Compare commits

..

52 Commits

Author SHA1 Message Date
Oleksiy Syvokon
de9d470e4f open_ai: Configurable model capabilities 2025-08-17 11:27:52 +03:00
Marshall Bowers
f17f63ec84 Remove /docs slash command (#36325)
This PR removes the `/docs` slash command.

We never fully shipped this—with it requiring explicit opt-in via a
setting—and it doesn't seem like the feature is needed in an agentic
world.

Release Notes:

- Removed the `/docs` slash command.
2025-08-16 19:00:31 +00:00
Marshall Bowers
15a1eb2a2e emmet: Extract to zed-extensions/emmet repository (#36323)
This PR extracts the Emmet extension to the
[zed-extensions/emmet](https://github.com/zed-extensions/emmet)
repository.

Release Notes:

- N/A
2025-08-16 17:02:51 +00:00
Ben Brandt
332626e582 Allow Permission Request to only require a ToolCallUpdate instead of a full tool call (#36319)
Release Notes:

- N/A
2025-08-16 15:04:09 +00:00
Finn Evers
7b3fe0a474 Make agent font size inherit the UI font size by default (#36306)
Ensures issues like #36242 and #36295 do not arise where users are
confused that the agent panel does not follow the default UI font size
whilst also keeping the possibility of customization. The agent font
size was matching the UI font size previously alredy, which makes it
easier to change it for most scenarios.

Also cleans up some related logic around modifying the font sizes.

Release Notes:

- The agent panel font size will now inherit the UI font size by default
if not set in your settings.
2025-08-16 14:35:06 +00:00
Marshall Bowers
36184a71df collab: Drop rate_buckets table (#36315)
This PR drops the `rate_buckets` table, as we're no longer using it.

Release Notes:

- N/A
2025-08-16 14:11:36 +00:00
Marshall Bowers
ea7bc96c05 collab: Remove billing-related tables from SQLite schema (#36312)
This PR removes the billing-related tables from the SQLite schema, as we
don't actually reference these tables anywhere in the Collab codebase
anymore.

Release Notes:

- N/A
2025-08-16 13:52:14 +00:00
Marshall Bowers
d1958aa439 collab: Add orb_customer_id to billing_customers (#36310)
This PR adds an `orb_customer_id` column to the `billing_customers`
table.

Release Notes:

- N/A
2025-08-16 13:48:38 +00:00
Marshall Bowers
5620e359af collab: Make admin column non-nullable on users table (#36307)
This PR updates the `admin` column on the `users` table to be
non-nullable.

We were already treating it like this in practice.

All rows in the production database already have a value for the `admin`
column.

Release Notes:

- N/A
2025-08-16 13:09:14 +00:00
Finn Evers
6f2e7c355e Ensure bundled files are opened as read-only (#36299)
Closes #36297

While we set the editor as read-only for bundled files, we didn't do
this for the underlying buffer. This PR fixes this and adds a test for
the corresponding case.

Release Notes:

- Fixed an issue where bundled files (e.g. the default settings) could
be edited in some circumstances
2025-08-16 11:36:17 +00:00
Lukas Wirth
864d4bc1d1 editor: Drop multiline targets in navigation buffers (#36291)
Release Notes:

- N/A
2025-08-16 07:55:46 +00:00
Julia Ryan
7784fac288 Separate minidump crashes from panics (#36267)
The minidump-based crash reporting is now entirely separate from our
legacy panic_hook-based reporting. This should improve the association
of minidumps with their metadata and give us more consistent crash
reports.

Release Notes:

- N/A

---------

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
2025-08-16 06:33:32 +00:00
zumbalogy
f5f14111ef Add setting for hiding the status_bar.cursor_position_button (#36288)
Release Notes:

- Added an option for the status_bar.cursor_position_button. Setting to
`false` will hide the button. It defaults to `true`.

This builds off the recent work to hide the language selection button
(https://github.com/zed-industries/zed/pull/33977). I tried to follow
that pattern, and to pick a clear name for the option, but any
feedback/change is welcome.

---------

Co-authored-by: zumbalogy <3770982+zumbalogy@users.noreply.github.com>
2025-08-16 09:19:38 +03:00
Marshall Bowers
e664a9bc48 collab: Remove unused billing-related database code (#36282)
This PR removes a bunch of unused database code related to billing, as
we no longer need it.

Release Notes:

- N/A
2025-08-15 22:58:10 +00:00
Cole Miller
bf34e185d5 Move MentionSet to message_editor module (#36281)
This is a more natural place for it than its current home next to the
completion provider.

Release Notes:

- N/A
2025-08-15 18:47:36 -04:00
Marshall Bowers
b9c110e63e collab: Remove GET /users/look_up endpoint (#36279)
This PR removes the `GET /users/look_up` endpoint from Collab, as it has
been moved to Cloud.

Release Notes:

- N/A
2025-08-15 22:01:41 +00:00
Ben Kunkle
f642f7615f keymap_ui: Don't try to parse empty action arguments as JSON (#36278)
Closes #ISSUE

Release Notes:

- Keymap Editor: Fixed an issue where leaving the arguments field empty
would result in an error even if arguments were optional
2025-08-15 17:59:57 -04:00
Cole Miller
3d77ad7e1a thread_view: Start loading images as soon as they're added (#36276)
Release Notes:

- N/A
2025-08-15 17:39:33 -04:00
Yang Gang
f365403618 agent: Update use_modifier_to_send behavior description for Windows (#36230)
Release Notes:

- N/A

Signed-off-by: Yang Gang <yanggang.uefi@gmail.com>
2025-08-15 21:03:50 +00:00
Agus Zubiaga
9eb1ff2726 acp thread view: Always use editors for user messages (#36256)
This means the cursor will be at the position you clicked:


https://github.com/user-attachments/assets/0693950d-7513-4d90-88e2-55817df7213a


Release Notes:

- N/A
2025-08-15 21:03:36 +00:00
Marshall Bowers
239e479aed collab: Remove Stripe code (#36275)
This PR removes the code for integrating with Stripe from Collab.

All of these concerns are now handled by Cloud.

Release Notes:

- N/A
2025-08-15 20:49:56 +00:00
Finn Evers
3e0a755486 Remove some redundant entity clones (#36274)
`cx.entity()` already returns an owned entity, so there is no need for
these clones.

Release Notes:

- N/A
2025-08-15 20:27:44 +00:00
Marshall Bowers
7199c733b2 proto: Remove AcceptTermsOfService message (#36272)
This PR removes the `AcceptTermsOfService` RPC message.

We're no longer using the message after
https://github.com/zed-industries/zed/pull/36255.

Release Notes:

- N/A
2025-08-15 20:21:45 +00:00
Finn Evers
65f64aa513 search: Fix recently introduced issues with the search bars (#36271)
Follow-up to https://github.com/zed-industries/zed/pull/36233

The above PR simplified the handling but introduced some bugs: The
replace buttons were no longer clickable, some buttons also lost their
toggle states, some buttons shared their element id and, lastly, some
buttons were clickable but would not trigger the right action. This PR
fixes all that.

Release Notes:

- N/A
2025-08-15 22:21:21 +02:00
Marshall Bowers
2a9d4599cd proto: Remove unused types (#36269)
This PR removes some unused types from the RPC protocol.

Release Notes:

- N/A
2025-08-15 19:46:23 +00:00
Joseph T. Lyons
75f85b3aaa Remove old telemetry events and transformation layer (#36263)
Successor to: https://github.com/zed-industries/zed/pull/25179

Release Notes:

- N/A
2025-08-15 15:37:52 -04:00
Marshall Bowers
b3cad8b527 proto: Remove UpdateUserPlan message (#36268)
This PR removes the `UpdateUserPlan` RPC message.

We're no longer using the message after
https://github.com/zed-industries/zed/pull/36255.

Release Notes:

- N/A
2025-08-15 19:21:04 +00:00
Cole Miller
1931889759 thread_view: Move handlers for confirmed completions to the MessageEditor (#36214)
Release Notes:

- N/A

---------

Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-08-15 18:55:34 +00:00
Finn Evers
3c5d5a1d57 editor: Add access method for project (#36266)
This resolves a `TODO` that I've stumbled upon too many times whilst
looking at the editor code.

Release Notes:

- N/A
2025-08-15 18:34:22 +00:00
Marshall Bowers
bd1fda6782 proto: Remove GetPrivateUserInfo message (#36265)
This PR removes the `GetPrivateUserInfo` RPC message.

We're no longer using the message after
https://github.com/zed-industries/zed/pull/36255.

Release Notes:

- N/A
2025-08-15 18:27:31 +00:00
Marshall Bowers
e452aba9da proto: Order reserved fields (#36261)
This PR orders the `reserved` fields in the RPC `Envelope`, as they had
gotten unsorted.

Release Notes:

- N/A
2025-08-15 17:59:08 +00:00
Marshall Bowers
75b832029a Remove RPC messages pertaining to the LLM token (#36252)
This PR removes the RPC messages pertaining to the LLM token.

We now retrieve the LLM token from Cloud.

Release Notes:

- N/A
2025-08-15 13:26:21 -04:00
Marshall Bowers
257e0991d8 collab: Increase minimum required version to connect (#36255)
This PR increases the minimum required version to connect to Collab.

Previously this was set at v0.157.0.

The new minimum required version is v0.198.4, which is the first version
where we no longer connect to Collab automatically.

Clients on the v0.199.x minor version will also need to be v0.199.2 or
greater in order to connect, due to us hotfixing the connection changes
to the Preview branch.

We're doing this to force clients to upgrade in order to connect to
Collab, as we're going to be removing some of the old RPC usages related
to authentication that are no longer used. Therefore, we want users to
be on a version of Zed that does not rely on those messages.

Users will see a message similar to this one, prompting them to upgrade:

<img width="1209" height="875" alt="Screenshot 2025-08-15 at 11 37
55 AM"
src="https://github.com/user-attachments/assets/59ffff3e-8f82-4152-84a8-776c691eaaee"
/>

> Note: In this case I'm simulating the error state, which is why I'm
signed in via Cloud while still not being able to connect to Collab.
Users on older versions will see the "Please update Zed to Collaborate"
message without being signed in.

Release Notes:

- N/A
2025-08-15 16:13:52 +00:00
Umesh Yadav
c39f294bcb remote: Add support for additional SSH arguments in SshSocket (#33243)
Closes #29438

Release Notes:

- Fix SSH agent forwarding doesn't work when using SSH remote
development.
2025-08-15 10:13:18 -06:00
Oleksiy Syvokon
7671f34f88 agent: Create checkpoint before/after every edit operation (#36253)
1. Previously, checkpoints only appeared when an agent's edit happened
immediately after a user message. This is rare (agent usually collects
some context first), so they were almost never shown. This is now fixed.

2. After this change, a checkpoint is created after every edit
operation. So when the agent edits files five times in a single dialog
turn, we will now display five checkpoints.

As a bonus, it's now possible to undo only a part of a long agent
response.

Closes #36092, #32917

Release Notes:

- Create agent checkpoints more frequently (before every edit)
2025-08-15 15:37:24 +00:00
Igal Tabachnik
7993ee9c07 Suggest unsaved buffer content text as the default filename (#35707)
Closes #24672

This PR complements a feature added earlier by @JosephTLyons (in
https://github.com/zed-industries/zed/pull/32353) where the text is
considered as the tab title in a new buffer. It piggybacks off that
change and sets the title as the suggested filename in the save dialog
(completely mirroring the same functionality in VSCode):

![2025-08-05 11 50
28](https://github.com/user-attachments/assets/49ad9e4a-5559-44b0-a4b0-ae19890e478e)

Release Notes:

- Text entered in a new untitled buffer is considered as the default
filename when saving
2025-08-15 17:26:38 +02:00
Marshall Bowers
485802b9e5 collab: Remove endpoints for issuing notifications from Cloud (#36249)
This PR removes the `POST /users/:id/refresh_llm_tokens` and `POST
/users/:id/update_plan` endpoints from Collab.

These endpoints were added to be called by Cloud in order to push down
notifications over the Collab RPC connection.

Cloud now sends down notifications to clients directly, so we no longer
need these endpoints.

All calls to these endpoints have already been removed in production.

Release Notes:

- N/A
2025-08-15 14:46:06 +00:00
Bennet Bo Fenner
1e41d86b31 agent2: Set thread_id, prompt_id, temperature on request (#36246)
Release Notes:

- N/A
2025-08-15 14:23:55 +00:00
Bennet Bo Fenner
10a2426a58 agent2: Port profile selector (#36244)
Release Notes:

- N/A
2025-08-15 14:06:56 +00:00
Agus Zubiaga
91e6b38285 Log agent servers stderr (#36243)
Release Notes:

- N/A
2025-08-15 10:58:57 -03:00
Bennet Bo Fenner
f63036548c agent2: Implement prompt caching (#36236)
Release Notes:

- N/A
2025-08-15 15:17:56 +02:00
Lukas Wirth
846ed6adf9 search: Fix project search not rendering matches count (#36238)
Follow up to https://github.com/zed-industries/zed/pull/36103/

Release Notes:

- N/A
2025-08-15 12:54:05 +00:00
Daniel Sauble
708c434bd4 workspace: Highlight where dragged tab will be dropped (#34740)
Closes #18565

I could use some advice on the color palette / theming. A couple
options:

1. The `drop_target_background` color could be used for the border if we
didn't use it for the background of the tab. In VSCode, the background
color of tabs doesn't change as you're dragging, there's just a border
between tabs. My only concern with this option is that the current
`drop_target_background` color is a bit subtle when used for a small
area like a border.

2. Another option could be to add a `drop_target_border` theme color,
but I don't know how much complexity this adds to implementation
(presumably all existing themes would need to be updated?).

Demo:


https://github.com/user-attachments/assets/0b7c04ea-5ec5-4b45-adad-156dfbf552db

Release Notes:

- Highlight where a dragged tab will be dropped between two other tabs

---------

Co-authored-by: Smit Barmase <heysmitbarmase@gmail.com>
2025-08-15 11:43:29 +00:00
Bennet Bo Fenner
6f3cd42411 agent2: Port Zed AI features (#36172)
Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
2025-08-15 11:17:17 +00:00
smit
f8b0105258 project: Fix LSP TextDocumentSyncCapability dynamic registration (#36234)
Closes #36213

Use `textDocument/didChange`
([docs](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocument_synchronization))
instead of `textDocument/synchronization`.

Release Notes:

- Fixed an issue where Dart projects were being formatted incorrectly by
the language server.
2025-08-15 16:24:54 +05:30
Oleksiy Syvokon
2a57b160b0 openai: Don't send prompt_cache_key for OpenAI-compatible models (#36231)
Some APIs fail when they get this parameter

Closes #36215

Release Notes:

- Fixed OpenAI-compatible providers that don't support prompt caching
and/or reasoning
2025-08-15 13:54:24 +03:00
Lukas Wirth
d891348442 search: Simplify search options handling (#36233)
Release Notes:

- N/A
2025-08-15 10:34:54 +00:00
David Kleingeld
4f0b00b0d9 Add component NotificationFrame & CaptureAudio parts for testing (#36081)
Adds component NotificationFrame. It implements a subset of MessageNotification as a Component and refactors MessageNotification to use NotificationFrame. Having some notification UI Component is nice as it allows us to easily build new types of notifications.

Uses the new NotificationFrame component for CaptureAudioNotification. 

Adds a CaptureAudio action in the dev namespace (not meant for
end-users). It records 10 seconds of audio and saves that to a wav file.

Release Notes:

- N/A

---------

Co-authored-by: Mikayla <mikayla@zed.dev>
2025-08-15 10:10:52 +00:00
Oleksiy Syvokon
a3dcc76687 openai: Don't send reasoning_effort if it's not set (#36228)
Release Notes:

- N/A
2025-08-15 09:12:18 +00:00
Lukas Wirth
8d6982e78f search: Fix some inconsistencies between project and buffer search bars (#36103)
- project search query string now turns red when no results are found
matching buffer search behavior
- General code deduplication as well as more consistent layout between
the two bars, as some minor details have drifted apart
- Tab cycling in buffer search now ends up in editor focus when cycling
backwards, matching forward cycling
- Report parse errors in filter include and exclude editors

Release Notes:

- N/A
2025-08-15 09:56:47 +02:00
smit
23d0433158 linux: Fix keyboard events not working on first start in X11 (#36224)
Closes #29083

On X11, `ibus-x11` crashes on some distros after Zed interacts with it.
This is not unique to Zed, `xim-rs` shows the same behavior, and there
are similar upstream `ibus` reports with apps like Blender:

- https://github.com/ibus/ibus/issues/2697

I opened an upstream issue to track this:

- https://github.com/ibus/ibus/issues/2789

When this crash happens, we don’t get a disconnect event, so Zed keeps
sending events to the IM server and waits for a response. It works on
subsequent starts because IM server doesn't exist now and we default to
non-XIM path.

This PR detects the crash via X11 events and falls back to the non-XIM
path so typing keeps working. We still need to investigate whether the
root cause is in `xim-rs` or `ibus-x11`.

Release Notes:

- Fixed an issue on X11 where keyboard input sometimes didn’t work on
first start.
2025-08-15 12:51:32 +05:30
Alvaro Parker
4d27b228f7 remote server: Use env flag to opt out of musl remote server build (#36069)
Closes #ISSUE

This will allow devs to opt out of the musl build when developing zed by
running `ZED_BUILD_REMOTE_SERVER=nomusl cargo r` which also fixes remote
builds on NixOS.

Release Notes:

- Add a env flag (`ZED_BUILD_REMOTE_SERVER=nomusl`) to opt out of musl
builds when building the remote server
2025-08-14 20:31:01 -06:00
214 changed files with 5242 additions and 10640 deletions

View File

@@ -25,6 +25,8 @@ third-party = [
{ name = "reqwest", version = "0.11.27" },
# build of remote_server should not include scap / its x11 dependency
{ name = "scap", git = "https://github.com/zed-industries/scap", rev = "808aa5c45b41e8f44729d02e38fd00a2fe2722e7" },
# build of remote_server should not need to include on libalsa through rodio
{ name = "rodio" },
]
[final-excludes]
@@ -32,7 +34,6 @@ workspace-members = [
"zed_extension_api",
# exclude all extensions
"zed_emmet",
"zed_glsl",
"zed_html",
"zed_proto",

220
Cargo.lock generated
View File

@@ -172,9 +172,9 @@ dependencies = [
[[package]]
name = "agent-client-protocol"
version = "0.0.24"
version = "0.0.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fd68bbbef8e424fb8a605c5f0b00c360f682c4528b0a5feb5ec928aaf5ce28e"
checksum = "2ab66add8be8d6a963f5bf4070045c1bbf36472837654c73e2298dd16bda5bf7"
dependencies = [
"anyhow",
"futures 0.3.31",
@@ -347,7 +347,6 @@ dependencies = [
"gpui",
"html_to_markdown",
"http_client",
"indexed_docs",
"indoc",
"inventory",
"itertools 0.14.0",
@@ -872,7 +871,6 @@ dependencies = [
"gpui",
"html_to_markdown",
"http_client",
"indexed_docs",
"language",
"pretty_assertions",
"project",
@@ -1262,26 +1260,6 @@ dependencies = [
"syn 2.0.101",
]
[[package]]
name = "async-stripe"
version = "0.40.0"
source = "git+https://github.com/zed-industries/async-stripe?rev=3672dd4efb7181aa597bf580bf5a2f5d23db6735#3672dd4efb7181aa597bf580bf5a2f5d23db6735"
dependencies = [
"chrono",
"futures-util",
"http-types",
"hyper 0.14.32",
"hyper-rustls 0.24.2",
"serde",
"serde_json",
"serde_path_to_error",
"serde_qs 0.10.1",
"smart-default 0.6.0",
"smol_str 0.1.24",
"thiserror 1.0.69",
"tokio",
]
[[package]]
name = "async-tar"
version = "0.5.0"
@@ -2083,12 +2061,6 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "349a06037c7bf932dd7e7d1f653678b2038b9ad46a74102f1fc7bd7872678cce"
[[package]]
name = "base64"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
[[package]]
name = "base64"
version = "0.21.7"
@@ -3281,7 +3253,6 @@ dependencies = [
"anyhow",
"assistant_context",
"assistant_slash_command",
"async-stripe",
"async-trait",
"async-tungstenite",
"audio",
@@ -3297,7 +3268,6 @@ dependencies = [
"chrono",
"client",
"clock",
"cloud_llm_client",
"collab_ui",
"collections",
"command_palette_hooks",
@@ -3308,7 +3278,6 @@ dependencies = [
"dap_adapters",
"dashmap 6.1.0",
"debugger_ui",
"derive_more 0.99.19",
"editor",
"envy",
"extension",
@@ -3324,7 +3293,6 @@ dependencies = [
"http_client",
"hyper 0.14.32",
"indoc",
"jsonwebtoken",
"language",
"language_model",
"livekit_api",
@@ -3370,7 +3338,6 @@ dependencies = [
"telemetry_events",
"text",
"theme",
"thiserror 2.0.12",
"time",
"tokio",
"toml 0.8.20",
@@ -3872,7 +3839,7 @@ dependencies = [
"rustc-hash 1.1.0",
"rustybuzz 0.14.1",
"self_cell",
"smol_str 0.2.2",
"smol_str",
"swash",
"sys-locale",
"ttf-parser 0.21.1",
@@ -4069,6 +4036,8 @@ dependencies = [
"minidumper",
"paths",
"release_channel",
"serde",
"serde_json",
"smol",
"workspace-hack",
]
@@ -6376,17 +6345,6 @@ dependencies = [
"windows-targets 0.48.5",
]
[[package]]
name = "getrandom"
version = "0.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce"
dependencies = [
"cfg-if",
"libc",
"wasi 0.9.0+wasi-snapshot-preview1",
]
[[package]]
name = "getrandom"
version = "0.2.15"
@@ -7883,6 +7841,12 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "hound"
version = "3.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f"
[[package]]
name = "html5ever"
version = "0.27.0"
@@ -7984,27 +7948,6 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "add0ab9360ddbd88cfeb3bd9574a1d85cfdfa14db10b3e21d3700dbc4328758f"
[[package]]
name = "http-types"
version = "2.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e9b187a72d63adbfba487f48095306ac823049cb504ee195541e91c7775f5ad"
dependencies = [
"anyhow",
"async-channel 1.9.0",
"base64 0.13.1",
"futures-lite 1.13.0",
"http 0.2.12",
"infer",
"pin-project-lite",
"rand 0.7.3",
"serde",
"serde_json",
"serde_qs 0.8.5",
"serde_urlencoded",
"url",
]
[[package]]
name = "http_client"
version = "0.1.0"
@@ -8438,34 +8381,6 @@ version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0263a3d970d5c054ed9312c0057b4f3bde9c0b33836d3637361d4a9e6e7a408"
[[package]]
name = "indexed_docs"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"cargo_metadata",
"collections",
"derive_more 0.99.19",
"extension",
"fs",
"futures 0.3.31",
"fuzzy",
"gpui",
"heed",
"html_to_markdown",
"http_client",
"indexmap",
"indoc",
"parking_lot",
"paths",
"pretty_assertions",
"serde",
"strum 0.27.1",
"util",
"workspace-hack",
]
[[package]]
name = "indexmap"
version = "2.9.0"
@@ -8483,12 +8398,6 @@ version = "2.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd"
[[package]]
name = "infer"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64e9829a50b42bb782c1df523f78d332fe371b10c661e78b7a3c34b0198e9fac"
[[package]]
name = "inherent"
version = "1.0.12"
@@ -9711,6 +9620,7 @@ dependencies = [
"objc",
"parking_lot",
"postage",
"rodio",
"scap",
"serde",
"serde_json",
@@ -10264,7 +10174,7 @@ dependencies = [
"num-traits",
"range-map",
"scroll",
"smart-default 0.7.1",
"smart-default",
]
[[package]]
@@ -13138,19 +13048,6 @@ version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
[[package]]
name = "rand"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03"
dependencies = [
"getrandom 0.1.16",
"libc",
"rand_chacha 0.2.2",
"rand_core 0.5.1",
"rand_hc",
]
[[package]]
name = "rand"
version = "0.8.5"
@@ -13172,16 +13069,6 @@ dependencies = [
"rand_core 0.9.3",
]
[[package]]
name = "rand_chacha"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402"
dependencies = [
"ppv-lite86",
"rand_core 0.5.1",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
@@ -13202,15 +13089,6 @@ dependencies = [
"rand_core 0.9.3",
]
[[package]]
name = "rand_core"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
dependencies = [
"getrandom 0.1.16",
]
[[package]]
name = "rand_core"
version = "0.6.4"
@@ -13229,15 +13107,6 @@ dependencies = [
"getrandom 0.3.2",
]
[[package]]
name = "rand_hc"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
dependencies = [
"rand_core 0.5.1",
]
[[package]]
name = "range-map"
version = "0.2.0"
@@ -13972,6 +13841,7 @@ checksum = "e40ecf59e742e03336be6a3d53755e789fd05a059fa22dfa0ed624722319e183"
dependencies = [
"cpal",
"dasp_sample",
"hound",
"num-rational",
"symphonia",
"tracing",
@@ -14891,28 +14761,6 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_qs"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7715380eec75f029a4ef7de39a9200e0a63823176b759d055b613f5a87df6a6"
dependencies = [
"percent-encoding",
"serde",
"thiserror 1.0.69",
]
[[package]]
name = "serde_qs"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8cac3f1e2ca2fe333923a1ae72caca910b98ed0630bb35ef6f8c8517d6e81afa"
dependencies = [
"percent-encoding",
"serde",
"thiserror 1.0.69",
]
[[package]]
name = "serde_repr"
version = "0.1.20"
@@ -15289,17 +15137,6 @@ dependencies = [
"serde",
]
[[package]]
name = "smart-default"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "133659a15339456eeeb07572eb02a91c91e9815e9cbc89566944d2c8d3efdbf6"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "smart-default"
version = "0.7.1"
@@ -15328,15 +15165,6 @@ dependencies = [
"futures-lite 2.6.0",
]
[[package]]
name = "smol_str"
version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fad6c857cbab2627dcf01ec85a623ca4e7dcb5691cbaa3d7fb7653671f0d09c9"
dependencies = [
"serde",
]
[[package]]
name = "smol_str"
version = "0.2.2"
@@ -18185,12 +18013,6 @@ dependencies = [
"tracing",
]
[[package]]
name = "wasi"
version = "0.9.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519"
[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"
@@ -20277,7 +20099,7 @@ dependencies = [
[[package]]
name = "xim"
version = "0.4.0"
source = "git+https://github.com/XDeme1/xim-rs?rev=d50d461764c2213655cd9cf65a0ea94c70d3c4fd#d50d461764c2213655cd9cf65a0ea94c70d3c4fd"
source = "git+https://github.com/zed-industries/xim-rs?rev=c0a70c1bd2ce197364216e5e818a2cb3adb99a8d#c0a70c1bd2ce197364216e5e818a2cb3adb99a8d"
dependencies = [
"ahash 0.8.11",
"hashbrown 0.14.5",
@@ -20290,7 +20112,7 @@ dependencies = [
[[package]]
name = "xim-ctext"
version = "0.3.0"
source = "git+https://github.com/XDeme1/xim-rs?rev=d50d461764c2213655cd9cf65a0ea94c70d3c4fd#d50d461764c2213655cd9cf65a0ea94c70d3c4fd"
source = "git+https://github.com/zed-industries/xim-rs?rev=c0a70c1bd2ce197364216e5e818a2cb3adb99a8d#c0a70c1bd2ce197364216e5e818a2cb3adb99a8d"
dependencies = [
"encoding_rs",
]
@@ -20298,7 +20120,7 @@ dependencies = [
[[package]]
name = "xim-parser"
version = "0.2.1"
source = "git+https://github.com/XDeme1/xim-rs?rev=d50d461764c2213655cd9cf65a0ea94c70d3c4fd#d50d461764c2213655cd9cf65a0ea94c70d3c4fd"
source = "git+https://github.com/zed-industries/xim-rs?rev=c0a70c1bd2ce197364216e5e818a2cb3adb99a8d#c0a70c1bd2ce197364216e5e818a2cb3adb99a8d"
dependencies = [
"bitflags 2.9.0",
]
@@ -20576,6 +20398,7 @@ dependencies = [
"language_tools",
"languages",
"libc",
"livekit_client",
"log",
"markdown",
"markdown_preview",
@@ -20667,13 +20490,6 @@ dependencies = [
"workspace-hack",
]
[[package]]
name = "zed_emmet"
version = "0.0.6"
dependencies = [
"zed_extension_api 0.1.0",
]
[[package]]
name = "zed_extension_api"
version = "0.1.0"

View File

@@ -81,7 +81,6 @@ members = [
"crates/http_client_tls",
"crates/icons",
"crates/image_viewer",
"crates/indexed_docs",
"crates/edit_prediction",
"crates/edit_prediction_button",
"crates/inspector_ui",
@@ -199,7 +198,6 @@ members = [
# Extensions
#
"extensions/emmet",
"extensions/glsl",
"extensions/html",
"extensions/proto",
@@ -306,7 +304,6 @@ http_client = { path = "crates/http_client" }
http_client_tls = { path = "crates/http_client_tls" }
icons = { path = "crates/icons" }
image_viewer = { path = "crates/image_viewer" }
indexed_docs = { path = "crates/indexed_docs" }
edit_prediction = { path = "crates/edit_prediction" }
edit_prediction_button = { path = "crates/edit_prediction_button" }
inspector_ui = { path = "crates/inspector_ui" }
@@ -363,6 +360,7 @@ remote_server = { path = "crates/remote_server" }
repl = { path = "crates/repl" }
reqwest_client = { path = "crates/reqwest_client" }
rich_text = { path = "crates/rich_text" }
rodio = { version = "0.21.1", default-features = false }
rope = { path = "crates/rope" }
rpc = { path = "crates/rpc" }
rules_library = { path = "crates/rules_library" }
@@ -425,7 +423,7 @@ zlog_settings = { path = "crates/zlog_settings" }
#
agentic-coding-protocol = "0.0.10"
agent-client-protocol = "0.0.24"
agent-client-protocol = "0.0.25"
aho-corasick = "1.1"
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
any_vec = "0.14"
@@ -564,7 +562,6 @@ reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "951c77
"socks",
"stream",
] }
rodio = { version = "0.21.1", default-features = false }
rsa = "0.9.6"
runtimelib = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734", default-features = false, features = [
"async-dispatcher-runtime",
@@ -667,20 +664,6 @@ workspace-hack = "0.1.0"
yawc = { git = "https://github.com/deviant-forks/yawc", rev = "1899688f3e69ace4545aceb97b2a13881cf26142" }
zstd = "0.11"
[workspace.dependencies.async-stripe]
git = "https://github.com/zed-industries/async-stripe"
rev = "3672dd4efb7181aa597bf580bf5a2f5d23db6735"
default-features = false
features = [
"runtime-tokio-hyper-rustls",
"billing",
"checkout",
"events",
# The features below are only enabled to get the `events` feature to build.
"chrono",
"connect",
]
[workspace.dependencies.windows]
version = "0.61"
features = [

View File

@@ -71,8 +71,8 @@
"ui_font_weight": 400,
// The default font size for text in the UI
"ui_font_size": 16,
// The default font size for text in the agent panel
"agent_font_size": 16,
// The default font size for text in the agent panel. Falls back to the UI font size if unset.
"agent_font_size": null,
// How much to fade out unused code.
"unnecessary_code_fade": 0.3,
// Active pane styling settings.
@@ -887,11 +887,6 @@
},
// The settings for slash commands.
"slash_commands": {
// Settings for the `/docs` slash command.
"docs": {
// Whether `/docs` is enabled.
"enabled": false
},
// Settings for the `/project` slash command.
"project": {
// Whether `/project` is enabled.
@@ -1256,7 +1251,9 @@
// Status bar-related settings.
"status_bar": {
// Whether to show the active language button in the status bar.
"active_language_button": true
"active_language_button": true,
// Whether to show the cursor position button in the status bar.
"cursor_position_button": true
},
// Settings specific to the terminal
"terminal": {

View File

@@ -33,13 +33,23 @@ pub struct UserMessage {
pub id: Option<UserMessageId>,
pub content: ContentBlock,
pub chunks: Vec<acp::ContentBlock>,
pub checkpoint: Option<GitStoreCheckpoint>,
pub checkpoint: Option<Checkpoint>,
}
#[derive(Debug)]
pub struct Checkpoint {
git_checkpoint: GitStoreCheckpoint,
pub show: bool,
}
impl UserMessage {
fn to_markdown(&self, cx: &App) -> String {
let mut markdown = String::new();
if let Some(_) = self.checkpoint {
if self
.checkpoint
.as_ref()
.map_or(false, |checkpoint| checkpoint.show)
{
writeln!(markdown, "## User (checkpoint)").unwrap();
} else {
writeln!(markdown, "## User").unwrap();
@@ -99,7 +109,7 @@ pub enum AgentThreadEntry {
}
impl AgentThreadEntry {
fn to_markdown(&self, cx: &App) -> String {
pub fn to_markdown(&self, cx: &App) -> String {
match self {
Self::UserMessage(message) => message.to_markdown(cx),
Self::AssistantMessage(message) => message.to_markdown(cx),
@@ -107,6 +117,14 @@ impl AgentThreadEntry {
}
}
pub fn user_message(&self) -> Option<&UserMessage> {
if let AgentThreadEntry::UserMessage(message) = self {
Some(message)
} else {
None
}
}
pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
if let AgentThreadEntry::ToolCall(call) = self {
itertools::Either::Left(call.diffs())
@@ -774,7 +792,7 @@ impl AcpThread {
&mut self,
update: acp::SessionUpdate,
cx: &mut Context<Self>,
) -> Result<()> {
) -> Result<(), acp::Error> {
match update {
acp::SessionUpdate::UserMessageChunk { content } => {
self.push_user_content_block(None, content, cx);
@@ -786,7 +804,7 @@ impl AcpThread {
self.push_assistant_content_block(content, true, cx);
}
acp::SessionUpdate::ToolCall(tool_call) => {
self.upsert_tool_call(tool_call, cx);
self.upsert_tool_call(tool_call, cx)?;
}
acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
self.update_tool_call(tool_call_update, cx)?;
@@ -922,32 +940,40 @@ impl AcpThread {
}
/// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
pub fn upsert_tool_call(
&mut self,
tool_call: acp::ToolCall,
cx: &mut Context<Self>,
) -> Result<(), acp::Error> {
let status = ToolCallStatus::Allowed {
status: tool_call.status,
};
self.upsert_tool_call_inner(tool_call, status, cx)
self.upsert_tool_call_inner(tool_call.into(), status, cx)
}
/// Fails if id does not match an existing entry.
pub fn upsert_tool_call_inner(
&mut self,
tool_call: acp::ToolCall,
tool_call_update: acp::ToolCallUpdate,
status: ToolCallStatus,
cx: &mut Context<Self>,
) {
) -> Result<(), acp::Error> {
let language_registry = self.project.read(cx).languages().clone();
let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
let id = call.id.clone();
let id = tool_call_update.id.clone();
if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
*current_call = call;
if let Some((ix, current_call)) = self.tool_call_mut(&id) {
current_call.update_fields(tool_call_update.fields, language_registry, cx);
current_call.status = status;
cx.emit(AcpThreadEvent::EntryUpdated(ix));
} else {
let call =
ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
self.push_entry(AgentThreadEntry::ToolCall(call), cx);
};
self.resolve_locations(id, cx);
Ok(())
}
fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
@@ -1016,10 +1042,10 @@ impl AcpThread {
pub fn request_tool_call_authorization(
&mut self,
tool_call: acp::ToolCall,
tool_call: acp::ToolCallUpdate,
options: Vec<acp::PermissionOption>,
cx: &mut Context<Self>,
) -> oneshot::Receiver<acp::PermissionOptionId> {
) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
let (tx, rx) = oneshot::channel();
let status = ToolCallStatus::WaitingForConfirmation {
@@ -1027,9 +1053,9 @@ impl AcpThread {
respond_tx: tx,
};
self.upsert_tool_call_inner(tool_call, status, cx);
self.upsert_tool_call_inner(tool_call, status, cx)?;
cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
rx
Ok(rx)
}
pub fn authorize_tool_call(
@@ -1145,9 +1171,12 @@ impl AcpThread {
self.project.read(cx).languages().clone(),
cx,
);
let request = acp::PromptRequest {
prompt: message.clone(),
session_id: self.session_id.clone(),
};
let git_store = self.project.read(cx).git_store().clone();
let old_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
let message_id = if self
.connection
.session_editor(&self.session_id, cx)
@@ -1161,68 +1190,63 @@ impl AcpThread {
AgentThreadEntry::UserMessage(UserMessage {
id: message_id.clone(),
content: block,
chunks: message.clone(),
chunks: message,
checkpoint: None,
}),
cx,
);
self.run_turn(cx, async move |this, cx| {
let old_checkpoint = git_store
.update(cx, |git, cx| git.checkpoint(cx))?
.await
.context("failed to get old checkpoint")
.log_err();
this.update(cx, |this, cx| {
if let Some((_ix, message)) = this.last_user_message() {
message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
git_checkpoint,
show: false,
});
}
this.connection.prompt(message_id, request, cx)
})?
.await
})
}
pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
self.run_turn(cx, async move |this, cx| {
this.update(cx, |this, cx| {
this.connection
.resume(&this.session_id, cx)
.map(|resume| resume.run(cx))
})?
.context("resuming a session is not supported")?
.await
})
}
fn run_turn(
&mut self,
cx: &mut Context<Self>,
f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
) -> BoxFuture<'static, Result<()>> {
self.clear_completed_plan_entries(cx);
let (old_checkpoint_tx, old_checkpoint_rx) = oneshot::channel();
let (tx, rx) = oneshot::channel();
let cancel_task = self.cancel(cx);
let request = acp::PromptRequest {
prompt: message,
session_id: self.session_id.clone(),
};
self.send_task = Some(cx.spawn({
let message_id = message_id.clone();
async move |this, cx| {
cancel_task.await;
old_checkpoint_tx.send(old_checkpoint.await).ok();
if let Ok(result) = this.update(cx, |this, cx| {
this.connection.prompt(message_id, request, cx)
}) {
tx.send(result.await).log_err();
}
}
self.send_task = Some(cx.spawn(async move |this, cx| {
cancel_task.await;
tx.send(f(this, cx).await).ok();
}));
cx.spawn(async move |this, cx| {
let old_checkpoint = old_checkpoint_rx
.await
.map_err(|_| anyhow!("send canceled"))
.flatten()
.context("failed to get old checkpoint")
.log_err();
let response = rx.await;
if let Some((old_checkpoint, message_id)) = old_checkpoint.zip(message_id) {
let new_checkpoint = git_store
.update(cx, |git, cx| git.checkpoint(cx))?
.await
.context("failed to get new checkpoint")
.log_err();
if let Some(new_checkpoint) = new_checkpoint {
let equal = git_store
.update(cx, |git, cx| {
git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
})?
.await
.unwrap_or(true);
if !equal {
this.update(cx, |this, cx| {
if let Some((ix, message)) = this.user_message_mut(&message_id) {
message.checkpoint = Some(old_checkpoint);
cx.emit(AcpThreadEvent::EntryUpdated(ix));
}
})?;
}
}
}
this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
.await?;
this.update(cx, |this, cx| {
match response {
@@ -1294,7 +1318,10 @@ impl AcpThread {
return Task::ready(Err(anyhow!("message not found")));
};
let checkpoint = message.checkpoint.clone();
let checkpoint = message
.checkpoint
.as_ref()
.map(|c| c.git_checkpoint.clone());
let git_store = self.project.read(cx).git_store().clone();
cx.spawn(async move |this, cx| {
@@ -1316,6 +1343,59 @@ impl AcpThread {
})
}
fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
let git_store = self.project.read(cx).git_store().clone();
let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
if let Some(checkpoint) = message.checkpoint.as_ref() {
checkpoint.git_checkpoint.clone()
} else {
return Task::ready(Ok(()));
}
} else {
return Task::ready(Ok(()));
};
let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
cx.spawn(async move |this, cx| {
let new_checkpoint = new_checkpoint
.await
.context("failed to get new checkpoint")
.log_err();
if let Some(new_checkpoint) = new_checkpoint {
let equal = git_store
.update(cx, |git, cx| {
git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
})?
.await
.unwrap_or(true);
this.update(cx, |this, cx| {
let (ix, message) = this.last_user_message().context("no user message")?;
let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
checkpoint.show = !equal;
cx.emit(AcpThreadEvent::EntryUpdated(ix));
anyhow::Ok(())
})??;
}
Ok(())
})
}
fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
self.entries
.iter_mut()
.enumerate()
.rev()
.find_map(|(ix, entry)| {
if let AgentThreadEntry::UserMessage(message) = entry {
Some((ix, message))
} else {
None
}
})
}
fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
self.entries.iter().find_map(|entry| {
if let AgentThreadEntry::UserMessage(message) = entry {
@@ -1552,6 +1632,7 @@ mod tests {
use settings::SettingsStore;
use smol::stream::StreamExt as _;
use std::{
any::Any,
cell::RefCell,
path::Path,
rc::Rc,
@@ -2284,6 +2365,10 @@ mod tests {
_session_id: session_id.clone(),
}))
}
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
self
}
}
struct FakeAgentSessionEditor {

View File

@@ -4,7 +4,7 @@ use anyhow::Result;
use collections::IndexMap;
use gpui::{Entity, SharedString, Task};
use project::Project;
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use ui::{App, IconName};
use uuid::Uuid;
@@ -36,6 +36,14 @@ pub trait AgentConnection {
cx: &mut App,
) -> Task<Result<acp::PromptResponse>>;
fn resume(
&self,
_session_id: &acp::SessionId,
_cx: &mut App,
) -> Option<Rc<dyn AgentSessionResume>> {
None
}
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
fn session_editor(
@@ -53,12 +61,24 @@ pub trait AgentConnection {
fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
None
}
fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
}
impl dyn AgentConnection {
pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
self.into_any().downcast().ok()
}
}
pub trait AgentSessionEditor {
fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
}
pub trait AgentSessionResume {
fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
}
#[derive(Debug)]
pub struct AuthRequired;
@@ -266,12 +286,12 @@ mod test_support {
if let Some((tool_call, options)) = permission_request {
let permission = thread.update(cx, |thread, cx| {
thread.request_tool_call_authorization(
tool_call.clone(),
tool_call.clone().into(),
options.clone(),
cx,
)
})?;
permission.await?;
permission?.await?;
}
thread.update(cx, |thread, cx| {
thread.handle_session_update(update.clone(), cx).unwrap();
@@ -299,6 +319,10 @@ mod test_support {
) -> Option<Rc<dyn AgentSessionEditor>> {
Some(Rc::new(StubAgentSessionEditor))
}
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
self
}
}
struct StubAgentSessionEditor;

View File

@@ -844,11 +844,17 @@ impl Thread {
.await
.unwrap_or(false);
if !equal {
this.update(cx, |this, cx| {
this.insert_checkpoint(pending_checkpoint, cx)
})?;
}
this.update(cx, |this, cx| {
this.pending_checkpoint = if equal {
Some(pending_checkpoint)
} else {
this.insert_checkpoint(pending_checkpoint, cx);
Some(ThreadCheckpoint {
message_id: this.next_message_id,
git_checkpoint: final_checkpoint,
})
}
})?;
Ok(())
}

View File

@@ -1,9 +1,8 @@
use crate::{AgentResponseEvent, Thread, templates::Templates};
use crate::{
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent,
WebSearchTool,
AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
};
use acp_thread::AgentModelSelector;
use agent_client_protocol as acp;
@@ -11,6 +10,7 @@ use agent_settings::AgentSettings;
use anyhow::{Context as _, Result, anyhow};
use collections::{HashSet, IndexMap};
use fs::Fs;
use futures::channel::mpsc;
use futures::{StreamExt, future};
use gpui::{
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
@@ -21,6 +21,7 @@ use prompt_store::{
ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
};
use settings::update_settings_file;
use std::any::Any;
use std::cell::RefCell;
use std::collections::HashMap;
use std::path::Path;
@@ -426,9 +427,9 @@ impl NativeAgent {
self.models.refresh_list(cx);
for session in self.sessions.values_mut() {
session.thread.update(cx, |thread, _| {
let model_id = LanguageModels::model_id(&thread.selected_model);
let model_id = LanguageModels::model_id(&thread.model());
if let Some(model) = self.models.model_from_id(&model_id) {
thread.selected_model = model.clone();
thread.set_model(model.clone());
}
});
}
@@ -439,6 +440,125 @@ impl NativeAgent {
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
impl NativeAgentConnection {
pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
self.0
.read(cx)
.sessions
.get(session_id)
.map(|session| session.thread.clone())
}
fn run_turn(
&self,
session_id: acp::SessionId,
cx: &mut App,
f: impl 'static
+ FnOnce(
Entity<Thread>,
&mut App,
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
) -> Task<Result<acp::PromptResponse>> {
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
agent
.sessions
.get_mut(&session_id)
.map(|s| (s.thread.clone(), s.acp_thread.clone()))
}) else {
return Task::ready(Err(anyhow!("Session not found")));
};
log::debug!("Found session for: {}", session_id);
let mut response_stream = match f(thread, cx) {
Ok(stream) => stream,
Err(err) => return Task::ready(Err(err)),
};
cx.spawn(async move |cx| {
// Handle response stream and forward to session.acp_thread
while let Some(result) = response_stream.next().await {
match result {
Ok(event) => {
log::trace!("Received completion event: {:?}", event);
match event {
AgentResponseEvent::Text(text) => {
acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent {
text,
annotations: None,
}),
false,
cx,
)
})?;
}
AgentResponseEvent::Thinking(text) => {
acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent {
text,
annotations: None,
}),
true,
cx,
)
})?;
}
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
tool_call,
options,
response,
}) => {
let recv = acp_thread.update(cx, |thread, cx| {
thread.request_tool_call_authorization(tool_call, options, cx)
})?;
cx.background_spawn(async move {
if let Some(recv) = recv.log_err()
&& let Some(option) = recv
.await
.context("authorization sender was dropped")
.log_err()
{
response
.send(option)
.map(|_| anyhow!("authorization receiver was dropped"))
.log_err();
}
})
.detach();
}
AgentResponseEvent::ToolCall(tool_call) => {
acp_thread.update(cx, |thread, cx| {
thread.upsert_tool_call(tool_call, cx)
})??;
}
AgentResponseEvent::ToolCallUpdate(update) => {
acp_thread.update(cx, |thread, cx| {
thread.update_tool_call(update, cx)
})??;
}
AgentResponseEvent::Stop(stop_reason) => {
log::debug!("Assistant message complete: {:?}", stop_reason);
return Ok(acp::PromptResponse { stop_reason });
}
}
}
Err(e) => {
log::error!("Error in model response stream: {:?}", e);
return Err(e);
}
}
}
log::info!("Response stream completed");
anyhow::Ok(acp::PromptResponse {
stop_reason: acp::StopReason::EndTurn,
})
})
}
}
impl AgentModelSelector for NativeAgentConnection {
fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
log::debug!("NativeAgentConnection::list_models called");
@@ -472,7 +592,7 @@ impl AgentModelSelector for NativeAgentConnection {
};
thread.update(cx, |thread, _cx| {
thread.selected_model = model.clone();
thread.set_model(model.clone());
});
update_settings_file::<AgentSettings>(
@@ -502,7 +622,7 @@ impl AgentModelSelector for NativeAgentConnection {
else {
return Task::ready(Err(anyhow!("Session not found")));
};
let model = thread.read(cx).selected_model.clone();
let model = thread.read(cx).model().clone();
let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
else {
return Task::ready(Err(anyhow!("Provider not found")));
@@ -644,25 +764,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
) -> Task<Result<acp::PromptResponse>> {
let id = id.expect("UserMessageId is required");
let session_id = params.session_id.clone();
let agent = self.0.clone();
log::info!("Received prompt request for session: {}", session_id);
log::debug!("Prompt blocks count: {}", params.prompt.len());
cx.spawn(async move |cx| {
// Get session
let (thread, acp_thread) = agent
.update(cx, |agent, _| {
agent
.sessions
.get_mut(&session_id)
.map(|s| (s.thread.clone(), s.acp_thread.clone()))
})?
.ok_or_else(|| {
log::error!("Session not found: {}", session_id);
anyhow::anyhow!("Session not found")
})?;
log::debug!("Found session for: {}", session_id);
self.run_turn(session_id, cx, |thread, cx| {
let content: Vec<UserMessageContent> = params
.prompt
.into_iter()
@@ -672,99 +777,27 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
log::debug!("Message id: {:?}", id);
log::debug!("Message content: {:?}", content);
// Get model using the ModelSelector capability (always available for agent2)
// Get the selected model from the thread directly
let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
// Send to thread
log::info!("Sending message to thread with model: {:?}", model.name());
let mut response_stream =
thread.update(cx, |thread, cx| thread.send(id, content, cx))?;
// Handle response stream and forward to session.acp_thread
while let Some(result) = response_stream.next().await {
match result {
Ok(event) => {
log::trace!("Received completion event: {:?}", event);
match event {
AgentResponseEvent::Text(text) => {
acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent {
text,
annotations: None,
}),
false,
cx,
)
})?;
}
AgentResponseEvent::Thinking(text) => {
acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent {
text,
annotations: None,
}),
true,
cx,
)
})?;
}
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
tool_call,
options,
response,
}) => {
let recv = acp_thread.update(cx, |thread, cx| {
thread.request_tool_call_authorization(tool_call, options, cx)
})?;
cx.background_spawn(async move {
if let Some(option) = recv
.await
.context("authorization sender was dropped")
.log_err()
{
response
.send(option)
.map(|_| anyhow!("authorization receiver was dropped"))
.log_err();
}
})
.detach();
}
AgentResponseEvent::ToolCall(tool_call) => {
acp_thread.update(cx, |thread, cx| {
thread.upsert_tool_call(tool_call, cx)
})?;
}
AgentResponseEvent::ToolCallUpdate(update) => {
acp_thread.update(cx, |thread, cx| {
thread.update_tool_call(update, cx)
})??;
}
AgentResponseEvent::Stop(stop_reason) => {
log::debug!("Assistant message complete: {:?}", stop_reason);
return Ok(acp::PromptResponse { stop_reason });
}
}
}
Err(e) => {
log::error!("Error in model response stream: {:?}", e);
// TODO: Consider sending an error message to the UI
break;
}
}
}
log::info!("Response stream completed");
anyhow::Ok(acp::PromptResponse {
stop_reason: acp::StopReason::EndTurn,
})
Ok(thread.update(cx, |thread, cx| {
log::info!(
"Sending message to thread with model: {:?}",
thread.model().name()
);
thread.send(id, content, cx)
}))
})
}
fn resume(
&self,
session_id: &acp::SessionId,
_cx: &mut App,
) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
Some(Rc::new(NativeAgentSessionResume {
connection: self.clone(),
session_id: session_id.clone(),
}) as _)
}
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
log::info!("Cancelling on session: {}", session_id);
self.0.update(cx, |agent, cx| {
@@ -786,6 +819,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
.map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
})
}
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
self
}
}
struct NativeAgentSessionEditor(Entity<Thread>);
@@ -796,6 +833,20 @@ impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
}
}
struct NativeAgentSessionResume {
connection: NativeAgentConnection,
session_id: acp::SessionId,
}
impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
self.connection
.run_turn(self.session_id.clone(), cx, |thread, cx| {
thread.update(cx, |thread, cx| thread.resume(cx))
})
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -957,7 +1008,7 @@ mod tests {
agent.read_with(cx, |agent, _| {
let session = agent.sessions.get(&session_id).unwrap();
session.thread.read_with(cx, |thread, _| {
assert_eq!(thread.selected_model.id().0, "fake");
assert_eq!(thread.model().id().0, "fake");
});
});

View File

@@ -12,10 +12,11 @@ use gpui::{
};
use indoc::indoc;
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, Role, StopReason,
fake_provider::FakeLanguageModel,
LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry,
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
Role, StopReason, fake_provider::FakeLanguageModel,
};
use pretty_assertions::assert_eq;
use project::Project;
use prompt_store::ProjectContext;
use reqwest_client::ReqwestClient;
@@ -129,6 +130,134 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
);
}
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
// Send initial user message and verify it's cached
thread.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Message 1"], cx)
});
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
assert_eq!(
completion.messages[1..],
vec![LanguageModelRequestMessage {
role: Role::User,
content: vec!["Message 1".into()],
cache: true
}]
);
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
"Response to Message 1".into(),
));
fake_model.end_last_completion_stream();
cx.run_until_parked();
// Send another user message and verify only the latest is cached
thread.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Message 2"], cx)
});
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
assert_eq!(
completion.messages[1..],
vec![
LanguageModelRequestMessage {
role: Role::User,
content: vec!["Message 1".into()],
cache: false
},
LanguageModelRequestMessage {
role: Role::Assistant,
content: vec!["Response to Message 1".into()],
cache: false
},
LanguageModelRequestMessage {
role: Role::User,
content: vec!["Message 2".into()],
cache: true
}
]
);
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
"Response to Message 2".into(),
));
fake_model.end_last_completion_stream();
cx.run_until_parked();
// Simulate a tool call and verify that the latest tool result is cached
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
thread.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
});
cx.run_until_parked();
let tool_use = LanguageModelToolUse {
id: "tool_1".into(),
name: EchoTool.name().into(),
raw_input: json!({"text": "test"}).to_string(),
input: json!({"text": "test"}),
is_input_complete: true,
};
fake_model
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
fake_model.end_last_completion_stream();
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
let tool_result = LanguageModelToolResult {
tool_use_id: "tool_1".into(),
tool_name: EchoTool.name().into(),
is_error: false,
content: "test".into(),
output: Some("test".into()),
};
assert_eq!(
completion.messages[1..],
vec![
LanguageModelRequestMessage {
role: Role::User,
content: vec!["Message 1".into()],
cache: false
},
LanguageModelRequestMessage {
role: Role::Assistant,
content: vec!["Response to Message 1".into()],
cache: false
},
LanguageModelRequestMessage {
role: Role::User,
content: vec!["Message 2".into()],
cache: false
},
LanguageModelRequestMessage {
role: Role::Assistant,
content: vec!["Response to Message 2".into()],
cache: false
},
LanguageModelRequestMessage {
role: Role::User,
content: vec!["Use the echo tool".into()],
cache: false
},
LanguageModelRequestMessage {
role: Role::Assistant,
content: vec![MessageContent::ToolUse(tool_use)],
cache: false
},
LanguageModelRequestMessage {
role: Role::User,
content: vec![MessageContent::ToolResult(tool_result)],
cache: true
}
]
);
}
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
@@ -394,8 +523,194 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let events = thread.update(cx, |thread, cx| {
thread.add_tool(EchoTool);
thread.send(UserMessageId::new(), ["abc"], cx)
});
cx.run_until_parked();
let tool_use = LanguageModelToolUse {
id: "tool_id_1".into(),
name: EchoTool.name().into(),
raw_input: "{}".into(),
input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
is_input_complete: true,
};
fake_model
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
fake_model.end_last_completion_stream();
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
let tool_result = LanguageModelToolResult {
tool_use_id: "tool_id_1".into(),
tool_name: EchoTool.name().into(),
is_error: false,
content: "def".into(),
output: Some("def".into()),
};
assert_eq!(
completion.messages[1..],
vec![
LanguageModelRequestMessage {
role: Role::User,
content: vec!["abc".into()],
cache: false
},
LanguageModelRequestMessage {
role: Role::Assistant,
content: vec![MessageContent::ToolUse(tool_use.clone())],
cache: false
},
LanguageModelRequestMessage {
role: Role::User,
content: vec![MessageContent::ToolResult(tool_result.clone())],
cache: true
},
]
);
// Simulate reaching tool use limit.
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
));
fake_model.end_last_completion_stream();
let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
assert!(
last_event
.unwrap_err()
.is::<language_model::ToolUseLimitReachedError>()
);
let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
assert_eq!(
completion.messages[1..],
vec![
LanguageModelRequestMessage {
role: Role::User,
content: vec!["abc".into()],
cache: false
},
LanguageModelRequestMessage {
role: Role::Assistant,
content: vec![MessageContent::ToolUse(tool_use)],
cache: false
},
LanguageModelRequestMessage {
role: Role::User,
content: vec![MessageContent::ToolResult(tool_result)],
cache: false
},
LanguageModelRequestMessage {
role: Role::User,
content: vec!["Continue where you left off".into()],
cache: true
}
]
);
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
fake_model.end_last_completion_stream();
events.collect::<Vec<_>>().await;
thread.read_with(cx, |thread, _cx| {
assert_eq!(
thread.last_message().unwrap().to_markdown(),
indoc! {"
## Assistant
Done
"}
)
});
// Ensure we error if calling resume when tool use limit was *not* reached.
let error = thread
.update(cx, |thread, cx| thread.resume(cx))
.unwrap_err();
assert_eq!(
error.to_string(),
"can only resume after tool use limit is reached"
)
}
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let events = thread.update(cx, |thread, cx| {
thread.add_tool(EchoTool);
thread.send(UserMessageId::new(), ["abc"], cx)
});
cx.run_until_parked();
let tool_use = LanguageModelToolUse {
id: "tool_id_1".into(),
name: EchoTool.name().into(),
raw_input: "{}".into(),
input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
is_input_complete: true,
};
let tool_result = LanguageModelToolResult {
tool_use_id: "tool_id_1".into(),
tool_name: EchoTool.name().into(),
is_error: false,
content: "def".into(),
output: Some("def".into()),
};
fake_model
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
));
fake_model.end_last_completion_stream();
let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
assert!(
last_event
.unwrap_err()
.is::<language_model::ToolUseLimitReachedError>()
);
thread.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), vec!["ghi"], cx)
});
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
assert_eq!(
completion.messages[1..],
vec![
LanguageModelRequestMessage {
role: Role::User,
content: vec!["abc".into()],
cache: false
},
LanguageModelRequestMessage {
role: Role::Assistant,
content: vec![MessageContent::ToolUse(tool_use)],
cache: false
},
LanguageModelRequestMessage {
role: Role::User,
content: vec![MessageContent::ToolResult(tool_result)],
cache: false
},
LanguageModelRequestMessage {
role: Role::User,
content: vec!["ghi".into()],
cache: true
}
]
);
}
async fn expect_tool_call(
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
) -> acp::ToolCall {
let event = events
.next()
@@ -411,7 +726,7 @@ async fn expect_tool_call(
}
async fn expect_tool_call_update_fields(
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
) -> acp::ToolCallUpdate {
let event = events
.next()
@@ -429,7 +744,7 @@ async fn expect_tool_call_update_fields(
}
async fn next_tool_call_authorization(
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
) -> ToolCallAuthorization {
loop {
let event = events
@@ -1007,9 +1322,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
}
/// Filters out the stop events for asserting against in tests
fn stop_events(
result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
) -> Vec<acp::StopReason> {
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
result_events
.into_iter()
.filter_map(|event| match event.unwrap() {

View File

@@ -7,7 +7,7 @@ use std::future;
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct EchoToolInput {
/// The text to echo.
text: String,
pub text: String,
}
pub struct EchoTool;

View File

@@ -2,10 +2,10 @@ use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
use acp_thread::{MentionUri, UserMessageId};
use action_log::ActionLog;
use agent_client_protocol as acp;
use agent_settings::{AgentProfileId, AgentSettings};
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::adapt_schema_to_format;
use cloud_llm_client::{CompletionIntent, CompletionMode};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
use collections::IndexMap;
use fs::Fs;
use futures::{
@@ -14,10 +14,10 @@ use futures::{
};
use gpui::{App, Context, Entity, SharedString, Task};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
};
use project::Project;
use prompt_store::ProjectContext;
@@ -28,11 +28,54 @@ use smol::stream::StreamExt;
use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
use std::{fmt::Write, ops::Range};
use util::{ResultExt, markdown::MarkdownCodeBlock};
use uuid::Uuid;
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
impl ThreadId {
pub fn new() -> Self {
Self(Uuid::new_v4().to_string().into())
}
}
impl std::fmt::Display for ThreadId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl From<&str> for ThreadId {
fn from(value: &str) -> Self {
Self(value.into())
}
}
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
impl PromptId {
pub fn new() -> Self {
Self(Uuid::new_v4().to_string().into())
}
}
impl std::fmt::Display for PromptId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Message {
User(UserMessage),
Agent(AgentMessage),
Resume,
}
impl Message {
@@ -47,6 +90,7 @@ impl Message {
match self {
Message::User(message) => message.to_markdown(),
Message::Agent(message) => message.to_markdown(),
Message::Resume => "[resumed after tool use limit was reached]".into(),
}
}
}
@@ -320,7 +364,11 @@ impl AgentMessage {
}
pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
let mut content = Vec::with_capacity(self.content.len());
let mut assistant_message = LanguageModelRequestMessage {
role: Role::Assistant,
content: Vec::with_capacity(self.content.len()),
cache: false,
};
for chunk in &self.content {
let chunk = match chunk {
AgentMessageContent::Text(text) => {
@@ -342,29 +390,30 @@ impl AgentMessage {
language_model::MessageContent::Image(value.clone())
}
};
content.push(chunk);
assistant_message.content.push(chunk);
}
let mut messages = vec![LanguageModelRequestMessage {
role: Role::Assistant,
content,
let mut user_message = LanguageModelRequestMessage {
role: Role::User,
content: Vec::new(),
cache: false,
}];
};
if !self.tool_results.is_empty() {
let mut tool_results = Vec::with_capacity(self.tool_results.len());
for tool_result in self.tool_results.values() {
tool_results.push(language_model::MessageContent::ToolResult(
for tool_result in self.tool_results.values() {
user_message
.content
.push(language_model::MessageContent::ToolResult(
tool_result.clone(),
));
}
messages.push(LanguageModelRequestMessage {
role: Role::User,
content: tool_results,
cache: false,
});
}
let mut messages = Vec::new();
if !assistant_message.content.is_empty() {
messages.push(assistant_message);
}
if !user_message.content.is_empty() {
messages.push(user_message);
}
messages
}
}
@@ -399,12 +448,14 @@ pub enum AgentResponseEvent {
#[derive(Debug)]
pub struct ToolCallAuthorization {
pub tool_call: acp::ToolCall,
pub tool_call: acp::ToolCallUpdate,
pub options: Vec<acp::PermissionOption>,
pub response: oneshot::Sender<acp::PermissionOptionId>,
}
pub struct Thread {
id: ThreadId,
prompt_id: PromptId,
messages: Vec<Message>,
completion_mode: CompletionMode,
/// Holds the task that handles agent interaction until the end of the turn.
@@ -413,11 +464,12 @@ pub struct Thread {
running_turn: Option<Task<()>>,
pending_message: Option<AgentMessage>,
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
tool_use_limit_reached: bool,
context_server_registry: Entity<ContextServerRegistry>,
profile_id: AgentProfileId,
project_context: Rc<RefCell<ProjectContext>>,
templates: Arc<Templates>,
pub selected_model: Arc<dyn LanguageModel>,
model: Arc<dyn LanguageModel>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
}
@@ -429,21 +481,24 @@ impl Thread {
context_server_registry: Entity<ContextServerRegistry>,
action_log: Entity<ActionLog>,
templates: Arc<Templates>,
default_model: Arc<dyn LanguageModel>,
model: Arc<dyn LanguageModel>,
cx: &mut Context<Self>,
) -> Self {
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
Self {
id: ThreadId::new(),
prompt_id: PromptId::new(),
messages: Vec::new(),
completion_mode: CompletionMode::Normal,
running_turn: None,
pending_message: None,
tools: BTreeMap::default(),
tool_use_limit_reached: false,
context_server_registry,
profile_id,
project_context,
templates,
selected_model: default_model,
model,
project,
action_log,
}
@@ -457,7 +512,19 @@ impl Thread {
&self.action_log
}
pub fn set_mode(&mut self, mode: CompletionMode) {
pub fn model(&self) -> &Arc<dyn LanguageModel> {
&self.model
}
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
self.model = model;
}
pub fn completion_mode(&self) -> CompletionMode {
self.completion_mode
}
pub fn set_completion_mode(&mut self, mode: CompletionMode) {
self.completion_mode = mode;
}
@@ -478,6 +545,10 @@ impl Thread {
self.tools.remove(name).is_some()
}
pub fn profile(&self) -> &AgentProfileId {
&self.profile_id
}
pub fn set_profile(&mut self, profile_id: AgentProfileId) {
self.profile_id = profile_id;
}
@@ -499,36 +570,60 @@ impl Thread {
Ok(())
}
pub fn resume(
&mut self,
cx: &mut Context<Self>,
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
anyhow::ensure!(
self.tool_use_limit_reached,
"can only resume after tool use limit is reached"
);
self.messages.push(Message::Resume);
cx.notify();
log::info!("Total messages in thread: {}", self.messages.len());
Ok(self.run_turn(cx))
}
/// Sending a message results in the model streaming a response, which could include tool calls.
/// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
pub fn send<T>(
&mut self,
message_id: UserMessageId,
id: UserMessageId,
content: impl IntoIterator<Item = T>,
cx: &mut Context<Self>,
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>>
where
T: Into<UserMessageContent>,
{
let model = self.selected_model.clone();
log::info!("Thread::send called with model: {:?}", self.model.name());
self.advance_prompt_id();
let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
log::info!("Thread::send called with model: {:?}", model.name());
log::debug!("Thread::send content: {:?}", content);
self.messages
.push(Message::User(UserMessage { id, content }));
cx.notify();
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let event_stream = AgentResponseEventStream(events_tx);
self.messages.push(Message::User(UserMessage {
id: message_id.clone(),
content,
}));
log::info!("Total messages in thread: {}", self.messages.len());
self.run_turn(cx)
}
fn run_turn(
&mut self,
cx: &mut Context<Self>,
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
let model = self.model.clone();
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
let event_stream = AgentResponseEventStream(events_tx);
let message_ix = self.messages.len().saturating_sub(1);
self.tool_use_limit_reached = false;
self.running_turn = Some(cx.spawn(async move |this, cx| {
log::info!("Starting agent turn execution");
let turn_result = async {
let turn_result: Result<()> = async {
let mut completion_intent = CompletionIntent::UserPrompt;
loop {
log::debug!(
@@ -543,13 +638,22 @@ impl Thread {
let mut events = model.stream_completion(request, cx).await?;
log::debug!("Stream completion started successfully");
let mut tool_use_limit_reached = false;
let mut tool_uses = FuturesUnordered::new();
while let Some(event) = events.next().await {
match event? {
LanguageModelCompletionEvent::StatusUpdate(
CompletionRequestStatus::ToolUseLimitReached,
) => {
tool_use_limit_reached = true;
}
LanguageModelCompletionEvent::Stop(reason) => {
event_stream.send_stop(reason);
if reason == StopReason::Refusal {
this.update(cx, |this, _cx| this.truncate(message_id))??;
this.update(cx, |this, _cx| {
this.flush_pending_message();
this.messages.truncate(message_ix);
})?;
return Ok(());
}
}
@@ -567,12 +671,7 @@ impl Thread {
}
}
if tool_uses.is_empty() {
log::info!("No tool uses found, completing turn");
return Ok(());
}
log::info!("Found {} tool uses to execute", tool_uses.len());
let used_tools = tool_uses.is_empty();
while let Some(tool_result) = tool_uses.next().await {
log::info!("Tool finished {:?}", tool_result);
@@ -596,8 +695,17 @@ impl Thread {
.ok();
}
this.update(cx, |this, _| this.flush_pending_message())?;
completion_intent = CompletionIntent::ToolResults;
if tool_use_limit_reached {
log::info!("Tool use limit reached, completing turn");
this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
return Err(language_model::ToolUseLimitReachedError.into());
} else if used_tools {
log::info!("No tool uses found, completing turn");
return Ok(());
} else {
this.update(cx, |this, _| this.flush_pending_message())?;
completion_intent = CompletionIntent::ToolResults;
}
}
}
.await;
@@ -678,10 +786,10 @@ impl Thread {
fn handle_text_event(
&mut self,
new_text: String,
events_stream: &AgentResponseEventStream,
event_stream: &AgentResponseEventStream,
cx: &mut Context<Self>,
) {
events_stream.send_text(&new_text);
event_stream.send_text(&new_text);
let last_message = self.pending_message();
if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
@@ -793,13 +901,14 @@ impl Thread {
let fs = self.project.read(cx).fs().clone();
let tool_event_stream =
ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs));
ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
tool_event_stream.update_fields(acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
});
let supports_images = self.selected_model.supports_images();
let supports_images = self.model.supports_images();
let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
log::info!("Running tool {}", tool_use.name);
Some(cx.foreground_executor().spawn(async move {
let tool_result = tool_result.await.and_then(|output| {
if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
@@ -902,7 +1011,7 @@ impl Thread {
name: tool_name,
description: tool.description().to_string(),
input_schema: tool
.input_schema(self.selected_model.tool_input_format())
.input_schema(self.model.tool_input_format())
.log_err()?,
})
})
@@ -914,15 +1023,15 @@ impl Thread {
log::info!("Request includes {} tools", tools.len());
let request = LanguageModelRequest {
thread_id: None,
prompt_id: None,
thread_id: Some(self.id.to_string()),
prompt_id: Some(self.prompt_id.to_string()),
intent: Some(completion_intent),
mode: Some(self.completion_mode),
mode: Some(self.completion_mode.into()),
messages,
tools,
tool_choice: None,
stop: Vec::new(),
temperature: None,
temperature: AgentSettings::temperature_for_model(self.model(), cx),
thinking_allowed: true,
};
@@ -935,7 +1044,7 @@ impl Thread {
.profiles
.get(&self.profile_id)
.context("profile not found")?;
let provider_id = self.selected_model.provider_id();
let provider_id = self.model.provider_id();
Ok(self
.tools
@@ -971,6 +1080,11 @@ impl Thread {
match message {
Message::User(message) => messages.push(message.to_request()),
Message::Agent(message) => messages.extend(message.to_request()),
Message::Resume => messages.push(LanguageModelRequestMessage {
role: Role::User,
content: vec!["Continue where you left off".into()],
cache: false,
}),
}
}
@@ -978,6 +1092,14 @@ impl Thread {
messages.extend(message.to_request());
}
if let Some(last_user_message) = messages
.iter_mut()
.rev()
.find(|message| message.role == Role::User)
{
last_user_message.cache = true;
}
messages
}
@@ -997,6 +1119,10 @@ impl Thread {
markdown
}
fn advance_prompt_id(&mut self) {
self.prompt_id = PromptId::new();
}
}
pub trait AgentTool
@@ -1123,9 +1249,7 @@ where
}
#[derive(Clone)]
struct AgentResponseEventStream(
mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
struct AgentResponseEventStream(mpsc::UnboundedSender<Result<AgentResponseEvent>>);
impl AgentResponseEventStream {
fn send_text(&self, text: &str) {
@@ -1212,16 +1336,14 @@ impl AgentResponseEventStream {
}
}
fn send_error(&self, error: LanguageModelCompletionError) {
self.0.unbounded_send(Err(error)).ok();
fn send_error(&self, error: impl Into<anyhow::Error>) {
self.0.unbounded_send(Err(error.into())).ok();
}
}
#[derive(Clone)]
pub struct ToolCallEventStream {
tool_use_id: LanguageModelToolUseId,
kind: acp::ToolKind,
input: serde_json::Value,
stream: AgentResponseEventStream,
fs: Option<Arc<dyn Fs>>,
}
@@ -1229,35 +1351,21 @@ pub struct ToolCallEventStream {
impl ToolCallEventStream {
#[cfg(test)]
pub fn test() -> (Self, ToolCallEventStreamReceiver) {
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
let stream = ToolCallEventStream::new(
&LanguageModelToolUse {
id: "test_id".into(),
name: "test_tool".into(),
raw_input: String::new(),
input: serde_json::Value::Null,
is_input_complete: true,
},
acp::ToolKind::Other,
AgentResponseEventStream(events_tx),
None,
);
let stream =
ToolCallEventStream::new("test_id".into(), AgentResponseEventStream(events_tx), None);
(stream, ToolCallEventStreamReceiver(events_rx))
}
fn new(
tool_use: &LanguageModelToolUse,
kind: acp::ToolKind,
tool_use_id: LanguageModelToolUseId,
stream: AgentResponseEventStream,
fs: Option<Arc<dyn Fs>>,
) -> Self {
Self {
tool_use_id: tool_use.id.clone(),
kind,
input: tool_use.input.clone(),
tool_use_id,
stream,
fs,
}
@@ -1304,12 +1412,13 @@ impl ToolCallEventStream {
.0
.unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
ToolCallAuthorization {
tool_call: AgentResponseEventStream::initial_tool_call(
&self.tool_use_id,
title.into(),
self.kind.clone(),
self.input.clone(),
),
tool_call: acp::ToolCallUpdate {
id: acp::ToolCallId(self.tool_use_id.to_string().into()),
fields: acp::ToolCallUpdateFields {
title: Some(title.into()),
..Default::default()
},
},
options: vec![
acp::PermissionOption {
id: acp::PermissionOptionId("always_allow".into()),
@@ -1351,9 +1460,7 @@ impl ToolCallEventStream {
}
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(
mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<AgentResponseEvent>>);
#[cfg(test)]
impl ToolCallEventStreamReceiver {
@@ -1381,7 +1488,7 @@ impl ToolCallEventStreamReceiver {
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;
type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent>>;
fn deref(&self) -> &Self::Target {
&self.0

View File

@@ -241,7 +241,7 @@ impl AgentTool for EditFileTool {
thread.build_completion_request(CompletionIntent::ToolResults, cx)
});
let thread = self.thread.read(cx);
let model = thread.selected_model.clone();
let model = thread.model().clone();
let action_log = thread.action_log().clone();
let authorize = self.authorize(&input, &event_stream, cx);
@@ -1001,7 +1001,10 @@ mod tests {
});
let event = stream_rx.expect_authorization().await;
assert_eq!(event.tool_call.title, "test 1 (local settings)");
assert_eq!(
event.tool_call.fields.title,
Some("test 1 (local settings)".into())
);
// Test 2: Path outside project should require confirmation
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
@@ -1018,7 +1021,7 @@ mod tests {
});
let event = stream_rx.expect_authorization().await;
assert_eq!(event.tool_call.title, "test 2");
assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
// Test 3: Relative path without .zed should not require confirmation
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
@@ -1051,7 +1054,10 @@ mod tests {
)
});
let event = stream_rx.expect_authorization().await;
assert_eq!(event.tool_call.title, "test 4 (local settings)");
assert_eq!(
event.tool_call.fields.title,
Some("test 4 (local settings)".into())
);
// Test 5: When always_allow_tool_actions is enabled, no confirmation needed
cx.update(|cx| {

View File

@@ -5,7 +5,7 @@ use anyhow::{Context as _, Result, anyhow};
use futures::channel::oneshot;
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use project::Project;
use std::{cell::RefCell, path::Path, rc::Rc};
use std::{any::Any, cell::RefCell, path::Path, rc::Rc};
use ui::App;
use util::ResultExt as _;
@@ -135,9 +135,9 @@ impl acp_old::Client for OldAcpClientDelegate {
let response = cx
.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.request_tool_call_authorization(tool_call, acp_options, cx)
thread.request_tool_call_authorization(tool_call.into(), acp_options, cx)
})
})?
})??
.context("Failed to update thread")?
.await;
@@ -168,7 +168,7 @@ impl acp_old::Client for OldAcpClientDelegate {
cx,
)
})
})?
})??
.context("Failed to update thread")?;
Ok(acp_old::PushToolCallResponse {
@@ -507,4 +507,8 @@ impl AgentConnection for AcpConnection {
})
.detach_and_log_err(cx)
}
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
self
}
}

View File

@@ -1,11 +1,13 @@
use agent_client_protocol::{self as acp, Agent as _};
use anyhow::anyhow;
use collections::HashMap;
use futures::AsyncBufReadExt as _;
use futures::channel::oneshot;
use futures::io::BufReader;
use project::Project;
use std::cell::RefCell;
use std::path::Path;
use std::rc::Rc;
use std::{any::Any, cell::RefCell};
use anyhow::{Context as _, Result};
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
@@ -40,12 +42,13 @@ impl AcpConnection {
.current_dir(root_dir)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::inherit())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true)
.spawn()?;
let stdout = child.stdout.take().expect("Failed to take stdout");
let stdin = child.stdin.take().expect("Failed to take stdin");
let stdout = child.stdout.take().context("Failed to take stdout")?;
let stdin = child.stdin.take().context("Failed to take stdin")?;
let stderr = child.stderr.take().context("Failed to take stderr")?;
log::trace!("Spawned (pid: {})", child.id());
let sessions = Rc::new(RefCell::new(HashMap::default()));
@@ -63,6 +66,18 @@ impl AcpConnection {
let io_task = cx.background_spawn(io_task);
cx.background_spawn(async move {
let mut stderr = BufReader::new(stderr);
let mut line = String::new();
while let Ok(n) = stderr.read_line(&mut line).await
&& n > 0
{
log::warn!("agent stderr: {}", &line);
line.clear();
}
})
.detach();
cx.spawn({
let sessions = sessions.clone();
async move |cx| {
@@ -191,6 +206,10 @@ impl AgentConnection for AcpConnection {
.spawn(async move { conn.cancel(params).await })
.detach();
}
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
self
}
}
struct ClientDelegate {
@@ -214,7 +233,7 @@ impl acp::Client for ClientDelegate {
thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
})?;
let result = rx.await;
let result = rx?.await;
let outcome = match result {
Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },

View File

@@ -6,6 +6,7 @@ use context_server::listener::McpServerTool;
use project::Project;
use settings::SettingsStore;
use smol::process::Child;
use std::any::Any;
use std::cell::RefCell;
use std::fmt::Display;
use std::path::Path;
@@ -13,7 +14,7 @@ use std::rc::Rc;
use uuid::Uuid;
use agent_client_protocol as acp;
use anyhow::{Result, anyhow};
use anyhow::{Context as _, Result, anyhow};
use futures::channel::oneshot;
use futures::{AsyncBufReadExt, AsyncWriteExt};
use futures::{
@@ -129,12 +130,25 @@ impl AgentConnection for ClaudeAgentConnection {
&cwd,
)?;
let stdin = child.stdin.take().unwrap();
let stdout = child.stdout.take().unwrap();
let stdout = child.stdout.take().context("Failed to take stdout")?;
let stdin = child.stdin.take().context("Failed to take stdin")?;
let stderr = child.stderr.take().context("Failed to take stderr")?;
let pid = child.id();
log::trace!("Spawned (pid: {})", pid);
cx.background_spawn(async move {
let mut stderr = BufReader::new(stderr);
let mut line = String::new();
while let Ok(n) = stderr.read_line(&mut line).await
&& n > 0
{
log::warn!("agent stderr: {}", &line);
line.clear();
}
})
.detach();
cx.background_spawn(async move {
let mut outgoing_rx = Some(outgoing_rx);
@@ -289,6 +303,10 @@ impl AgentConnection for ClaudeAgentConnection {
})
.log_err();
}
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
self
}
}
#[derive(Clone, Copy)]
@@ -340,7 +358,7 @@ fn spawn_claude(
.current_dir(root_dir)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::inherit())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true)
.spawn()?;
@@ -542,8 +560,9 @@ impl ClaudeAgentSession {
thread.upsert_tool_call(
claude_tool.as_acp(acp::ToolCallId(id.into())),
cx,
);
)?;
}
anyhow::Ok(())
})
.log_err();
}

View File

@@ -154,7 +154,7 @@ impl McpServerTool for PermissionTool {
let chosen_option = thread
.update(cx, |thread, cx| {
thread.request_tool_call_authorization(
claude_tool.as_acp(tool_call_id),
claude_tool.as_acp(tool_call_id).into(),
vec![
acp::PermissionOption {
id: allow_option_id.clone(),
@@ -169,7 +169,7 @@ impl McpServerTool for PermissionTool {
],
cx,
)
})?
})??
.await?;
let response = if chosen_option == allow_option_id {

View File

@@ -309,7 +309,7 @@ pub struct AgentSettingsContent {
///
/// Default: true
expand_terminal_card: Option<bool>,
/// Whether to always use cmd-enter (or ctrl-enter on Linux) to send messages in the agent panel.
/// Whether to always use cmd-enter (or ctrl-enter on Linux or Windows) to send messages in the agent panel.
///
/// Default: false
use_modifier_to_send: Option<bool>,

View File

@@ -50,7 +50,6 @@ fuzzy.workspace = true
gpui.workspace = true
html_to_markdown.workspace = true
http_client.workspace = true
indexed_docs.workspace = true
indoc.workspace = true
inventory.workspace = true
itertools.workspace = true

File diff suppressed because it is too large Load Diff

View File

@@ -1,45 +1,141 @@
use std::{collections::HashMap, ops::Range};
use std::ops::Range;
use acp_thread::AcpThread;
use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer};
use acp_thread::{AcpThread, AgentThreadEntry};
use agent::{TextThreadStore, ThreadStore};
use collections::HashMap;
use editor::{Editor, EditorMode, MinimapVisibility};
use gpui::{
AnyEntity, App, AppContext as _, Entity, EntityId, TextStyleRefinement, WeakEntity, Window,
AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, TextStyleRefinement,
WeakEntity, Window,
};
use language::language_settings::SoftWrap;
use project::Project;
use settings::Settings as _;
use terminal_view::TerminalView;
use theme::ThemeSettings;
use ui::TextSize;
use ui::{Context, TextSize};
use workspace::Workspace;
#[derive(Default)]
use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
pub struct EntryViewState {
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
thread_store: Entity<ThreadStore>,
text_thread_store: Entity<TextThreadStore>,
entries: Vec<Entry>,
}
impl EntryViewState {
pub fn new(
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
thread_store: Entity<ThreadStore>,
text_thread_store: Entity<TextThreadStore>,
) -> Self {
Self {
workspace,
project,
thread_store,
text_thread_store,
entries: Vec::new(),
}
}
pub fn entry(&self, index: usize) -> Option<&Entry> {
self.entries.get(index)
}
pub fn sync_entry(
&mut self,
workspace: WeakEntity<Workspace>,
thread: Entity<AcpThread>,
index: usize,
thread: &Entity<AcpThread>,
window: &mut Window,
cx: &mut App,
cx: &mut Context<Self>,
) {
debug_assert!(index <= self.entries.len());
let entry = if let Some(entry) = self.entries.get_mut(index) {
entry
} else {
self.entries.push(Entry::default());
self.entries.last_mut().unwrap()
let Some(thread_entry) = thread.read(cx).entries().get(index) else {
return;
};
entry.sync_diff_multibuffers(&thread, index, window, cx);
entry.sync_terminals(&workspace, &thread, index, window, cx);
match thread_entry {
AgentThreadEntry::UserMessage(message) => {
let has_id = message.id.is_some();
let chunks = message.chunks.clone();
let message_editor = cx.new(|cx| {
let mut editor = MessageEditor::new(
self.workspace.clone(),
self.project.clone(),
self.thread_store.clone(),
self.text_thread_store.clone(),
editor::EditorMode::AutoHeight {
min_lines: 1,
max_lines: None,
},
window,
cx,
);
if !has_id {
editor.set_read_only(true, cx);
}
editor.set_message(chunks, window, cx);
editor
});
cx.subscribe(&message_editor, move |_, editor, event, cx| {
cx.emit(EntryViewEvent {
entry_index: index,
view_event: ViewEvent::MessageEditorEvent(editor, *event),
})
})
.detach();
self.set_entry(index, Entry::UserMessage(message_editor));
}
AgentThreadEntry::ToolCall(tool_call) => {
let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) {
views
} else {
self.set_entry(index, Entry::empty());
let Some(Entry::Content(views)) = self.entries.get_mut(index) else {
unreachable!()
};
views
};
for terminal in terminals {
views.entry(terminal.entity_id()).or_insert_with(|| {
create_terminal(
self.workspace.clone(),
self.project.clone(),
terminal.clone(),
window,
cx,
)
.into_any()
});
}
for diff in diffs {
views
.entry(diff.entity_id())
.or_insert_with(|| create_editor_diff(diff.clone(), window, cx).into_any());
}
}
AgentThreadEntry::AssistantMessage(_) => {
if index == self.entries.len() {
self.entries.push(Entry::empty())
}
}
};
}
fn set_entry(&mut self, index: usize, entry: Entry) {
if index == self.entries.len() {
self.entries.push(entry);
} else {
self.entries[index] = entry;
}
}
pub fn remove(&mut self, range: Range<usize>) {
@@ -48,26 +144,51 @@ impl EntryViewState {
pub fn settings_changed(&mut self, cx: &mut App) {
for entry in self.entries.iter() {
for view in entry.views.values() {
if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
diff_editor.update(cx, |diff_editor, cx| {
diff_editor
.set_text_style_refinement(diff_editor_text_style_refinement(cx));
cx.notify();
})
match entry {
Entry::UserMessage { .. } => {}
Entry::Content(response_views) => {
for view in response_views.values() {
if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
diff_editor.update(cx, |diff_editor, cx| {
diff_editor.set_text_style_refinement(
diff_editor_text_style_refinement(cx),
);
cx.notify();
})
}
}
}
}
}
}
}
pub struct Entry {
views: HashMap<EntityId, AnyEntity>,
impl EventEmitter<EntryViewEvent> for EntryViewState {}
pub struct EntryViewEvent {
pub entry_index: usize,
pub view_event: ViewEvent,
}
pub enum ViewEvent {
MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
}
pub enum Entry {
UserMessage(Entity<MessageEditor>),
Content(HashMap<EntityId, AnyEntity>),
}
impl Entry {
pub fn editor_for_diff(&self, diff: &Entity<MultiBuffer>) -> Option<Entity<Editor>> {
self.views
pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
match self {
Self::UserMessage(editor) => Some(editor),
Entry::Content(_) => None,
}
}
pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
self.content_map()?
.get(&diff.entity_id())
.cloned()
.map(|entity| entity.downcast::<Editor>().unwrap())
@@ -77,118 +198,88 @@ impl Entry {
&self,
terminal: &Entity<acp_thread::Terminal>,
) -> Option<Entity<TerminalView>> {
self.views
self.content_map()?
.get(&terminal.entity_id())
.cloned()
.map(|entity| entity.downcast::<TerminalView>().unwrap())
}
fn sync_diff_multibuffers(
&mut self,
thread: &Entity<AcpThread>,
index: usize,
window: &mut Window,
cx: &mut App,
) {
let Some(entry) = thread.read(cx).entries().get(index) else {
return;
};
let multibuffers = entry
.diffs()
.map(|diff| diff.read(cx).multibuffer().clone());
let multibuffers = multibuffers.collect::<Vec<_>>();
for multibuffer in multibuffers {
if self.views.contains_key(&multibuffer.entity_id()) {
return;
}
let editor = cx.new(|cx| {
let mut editor = Editor::new(
EditorMode::Full {
scale_ui_elements_with_buffer_font_size: false,
show_active_line_background: false,
sized_by_content: true,
},
multibuffer.clone(),
None,
window,
cx,
);
editor.set_show_gutter(false, cx);
editor.disable_inline_diagnostics();
editor.disable_expand_excerpt_buttons(cx);
editor.set_show_vertical_scrollbar(false, cx);
editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
editor.set_soft_wrap_mode(SoftWrap::None, cx);
editor.scroll_manager.set_forbid_vertical_scroll(true);
editor.set_show_indent_guides(false, cx);
editor.set_read_only(true);
editor.set_show_breakpoints(false, cx);
editor.set_show_code_actions(false, cx);
editor.set_show_git_diff_gutter(false, cx);
editor.set_expand_all_diff_hunks(cx);
editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
editor
});
let entity_id = multibuffer.entity_id();
self.views.insert(entity_id, editor.into_any());
fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
match self {
Self::Content(map) => Some(map),
_ => None,
}
}
fn sync_terminals(
&mut self,
workspace: &WeakEntity<Workspace>,
thread: &Entity<AcpThread>,
index: usize,
window: &mut Window,
cx: &mut App,
) {
let Some(entry) = thread.read(cx).entries().get(index) else {
return;
};
let terminals = entry
.terminals()
.map(|terminal| terminal.clone())
.collect::<Vec<_>>();
for terminal in terminals {
if self.views.contains_key(&terminal.entity_id()) {
return;
}
let Some(strong_workspace) = workspace.upgrade() else {
return;
};
let terminal_view = cx.new(|cx| {
let mut view = TerminalView::new(
terminal.read(cx).inner().clone(),
workspace.clone(),
None,
strong_workspace.read(cx).project().downgrade(),
window,
cx,
);
view.set_embedded_mode(Some(1000), cx);
view
});
let entity_id = terminal.entity_id();
self.views.insert(entity_id, terminal_view.into_any());
}
fn empty() -> Self {
Self::Content(HashMap::default())
}
#[cfg(test)]
pub fn len(&self) -> usize {
self.views.len()
pub fn has_content(&self) -> bool {
match self {
Self::Content(map) => !map.is_empty(),
Self::UserMessage(_) => false,
}
}
}
fn create_terminal(
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
terminal: Entity<acp_thread::Terminal>,
window: &mut Window,
cx: &mut App,
) -> Entity<TerminalView> {
cx.new(|cx| {
let mut view = TerminalView::new(
terminal.read(cx).inner().clone(),
workspace.clone(),
None,
project.downgrade(),
window,
cx,
);
view.set_embedded_mode(Some(1000), cx);
view
})
}
fn create_editor_diff(
diff: Entity<acp_thread::Diff>,
window: &mut Window,
cx: &mut App,
) -> Entity<Editor> {
cx.new(|cx| {
let mut editor = Editor::new(
EditorMode::Full {
scale_ui_elements_with_buffer_font_size: false,
show_active_line_background: false,
sized_by_content: true,
},
diff.read(cx).multibuffer().clone(),
None,
window,
cx,
);
editor.set_show_gutter(false, cx);
editor.disable_inline_diagnostics();
editor.disable_expand_excerpt_buttons(cx);
editor.set_show_vertical_scrollbar(false, cx);
editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
editor.set_soft_wrap_mode(SoftWrap::None, cx);
editor.scroll_manager.set_forbid_vertical_scroll(true);
editor.set_show_indent_guides(false, cx);
editor.set_read_only(true);
editor.set_show_breakpoints(false, cx);
editor.set_show_code_actions(false, cx);
editor.set_show_git_diff_gutter(false, cx);
editor.set_expand_all_diff_hunks(cx);
editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
editor
})
}
fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
TextStyleRefinement {
font_size: Some(
@@ -201,26 +292,20 @@ fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
}
}
impl Default for Entry {
fn default() -> Self {
Self {
// Avoid allocating in the heap by default
views: HashMap::with_capacity(0),
}
}
}
#[cfg(test)]
mod tests {
use std::{path::Path, rc::Rc};
use acp_thread::{AgentConnection, StubAgentConnection};
use agent::{TextThreadStore, ThreadStore};
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
use editor::{EditorSettings, RowInfo};
use fs::FakeFs;
use gpui::{SemanticVersion, TestAppContext};
use gpui::{AppContext as _, SemanticVersion, TestAppContext};
use crate::acp::entry_view_state::EntryViewState;
use multi_buffer::MultiBufferRow;
use pretty_assertions::assert_matches;
use project::Project;
@@ -230,8 +315,6 @@ mod tests {
use util::path;
use workspace::Workspace;
use crate::acp::entry_view_state::EntryViewState;
#[gpui::test]
async fn test_diff_sync(cx: &mut TestAppContext) {
init_test(cx);
@@ -269,7 +352,7 @@ mod tests {
.update(|_, cx| {
connection
.clone()
.new_thread(project, Path::new(path!("/project")), cx)
.new_thread(project.clone(), Path::new(path!("/project")), cx)
})
.await
.unwrap();
@@ -279,12 +362,23 @@ mod tests {
connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
});
let mut view_state = EntryViewState::default();
cx.update(|window, cx| {
view_state.sync_entry(workspace.downgrade(), thread.clone(), 0, window, cx);
let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx));
let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx));
let view_state = cx.new(|_cx| {
EntryViewState::new(
workspace.downgrade(),
project.clone(),
thread_store,
text_thread_store,
)
});
let multibuffer = thread.read_with(cx, |thread, cx| {
view_state.update_in(cx, |view_state, window, cx| {
view_state.sync_entry(0, &thread, window, cx)
});
let diff = thread.read_with(cx, |thread, _cx| {
thread
.entries()
.get(0)
@@ -292,15 +386,14 @@ mod tests {
.diffs()
.next()
.unwrap()
.read(cx)
.multibuffer()
.clone()
});
cx.run_until_parked();
let entry = view_state.entry(0).unwrap();
let diff_editor = entry.editor_for_diff(&multibuffer).unwrap();
let diff_editor = view_state.read_with(cx, |view_state, _cx| {
view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
});
assert_eq!(
diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
"hi world\nhello world"

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -465,7 +465,7 @@ impl AgentConfiguration {
"modifier-send",
"Use modifier to submit a message",
Some(
"Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.".into(),
"Make a modifier (cmd-enter on macOS, ctrl-enter on Linux or Windows) required to send messages.".into(),
),
use_modifier_to_send,
move |state, _window, cx| {
@@ -1035,7 +1035,6 @@ fn extension_only_provides_context_server(manifest: &ExtensionManifest) -> bool
&& manifest.grammars.is_empty()
&& manifest.language_servers.is_empty()
&& manifest.slash_commands.is_empty()
&& manifest.indexed_docs_providers.is_empty()
&& manifest.snippets.is_none()
&& manifest.debug_locators.is_empty()
}

View File

@@ -818,12 +818,10 @@ impl AgentPanel {
ActiveView::Thread { thread, .. } => {
thread.update(cx, |thread, cx| thread.cancel_last_completion(window, cx));
}
ActiveView::ExternalAgentThread { thread_view, .. } => {
thread_view.update(cx, |thread_element, cx| {
thread_element.cancel_generation(cx)
});
}
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {}
ActiveView::ExternalAgentThread { .. }
| ActiveView::TextThread { .. }
| ActiveView::History
| ActiveView::Configuration => {}
}
}
@@ -1259,13 +1257,11 @@ impl AgentPanel {
ThemeSettings::get_global(cx).agent_font_size(cx) + delta;
let _ = settings
.agent_font_size
.insert(theme::clamp_font_size(agent_font_size).0);
.insert(Some(theme::clamp_font_size(agent_font_size).into()));
},
);
} else {
theme::adjust_agent_font_size(cx, |size| {
*size += delta;
});
theme::adjust_agent_font_size(cx, |size| size + delta);
}
}
WhichFontSize::BufferFont => {

View File

@@ -5,7 +5,6 @@ mod agent_diff;
mod agent_model_selector;
mod agent_panel;
mod buffer_codegen;
mod burn_mode_tooltip;
mod context_picker;
mod context_server_configuration;
mod context_strip;
@@ -243,7 +242,6 @@ pub fn init(
client.telemetry().clone(),
cx,
);
indexed_docs::init(cx);
cx.observe_new(move |workspace, window, cx| {
ConfigureContextServerModal::register(workspace, language_registry.clone(), window, cx)
})
@@ -410,12 +408,6 @@ fn update_slash_commands_from_settings(cx: &mut App) {
let slash_command_registry = SlashCommandRegistry::global(cx);
let settings = SlashCommandSettings::get_global(cx);
if settings.docs.enabled {
slash_command_registry.register_command(assistant_slash_commands::DocsSlashCommand, true);
} else {
slash_command_registry.unregister_command(assistant_slash_commands::DocsSlashCommand);
}
if settings.cargo_workspace.enabled {
slash_command_registry
.register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true);

View File

@@ -1,61 +0,0 @@
use gpui::{Context, FontWeight, IntoElement, Render, Window};
use ui::{prelude::*, tooltip_container};
pub struct BurnModeTooltip {
selected: bool,
}
impl BurnModeTooltip {
pub fn new() -> Self {
Self { selected: false }
}
pub fn selected(mut self, selected: bool) -> Self {
self.selected = selected;
self
}
}
impl Render for BurnModeTooltip {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let (icon, color) = if self.selected {
(IconName::ZedBurnModeOn, Color::Error)
} else {
(IconName::ZedBurnMode, Color::Default)
};
let turned_on = h_flex()
.h_4()
.px_1()
.border_1()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().text_accent.opacity(0.1))
.rounded_sm()
.child(
Label::new("ON")
.size(LabelSize::XSmall)
.weight(FontWeight::SEMIBOLD)
.color(Color::Accent),
);
let title = h_flex()
.gap_1p5()
.child(Icon::new(icon).size(IconSize::Small).color(color))
.child(Label::new("Burn Mode"))
.when(self.selected, |title| title.child(turned_on));
tooltip_container(window, cx, |this, _, _| {
this
.child(title)
.child(
div()
.max_w_64()
.child(
Label::new("Enables models to use large context windows, unlimited tool calls, and other capabilities for expanded reasoning.")
.size(LabelSize::Small)
.color(Color::Muted)
)
)
})
}
}

View File

@@ -13,7 +13,7 @@ use anyhow::{Result, anyhow};
use collections::HashSet;
pub use completion_provider::ContextPickerCompletionProvider;
use editor::display_map::{Crease, CreaseId, CreaseMetadata, FoldId};
use editor::{Anchor, AnchorRangeExt as _, Editor, ExcerptId, FoldPlaceholder, ToOffset};
use editor::{Anchor, Editor, ExcerptId, FoldPlaceholder, ToOffset};
use fetch_context_picker::FetchContextPicker;
use file_context_picker::FileContextPicker;
use file_context_picker::render_file_context_entry;
@@ -228,7 +228,7 @@ impl ContextPicker {
}
fn build_menu(&mut self, window: &mut Window, cx: &mut Context<Self>) -> Entity<ContextMenu> {
let context_picker = cx.entity().clone();
let context_picker = cx.entity();
let menu = ContextMenu::build(window, cx, move |menu, _window, cx| {
let recent = self.recent_entries(cx);
@@ -837,42 +837,9 @@ fn render_fold_icon_button(
) -> Arc<dyn Send + Sync + Fn(FoldId, Range<Anchor>, &mut App) -> AnyElement> {
Arc::new({
move |fold_id, fold_range, cx| {
let is_in_text_selection = editor.upgrade().is_some_and(|editor| {
editor.update(cx, |editor, cx| {
let snapshot = editor
.buffer()
.update(cx, |multi_buffer, cx| multi_buffer.snapshot(cx));
let is_in_pending_selection = || {
editor
.selections
.pending
.as_ref()
.is_some_and(|pending_selection| {
pending_selection
.selection
.range()
.includes(&fold_range, &snapshot)
})
};
let mut is_in_complete_selection = || {
editor
.selections
.disjoint_in_range::<usize>(fold_range.clone(), cx)
.into_iter()
.any(|selection| {
// This is needed to cover a corner case, if we just check for an existing
// selection in the fold range, having a cursor at the start of the fold
// marks it as selected. Non-empty selections don't cause this.
let length = selection.end - selection.start;
length > 0
})
};
is_in_pending_selection() || is_in_complete_selection()
})
});
let is_in_text_selection = editor
.update(cx, |editor, cx| editor.is_range_selected(&fold_range, cx))
.unwrap_or_default();
ButtonLike::new(fold_id)
.style(ButtonStyle::Filled)

View File

@@ -72,7 +72,7 @@ pub fn init(
let Some(window) = window else {
return;
};
let workspace = cx.entity().clone();
let workspace = cx.entity();
InlineAssistant::update_global(cx, |inline_assistant, cx| {
inline_assistant.register_workspace(&workspace, window, cx)
});

View File

@@ -1,5 +1,6 @@
use std::{cmp::Reverse, sync::Arc};
use cloud_llm_client::Plan;
use collections::{HashSet, IndexMap};
use feature_flags::ZedProFeatureFlag;
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
@@ -10,7 +11,6 @@ use language_model::{
};
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
use proto::Plan;
use ui::{ListItem, ListItemSpacing, prelude::*};
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
@@ -536,7 +536,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
) -> Option<gpui::AnyElement> {
use feature_flags::FeatureFlagAppExt;
let plan = proto::Plan::ZedPro;
let plan = Plan::ZedPro;
Some(
h_flex()
@@ -557,7 +557,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
window
.dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
}),
Plan::Free | Plan::ZedProTrial => Button::new(
Plan::ZedFree | Plan::ZedProTrial => Button::new(
"try-pro",
if plan == Plan::ZedProTrial {
"Upgrade to Pro"

View File

@@ -6,7 +6,7 @@ use crate::agent_diff::AgentDiffThread;
use crate::agent_model_selector::AgentModelSelector;
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
use crate::ui::{
MaxModeTooltip,
BurnModeTooltip,
preview::{AgentPreview, UsageCallout},
};
use agent::history_store::HistoryStore;
@@ -14,7 +14,7 @@ use agent::{
context::{AgentContextKey, ContextLoadResult, load_context},
context_store::ContextStoreEvent,
};
use agent_settings::{AgentSettings, CompletionMode};
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
use ai_onboarding::ApiKeysWithProviders;
use buffer_diff::BufferDiff;
use cloud_llm_client::CompletionIntent;
@@ -55,7 +55,7 @@ use zed_actions::agent::ToggleModelSelector;
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention};
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::profile_selector::ProfileSelector;
use crate::profile_selector::{ProfileProvider, ProfileSelector};
use crate::{
ActiveThread, AgentDiffPane, ChatWithFollow, ExpandMessageEditor, Follow, KeepAll,
ModelUsageContext, NewThread, OpenAgentDiff, RejectAll, RemoveAllContext, ToggleBurnMode,
@@ -152,6 +152,24 @@ pub(crate) fn create_editor(
editor
}
impl ProfileProvider for Entity<Thread> {
fn profiles_supported(&self, cx: &App) -> bool {
self.read(cx)
.configured_model()
.map_or(false, |model| model.model.supports_tools())
}
fn profile_id(&self, cx: &App) -> AgentProfileId {
self.read(cx).profile().id().clone()
}
fn set_profile(&self, profile_id: AgentProfileId, cx: &mut App) {
self.update(cx, |this, cx| {
this.set_profile(profile_id, cx);
});
}
}
impl MessageEditor {
pub fn new(
fs: Arc<dyn Fs>,
@@ -221,8 +239,9 @@ impl MessageEditor {
)
});
let profile_selector =
cx.new(|cx| ProfileSelector::new(fs, thread.clone(), editor.focus_handle(cx), cx));
let profile_selector = cx.new(|cx| {
ProfileSelector::new(fs, Arc::new(thread.clone()), editor.focus_handle(cx), cx)
});
Self {
editor: editor.clone(),
@@ -605,7 +624,7 @@ impl MessageEditor {
this.toggle_burn_mode(&ToggleBurnMode, window, cx);
}))
.tooltip(move |_window, cx| {
cx.new(|_| MaxModeTooltip::new().selected(burn_mode_enabled))
cx.new(|_| BurnModeTooltip::new().selected(burn_mode_enabled))
.into()
})
.into_any_element(),

View File

@@ -1,12 +1,8 @@
use crate::{ManageProfiles, ToggleProfileSelector};
use agent::{
Thread,
agent_profile::{AgentProfile, AvailableProfiles},
};
use agent::agent_profile::{AgentProfile, AvailableProfiles};
use agent_settings::{AgentDockPosition, AgentProfileId, AgentSettings, builtin_profiles};
use fs::Fs;
use gpui::{Action, Empty, Entity, FocusHandle, Subscription, prelude::*};
use language_model::LanguageModelRegistry;
use gpui::{Action, Entity, FocusHandle, Subscription, prelude::*};
use settings::{Settings as _, SettingsStore, update_settings_file};
use std::sync::Arc;
use ui::{
@@ -14,10 +10,22 @@ use ui::{
prelude::*,
};
/// Trait for types that can provide and manage agent profiles
pub trait ProfileProvider {
/// Get the current profile ID
fn profile_id(&self, cx: &App) -> AgentProfileId;
/// Set the profile ID
fn set_profile(&self, profile_id: AgentProfileId, cx: &mut App);
/// Check if profiles are supported in the current context (e.g. if the model that is selected has tool support)
fn profiles_supported(&self, cx: &App) -> bool;
}
pub struct ProfileSelector {
profiles: AvailableProfiles,
fs: Arc<dyn Fs>,
thread: Entity<Thread>,
provider: Arc<dyn ProfileProvider>,
menu_handle: PopoverMenuHandle<ContextMenu>,
focus_handle: FocusHandle,
_subscriptions: Vec<Subscription>,
@@ -26,7 +34,7 @@ pub struct ProfileSelector {
impl ProfileSelector {
pub fn new(
fs: Arc<dyn Fs>,
thread: Entity<Thread>,
provider: Arc<dyn ProfileProvider>,
focus_handle: FocusHandle,
cx: &mut Context<Self>,
) -> Self {
@@ -37,7 +45,7 @@ impl ProfileSelector {
Self {
profiles: AgentProfile::available_profiles(cx),
fs,
thread,
provider,
menu_handle: PopoverMenuHandle::default(),
focus_handle,
_subscriptions: vec![settings_subscription],
@@ -113,10 +121,10 @@ impl ProfileSelector {
builtin_profiles::MINIMAL => Some("Chat about anything with no tools."),
_ => None,
};
let thread_profile_id = self.thread.read(cx).profile().id();
let thread_profile_id = self.provider.profile_id(cx);
let entry = ContextMenuEntry::new(profile_name.clone())
.toggleable(IconPosition::End, &profile_id == thread_profile_id);
.toggleable(IconPosition::End, profile_id == thread_profile_id);
let entry = if let Some(doc_text) = documentation {
entry.documentation_aside(documentation_side(settings.dock), move |_| {
@@ -128,7 +136,7 @@ impl ProfileSelector {
entry.handler({
let fs = self.fs.clone();
let thread = self.thread.clone();
let provider = self.provider.clone();
let profile_id = profile_id.clone();
move |_window, cx| {
update_settings_file::<AgentSettings>(fs.clone(), cx, {
@@ -138,9 +146,7 @@ impl ProfileSelector {
}
});
thread.update(cx, |this, cx| {
this.set_profile(profile_id.clone(), cx);
});
provider.set_profile(profile_id.clone(), cx);
}
})
}
@@ -149,23 +155,15 @@ impl ProfileSelector {
impl Render for ProfileSelector {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let settings = AgentSettings::get_global(cx);
let profile_id = self.thread.read(cx).profile().id();
let profile = settings.profiles.get(profile_id);
let profile_id = self.provider.profile_id(cx);
let profile = settings.profiles.get(&profile_id);
let selected_profile = profile
.map(|profile| profile.name.clone())
.unwrap_or_else(|| "Unknown".into());
let configured_model = self.thread.read(cx).configured_model().or_else(|| {
let model_registry = LanguageModelRegistry::read_global(cx);
model_registry.default_model()
});
let Some(configured_model) = configured_model else {
return Empty.into_any_element();
};
if configured_model.model.supports_tools() {
let this = cx.entity().clone();
if self.provider.profiles_supported(cx) {
let this = cx.entity();
let focus_handle = self.focus_handle.clone();
let trigger_button = Button::new("profile-selector-model", selected_profile)
.label_size(LabelSize::Small)

View File

@@ -7,22 +7,11 @@ use settings::{Settings, SettingsSources};
/// Settings for slash commands.
#[derive(Deserialize, Serialize, Debug, Default, Clone, JsonSchema)]
pub struct SlashCommandSettings {
/// Settings for the `/docs` slash command.
#[serde(default)]
pub docs: DocsCommandSettings,
/// Settings for the `/cargo-workspace` slash command.
#[serde(default)]
pub cargo_workspace: CargoWorkspaceCommandSettings,
}
/// Settings for the `/docs` slash command.
#[derive(Deserialize, Serialize, Debug, Default, Clone, JsonSchema)]
pub struct DocsCommandSettings {
/// Whether `/docs` is enabled.
#[serde(default)]
pub enabled: bool,
}
/// Settings for the `/cargo-workspace` slash command.
#[derive(Deserialize, Serialize, Debug, Default, Clone, JsonSchema)]
pub struct CargoWorkspaceCommandSettings {

View File

@@ -1,14 +1,11 @@
use crate::{
burn_mode_tooltip::BurnModeTooltip,
language_model_selector::{LanguageModelSelector, language_model_selector},
ui::BurnModeTooltip,
};
use agent_settings::{AgentSettings, CompletionMode};
use anyhow::Result;
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection, SlashCommandWorkingSet};
use assistant_slash_commands::{
DefaultSlashCommand, DocsSlashCommand, DocsSlashCommandArgs, FileSlashCommand,
selections_creases,
};
use assistant_slash_commands::{DefaultSlashCommand, FileSlashCommand, selections_creases};
use client::{proto, zed_urls};
use collections::{BTreeSet, HashMap, HashSet, hash_map};
use editor::{
@@ -30,7 +27,6 @@ use gpui::{
StatefulInteractiveElement, Styled, Subscription, Task, Transformation, WeakEntity, actions,
div, img, percentage, point, prelude::*, pulsating_between, size,
};
use indexed_docs::IndexedDocsStore;
use language::{
BufferSnapshot, LspAdapterDelegate, ToOffset,
language_settings::{SoftWrap, all_language_settings},
@@ -77,7 +73,7 @@ use crate::{slash_command::SlashCommandCompletionProvider, slash_command_picker}
use assistant_context::{
AssistantContext, CacheStatus, Content, ContextEvent, ContextId, InvokedSlashCommandId,
InvokedSlashCommandStatus, Message, MessageId, MessageMetadata, MessageStatus,
ParsedSlashCommand, PendingSlashCommandStatus, ThoughtProcessOutputSection,
PendingSlashCommandStatus, ThoughtProcessOutputSection,
};
actions!(
@@ -701,19 +697,7 @@ impl TextThreadEditor {
}
};
let render_trailer = {
let command = command.clone();
move |row, _unfold, _window: &mut Window, cx: &mut App| {
// TODO: In the future we should investigate how we can expose
// this as a hook on the `SlashCommand` trait so that we don't
// need to special-case it here.
if command.name == DocsSlashCommand::NAME {
return render_docs_slash_command_trailer(
row,
command.clone(),
cx,
);
}
move |_row, _unfold, _window: &mut Window, _cx: &mut App| {
Empty.into_any()
}
};
@@ -2398,70 +2382,6 @@ fn render_pending_slash_command_gutter_decoration(
icon.into_any_element()
}
fn render_docs_slash_command_trailer(
row: MultiBufferRow,
command: ParsedSlashCommand,
cx: &mut App,
) -> AnyElement {
if command.arguments.is_empty() {
return Empty.into_any();
}
let args = DocsSlashCommandArgs::parse(&command.arguments);
let Some(store) = args
.provider()
.and_then(|provider| IndexedDocsStore::try_global(provider, cx).ok())
else {
return Empty.into_any();
};
let Some(package) = args.package() else {
return Empty.into_any();
};
let mut children = Vec::new();
if store.is_indexing(&package) {
children.push(
div()
.id(("crates-being-indexed", row.0))
.child(Icon::new(IconName::ArrowCircle).with_animation(
"arrow-circle",
Animation::new(Duration::from_secs(4)).repeat(),
|icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
))
.tooltip({
let package = package.clone();
Tooltip::text(format!("Indexing {package}"))
})
.into_any_element(),
);
}
if let Some(latest_error) = store.latest_error_for_package(&package) {
children.push(
div()
.id(("latest-error", row.0))
.child(
Icon::new(IconName::Warning)
.size(IconSize::Small)
.color(Color::Warning),
)
.tooltip(Tooltip::text(format!("Failed to index: {latest_error}")))
.into_any_element(),
)
}
let is_indexing = store.is_indexing(&package);
let latest_error = store.latest_error_for_package(&package);
if !is_indexing && latest_error.is_none() {
return Empty.into_any();
}
h_flex().gap_2().children(children).into_any_element()
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CopyMetadata {
creases: Vec<SelectedCreaseMetadata>,

View File

@@ -2,11 +2,11 @@ use crate::ToggleBurnMode;
use gpui::{Context, FontWeight, IntoElement, Render, Window};
use ui::{KeyBinding, prelude::*, tooltip_container};
pub struct MaxModeTooltip {
pub struct BurnModeTooltip {
selected: bool,
}
impl MaxModeTooltip {
impl BurnModeTooltip {
pub fn new() -> Self {
Self { selected: false }
}
@@ -17,7 +17,7 @@ impl MaxModeTooltip {
}
}
impl Render for MaxModeTooltip {
impl Render for BurnModeTooltip {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let (icon, color) = if self.selected {
(IconName::ZedBurnModeOn, Color::Error)

View File

@@ -27,7 +27,6 @@ globset.workspace = true
gpui.workspace = true
html_to_markdown.workspace = true
http_client.workspace = true
indexed_docs.workspace = true
language.workspace = true
project.workspace = true
prompt_store.workspace = true

View File

@@ -3,7 +3,6 @@ mod context_server_command;
mod default_command;
mod delta_command;
mod diagnostics_command;
mod docs_command;
mod fetch_command;
mod file_command;
mod now_command;
@@ -18,7 +17,6 @@ pub use crate::context_server_command::*;
pub use crate::default_command::*;
pub use crate::delta_command::*;
pub use crate::diagnostics_command::*;
pub use crate::docs_command::*;
pub use crate::fetch_command::*;
pub use crate::file_command::*;
pub use crate::now_command::*;

View File

@@ -1,543 +0,0 @@
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::time::Duration;
use anyhow::{Context as _, Result, anyhow, bail};
use assistant_slash_command::{
ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
SlashCommandResult,
};
use gpui::{App, BackgroundExecutor, Entity, Task, WeakEntity};
use indexed_docs::{
DocsDotRsProvider, IndexedDocsRegistry, IndexedDocsStore, LocalRustdocProvider, PackageName,
ProviderId,
};
use language::{BufferSnapshot, LspAdapterDelegate};
use project::{Project, ProjectPath};
use ui::prelude::*;
use util::{ResultExt, maybe};
use workspace::Workspace;
pub struct DocsSlashCommand;
impl DocsSlashCommand {
pub const NAME: &'static str = "docs";
fn path_to_cargo_toml(project: Entity<Project>, cx: &mut App) -> Option<Arc<Path>> {
let worktree = project.read(cx).worktrees(cx).next()?;
let worktree = worktree.read(cx);
let entry = worktree.entry_for_path("Cargo.toml")?;
let path = ProjectPath {
worktree_id: worktree.id(),
path: entry.path.clone(),
};
Some(Arc::from(
project.read(cx).absolute_path(&path, cx)?.as_path(),
))
}
/// Ensures that the indexed doc providers for Rust are registered.
///
/// Ideally we would do this sooner, but we need to wait until we're able to
/// access the workspace so we can read the project.
fn ensure_rust_doc_providers_are_registered(
&self,
workspace: Option<WeakEntity<Workspace>>,
cx: &mut App,
) {
let indexed_docs_registry = IndexedDocsRegistry::global(cx);
if indexed_docs_registry
.get_provider_store(LocalRustdocProvider::id())
.is_none()
{
let index_provider_deps = maybe!({
let workspace = workspace
.as_ref()
.context("no workspace")?
.upgrade()
.context("workspace dropped")?;
let project = workspace.read(cx).project().clone();
let fs = project.read(cx).fs().clone();
let cargo_workspace_root = Self::path_to_cargo_toml(project, cx)
.and_then(|path| path.parent().map(|path| path.to_path_buf()))
.context("no Cargo workspace root found")?;
anyhow::Ok((fs, cargo_workspace_root))
});
if let Some((fs, cargo_workspace_root)) = index_provider_deps.log_err() {
indexed_docs_registry.register_provider(Box::new(LocalRustdocProvider::new(
fs,
cargo_workspace_root,
)));
}
}
if indexed_docs_registry
.get_provider_store(DocsDotRsProvider::id())
.is_none()
{
let http_client = maybe!({
let workspace = workspace
.as_ref()
.context("no workspace")?
.upgrade()
.context("workspace was dropped")?;
let project = workspace.read(cx).project().clone();
anyhow::Ok(project.read(cx).client().http_client())
});
if let Some(http_client) = http_client.log_err() {
indexed_docs_registry
.register_provider(Box::new(DocsDotRsProvider::new(http_client)));
}
}
}
/// Runs just-in-time indexing for a given package, in case the slash command
/// is run without any entries existing in the index.
fn run_just_in_time_indexing(
store: Arc<IndexedDocsStore>,
key: String,
package: PackageName,
executor: BackgroundExecutor,
) -> Task<()> {
executor.clone().spawn(async move {
let (prefix, needs_full_index) = if let Some((prefix, _)) = key.split_once('*') {
// If we have a wildcard in the search, we want to wait until
// we've completely finished indexing so we get a full set of
// results for the wildcard.
(prefix.to_string(), true)
} else {
(key, false)
};
// If we already have some entries, we assume that we've indexed the package before
// and don't need to do it again.
let has_any_entries = store
.any_with_prefix(prefix.clone())
.await
.unwrap_or_default();
if has_any_entries {
return ();
};
let index_task = store.clone().index(package.clone());
if needs_full_index {
_ = index_task.await;
} else {
loop {
executor.timer(Duration::from_millis(200)).await;
if store
.any_with_prefix(prefix.clone())
.await
.unwrap_or_default()
|| !store.is_indexing(&package)
{
break;
}
}
}
})
}
}
impl SlashCommand for DocsSlashCommand {
fn name(&self) -> String {
Self::NAME.into()
}
fn description(&self) -> String {
"insert docs".into()
}
fn menu_text(&self) -> String {
"Insert Documentation".into()
}
fn requires_argument(&self) -> bool {
true
}
fn complete_argument(
self: Arc<Self>,
arguments: &[String],
_cancel: Arc<AtomicBool>,
workspace: Option<WeakEntity<Workspace>>,
_: &mut Window,
cx: &mut App,
) -> Task<Result<Vec<ArgumentCompletion>>> {
self.ensure_rust_doc_providers_are_registered(workspace, cx);
let indexed_docs_registry = IndexedDocsRegistry::global(cx);
let args = DocsSlashCommandArgs::parse(arguments);
let store = args
.provider()
.context("no docs provider specified")
.and_then(|provider| IndexedDocsStore::try_global(provider, cx));
cx.background_spawn(async move {
fn build_completions(items: Vec<String>) -> Vec<ArgumentCompletion> {
items
.into_iter()
.map(|item| ArgumentCompletion {
label: item.clone().into(),
new_text: item.to_string(),
after_completion: assistant_slash_command::AfterCompletion::Run,
replace_previous_arguments: false,
})
.collect()
}
match args {
DocsSlashCommandArgs::NoProvider => {
let providers = indexed_docs_registry.list_providers();
if providers.is_empty() {
return Ok(vec![ArgumentCompletion {
label: "No available docs providers.".into(),
new_text: String::new(),
after_completion: false.into(),
replace_previous_arguments: false,
}]);
}
Ok(providers
.into_iter()
.map(|provider| ArgumentCompletion {
label: provider.to_string().into(),
new_text: provider.to_string(),
after_completion: false.into(),
replace_previous_arguments: false,
})
.collect())
}
DocsSlashCommandArgs::SearchPackageDocs {
provider,
package,
index,
} => {
let store = store?;
if index {
// We don't need to hold onto this task, as the `IndexedDocsStore` will hold it
// until it completes.
drop(store.clone().index(package.as_str().into()));
}
let suggested_packages = store.clone().suggest_packages().await?;
let search_results = store.search(package).await;
let mut items = build_completions(search_results);
let workspace_crate_completions = suggested_packages
.into_iter()
.filter(|package_name| {
!items
.iter()
.any(|item| item.label.text() == package_name.as_ref())
})
.map(|package_name| ArgumentCompletion {
label: format!("{package_name} (unindexed)").into(),
new_text: format!("{package_name}"),
after_completion: true.into(),
replace_previous_arguments: false,
})
.collect::<Vec<_>>();
items.extend(workspace_crate_completions);
if items.is_empty() {
return Ok(vec![ArgumentCompletion {
label: format!(
"Enter a {package_term} name.",
package_term = package_term(&provider)
)
.into(),
new_text: provider.to_string(),
after_completion: false.into(),
replace_previous_arguments: false,
}]);
}
Ok(items)
}
DocsSlashCommandArgs::SearchItemDocs { item_path, .. } => {
let store = store?;
let items = store.search(item_path).await;
Ok(build_completions(items))
}
}
})
}
fn run(
self: Arc<Self>,
arguments: &[String],
_context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
_context_buffer: BufferSnapshot,
_workspace: WeakEntity<Workspace>,
_delegate: Option<Arc<dyn LspAdapterDelegate>>,
_: &mut Window,
cx: &mut App,
) -> Task<SlashCommandResult> {
if arguments.is_empty() {
return Task::ready(Err(anyhow!("missing an argument")));
};
let args = DocsSlashCommandArgs::parse(arguments);
let executor = cx.background_executor().clone();
let task = cx.background_spawn({
let store = args
.provider()
.context("no docs provider specified")
.and_then(|provider| IndexedDocsStore::try_global(provider, cx));
async move {
let (provider, key) = match args.clone() {
DocsSlashCommandArgs::NoProvider => bail!("no docs provider specified"),
DocsSlashCommandArgs::SearchPackageDocs {
provider, package, ..
} => (provider, package),
DocsSlashCommandArgs::SearchItemDocs {
provider,
item_path,
..
} => (provider, item_path),
};
if key.trim().is_empty() {
bail!(
"no {package_term} name provided",
package_term = package_term(&provider)
);
}
let store = store?;
if let Some(package) = args.package() {
Self::run_just_in_time_indexing(store.clone(), key.clone(), package, executor)
.await;
}
let (text, ranges) = if let Some((prefix, _)) = key.split_once('*') {
let docs = store.load_many_by_prefix(prefix.to_string()).await?;
let mut text = String::new();
let mut ranges = Vec::new();
for (key, docs) in docs {
let prev_len = text.len();
text.push_str(&docs.0);
text.push_str("\n");
ranges.push((key, prev_len..text.len()));
text.push_str("\n");
}
(text, ranges)
} else {
let item_docs = store.load(key.clone()).await?;
let text = item_docs.to_string();
let range = 0..text.len();
(text, vec![(key, range)])
};
anyhow::Ok((provider, text, ranges))
}
});
cx.foreground_executor().spawn(async move {
let (provider, text, ranges) = task.await?;
Ok(SlashCommandOutput {
text,
sections: ranges
.into_iter()
.map(|(key, range)| SlashCommandOutputSection {
range,
icon: IconName::FileDoc,
label: format!("docs ({provider}): {key}",).into(),
metadata: None,
})
.collect(),
run_commands_in_text: false,
}
.to_event_stream())
})
}
}
fn is_item_path_delimiter(char: char) -> bool {
!char.is_alphanumeric() && char != '-' && char != '_'
}
#[derive(Debug, PartialEq, Clone)]
pub enum DocsSlashCommandArgs {
NoProvider,
SearchPackageDocs {
provider: ProviderId,
package: String,
index: bool,
},
SearchItemDocs {
provider: ProviderId,
package: String,
item_path: String,
},
}
impl DocsSlashCommandArgs {
pub fn parse(arguments: &[String]) -> Self {
let Some(provider) = arguments
.get(0)
.cloned()
.filter(|arg| !arg.trim().is_empty())
else {
return Self::NoProvider;
};
let provider = ProviderId(provider.into());
let Some(argument) = arguments.get(1) else {
return Self::NoProvider;
};
if let Some((package, rest)) = argument.split_once(is_item_path_delimiter) {
if rest.trim().is_empty() {
Self::SearchPackageDocs {
provider,
package: package.to_owned(),
index: true,
}
} else {
Self::SearchItemDocs {
provider,
package: package.to_owned(),
item_path: argument.to_owned(),
}
}
} else {
Self::SearchPackageDocs {
provider,
package: argument.to_owned(),
index: false,
}
}
}
pub fn provider(&self) -> Option<ProviderId> {
match self {
Self::NoProvider => None,
Self::SearchPackageDocs { provider, .. } | Self::SearchItemDocs { provider, .. } => {
Some(provider.clone())
}
}
}
pub fn package(&self) -> Option<PackageName> {
match self {
Self::NoProvider => None,
Self::SearchPackageDocs { package, .. } | Self::SearchItemDocs { package, .. } => {
Some(package.as_str().into())
}
}
}
}
/// Returns the term used to refer to a package.
fn package_term(provider: &ProviderId) -> &'static str {
if provider == &DocsDotRsProvider::id() || provider == &LocalRustdocProvider::id() {
return "crate";
}
"package"
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_docs_slash_command_args() {
assert_eq!(
DocsSlashCommandArgs::parse(&["".to_string()]),
DocsSlashCommandArgs::NoProvider
);
assert_eq!(
DocsSlashCommandArgs::parse(&["rustdoc".to_string()]),
DocsSlashCommandArgs::NoProvider
);
assert_eq!(
DocsSlashCommandArgs::parse(&["rustdoc".to_string(), "".to_string()]),
DocsSlashCommandArgs::SearchPackageDocs {
provider: ProviderId("rustdoc".into()),
package: "".into(),
index: false
}
);
assert_eq!(
DocsSlashCommandArgs::parse(&["gleam".to_string(), "".to_string()]),
DocsSlashCommandArgs::SearchPackageDocs {
provider: ProviderId("gleam".into()),
package: "".into(),
index: false
}
);
assert_eq!(
DocsSlashCommandArgs::parse(&["rustdoc".to_string(), "gpui".to_string()]),
DocsSlashCommandArgs::SearchPackageDocs {
provider: ProviderId("rustdoc".into()),
package: "gpui".into(),
index: false,
}
);
assert_eq!(
DocsSlashCommandArgs::parse(&["gleam".to_string(), "gleam_stdlib".to_string()]),
DocsSlashCommandArgs::SearchPackageDocs {
provider: ProviderId("gleam".into()),
package: "gleam_stdlib".into(),
index: false
}
);
// Adding an item path delimiter indicates we can start indexing.
assert_eq!(
DocsSlashCommandArgs::parse(&["rustdoc".to_string(), "gpui:".to_string()]),
DocsSlashCommandArgs::SearchPackageDocs {
provider: ProviderId("rustdoc".into()),
package: "gpui".into(),
index: true,
}
);
assert_eq!(
DocsSlashCommandArgs::parse(&["gleam".to_string(), "gleam_stdlib/".to_string()]),
DocsSlashCommandArgs::SearchPackageDocs {
provider: ProviderId("gleam".into()),
package: "gleam_stdlib".into(),
index: true
}
);
assert_eq!(
DocsSlashCommandArgs::parse(&[
"rustdoc".to_string(),
"gpui::foo::bar::Baz".to_string()
]),
DocsSlashCommandArgs::SearchItemDocs {
provider: ProviderId("rustdoc".into()),
package: "gpui".into(),
item_path: "gpui::foo::bar::Baz".into()
}
);
assert_eq!(
DocsSlashCommandArgs::parse(&[
"gleam".to_string(),
"gleam_stdlib/gleam/int".to_string()
]),
DocsSlashCommandArgs::SearchItemDocs {
provider: ProviderId("gleam".into()),
package: "gleam_stdlib".into(),
item_path: "gleam_stdlib/gleam/int".into()
}
);
}
}

View File

@@ -18,6 +18,6 @@ collections.workspace = true
derive_more.workspace = true
gpui.workspace = true
parking_lot.workspace = true
rodio = { workspace = true, features = ["wav", "playback", "tracing"] }
rodio = { workspace = true, features = [ "wav", "playback", "tracing" ] }
util.workspace = true
workspace-hack.workspace = true

View File

@@ -1,16 +1,12 @@
use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
use anyhow::{Context as _, Result, anyhow};
use chrono::Duration;
use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo};
use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit};
use futures::{StreamExt, stream::BoxStream};
use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext};
use http_client::{AsyncBody, Method, Request, http};
use parking_lot::Mutex;
use rpc::{
ConnectionId, Peer, Receipt, TypedEnvelope,
proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
};
use rpc::{ConnectionId, Peer, Receipt, TypedEnvelope, proto};
use std::sync::Arc;
pub struct FakeServer {
@@ -187,50 +183,27 @@ impl FakeServer {
pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
self.executor.start_waiting();
loop {
let message = self
.state
.lock()
.incoming
.as_mut()
.expect("not connected")
.next()
.await
.context("other half hung up")?;
self.executor.finish_waiting();
let type_name = message.payload_type_name();
let message = message.into_any();
let message = self
.state
.lock()
.incoming
.as_mut()
.expect("not connected")
.next()
.await
.context("other half hung up")?;
self.executor.finish_waiting();
let type_name = message.payload_type_name();
let message = message.into_any();
if message.is::<TypedEnvelope<M>>() {
return Ok(*message.downcast().unwrap());
}
let accepted_tos_at = chrono::Utc::now()
.checked_sub_signed(Duration::hours(5))
.expect("failed to build accepted_tos_at")
.timestamp() as u64;
if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
self.respond(
message
.downcast::<TypedEnvelope<GetPrivateUserInfo>>()
.unwrap()
.receipt(),
GetPrivateUserInfoResponse {
metrics_id: "the-metrics-id".into(),
staff: false,
flags: Default::default(),
accepted_tos_at: Some(accepted_tos_at),
},
);
continue;
}
panic!(
"fake server received unexpected message type: {:?}",
type_name
);
if message.is::<TypedEnvelope<M>>() {
return Ok(*message.downcast().unwrap());
}
panic!(
"fake server received unexpected message type: {:?}",
type_name
);
}
pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {

View File

@@ -177,7 +177,6 @@ impl UserStore {
let (mut current_user_tx, current_user_rx) = watch::channel();
let (update_contacts_tx, mut update_contacts_rx) = mpsc::unbounded();
let rpc_subscriptions = vec![
client.add_message_handler(cx.weak_entity(), Self::handle_update_plan),
client.add_message_handler(cx.weak_entity(), Self::handle_update_contacts),
client.add_message_handler(cx.weak_entity(), Self::handle_update_invite_info),
client.add_message_handler(cx.weak_entity(), Self::handle_show_contacts),
@@ -343,26 +342,6 @@ impl UserStore {
Ok(())
}
async fn handle_update_plan(
this: Entity<Self>,
_message: TypedEnvelope<proto::UpdateUserPlan>,
mut cx: AsyncApp,
) -> Result<()> {
let client = this
.read_with(&cx, |this, _| this.client.upgrade())?
.context("client was dropped")?;
let response = client
.cloud_client()
.get_authenticated_user()
.await
.context("failed to fetch authenticated user")?;
this.update(&mut cx, |this, cx| {
this.update_authenticated_user(response, cx);
})
}
fn update_contacts(&mut self, message: UpdateContacts, cx: &Context<Self>) -> Task<Result<()>> {
match message {
UpdateContacts::Wait(barrier) => {
@@ -1019,19 +998,6 @@ impl RequestUsage {
}
}
pub fn from_proto(amount: u32, limit: proto::UsageLimit) -> Option<Self> {
let limit = match limit.variant? {
proto::usage_limit::Variant::Limited(limited) => {
UsageLimit::Limited(limited.limit as i32)
}
proto::usage_limit::Variant::Unlimited(_) => UsageLimit::Unlimited,
};
Some(RequestUsage {
limit,
amount: amount as i32,
})
}
fn from_headers(
limit_name: &str,
amount_name: &str,

View File

@@ -19,7 +19,6 @@ test-support = ["sqlite"]
[dependencies]
anyhow.workspace = true
async-stripe.workspace = true
async-trait.workspace = true
async-tungstenite.workspace = true
aws-config = { version = "1.1.5" }
@@ -30,16 +29,13 @@ axum-extra = { version = "0.4", features = ["erased-json"] }
base64.workspace = true
chrono.workspace = true
clock.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
dashmap.workspace = true
derive_more.workspace = true
envy = "0.4.2"
futures.workspace = true
gpui.workspace = true
hex.workspace = true
http_client.workspace = true
jsonwebtoken.workspace = true
livekit_api.workspace = true
log.workspace = true
nanoid.workspace = true
@@ -65,7 +61,6 @@ subtle.workspace = true
supermaven_api.workspace = true
telemetry_events.workspace = true
text.workspace = true
thiserror.workspace = true
time.workspace = true
tokio = { workspace = true, features = ["full"] }
toml.workspace = true
@@ -136,6 +131,3 @@ util.workspace = true
workspace = { workspace = true, features = ["test-support"] }
worktree = { workspace = true, features = ["test-support"] }
zlog.workspace = true
[package.metadata.cargo-machete]
ignored = ["async-stripe"]

View File

@@ -219,12 +219,6 @@ spec:
secretKeyRef:
name: slack
key: panics_webhook
- name: STRIPE_API_KEY
valueFrom:
secretKeyRef:
name: stripe
key: api_key
optional: true
- name: COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR
value: "1000"
- name: SUPERMAVEN_ADMIN_API_KEY

View File

@@ -474,67 +474,6 @@ CREATE UNIQUE INDEX "index_extensions_external_id" ON "extensions" ("external_id
CREATE INDEX "index_extensions_total_download_count" ON "extensions" ("total_download_count");
CREATE TABLE rate_buckets (
user_id INT NOT NULL,
rate_limit_name VARCHAR(255) NOT NULL,
token_count INT NOT NULL,
last_refill TIMESTAMP WITHOUT TIME ZONE NOT NULL,
PRIMARY KEY (user_id, rate_limit_name),
FOREIGN KEY (user_id) REFERENCES users (id)
);
CREATE INDEX idx_user_id_rate_limit ON rate_buckets (user_id, rate_limit_name);
CREATE TABLE IF NOT EXISTS billing_preferences (
id INTEGER PRIMARY KEY AUTOINCREMENT,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
user_id INTEGER NOT NULL REFERENCES users (id),
max_monthly_llm_usage_spending_in_cents INTEGER NOT NULL,
model_request_overages_enabled bool NOT NULL DEFAULT FALSE,
model_request_overages_spend_limit_in_cents integer NOT NULL DEFAULT 0
);
CREATE UNIQUE INDEX "uix_billing_preferences_on_user_id" ON billing_preferences (user_id);
CREATE TABLE IF NOT EXISTS billing_customers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
user_id INTEGER NOT NULL REFERENCES users (id),
has_overdue_invoices BOOLEAN NOT NULL DEFAULT FALSE,
stripe_customer_id TEXT NOT NULL,
trial_started_at TIMESTAMP
);
CREATE UNIQUE INDEX "uix_billing_customers_on_user_id" ON billing_customers (user_id);
CREATE UNIQUE INDEX "uix_billing_customers_on_stripe_customer_id" ON billing_customers (stripe_customer_id);
CREATE TABLE IF NOT EXISTS billing_subscriptions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
billing_customer_id INTEGER NOT NULL REFERENCES billing_customers (id),
stripe_subscription_id TEXT NOT NULL,
stripe_subscription_status TEXT NOT NULL,
stripe_cancel_at TIMESTAMP,
stripe_cancellation_reason TEXT,
kind TEXT,
stripe_current_period_start BIGINT,
stripe_current_period_end BIGINT
);
CREATE INDEX "ix_billing_subscriptions_on_billing_customer_id" ON billing_subscriptions (billing_customer_id);
CREATE UNIQUE INDEX "uix_billing_subscriptions_on_stripe_subscription_id" ON billing_subscriptions (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS processed_stripe_events (
stripe_event_id TEXT PRIMARY KEY,
stripe_event_type TEXT NOT NULL,
stripe_event_created_timestamp INTEGER NOT NULL,
processed_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX "ix_processed_stripe_events_on_stripe_event_created_timestamp" ON processed_stripe_events (stripe_event_created_timestamp);
CREATE TABLE IF NOT EXISTS "breakpoints" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT,
"project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE,

View File

@@ -0,0 +1,2 @@
alter table users
alter column admin set not null;

View File

@@ -0,0 +1,2 @@
alter table billing_customers
add column orb_customer_id text;

View File

@@ -0,0 +1 @@
drop table rate_buckets;

View File

@@ -1,19 +1,11 @@
pub mod billing;
pub mod contributors;
pub mod events;
pub mod extensions;
pub mod ips_file;
pub mod slack;
use crate::db::Database;
use crate::{
AppState, Error, Result, auth,
db::{User, UserId},
rpc,
};
use ::rpc::proto;
use crate::{AppState, Error, Result, auth, db::UserId, rpc};
use anyhow::Context as _;
use axum::extract;
use axum::{
Extension, Json, Router,
body::Body,
@@ -25,7 +17,6 @@ use axum::{
routing::{get, post},
};
use axum_extra::response::ErasedJson;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::sync::{Arc, OnceLock};
use tower::ServiceBuilder;
@@ -100,10 +91,7 @@ impl std::fmt::Display for SystemIdHeader {
pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
Router::new()
.route("/users/look_up", get(look_up_user))
.route("/users/:id/access_tokens", post(create_access_token))
.route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens))
.route("/users/:id/update_plan", post(update_plan))
.route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
.merge(contributors::router())
.layer(
@@ -144,99 +132,6 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
Ok::<_, Error>(next.run(req).await)
}
#[derive(Debug, Deserialize)]
struct LookUpUserParams {
identifier: String,
}
#[derive(Debug, Serialize)]
struct LookUpUserResponse {
user: Option<User>,
}
async fn look_up_user(
Query(params): Query<LookUpUserParams>,
Extension(app): Extension<Arc<AppState>>,
) -> Result<Json<LookUpUserResponse>> {
let user = resolve_identifier_to_user(&app.db, &params.identifier).await?;
let user = if let Some(user) = user {
match user {
UserOrId::User(user) => Some(user),
UserOrId::Id(id) => app.db.get_user_by_id(id).await?,
}
} else {
None
};
Ok(Json(LookUpUserResponse { user }))
}
enum UserOrId {
User(User),
Id(UserId),
}
async fn resolve_identifier_to_user(
db: &Arc<Database>,
identifier: &str,
) -> Result<Option<UserOrId>> {
if let Some(identifier) = identifier.parse::<i32>().ok() {
let user = db.get_user_by_id(UserId(identifier)).await?;
return Ok(user.map(UserOrId::User));
}
if identifier.starts_with("cus_") {
let billing_customer = db
.get_billing_customer_by_stripe_customer_id(&identifier)
.await?;
return Ok(billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id)));
}
if identifier.starts_with("sub_") {
let billing_subscription = db
.get_billing_subscription_by_stripe_subscription_id(&identifier)
.await?;
if let Some(billing_subscription) = billing_subscription {
let billing_customer = db
.get_billing_customer_by_id(billing_subscription.billing_customer_id)
.await?;
return Ok(
billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id))
);
} else {
return Ok(None);
}
}
if identifier.contains('@') {
let user = db.get_user_by_email(identifier).await?;
return Ok(user.map(UserOrId::User));
}
if let Some(user) = db.get_user_by_github_login(identifier).await? {
return Ok(Some(UserOrId::User(user)));
}
Ok(None)
}
#[derive(Deserialize, Debug)]
struct CreateUserParams {
github_user_id: i32,
github_login: String,
email_address: String,
email_confirmation_code: Option<String>,
#[serde(default)]
admin: bool,
#[serde(default)]
invite_count: i32,
}
async fn get_rpc_server_snapshot(
Extension(rpc_server): Extension<Arc<rpc::Server>>,
) -> Result<ErasedJson> {
@@ -295,90 +190,3 @@ async fn create_access_token(
encrypted_access_token,
}))
}
#[derive(Serialize)]
struct RefreshLlmTokensResponse {}
async fn refresh_llm_tokens(
Path(user_id): Path<UserId>,
Extension(rpc_server): Extension<Arc<rpc::Server>>,
) -> Result<Json<RefreshLlmTokensResponse>> {
rpc_server.refresh_llm_tokens_for_user(user_id).await;
Ok(Json(RefreshLlmTokensResponse {}))
}
#[derive(Debug, Serialize, Deserialize)]
struct UpdatePlanBody {
pub plan: cloud_llm_client::Plan,
pub subscription_period: SubscriptionPeriod,
pub usage: cloud_llm_client::CurrentUsage,
pub trial_started_at: Option<DateTime<Utc>>,
pub is_usage_based_billing_enabled: bool,
pub is_account_too_young: bool,
pub has_overdue_invoices: bool,
}
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
struct SubscriptionPeriod {
pub started_at: DateTime<Utc>,
pub ended_at: DateTime<Utc>,
}
#[derive(Serialize)]
struct UpdatePlanResponse {}
async fn update_plan(
Path(user_id): Path<UserId>,
Extension(rpc_server): Extension<Arc<rpc::Server>>,
extract::Json(body): extract::Json<UpdatePlanBody>,
) -> Result<Json<UpdatePlanResponse>> {
let plan = match body.plan {
cloud_llm_client::Plan::ZedFree => proto::Plan::Free,
cloud_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
cloud_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
};
let update_user_plan = proto::UpdateUserPlan {
plan: plan.into(),
trial_started_at: body
.trial_started_at
.map(|trial_started_at| trial_started_at.timestamp() as u64),
is_usage_based_billing_enabled: Some(body.is_usage_based_billing_enabled),
usage: Some(proto::SubscriptionUsage {
model_requests_usage_amount: body.usage.model_requests.used,
model_requests_usage_limit: Some(usage_limit_to_proto(body.usage.model_requests.limit)),
edit_predictions_usage_amount: body.usage.edit_predictions.used,
edit_predictions_usage_limit: Some(usage_limit_to_proto(
body.usage.edit_predictions.limit,
)),
}),
subscription_period: Some(proto::SubscriptionPeriod {
started_at: body.subscription_period.started_at.timestamp() as u64,
ended_at: body.subscription_period.ended_at.timestamp() as u64,
}),
account_too_young: Some(body.is_account_too_young),
has_overdue_invoices: Some(body.has_overdue_invoices),
};
rpc_server
.update_plan_for_user(user_id, update_user_plan)
.await?;
Ok(Json(UpdatePlanResponse {}))
}
fn usage_limit_to_proto(limit: cloud_llm_client::UsageLimit) -> proto::UsageLimit {
proto::UsageLimit {
variant: Some(match limit {
cloud_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
cloud_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),
}
}

View File

@@ -1,59 +0,0 @@
use std::sync::Arc;
use stripe::SubscriptionStatus;
use crate::AppState;
use crate::db::billing_subscription::StripeSubscriptionStatus;
use crate::db::{CreateBillingCustomerParams, billing_customer};
use crate::stripe_client::{StripeClient, StripeCustomerId};
impl From<SubscriptionStatus> for StripeSubscriptionStatus {
fn from(value: SubscriptionStatus) -> Self {
match value {
SubscriptionStatus::Incomplete => Self::Incomplete,
SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
SubscriptionStatus::Trialing => Self::Trialing,
SubscriptionStatus::Active => Self::Active,
SubscriptionStatus::PastDue => Self::PastDue,
SubscriptionStatus::Canceled => Self::Canceled,
SubscriptionStatus::Unpaid => Self::Unpaid,
SubscriptionStatus::Paused => Self::Paused,
}
}
}
/// Finds or creates a billing customer using the provided customer.
pub async fn find_or_create_billing_customer(
app: &Arc<AppState>,
stripe_client: &dyn StripeClient,
customer_id: &StripeCustomerId,
) -> anyhow::Result<Option<billing_customer::Model>> {
// If we already have a billing customer record associated with the Stripe customer,
// there's nothing more we need to do.
if let Some(billing_customer) = app
.db
.get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref())
.await?
{
return Ok(Some(billing_customer));
}
let customer = stripe_client.get_customer(customer_id).await?;
let Some(email) = customer.email else {
return Ok(None);
};
let Some(user) = app.db.get_user_by_email(&email).await? else {
return Ok(None);
};
let billing_customer = app
.db
.create_billing_customer(&CreateBillingCustomerParams {
user_id: user.id,
stripe_customer_id: customer.id.to_string(),
})
.await?;
Ok(Some(billing_customer))
}

View File

@@ -564,170 +564,10 @@ fn for_snowflake(
country_code: Option<String>,
checksum_matched: bool,
) -> impl Iterator<Item = SnowflakeRow> {
body.events.into_iter().filter_map(move |event| {
body.events.into_iter().map(move |event| {
let timestamp =
first_event_at + Duration::milliseconds(event.milliseconds_since_first_event);
// We will need to double check, but I believe all of the events that
// are being transformed here are now migrated over to use the
// telemetry::event! macro, as of this commit so this code can go away
// when we feel enough users have upgraded past this point.
let (event_type, mut event_properties) = match &event.event {
Event::Editor(e) => (
match e.operation.as_str() {
"open" => "Editor Opened".to_string(),
"save" => "Editor Saved".to_string(),
_ => format!("Unknown Editor Event: {}", e.operation),
},
serde_json::to_value(e).unwrap(),
),
Event::EditPrediction(e) => (
format!(
"Edit Prediction {}",
if e.suggestion_accepted {
"Accepted"
} else {
"Discarded"
}
),
serde_json::to_value(e).unwrap(),
),
Event::EditPredictionRating(e) => (
"Edit Prediction Rated".to_string(),
serde_json::to_value(e).unwrap(),
),
Event::Call(e) => {
let event_type = match e.operation.trim() {
"unshare project" => "Project Unshared".to_string(),
"open channel notes" => "Channel Notes Opened".to_string(),
"share project" => "Project Shared".to_string(),
"join channel" => "Channel Joined".to_string(),
"hang up" => "Call Ended".to_string(),
"accept incoming" => "Incoming Call Accepted".to_string(),
"invite" => "Participant Invited".to_string(),
"disable microphone" => "Microphone Disabled".to_string(),
"enable microphone" => "Microphone Enabled".to_string(),
"enable screen share" => "Screen Share Enabled".to_string(),
"disable screen share" => "Screen Share Disabled".to_string(),
"decline incoming" => "Incoming Call Declined".to_string(),
_ => format!("Unknown Call Event: {}", e.operation),
};
(event_type, serde_json::to_value(e).unwrap())
}
Event::Assistant(e) => (
match e.phase {
telemetry_events::AssistantPhase::Response => "Assistant Responded".to_string(),
telemetry_events::AssistantPhase::Invoked => "Assistant Invoked".to_string(),
telemetry_events::AssistantPhase::Accepted => {
"Assistant Response Accepted".to_string()
}
telemetry_events::AssistantPhase::Rejected => {
"Assistant Response Rejected".to_string()
}
},
serde_json::to_value(e).unwrap(),
),
Event::Cpu(_) | Event::Memory(_) => return None,
Event::App(e) => {
let mut properties = json!({});
let event_type = match e.operation.trim() {
// App
"open" => "App Opened".to_string(),
"first open" => "App First Opened".to_string(),
"first open for release channel" => {
"App First Opened For Release Channel".to_string()
}
"close" => "App Closed".to_string(),
// Project
"open project" => "Project Opened".to_string(),
"open node project" => {
properties["project_type"] = json!("node");
"Project Opened".to_string()
}
"open pnpm project" => {
properties["project_type"] = json!("pnpm");
"Project Opened".to_string()
}
"open yarn project" => {
properties["project_type"] = json!("yarn");
"Project Opened".to_string()
}
// SSH
"create ssh server" => "SSH Server Created".to_string(),
"create ssh project" => "SSH Project Created".to_string(),
"open ssh project" => "SSH Project Opened".to_string(),
// Welcome Page
"welcome page: change keymap" => "Welcome Keymap Changed".to_string(),
"welcome page: change theme" => "Welcome Theme Changed".to_string(),
"welcome page: close" => "Welcome Page Closed".to_string(),
"welcome page: edit settings" => "Welcome Settings Edited".to_string(),
"welcome page: install cli" => "Welcome CLI Installed".to_string(),
"welcome page: open" => "Welcome Page Opened".to_string(),
"welcome page: open extensions" => "Welcome Extensions Page Opened".to_string(),
"welcome page: sign in to copilot" => "Welcome Copilot Signed In".to_string(),
"welcome page: toggle diagnostic telemetry" => {
"Welcome Diagnostic Telemetry Toggled".to_string()
}
"welcome page: toggle metric telemetry" => {
"Welcome Metric Telemetry Toggled".to_string()
}
"welcome page: toggle vim" => "Welcome Vim Mode Toggled".to_string(),
"welcome page: view docs" => "Welcome Documentation Viewed".to_string(),
// Extensions
"extensions page: open" => "Extensions Page Opened".to_string(),
"extensions: install extension" => "Extension Installed".to_string(),
"extensions: uninstall extension" => "Extension Uninstalled".to_string(),
// Misc
"markdown preview: open" => "Markdown Preview Opened".to_string(),
"project diagnostics: open" => "Project Diagnostics Opened".to_string(),
"project search: open" => "Project Search Opened".to_string(),
"repl sessions: open" => "REPL Session Started".to_string(),
// Feature Upsell
"feature upsell: toggle vim" => {
properties["source"] = json!("Feature Upsell");
"Vim Mode Toggled".to_string()
}
_ => e
.operation
.strip_prefix("feature upsell: viewed docs (")
.and_then(|s| s.strip_suffix(')'))
.map_or_else(
|| format!("Unknown App Event: {}", e.operation),
|docs_url| {
properties["url"] = json!(docs_url);
properties["source"] = json!("Feature Upsell");
"Documentation Viewed".to_string()
},
),
};
(event_type, properties)
}
Event::Setting(e) => (
"Settings Changed".to_string(),
serde_json::to_value(e).unwrap(),
),
Event::Extension(e) => (
"Extension Loaded".to_string(),
serde_json::to_value(e).unwrap(),
),
Event::Edit(e) => (
"Editor Edited".to_string(),
serde_json::to_value(e).unwrap(),
),
Event::Action(e) => (
"Action Invoked".to_string(),
serde_json::to_value(e).unwrap(),
),
Event::Repl(e) => (
"Kernel Status Changed".to_string(),
serde_json::to_value(e).unwrap(),
),
Event::Flexible(e) => (
e.event_type.clone(),
serde_json::to_value(&e.event_properties).unwrap(),
@@ -759,7 +599,7 @@ fn for_snowflake(
})
});
Some(SnowflakeRow {
SnowflakeRow {
time: timestamp,
user_id: body.metrics_id.clone(),
device_id: body.system_id.clone(),
@@ -767,7 +607,7 @@ fn for_snowflake(
event_properties,
user_properties,
insert_id: Some(Uuid::new_v4().to_string()),
})
}
})
}

View File

@@ -41,12 +41,7 @@ use worktree_settings_file::LocalSettingsKind;
pub use tests::TestDb;
pub use ids::*;
pub use queries::billing_customers::{CreateBillingCustomerParams, UpdateBillingCustomerParams};
pub use queries::billing_subscriptions::{
CreateBillingSubscriptionParams, UpdateBillingSubscriptionParams,
};
pub use queries::contributors::ContributorSelector;
pub use queries::processed_stripe_events::CreateProcessedStripeEventParams;
pub use sea_orm::ConnectOptions;
pub use tables::user::Model as User;
pub use tables::*;

View File

@@ -70,9 +70,6 @@ macro_rules! id_type {
}
id_type!(AccessTokenId);
id_type!(BillingCustomerId);
id_type!(BillingSubscriptionId);
id_type!(BillingPreferencesId);
id_type!(BufferId);
id_type!(ChannelBufferCollaboratorId);
id_type!(ChannelChatParticipantId);

View File

@@ -1,9 +1,6 @@
use super::*;
pub mod access_tokens;
pub mod billing_customers;
pub mod billing_preferences;
pub mod billing_subscriptions;
pub mod buffers;
pub mod channels;
pub mod contacts;
@@ -12,7 +9,6 @@ pub mod embeddings;
pub mod extensions;
pub mod messages;
pub mod notifications;
pub mod processed_stripe_events;
pub mod projects;
pub mod rooms;
pub mod servers;

View File

@@ -1,100 +0,0 @@
use super::*;
#[derive(Debug)]
pub struct CreateBillingCustomerParams {
pub user_id: UserId,
pub stripe_customer_id: String,
}
#[derive(Debug, Default)]
pub struct UpdateBillingCustomerParams {
pub user_id: ActiveValue<UserId>,
pub stripe_customer_id: ActiveValue<String>,
pub has_overdue_invoices: ActiveValue<bool>,
pub trial_started_at: ActiveValue<Option<DateTime>>,
}
impl Database {
/// Creates a new billing customer.
pub async fn create_billing_customer(
&self,
params: &CreateBillingCustomerParams,
) -> Result<billing_customer::Model> {
self.transaction(|tx| async move {
let customer = billing_customer::Entity::insert(billing_customer::ActiveModel {
user_id: ActiveValue::set(params.user_id),
stripe_customer_id: ActiveValue::set(params.stripe_customer_id.clone()),
..Default::default()
})
.exec_with_returning(&*tx)
.await?;
Ok(customer)
})
.await
}
/// Updates the specified billing customer.
pub async fn update_billing_customer(
&self,
id: BillingCustomerId,
params: &UpdateBillingCustomerParams,
) -> Result<()> {
self.transaction(|tx| async move {
billing_customer::Entity::update(billing_customer::ActiveModel {
id: ActiveValue::set(id),
user_id: params.user_id.clone(),
stripe_customer_id: params.stripe_customer_id.clone(),
has_overdue_invoices: params.has_overdue_invoices.clone(),
trial_started_at: params.trial_started_at.clone(),
created_at: ActiveValue::not_set(),
})
.exec(&*tx)
.await?;
Ok(())
})
.await
}
pub async fn get_billing_customer_by_id(
&self,
id: BillingCustomerId,
) -> Result<Option<billing_customer::Model>> {
self.transaction(|tx| async move {
Ok(billing_customer::Entity::find()
.filter(billing_customer::Column::Id.eq(id))
.one(&*tx)
.await?)
})
.await
}
/// Returns the billing customer for the user with the specified ID.
pub async fn get_billing_customer_by_user_id(
&self,
user_id: UserId,
) -> Result<Option<billing_customer::Model>> {
self.transaction(|tx| async move {
Ok(billing_customer::Entity::find()
.filter(billing_customer::Column::UserId.eq(user_id))
.one(&*tx)
.await?)
})
.await
}
/// Returns the billing customer for the user with the specified Stripe customer ID.
pub async fn get_billing_customer_by_stripe_customer_id(
&self,
stripe_customer_id: &str,
) -> Result<Option<billing_customer::Model>> {
self.transaction(|tx| async move {
Ok(billing_customer::Entity::find()
.filter(billing_customer::Column::StripeCustomerId.eq(stripe_customer_id))
.one(&*tx)
.await?)
})
.await
}
}

View File

@@ -1,17 +0,0 @@
use super::*;
impl Database {
/// Returns the billing preferences for the given user, if they exist.
pub async fn get_billing_preferences(
&self,
user_id: UserId,
) -> Result<Option<billing_preference::Model>> {
self.transaction(|tx| async move {
Ok(billing_preference::Entity::find()
.filter(billing_preference::Column::UserId.eq(user_id))
.one(&*tx)
.await?)
})
.await
}
}

View File

@@ -1,158 +0,0 @@
use anyhow::Context as _;
use crate::db::billing_subscription::{
StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
};
use super::*;
#[derive(Debug)]
pub struct CreateBillingSubscriptionParams {
pub billing_customer_id: BillingCustomerId,
pub kind: Option<SubscriptionKind>,
pub stripe_subscription_id: String,
pub stripe_subscription_status: StripeSubscriptionStatus,
pub stripe_cancellation_reason: Option<StripeCancellationReason>,
pub stripe_current_period_start: Option<i64>,
pub stripe_current_period_end: Option<i64>,
}
#[derive(Debug, Default)]
pub struct UpdateBillingSubscriptionParams {
pub billing_customer_id: ActiveValue<BillingCustomerId>,
pub kind: ActiveValue<Option<SubscriptionKind>>,
pub stripe_subscription_id: ActiveValue<String>,
pub stripe_subscription_status: ActiveValue<StripeSubscriptionStatus>,
pub stripe_cancel_at: ActiveValue<Option<DateTime>>,
pub stripe_cancellation_reason: ActiveValue<Option<StripeCancellationReason>>,
pub stripe_current_period_start: ActiveValue<Option<i64>>,
pub stripe_current_period_end: ActiveValue<Option<i64>>,
}
impl Database {
/// Creates a new billing subscription.
pub async fn create_billing_subscription(
&self,
params: &CreateBillingSubscriptionParams,
) -> Result<billing_subscription::Model> {
self.transaction(|tx| async move {
let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel {
billing_customer_id: ActiveValue::set(params.billing_customer_id),
kind: ActiveValue::set(params.kind),
stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status),
stripe_cancellation_reason: ActiveValue::set(params.stripe_cancellation_reason),
stripe_current_period_start: ActiveValue::set(params.stripe_current_period_start),
stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end),
..Default::default()
})
.exec(&*tx)
.await?
.last_insert_id;
Ok(billing_subscription::Entity::find_by_id(id)
.one(&*tx)
.await?
.context("failed to retrieve inserted billing subscription")?)
})
.await
}
/// Updates the specified billing subscription.
pub async fn update_billing_subscription(
&self,
id: BillingSubscriptionId,
params: &UpdateBillingSubscriptionParams,
) -> Result<()> {
self.transaction(|tx| async move {
billing_subscription::Entity::update(billing_subscription::ActiveModel {
id: ActiveValue::set(id),
billing_customer_id: params.billing_customer_id.clone(),
kind: params.kind.clone(),
stripe_subscription_id: params.stripe_subscription_id.clone(),
stripe_subscription_status: params.stripe_subscription_status.clone(),
stripe_cancel_at: params.stripe_cancel_at.clone(),
stripe_cancellation_reason: params.stripe_cancellation_reason.clone(),
stripe_current_period_start: params.stripe_current_period_start.clone(),
stripe_current_period_end: params.stripe_current_period_end.clone(),
created_at: ActiveValue::not_set(),
})
.exec(&*tx)
.await?;
Ok(())
})
.await
}
/// Returns the billing subscription with the specified Stripe subscription ID.
pub async fn get_billing_subscription_by_stripe_subscription_id(
&self,
stripe_subscription_id: &str,
) -> Result<Option<billing_subscription::Model>> {
self.transaction(|tx| async move {
Ok(billing_subscription::Entity::find()
.filter(
billing_subscription::Column::StripeSubscriptionId.eq(stripe_subscription_id),
)
.one(&*tx)
.await?)
})
.await
}
pub async fn get_active_billing_subscription(
&self,
user_id: UserId,
) -> Result<Option<billing_subscription::Model>> {
self.transaction(|tx| async move {
Ok(billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.filter(billing_customer::Column::UserId.eq(user_id))
.filter(
Condition::all()
.add(
Condition::any()
.add(
billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Active),
)
.add(
billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Trialing),
),
)
.add(billing_subscription::Column::Kind.is_not_null()),
)
.one(&*tx)
.await?)
})
.await
}
/// Returns whether the user has an active billing subscription.
pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
Ok(self.count_active_billing_subscriptions(user_id).await? > 0)
}
/// Returns the count of the active billing subscriptions for the user with the specified ID.
pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
self.transaction(|tx| async move {
let count = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.filter(
billing_customer::Column::UserId.eq(user_id).and(
billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Active)
.or(billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Trialing)),
),
)
.count(&*tx)
.await?;
Ok(count as usize)
})
.await
}
}

View File

@@ -1,69 +0,0 @@
use super::*;
#[derive(Debug)]
pub struct CreateProcessedStripeEventParams {
pub stripe_event_id: String,
pub stripe_event_type: String,
pub stripe_event_created_timestamp: i64,
}
impl Database {
/// Creates a new processed Stripe event.
pub async fn create_processed_stripe_event(
&self,
params: &CreateProcessedStripeEventParams,
) -> Result<()> {
self.transaction(|tx| async move {
processed_stripe_event::Entity::insert(processed_stripe_event::ActiveModel {
stripe_event_id: ActiveValue::set(params.stripe_event_id.clone()),
stripe_event_type: ActiveValue::set(params.stripe_event_type.clone()),
stripe_event_created_timestamp: ActiveValue::set(
params.stripe_event_created_timestamp,
),
..Default::default()
})
.exec_without_returning(&*tx)
.await?;
Ok(())
})
.await
}
/// Returns the processed Stripe event with the specified event ID.
pub async fn get_processed_stripe_event_by_event_id(
&self,
event_id: &str,
) -> Result<Option<processed_stripe_event::Model>> {
self.transaction(|tx| async move {
Ok(processed_stripe_event::Entity::find_by_id(event_id)
.one(&*tx)
.await?)
})
.await
}
/// Returns the processed Stripe events with the specified event IDs.
pub async fn get_processed_stripe_events_by_event_ids(
&self,
event_ids: &[&str],
) -> Result<Vec<processed_stripe_event::Model>> {
self.transaction(|tx| async move {
Ok(processed_stripe_event::Entity::find()
.filter(
processed_stripe_event::Column::StripeEventId.is_in(event_ids.iter().copied()),
)
.all(&*tx)
.await?)
})
.await
}
/// Returns whether the Stripe event with the specified ID has already been processed.
pub async fn already_processed_stripe_event(&self, event_id: &str) -> Result<bool> {
Ok(self
.get_processed_stripe_event_by_event_id(event_id)
.await?
.is_some())
}
}

View File

@@ -1,7 +1,4 @@
pub mod access_token;
pub mod billing_customer;
pub mod billing_preference;
pub mod billing_subscription;
pub mod buffer;
pub mod buffer_operation;
pub mod buffer_snapshot;
@@ -23,7 +20,6 @@ pub mod notification;
pub mod notification_kind;
pub mod observed_buffer_edits;
pub mod observed_channel_messages;
pub mod processed_stripe_event;
pub mod project;
pub mod project_collaborator;
pub mod project_repository;

View File

@@ -1,41 +0,0 @@
use crate::db::{BillingCustomerId, UserId};
use sea_orm::entity::prelude::*;
/// A billing customer.
#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "billing_customers")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: BillingCustomerId,
pub user_id: UserId,
pub stripe_customer_id: String,
pub has_overdue_invoices: bool,
pub trial_started_at: Option<DateTime>,
pub created_at: DateTime,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::user::Entity",
from = "Column::UserId",
to = "super::user::Column::Id"
)]
User,
#[sea_orm(has_many = "super::billing_subscription::Entity")]
BillingSubscription,
}
impl Related<super::user::Entity> for Entity {
fn to() -> RelationDef {
Relation::User.def()
}
}
impl Related<super::billing_subscription::Entity> for Entity {
fn to() -> RelationDef {
Relation::BillingSubscription.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -1,32 +0,0 @@
use crate::db::{BillingPreferencesId, UserId};
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "billing_preferences")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: BillingPreferencesId,
pub created_at: DateTime,
pub user_id: UserId,
pub max_monthly_llm_usage_spending_in_cents: i32,
pub model_request_overages_enabled: bool,
pub model_request_overages_spend_limit_in_cents: i32,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::user::Entity",
from = "Column::UserId",
to = "super::user::Column::Id"
)]
User,
}
impl Related<super::user::Entity> for Entity {
fn to() -> RelationDef {
Relation::User.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -1,176 +0,0 @@
use crate::db::{BillingCustomerId, BillingSubscriptionId};
use crate::stripe_client;
use chrono::{Datelike as _, NaiveDate, Utc};
use sea_orm::entity::prelude::*;
use serde::Serialize;
/// A billing subscription.
#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "billing_subscriptions")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: BillingSubscriptionId,
pub billing_customer_id: BillingCustomerId,
pub kind: Option<SubscriptionKind>,
pub stripe_subscription_id: String,
pub stripe_subscription_status: StripeSubscriptionStatus,
pub stripe_cancel_at: Option<DateTime>,
pub stripe_cancellation_reason: Option<StripeCancellationReason>,
pub stripe_current_period_start: Option<i64>,
pub stripe_current_period_end: Option<i64>,
pub created_at: DateTime,
}
impl Model {
pub fn current_period_start_at(&self) -> Option<DateTimeUtc> {
let period_start = self.stripe_current_period_start?;
chrono::DateTime::from_timestamp(period_start, 0)
}
pub fn current_period_end_at(&self) -> Option<DateTimeUtc> {
let period_end = self.stripe_current_period_end?;
chrono::DateTime::from_timestamp(period_end, 0)
}
pub fn current_period(
subscription: Option<Self>,
is_staff: bool,
) -> Option<(DateTimeUtc, DateTimeUtc)> {
if is_staff {
let now = Utc::now();
let year = now.year();
let month = now.month();
let first_day_of_this_month =
NaiveDate::from_ymd_opt(year, month, 1)?.and_hms_opt(0, 0, 0)?;
let next_month = if month == 12 { 1 } else { month + 1 };
let next_month_year = if month == 12 { year + 1 } else { year };
let first_day_of_next_month =
NaiveDate::from_ymd_opt(next_month_year, next_month, 1)?.and_hms_opt(23, 59, 59)?;
let last_day_of_this_month = first_day_of_next_month - chrono::Days::new(1);
Some((
first_day_of_this_month.and_utc(),
last_day_of_this_month.and_utc(),
))
} else {
let subscription = subscription?;
let period_start_at = subscription.current_period_start_at()?;
let period_end_at = subscription.current_period_end_at()?;
Some((period_start_at, period_end_at))
}
}
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::billing_customer::Entity",
from = "Column::BillingCustomerId",
to = "super::billing_customer::Column::Id"
)]
BillingCustomer,
}
impl Related<super::billing_customer::Entity> for Entity {
fn to() -> RelationDef {
Relation::BillingCustomer.def()
}
}
impl ActiveModelBehavior for ActiveModel {}
#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)]
#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
#[serde(rename_all = "snake_case")]
pub enum SubscriptionKind {
#[sea_orm(string_value = "zed_pro")]
ZedPro,
#[sea_orm(string_value = "zed_pro_trial")]
ZedProTrial,
#[sea_orm(string_value = "zed_free")]
ZedFree,
}
impl From<SubscriptionKind> for cloud_llm_client::Plan {
fn from(value: SubscriptionKind) -> Self {
match value {
SubscriptionKind::ZedPro => Self::ZedPro,
SubscriptionKind::ZedProTrial => Self::ZedProTrial,
SubscriptionKind::ZedFree => Self::ZedFree,
}
}
}
/// The status of a Stripe subscription.
///
/// [Stripe docs](https://docs.stripe.com/api/subscriptions/object#subscription_object-status)
#[derive(
Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash, Serialize,
)]
#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
#[serde(rename_all = "snake_case")]
pub enum StripeSubscriptionStatus {
#[default]
#[sea_orm(string_value = "incomplete")]
Incomplete,
#[sea_orm(string_value = "incomplete_expired")]
IncompleteExpired,
#[sea_orm(string_value = "trialing")]
Trialing,
#[sea_orm(string_value = "active")]
Active,
#[sea_orm(string_value = "past_due")]
PastDue,
#[sea_orm(string_value = "canceled")]
Canceled,
#[sea_orm(string_value = "unpaid")]
Unpaid,
#[sea_orm(string_value = "paused")]
Paused,
}
impl StripeSubscriptionStatus {
pub fn is_cancelable(&self) -> bool {
match self {
Self::Trialing | Self::Active | Self::PastDue => true,
Self::Incomplete
| Self::IncompleteExpired
| Self::Canceled
| Self::Unpaid
| Self::Paused => false,
}
}
}
/// The cancellation reason for a Stripe subscription.
///
/// [Stripe docs](https://docs.stripe.com/api/subscriptions/object#subscription_object-cancellation_details-reason)
#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)]
#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
#[serde(rename_all = "snake_case")]
pub enum StripeCancellationReason {
#[sea_orm(string_value = "cancellation_requested")]
CancellationRequested,
#[sea_orm(string_value = "payment_disputed")]
PaymentDisputed,
#[sea_orm(string_value = "payment_failed")]
PaymentFailed,
}
impl From<stripe_client::StripeCancellationDetailsReason> for StripeCancellationReason {
fn from(value: stripe_client::StripeCancellationDetailsReason) -> Self {
match value {
stripe_client::StripeCancellationDetailsReason::CancellationRequested => {
Self::CancellationRequested
}
stripe_client::StripeCancellationDetailsReason::PaymentDisputed => {
Self::PaymentDisputed
}
stripe_client::StripeCancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
}
}
}

View File

@@ -1,16 +0,0 @@
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "processed_stripe_events")]
pub struct Model {
#[sea_orm(primary_key)]
pub stripe_event_id: String,
pub stripe_event_type: String,
pub stripe_event_created_timestamp: i64,
pub processed_at: DateTime,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -29,8 +29,6 @@ pub struct Model {
pub enum Relation {
#[sea_orm(has_many = "super::access_token::Entity")]
AccessToken,
#[sea_orm(has_one = "super::billing_customer::Entity")]
BillingCustomer,
#[sea_orm(has_one = "super::room_participant::Entity")]
RoomParticipant,
#[sea_orm(has_many = "super::project::Entity")]
@@ -68,12 +66,6 @@ impl Related<super::access_token::Entity> for Entity {
}
}
impl Related<super::billing_customer::Entity> for Entity {
fn to() -> RelationDef {
Relation::BillingCustomer.def()
}
}
impl Related<super::room_participant::Entity> for Entity {
fn to() -> RelationDef {
Relation::RoomParticipant.def()

View File

@@ -8,7 +8,6 @@ mod embedding_tests;
mod extension_tests;
mod feature_flag_tests;
mod message_tests;
mod processed_stripe_event_tests;
mod user_tests;
use crate::migrations::run_database_migrations;

View File

@@ -1,38 +0,0 @@
use std::sync::Arc;
use crate::test_both_dbs;
use super::{CreateProcessedStripeEventParams, Database};
test_both_dbs!(
test_already_processed_stripe_event,
test_already_processed_stripe_event_postgres,
test_already_processed_stripe_event_sqlite
);
async fn test_already_processed_stripe_event(db: &Arc<Database>) {
let unprocessed_event_id = "evt_1PiJOuRxOf7d5PNaw2zzWiyO".to_string();
let processed_event_id = "evt_1PiIfMRxOf7d5PNakHrAUe8P".to_string();
db.create_processed_stripe_event(&CreateProcessedStripeEventParams {
stripe_event_id: processed_event_id.clone(),
stripe_event_type: "customer.created".into(),
stripe_event_created_timestamp: 1722355968,
})
.await
.unwrap();
assert!(
db.already_processed_stripe_event(&processed_event_id)
.await
.unwrap(),
"Expected {processed_event_id} to already be processed"
);
assert!(
!db.already_processed_stripe_event(&unprocessed_event_id)
.await
.unwrap(),
"Expected {unprocessed_event_id} to be unprocessed"
);
}

View File

@@ -7,8 +7,6 @@ pub mod llm;
pub mod migrations;
pub mod rpc;
pub mod seed;
pub mod stripe_billing;
pub mod stripe_client;
pub mod user_backfiller;
#[cfg(test)]
@@ -22,21 +20,16 @@ use axum::{
};
use db::{ChannelId, Database};
use executor::Executor;
use llm::db::LlmDatabase;
use serde::Deserialize;
use std::{path::PathBuf, sync::Arc};
use util::ResultExt;
use crate::stripe_billing::StripeBilling;
use crate::stripe_client::{RealStripeClient, StripeClient};
pub type Result<T, E = Error> = std::result::Result<T, E>;
pub enum Error {
Http(StatusCode, String, HeaderMap),
Database(sea_orm::error::DbErr),
Internal(anyhow::Error),
Stripe(stripe::StripeError),
}
impl From<anyhow::Error> for Error {
@@ -51,12 +44,6 @@ impl From<sea_orm::error::DbErr> for Error {
}
}
impl From<stripe::StripeError> for Error {
fn from(error: stripe::StripeError) -> Self {
Self::Stripe(error)
}
}
impl From<axum::Error> for Error {
fn from(error: axum::Error) -> Self {
Self::Internal(error.into())
@@ -104,14 +91,6 @@ impl IntoResponse for Error {
);
(StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
}
Error::Stripe(error) => {
log::error!(
"HTTP error {}: {:?}",
StatusCode::INTERNAL_SERVER_ERROR,
&error
);
(StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
}
}
}
}
@@ -122,7 +101,6 @@ impl std::fmt::Debug for Error {
Error::Http(code, message, _headers) => (code, message).fmt(f),
Error::Database(error) => error.fmt(f),
Error::Internal(error) => error.fmt(f),
Error::Stripe(error) => error.fmt(f),
}
}
}
@@ -133,7 +111,6 @@ impl std::fmt::Display for Error {
Error::Http(code, message, _) => write!(f, "{code}: {message}"),
Error::Database(error) => error.fmt(f),
Error::Internal(error) => error.fmt(f),
Error::Stripe(error) => error.fmt(f),
}
}
}
@@ -179,7 +156,6 @@ pub struct Config {
pub zed_client_checksum_seed: Option<String>,
pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>,
pub stripe_api_key: Option<String>,
pub supermaven_admin_api_key: Option<Arc<str>>,
pub user_backfiller_github_access_token: Option<Arc<str>>,
}
@@ -234,7 +210,6 @@ impl Config {
auto_join_channel_id: None,
migrations_path: None,
seed_path: None,
stripe_api_key: None,
supermaven_admin_api_key: None,
user_backfiller_github_access_token: None,
kinesis_region: None,
@@ -266,14 +241,8 @@ impl ServiceMode {
pub struct AppState {
pub db: Arc<Database>,
pub llm_db: Option<Arc<LlmDatabase>>,
pub livekit_client: Option<Arc<dyn livekit_api::Client>>,
pub blob_store_client: Option<aws_sdk_s3::Client>,
/// This is a real instance of the Stripe client; we're working to replace references to this with the
/// [`StripeClient`] trait.
pub real_stripe_client: Option<Arc<stripe::Client>>,
pub stripe_client: Option<Arc<dyn StripeClient>>,
pub stripe_billing: Option<Arc<StripeBilling>>,
pub executor: Executor,
pub kinesis_client: Option<::aws_sdk_kinesis::Client>,
pub config: Config,
@@ -286,20 +255,6 @@ impl AppState {
let mut db = Database::new(db_options).await?;
db.initialize_notification_kinds().await?;
let llm_db = if let Some((llm_database_url, llm_database_max_connections)) = config
.llm_database_url
.clone()
.zip(config.llm_database_max_connections)
{
let mut llm_db_options = db::ConnectOptions::new(llm_database_url);
llm_db_options.max_connections(llm_database_max_connections);
let mut llm_db = LlmDatabase::new(llm_db_options, executor.clone()).await?;
llm_db.initialize().await?;
Some(Arc::new(llm_db))
} else {
None
};
let livekit_client = if let Some(((server, key), secret)) = config
.livekit_server
.as_ref()
@@ -316,18 +271,10 @@ impl AppState {
};
let db = Arc::new(db);
let stripe_client = build_stripe_client(&config).map(Arc::new).log_err();
let this = Self {
db: db.clone(),
llm_db,
livekit_client,
blob_store_client: build_blob_store_client(&config).await.log_err(),
stripe_billing: stripe_client
.clone()
.map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
real_stripe_client: stripe_client.clone(),
stripe_client: stripe_client
.map(|stripe_client| Arc::new(RealStripeClient::new(stripe_client)) as _),
executor,
kinesis_client: if config.kinesis_access_key.is_some() {
build_kinesis_client(&config).await.log_err()
@@ -340,14 +287,6 @@ impl AppState {
}
}
fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
let api_key = config
.stripe_api_key
.as_ref()
.context("missing stripe_api_key")?;
Ok(stripe::Client::new(api_key))
}
async fn build_blob_store_client(config: &Config) -> anyhow::Result<aws_sdk_s3::Client> {
let keys = aws_sdk_s3::config::Credentials::new(
config

View File

@@ -1,12 +1 @@
pub mod db;
mod token;
pub use token::*;
pub const AGENT_EXTENDED_TRIAL_FEATURE_FLAG: &str = "agent-extended-trial";
/// The name of the feature flag that bypasses the account age check.
pub const BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG: &str = "bypass-account-age-check";
/// The minimum account age an account must have in order to use the LLM service.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);

View File

@@ -1,30 +1,9 @@
mod ids;
mod queries;
mod seed;
mod tables;
#[cfg(test)]
mod tests;
use cloud_llm_client::LanguageModelProvider;
use collections::HashMap;
pub use ids::*;
pub use seed::*;
pub use tables::*;
#[cfg(test)]
pub use tests::TestLlmDb;
use usage_measure::UsageMeasure;
use std::future::Future;
use std::sync::Arc;
use anyhow::Context;
pub use sea_orm::ConnectOptions;
use sea_orm::prelude::*;
use sea_orm::{
ActiveValue, DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait,
};
use sea_orm::{DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait};
use crate::Result;
use crate::db::TransactionHandle;
@@ -36,9 +15,6 @@ pub struct LlmDatabase {
pool: DatabaseConnection,
#[allow(unused)]
executor: Executor,
provider_ids: HashMap<LanguageModelProvider, ProviderId>,
models: HashMap<(LanguageModelProvider, String), model::Model>,
usage_measure_ids: HashMap<UsageMeasure, UsageMeasureId>,
#[cfg(test)]
runtime: Option<tokio::runtime::Runtime>,
}
@@ -51,59 +27,11 @@ impl LlmDatabase {
options: options.clone(),
pool: sea_orm::Database::connect(options).await?,
executor,
provider_ids: HashMap::default(),
models: HashMap::default(),
usage_measure_ids: HashMap::default(),
#[cfg(test)]
runtime: None,
})
}
pub async fn initialize(&mut self) -> Result<()> {
self.initialize_providers().await?;
self.initialize_models().await?;
self.initialize_usage_measures().await?;
Ok(())
}
/// Returns the list of all known models, with their [`LanguageModelProvider`].
pub fn all_models(&self) -> Vec<(LanguageModelProvider, model::Model)> {
self.models
.iter()
.map(|((model_provider, _model_name), model)| (*model_provider, model.clone()))
.collect::<Vec<_>>()
}
/// Returns the names of the known models for the given [`LanguageModelProvider`].
pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec<String> {
self.models
.keys()
.filter_map(|(model_provider, model_name)| {
if model_provider == &provider {
Some(model_name)
} else {
None
}
})
.cloned()
.collect::<Vec<_>>()
}
pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
Ok(self
.models
.get(&(provider, name.to_string()))
.with_context(|| format!("unknown model {provider:?}:{name}"))?)
}
pub fn model_by_id(&self, id: ModelId) -> Result<&model::Model> {
Ok(self
.models
.values()
.find(|model| model.id == id)
.with_context(|| format!("no model for ID {id:?}"))?)
}
pub fn options(&self) -> &ConnectOptions {
&self.options
}

View File

@@ -1,11 +0,0 @@
use sea_orm::{DbErr, entity::prelude::*};
use serde::{Deserialize, Serialize};
use crate::id_type;
id_type!(BillingEventId);
id_type!(ModelId);
id_type!(ProviderId);
id_type!(RevokedAccessTokenId);
id_type!(UsageId);
id_type!(UsageMeasureId);

View File

@@ -1,5 +0,0 @@
use super::*;
pub mod providers;
pub mod subscription_usages;
pub mod usages;

View File

@@ -1,134 +0,0 @@
use super::*;
use sea_orm::{QueryOrder, sea_query::OnConflict};
use std::str::FromStr;
use strum::IntoEnumIterator as _;
pub struct ModelParams {
pub provider: LanguageModelProvider,
pub name: String,
pub max_requests_per_minute: i64,
pub max_tokens_per_minute: i64,
pub max_tokens_per_day: i64,
pub price_per_million_input_tokens: i32,
pub price_per_million_output_tokens: i32,
}
impl LlmDatabase {
pub async fn initialize_providers(&mut self) -> Result<()> {
self.provider_ids = self
.transaction(|tx| async move {
let existing_providers = provider::Entity::find().all(&*tx).await?;
let mut new_providers = LanguageModelProvider::iter()
.filter(|provider| {
!existing_providers
.iter()
.any(|p| p.name == provider.to_string())
})
.map(|provider| provider::ActiveModel {
name: ActiveValue::set(provider.to_string()),
..Default::default()
})
.peekable();
if new_providers.peek().is_some() {
provider::Entity::insert_many(new_providers)
.exec(&*tx)
.await?;
}
let all_providers: HashMap<_, _> = provider::Entity::find()
.all(&*tx)
.await?
.iter()
.filter_map(|provider| {
LanguageModelProvider::from_str(&provider.name)
.ok()
.map(|p| (p, provider.id))
})
.collect();
Ok(all_providers)
})
.await?;
Ok(())
}
pub async fn initialize_models(&mut self) -> Result<()> {
let all_provider_ids = &self.provider_ids;
self.models = self
.transaction(|tx| async move {
let all_models: HashMap<_, _> = model::Entity::find()
.all(&*tx)
.await?
.into_iter()
.filter_map(|model| {
let provider = all_provider_ids.iter().find_map(|(provider, id)| {
if *id == model.provider_id {
Some(provider)
} else {
None
}
})?;
Some(((*provider, model.name.clone()), model))
})
.collect();
Ok(all_models)
})
.await?;
Ok(())
}
pub async fn insert_models(&mut self, models: &[ModelParams]) -> Result<()> {
let all_provider_ids = &self.provider_ids;
self.transaction(|tx| async move {
model::Entity::insert_many(models.iter().map(|model_params| {
let provider_id = all_provider_ids[&model_params.provider];
model::ActiveModel {
provider_id: ActiveValue::set(provider_id),
name: ActiveValue::set(model_params.name.clone()),
max_requests_per_minute: ActiveValue::set(model_params.max_requests_per_minute),
max_tokens_per_minute: ActiveValue::set(model_params.max_tokens_per_minute),
max_tokens_per_day: ActiveValue::set(model_params.max_tokens_per_day),
price_per_million_input_tokens: ActiveValue::set(
model_params.price_per_million_input_tokens,
),
price_per_million_output_tokens: ActiveValue::set(
model_params.price_per_million_output_tokens,
),
..Default::default()
}
}))
.on_conflict(
OnConflict::columns([model::Column::ProviderId, model::Column::Name])
.update_columns([
model::Column::MaxRequestsPerMinute,
model::Column::MaxTokensPerMinute,
model::Column::MaxTokensPerDay,
model::Column::PricePerMillionInputTokens,
model::Column::PricePerMillionOutputTokens,
])
.to_owned(),
)
.exec_without_returning(&*tx)
.await?;
Ok(())
})
.await?;
self.initialize_models().await
}
/// Returns the list of LLM providers.
pub async fn list_providers(&self) -> Result<Vec<LanguageModelProvider>> {
self.transaction(|tx| async move {
Ok(provider::Entity::find()
.order_by_asc(provider::Column::Name)
.all(&*tx)
.await?
.into_iter()
.filter_map(|p| LanguageModelProvider::from_str(&p.name).ok())
.collect())
})
.await
}
}

View File

@@ -1,38 +0,0 @@
use crate::db::UserId;
use super::*;
impl LlmDatabase {
pub async fn get_subscription_usage_for_period(
&self,
user_id: UserId,
period_start_at: DateTimeUtc,
period_end_at: DateTimeUtc,
) -> Result<Option<subscription_usage::Model>> {
self.transaction(|tx| async move {
self.get_subscription_usage_for_period_in_tx(
user_id,
period_start_at,
period_end_at,
&tx,
)
.await
})
.await
}
async fn get_subscription_usage_for_period_in_tx(
&self,
user_id: UserId,
period_start_at: DateTimeUtc,
period_end_at: DateTimeUtc,
tx: &DatabaseTransaction,
) -> Result<Option<subscription_usage::Model>> {
Ok(subscription_usage::Entity::find()
.filter(subscription_usage::Column::UserId.eq(user_id))
.filter(subscription_usage::Column::PeriodStartAt.eq(period_start_at))
.filter(subscription_usage::Column::PeriodEndAt.eq(period_end_at))
.one(tx)
.await?)
}
}

View File

@@ -1,44 +0,0 @@
use std::str::FromStr;
use strum::IntoEnumIterator as _;
use super::*;
impl LlmDatabase {
pub async fn initialize_usage_measures(&mut self) -> Result<()> {
let all_measures = self
.transaction(|tx| async move {
let existing_measures = usage_measure::Entity::find().all(&*tx).await?;
let new_measures = UsageMeasure::iter()
.filter(|measure| {
!existing_measures
.iter()
.any(|m| m.name == measure.to_string())
})
.map(|measure| usage_measure::ActiveModel {
name: ActiveValue::set(measure.to_string()),
..Default::default()
})
.collect::<Vec<_>>();
if !new_measures.is_empty() {
usage_measure::Entity::insert_many(new_measures)
.exec(&*tx)
.await?;
}
Ok(usage_measure::Entity::find().all(&*tx).await?)
})
.await?;
self.usage_measure_ids = all_measures
.into_iter()
.filter_map(|measure| {
UsageMeasure::from_str(&measure.name)
.ok()
.map(|um| (um, measure.id))
})
.collect();
Ok(())
}
}

View File

@@ -1,45 +0,0 @@
use super::*;
use crate::{Config, Result};
use queries::providers::ModelParams;
pub async fn seed_database(_config: &Config, db: &mut LlmDatabase, _force: bool) -> Result<()> {
db.insert_models(&[
ModelParams {
provider: LanguageModelProvider::Anthropic,
name: "claude-3-5-sonnet".into(),
max_requests_per_minute: 5,
max_tokens_per_minute: 20_000,
max_tokens_per_day: 300_000,
price_per_million_input_tokens: 300, // $3.00/MTok
price_per_million_output_tokens: 1500, // $15.00/MTok
},
ModelParams {
provider: LanguageModelProvider::Anthropic,
name: "claude-3-opus".into(),
max_requests_per_minute: 5,
max_tokens_per_minute: 10_000,
max_tokens_per_day: 300_000,
price_per_million_input_tokens: 1500, // $15.00/MTok
price_per_million_output_tokens: 7500, // $75.00/MTok
},
ModelParams {
provider: LanguageModelProvider::Anthropic,
name: "claude-3-sonnet".into(),
max_requests_per_minute: 5,
max_tokens_per_minute: 20_000,
max_tokens_per_day: 300_000,
price_per_million_input_tokens: 1500, // $15.00/MTok
price_per_million_output_tokens: 7500, // $75.00/MTok
},
ModelParams {
provider: LanguageModelProvider::Anthropic,
name: "claude-3-haiku".into(),
max_requests_per_minute: 5,
max_tokens_per_minute: 25_000,
max_tokens_per_day: 300_000,
price_per_million_input_tokens: 25, // $0.25/MTok
price_per_million_output_tokens: 125, // $1.25/MTok
},
])
.await
}

View File

@@ -1,6 +0,0 @@
pub mod model;
pub mod provider;
pub mod subscription_usage;
pub mod subscription_usage_meter;
pub mod usage;
pub mod usage_measure;

View File

@@ -1,48 +0,0 @@
use sea_orm::entity::prelude::*;
use crate::llm::db::{ModelId, ProviderId};
/// An LLM model.
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "models")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: ModelId,
pub provider_id: ProviderId,
pub name: String,
pub max_requests_per_minute: i64,
pub max_tokens_per_minute: i64,
pub max_input_tokens_per_minute: i64,
pub max_output_tokens_per_minute: i64,
pub max_tokens_per_day: i64,
pub price_per_million_input_tokens: i32,
pub price_per_million_cache_creation_input_tokens: i32,
pub price_per_million_cache_read_input_tokens: i32,
pub price_per_million_output_tokens: i32,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::provider::Entity",
from = "Column::ProviderId",
to = "super::provider::Column::Id"
)]
Provider,
#[sea_orm(has_many = "super::usage::Entity")]
Usages,
}
impl Related<super::provider::Entity> for Entity {
fn to() -> RelationDef {
Relation::Provider.def()
}
}
impl Related<super::usage::Entity> for Entity {
fn to() -> RelationDef {
Relation::Usages.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -1,25 +0,0 @@
use crate::llm::db::ProviderId;
use sea_orm::entity::prelude::*;
/// An LLM provider.
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "providers")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: ProviderId,
pub name: String,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(has_many = "super::model::Entity")]
Models,
}
impl Related<super::model::Entity> for Entity {
fn to() -> RelationDef {
Relation::Models.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -1,22 +0,0 @@
use crate::db::UserId;
use crate::db::billing_subscription::SubscriptionKind;
use sea_orm::entity::prelude::*;
use time::PrimitiveDateTime;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "subscription_usages_v2")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: Uuid,
pub user_id: UserId,
pub period_start_at: PrimitiveDateTime,
pub period_end_at: PrimitiveDateTime,
pub plan: SubscriptionKind,
pub model_requests: i32,
pub edit_predictions: i32,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -1,55 +0,0 @@
use sea_orm::entity::prelude::*;
use serde::Serialize;
use crate::llm::db::ModelId;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "subscription_usage_meters_v2")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: Uuid,
pub subscription_usage_id: Uuid,
pub model_id: ModelId,
pub mode: CompletionMode,
pub requests: i32,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::subscription_usage::Entity",
from = "Column::SubscriptionUsageId",
to = "super::subscription_usage::Column::Id"
)]
SubscriptionUsage,
#[sea_orm(
belongs_to = "super::model::Entity",
from = "Column::ModelId",
to = "super::model::Column::Id"
)]
Model,
}
impl Related<super::subscription_usage::Entity> for Entity {
fn to() -> RelationDef {
Relation::SubscriptionUsage.def()
}
}
impl Related<super::model::Entity> for Entity {
fn to() -> RelationDef {
Relation::Model.def()
}
}
impl ActiveModelBehavior for ActiveModel {}
#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)]
#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
#[serde(rename_all = "snake_case")]
pub enum CompletionMode {
#[sea_orm(string_value = "normal")]
Normal,
#[sea_orm(string_value = "max")]
Max,
}

View File

@@ -1,52 +0,0 @@
use crate::{
db::UserId,
llm::db::{ModelId, UsageId, UsageMeasureId},
};
use sea_orm::entity::prelude::*;
/// An LLM usage record.
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "usages")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: UsageId,
/// The ID of the Zed user.
///
/// Corresponds to the `users` table in the primary collab database.
pub user_id: UserId,
pub model_id: ModelId,
pub measure_id: UsageMeasureId,
pub timestamp: DateTime,
pub buckets: Vec<i64>,
pub is_staff: bool,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::model::Entity",
from = "Column::ModelId",
to = "super::model::Column::Id"
)]
Model,
#[sea_orm(
belongs_to = "super::usage_measure::Entity",
from = "Column::MeasureId",
to = "super::usage_measure::Column::Id"
)]
UsageMeasure,
}
impl Related<super::model::Entity> for Entity {
fn to() -> RelationDef {
Relation::Model.def()
}
}
impl Related<super::usage_measure::Entity> for Entity {
fn to() -> RelationDef {
Relation::UsageMeasure.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -1,36 +0,0 @@
use crate::llm::db::UsageMeasureId;
use sea_orm::entity::prelude::*;
#[derive(
Copy, Clone, Debug, PartialEq, Eq, Hash, strum::EnumString, strum::Display, strum::EnumIter,
)]
#[strum(serialize_all = "snake_case")]
pub enum UsageMeasure {
RequestsPerMinute,
TokensPerMinute,
InputTokensPerMinute,
OutputTokensPerMinute,
TokensPerDay,
}
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "usage_measures")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: UsageMeasureId,
pub name: String,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(has_many = "super::usage::Entity")]
Usages,
}
impl Related<super::usage::Entity> for Entity {
fn to() -> RelationDef {
Relation::Usages.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -1,107 +0,0 @@
mod provider_tests;
use gpui::BackgroundExecutor;
use parking_lot::Mutex;
use rand::prelude::*;
use sea_orm::ConnectionTrait;
use sqlx::migrate::MigrateDatabase;
use std::time::Duration;
use crate::migrations::run_database_migrations;
use super::*;
pub struct TestLlmDb {
pub db: Option<LlmDatabase>,
pub connection: Option<sqlx::AnyConnection>,
}
impl TestLlmDb {
pub fn postgres(background: BackgroundExecutor) -> Self {
static LOCK: Mutex<()> = Mutex::new(());
let _guard = LOCK.lock();
let mut rng = StdRng::from_entropy();
let url = format!(
"postgres://postgres@localhost/zed-llm-test-{}",
rng.r#gen::<u128>()
);
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_io()
.enable_time()
.build()
.unwrap();
let mut db = runtime.block_on(async {
sqlx::Postgres::create_database(&url)
.await
.expect("failed to create test db");
let mut options = ConnectOptions::new(url);
options
.max_connections(5)
.idle_timeout(Duration::from_secs(0));
let db = LlmDatabase::new(options, Executor::Deterministic(background))
.await
.unwrap();
let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
run_database_migrations(db.options(), migrations_path)
.await
.unwrap();
db
});
db.runtime = Some(runtime);
Self {
db: Some(db),
connection: None,
}
}
pub fn db(&mut self) -> &mut LlmDatabase {
self.db.as_mut().unwrap()
}
}
#[macro_export]
macro_rules! test_llm_db {
($test_name:ident, $postgres_test_name:ident) => {
#[gpui::test]
async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
if !cfg!(target_os = "macos") {
return;
}
let mut test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone());
$test_name(test_db.db()).await;
}
};
}
impl Drop for TestLlmDb {
fn drop(&mut self) {
let db = self.db.take().unwrap();
if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
db.runtime.as_ref().unwrap().block_on(async {
use util::ResultExt;
let query = "
SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity
WHERE
pg_stat_activity.datname = current_database() AND
pid <> pg_backend_pid();
";
db.pool
.execute(sea_orm::Statement::from_string(
db.pool.get_database_backend(),
query,
))
.await
.log_err();
sqlx::Postgres::drop_database(db.options.get_url())
.await
.log_err();
})
}
}
}

View File

@@ -1,31 +0,0 @@
use cloud_llm_client::LanguageModelProvider;
use pretty_assertions::assert_eq;
use crate::llm::db::LlmDatabase;
use crate::test_llm_db;
test_llm_db!(
test_initialize_providers,
test_initialize_providers_postgres
);
async fn test_initialize_providers(db: &mut LlmDatabase) {
let initial_providers = db.list_providers().await.unwrap();
assert_eq!(initial_providers, vec![]);
db.initialize_providers().await.unwrap();
// Do it twice, to make sure the operation is idempotent.
db.initialize_providers().await.unwrap();
let providers = db.list_providers().await.unwrap();
assert_eq!(
providers,
&[
LanguageModelProvider::Anthropic,
LanguageModelProvider::Google,
LanguageModelProvider::OpenAi,
]
)
}

View File

@@ -1,146 +0,0 @@
use crate::db::billing_subscription::SubscriptionKind;
use crate::db::{billing_customer, billing_subscription, user};
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG};
use crate::{Config, db::billing_preference};
use anyhow::{Context as _, Result};
use chrono::{NaiveDateTime, Utc};
use cloud_llm_client::Plan;
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use thiserror::Error;
use uuid::Uuid;
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LlmTokenClaims {
pub iat: u64,
pub exp: u64,
pub jti: String,
pub user_id: u64,
pub system_id: Option<String>,
pub metrics_id: Uuid,
pub github_user_login: String,
pub account_created_at: NaiveDateTime,
pub is_staff: bool,
pub has_llm_closed_beta_feature_flag: bool,
pub bypass_account_age_check: bool,
pub use_llm_request_queue: bool,
pub plan: Plan,
pub has_extended_trial: bool,
pub subscription_period: (NaiveDateTime, NaiveDateTime),
pub enable_model_request_overages: bool,
pub model_request_overages_spend_limit_in_cents: u32,
pub can_use_web_search_tool: bool,
#[serde(default)]
pub has_overdue_invoices: bool,
}
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
impl LlmTokenClaims {
pub fn create(
user: &user::Model,
is_staff: bool,
billing_customer: billing_customer::Model,
billing_preferences: Option<billing_preference::Model>,
feature_flags: &Vec<String>,
subscription: billing_subscription::Model,
system_id: Option<String>,
config: &Config,
) -> Result<String> {
let secret = config
.llm_api_secret
.as_ref()
.context("no LLM API secret")?;
let plan = if is_staff {
Plan::ZedPro
} else {
subscription.kind.map_or(Plan::ZedFree, |kind| match kind {
SubscriptionKind::ZedFree => Plan::ZedFree,
SubscriptionKind::ZedPro => Plan::ZedPro,
SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
})
};
let subscription_period =
billing_subscription::Model::current_period(Some(subscription), is_staff)
.map(|(start, end)| (start.naive_utc(), end.naive_utc()))
.context("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started.")?;
let now = Utc::now();
let claims = Self {
iat: now.timestamp() as u64,
exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
jti: uuid::Uuid::new_v4().to_string(),
user_id: user.id.to_proto(),
system_id,
metrics_id: user.metrics_id,
github_user_login: user.github_login.clone(),
account_created_at: user.account_created_at(),
is_staff,
has_llm_closed_beta_feature_flag: feature_flags
.iter()
.any(|flag| flag == "llm-closed-beta"),
bypass_account_age_check: feature_flags
.iter()
.any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG),
can_use_web_search_tool: true,
use_llm_request_queue: feature_flags.iter().any(|flag| flag == "llm-request-queue"),
plan,
has_extended_trial: feature_flags
.iter()
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG),
subscription_period,
enable_model_request_overages: billing_preferences
.as_ref()
.map_or(false, |preferences| {
preferences.model_request_overages_enabled
}),
model_request_overages_spend_limit_in_cents: billing_preferences
.as_ref()
.map_or(0, |preferences| {
preferences.model_request_overages_spend_limit_in_cents as u32
}),
has_overdue_invoices: billing_customer.has_overdue_invoices,
};
Ok(jsonwebtoken::encode(
&Header::default(),
&claims,
&EncodingKey::from_secret(secret.as_ref()),
)?)
}
pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
let secret = config
.llm_api_secret
.as_ref()
.context("no LLM API secret")?;
match jsonwebtoken::decode::<Self>(
token,
&DecodingKey::from_secret(secret.as_ref()),
&Validation::default(),
) {
Ok(token) => Ok(token.claims),
Err(e) => {
if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
Err(ValidateLlmTokenError::Expired)
} else {
Err(ValidateLlmTokenError::JwtError(e))
}
}
}
}
}
#[derive(Error, Debug)]
pub enum ValidateLlmTokenError {
#[error("access token is expired")]
Expired,
#[error("access token validation error: {0}")]
JwtError(#[from] jsonwebtoken::errors::Error),
#[error("{0}")]
Other(#[from] anyhow::Error),
}

View File

@@ -62,13 +62,6 @@ async fn main() -> Result<()> {
db.initialize_notification_kinds().await?;
collab::seed::seed(&config, &db, false).await?;
if let Some(llm_database_url) = config.llm_database_url.clone() {
let db_options = db::ConnectOptions::new(llm_database_url);
let mut db = LlmDatabase::new(db_options.clone(), Executor::Production).await?;
db.initialize().await?;
collab::llm::db::seed_database(&config, &mut db, true).await?;
}
}
Some("serve") => {
let mode = match args.next().as_deref() {
@@ -102,13 +95,6 @@ async fn main() -> Result<()> {
let state = AppState::new(config, Executor::Production).await?;
if let Some(stripe_billing) = state.stripe_billing.clone() {
let executor = state.executor.clone();
executor.spawn_detached(async move {
stripe_billing.initialize().await.trace_err();
});
}
if mode.is_collab() {
state.db.purge_old_embeddings().await.trace_err();
@@ -270,9 +256,6 @@ async fn setup_llm_database(config: &Config) -> Result<()> {
.llm_database_migrations_path
.as_deref()
.unwrap_or_else(|| {
#[cfg(feature = "sqlite")]
let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm.sqlite");
#[cfg(not(feature = "sqlite"))]
let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
Path::new(default_migrations)

View File

@@ -1,14 +1,6 @@
mod connection_pool;
use crate::api::billing::find_or_create_billing_customer;
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::db::LlmDatabase;
use crate::llm::{
AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
MIN_ACCOUNT_AGE_FOR_LLM_USE,
};
use crate::stripe_client::StripeCustomerId;
use crate::{
AppState, Error, Result, auth,
db::{
@@ -37,7 +29,6 @@ use axum::{
response::IntoResponse,
routing::get,
};
use chrono::Utc;
use collections::{HashMap, HashSet};
pub use connection_pool::{ConnectionPool, ZedVersion};
use core::fmt::{self, Debug, Formatter};
@@ -148,13 +139,6 @@ pub enum Principal {
}
impl Principal {
fn user(&self) -> &User {
match self {
Principal::User(user) => user,
Principal::Impersonated { user, .. } => user,
}
}
fn update_span(&self, span: &tracing::Span) {
match &self {
Principal::User(user) => {
@@ -218,6 +202,7 @@ struct Session {
/// The GeoIP country code for the user.
#[allow(unused)]
geoip_country_code: Option<String>,
#[allow(unused)]
system_id: Option<String>,
_executor: Executor,
}
@@ -463,9 +448,6 @@ impl Server {
.add_request_handler(follow)
.add_message_handler(unfollow)
.add_message_handler(update_followers)
.add_request_handler(get_private_user_info)
.add_request_handler(get_llm_api_token)
.add_request_handler(accept_terms_of_service)
.add_message_handler(acknowledge_channel_message)
.add_message_handler(acknowledge_buffer_version)
.add_request_handler(get_supermaven_api_key)
@@ -1000,8 +982,6 @@ impl Server {
.await?;
}
update_user_plan(session).await?;
let contacts = self.app_state.db.get_contacts(user.id).await?;
{
@@ -1081,53 +1061,6 @@ impl Server {
Ok(())
}
pub async fn update_plan_for_user(
self: &Arc<Self>,
user_id: UserId,
update_user_plan: proto::UpdateUserPlan,
) -> Result<()> {
let pool = self.connection_pool.lock();
for connection_id in pool.user_connection_ids(user_id) {
self.peer
.send(connection_id, update_user_plan.clone())
.trace_err();
}
Ok(())
}
/// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
/// message on the Collab server.
///
/// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
let user = self
.app_state
.db
.get_user_by_id(user_id)
.await?
.context("user not found")?;
let update_user_plan = make_update_user_plan_message(
&user,
user.admin,
&self.app_state.db,
self.app_state.llm_db.clone(),
)
.await?;
self.update_plan_for_user(user_id, update_user_plan).await
}
pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
let pool = self.connection_pool.lock();
for connection_id in pool.user_connection_ids(user_id) {
self.peer
.send(connection_id, proto::RefreshLlmToken {})
.trace_err();
}
}
pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
ServerSnapshot {
connection_pool: ConnectionPoolGuard {
@@ -2882,214 +2815,6 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
version.0.minor() < 139
}
async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
if is_staff {
return Ok(proto::Plan::ZedPro);
}
let subscription = db.get_active_billing_subscription(user_id).await?;
let subscription_kind = subscription.and_then(|subscription| subscription.kind);
let plan = if let Some(subscription_kind) = subscription_kind {
match subscription_kind {
SubscriptionKind::ZedPro => proto::Plan::ZedPro,
SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
SubscriptionKind::ZedFree => proto::Plan::Free,
}
} else {
proto::Plan::Free
};
Ok(plan)
}
async fn make_update_user_plan_message(
user: &User,
is_staff: bool,
db: &Arc<Database>,
llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
let feature_flags = db.get_user_flags(user.id).await?;
let plan = current_plan(db, user.id, is_staff).await?;
let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
let billing_preferences = db.get_billing_preferences(user.id).await?;
let (subscription_period, usage) = if let Some(llm_db) = llm_db {
let subscription = db.get_active_billing_subscription(user.id).await?;
let subscription_period =
crate::db::billing_subscription::Model::current_period(subscription, is_staff);
let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
llm_db
.get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
.await?
} else {
None
};
(subscription_period, usage)
} else {
(None, None)
};
let bypass_account_age_check = feature_flags
.iter()
.any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
let account_too_young = !matches!(plan, proto::Plan::ZedPro)
&& !bypass_account_age_check
&& user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
Ok(proto::UpdateUserPlan {
plan: plan.into(),
trial_started_at: billing_customer
.as_ref()
.and_then(|billing_customer| billing_customer.trial_started_at)
.map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
is_usage_based_billing_enabled: if is_staff {
Some(true)
} else {
billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
},
subscription_period: subscription_period.map(|(started_at, ended_at)| {
proto::SubscriptionPeriod {
started_at: started_at.timestamp() as u64,
ended_at: ended_at.timestamp() as u64,
}
}),
account_too_young: Some(account_too_young),
has_overdue_invoices: billing_customer
.map(|billing_customer| billing_customer.has_overdue_invoices),
usage: Some(
usage
.map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
.unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
),
})
}
fn model_requests_limit(
plan: cloud_llm_client::Plan,
feature_flags: &Vec<String>,
) -> cloud_llm_client::UsageLimit {
match plan.model_requests_limit() {
cloud_llm_client::UsageLimit::Limited(limit) => {
let limit = if plan == cloud_llm_client::Plan::ZedProTrial
&& feature_flags
.iter()
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
{
1_000
} else {
limit
};
cloud_llm_client::UsageLimit::Limited(limit)
}
cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
}
}
fn subscription_usage_to_proto(
plan: proto::Plan,
usage: crate::llm::db::subscription_usage::Model,
feature_flags: &Vec<String>,
) -> proto::SubscriptionUsage {
let plan = match plan {
proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
};
proto::SubscriptionUsage {
model_requests_usage_amount: usage.model_requests as u32,
model_requests_usage_limit: Some(proto::UsageLimit {
variant: Some(match model_requests_limit(plan, feature_flags) {
cloud_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
cloud_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),
}),
edit_predictions_usage_amount: usage.edit_predictions as u32,
edit_predictions_usage_limit: Some(proto::UsageLimit {
variant: Some(match plan.edit_predictions_limit() {
cloud_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
cloud_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),
}),
}
}
fn make_default_subscription_usage(
plan: proto::Plan,
feature_flags: &Vec<String>,
) -> proto::SubscriptionUsage {
let plan = match plan {
proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
};
proto::SubscriptionUsage {
model_requests_usage_amount: 0,
model_requests_usage_limit: Some(proto::UsageLimit {
variant: Some(match model_requests_limit(plan, feature_flags) {
cloud_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
cloud_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),
}),
edit_predictions_usage_amount: 0,
edit_predictions_usage_limit: Some(proto::UsageLimit {
variant: Some(match plan.edit_predictions_limit() {
cloud_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
cloud_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),
}),
}
}
async fn update_user_plan(session: &Session) -> Result<()> {
let db = session.db().await;
let update_user_plan = make_update_user_plan_message(
session.principal.user(),
session.is_staff(),
&db.0,
session.app_state.llm_db.clone(),
)
.await?;
session
.peer
.send(session.connection_id, update_user_plan)
.trace_err();
Ok(())
}
async fn subscribe_to_channels(
_: proto::SubscribeToChannels,
session: MessageContext,
@@ -4258,139 +3983,6 @@ async fn mark_notification_as_read(
Ok(())
}
/// Get the current users information
async fn get_private_user_info(
_request: proto::GetPrivateUserInfo,
response: Response<proto::GetPrivateUserInfo>,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
let user = db
.get_user_by_id(session.user_id())
.await?
.context("user not found")?;
let flags = db.get_user_flags(session.user_id()).await?;
response.send(proto::GetPrivateUserInfoResponse {
metrics_id,
staff: user.admin,
flags,
accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
})?;
Ok(())
}
/// Accept the terms of service (tos) on behalf of the current user
async fn accept_terms_of_service(
_request: proto::AcceptTermsOfService,
response: Response<proto::AcceptTermsOfService>,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let accepted_tos_at = Utc::now();
db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
.await?;
response.send(proto::AcceptTermsOfServiceResponse {
accepted_tos_at: accepted_tos_at.timestamp() as u64,
})?;
// When the user accepts the terms of service, we want to refresh their LLM
// token to grant access.
session
.peer
.send(session.connection_id, proto::RefreshLlmToken {})?;
Ok(())
}
async fn get_llm_api_token(
_request: proto::GetLlmToken,
response: Response<proto::GetLlmToken>,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let flags = db.get_user_flags(session.user_id()).await?;
let user_id = session.user_id();
let user = db
.get_user_by_id(user_id)
.await?
.with_context(|| format!("user {user_id} not found"))?;
if user.accepted_tos_at.is_none() {
Err(anyhow!("terms of service not accepted"))?
}
let stripe_client = session
.app_state
.stripe_client
.as_ref()
.context("failed to retrieve Stripe client")?;
let stripe_billing = session
.app_state
.stripe_billing
.as_ref()
.context("failed to retrieve Stripe billing object")?;
let billing_customer = if let Some(billing_customer) =
db.get_billing_customer_by_user_id(user.id).await?
{
billing_customer
} else {
let customer_id = stripe_billing
.find_or_create_customer_by_email(user.email_address.as_deref())
.await?;
find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
.await?
.context("billing customer not found")?
};
let billing_subscription =
if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
billing_subscription
} else {
let stripe_customer_id =
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
let stripe_subscription = stripe_billing
.subscribe_to_zed_free(stripe_customer_id)
.await?;
db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
billing_customer_id: billing_customer.id,
kind: Some(SubscriptionKind::ZedFree),
stripe_subscription_id: stripe_subscription.id.to_string(),
stripe_subscription_status: stripe_subscription.status.into(),
stripe_cancellation_reason: None,
stripe_current_period_start: Some(stripe_subscription.current_period_start),
stripe_current_period_end: Some(stripe_subscription.current_period_end),
})
.await?
};
let billing_preferences = db.get_billing_preferences(user.id).await?;
let token = LlmTokenClaims::create(
&user,
session.is_staff(),
billing_customer,
billing_preferences,
&flags,
billing_subscription,
session.system_id.clone(),
&session.app_state.config,
)?;
response.send(proto::GetLlmTokenResponse { token })?;
Ok(())
}
fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
let message = match message {
TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),

View File

@@ -30,7 +30,19 @@ impl fmt::Display for ZedVersion {
impl ZedVersion {
pub fn can_collaborate(&self) -> bool {
self.0 >= SemanticVersion::new(0, 157, 0)
// v0.198.4 is the first version where we no longer connect to Collab automatically.
// We reject any clients older than that to prevent them from connecting to Collab just for authentication.
if self.0 < SemanticVersion::new(0, 198, 4) {
return false;
}
// Since we hotfixed the changes to no longer connect to Collab automatically to Preview, we also need to reject
// versions in the range [v0.199.0, v0.199.1].
if self.0 >= SemanticVersion::new(0, 199, 0) && self.0 < SemanticVersion::new(0, 199, 2) {
return false;
}
true
}
}

View File

@@ -1,156 +0,0 @@
use std::sync::Arc;
use anyhow::anyhow;
use collections::HashMap;
use stripe::SubscriptionStatus;
use tokio::sync::RwLock;
use crate::Result;
use crate::stripe_client::{
RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateSubscriptionItems,
StripeCreateSubscriptionParams, StripeCustomerId, StripePrice, StripePriceId,
StripeSubscription,
};
pub struct StripeBilling {
state: RwLock<StripeBillingState>,
client: Arc<dyn StripeClient>,
}
#[derive(Default)]
struct StripeBillingState {
prices_by_lookup_key: HashMap<String, StripePrice>,
}
impl StripeBilling {
pub fn new(client: Arc<stripe::Client>) -> Self {
Self {
client: Arc::new(RealStripeClient::new(client.clone())),
state: RwLock::default(),
}
}
#[cfg(test)]
pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
Self {
client,
state: RwLock::default(),
}
}
pub fn client(&self) -> &Arc<dyn StripeClient> {
&self.client
}
pub async fn initialize(&self) -> Result<()> {
log::info!("StripeBilling: initializing");
let mut state = self.state.write().await;
let prices = self.client.list_prices().await?;
for price in prices {
if let Some(lookup_key) = price.lookup_key.clone() {
state.prices_by_lookup_key.insert(lookup_key, price);
}
}
log::info!("StripeBilling: initialized");
Ok(())
}
pub async fn zed_pro_price_id(&self) -> Result<StripePriceId> {
self.find_price_id_by_lookup_key("zed-pro").await
}
pub async fn zed_free_price_id(&self) -> Result<StripePriceId> {
self.find_price_id_by_lookup_key("zed-free").await
}
pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<StripePriceId> {
self.state
.read()
.await
.prices_by_lookup_key
.get(lookup_key)
.map(|price| price.id.clone())
.ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
}
pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<StripePrice> {
self.state
.read()
.await
.prices_by_lookup_key
.get(lookup_key)
.cloned()
.ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
}
/// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
/// not already exist.
///
/// Always returns a new Stripe customer if the email address is `None`.
pub async fn find_or_create_customer_by_email(
&self,
email_address: Option<&str>,
) -> Result<StripeCustomerId> {
let existing_customer = if let Some(email) = email_address {
let customers = self.client.list_customers_by_email(email).await?;
customers.first().cloned()
} else {
None
};
let customer_id = if let Some(existing_customer) = existing_customer {
existing_customer.id
} else {
let customer = self
.client
.create_customer(crate::stripe_client::CreateCustomerParams {
email: email_address,
})
.await?;
customer.id
};
Ok(customer_id)
}
pub async fn subscribe_to_zed_free(
&self,
customer_id: StripeCustomerId,
) -> Result<StripeSubscription> {
let zed_free_price_id = self.zed_free_price_id().await?;
let existing_subscriptions = self
.client
.list_subscriptions_for_customer(&customer_id)
.await?;
let existing_active_subscription =
existing_subscriptions.into_iter().find(|subscription| {
subscription.status == SubscriptionStatus::Active
|| subscription.status == SubscriptionStatus::Trialing
});
if let Some(subscription) = existing_active_subscription {
return Ok(subscription);
}
let params = StripeCreateSubscriptionParams {
customer: customer_id,
items: vec![StripeCreateSubscriptionItems {
price: Some(zed_free_price_id),
quantity: Some(1),
}],
automatic_tax: Some(StripeAutomaticTax { enabled: true }),
};
let subscription = self.client.create_subscription(params).await?;
Ok(subscription)
}
}

View File

@@ -1,285 +0,0 @@
#[cfg(test)]
mod fake_stripe_client;
mod real_stripe_client;
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::Result;
use async_trait::async_trait;
#[cfg(test)]
pub use fake_stripe_client::*;
pub use real_stripe_client::*;
use serde::{Deserialize, Serialize};
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Serialize)]
pub struct StripeCustomerId(pub Arc<str>);
#[derive(Debug, Clone)]
pub struct StripeCustomer {
pub id: StripeCustomerId,
pub email: Option<String>,
}
#[derive(Debug)]
pub struct CreateCustomerParams<'a> {
pub email: Option<&'a str>,
}
#[derive(Debug)]
pub struct UpdateCustomerParams<'a> {
pub email: Option<&'a str>,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripeSubscriptionId(pub Arc<str>);
#[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscription {
pub id: StripeSubscriptionId,
pub customer: StripeCustomerId,
// TODO: Create our own version of this enum.
pub status: stripe::SubscriptionStatus,
pub current_period_end: i64,
pub current_period_start: i64,
pub items: Vec<StripeSubscriptionItem>,
pub cancel_at: Option<i64>,
pub cancellation_details: Option<StripeCancellationDetails>,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripeSubscriptionItemId(pub Arc<str>);
#[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscriptionItem {
pub id: StripeSubscriptionItemId,
pub price: Option<StripePrice>,
}
#[derive(Debug, Clone, PartialEq)]
pub struct StripeCancellationDetails {
pub reason: Option<StripeCancellationDetailsReason>,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeCancellationDetailsReason {
CancellationRequested,
PaymentDisputed,
PaymentFailed,
}
#[derive(Debug)]
pub struct StripeCreateSubscriptionParams {
pub customer: StripeCustomerId,
pub items: Vec<StripeCreateSubscriptionItems>,
pub automatic_tax: Option<StripeAutomaticTax>,
}
#[derive(Debug)]
pub struct StripeCreateSubscriptionItems {
pub price: Option<StripePriceId>,
pub quantity: Option<u64>,
}
#[derive(Debug, Clone)]
pub struct UpdateSubscriptionParams {
pub items: Option<Vec<UpdateSubscriptionItems>>,
pub trial_settings: Option<StripeSubscriptionTrialSettings>,
}
#[derive(Debug, PartialEq, Clone)]
pub struct UpdateSubscriptionItems {
pub price: Option<StripePriceId>,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscriptionTrialSettings {
pub end_behavior: StripeSubscriptionTrialSettingsEndBehavior,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscriptionTrialSettingsEndBehavior {
pub missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod {
Cancel,
CreateInvoice,
Pause,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripePriceId(pub Arc<str>);
#[derive(Debug, PartialEq, Clone)]
pub struct StripePrice {
pub id: StripePriceId,
pub unit_amount: Option<i64>,
pub lookup_key: Option<String>,
pub recurring: Option<StripePriceRecurring>,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripePriceRecurring {
pub meter: Option<String>,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Deserialize)]
pub struct StripeMeterId(pub Arc<str>);
#[derive(Debug, Clone, Deserialize)]
pub struct StripeMeter {
pub id: StripeMeterId,
pub event_name: String,
}
#[derive(Debug, Serialize)]
pub struct StripeCreateMeterEventParams<'a> {
pub identifier: &'a str,
pub event_name: &'a str,
pub payload: StripeCreateMeterEventPayload<'a>,
pub timestamp: Option<i64>,
}
#[derive(Debug, Serialize)]
pub struct StripeCreateMeterEventPayload<'a> {
pub value: u64,
pub stripe_customer_id: &'a StripeCustomerId,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeBillingAddressCollection {
Auto,
Required,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripeCustomerUpdate {
pub address: Option<StripeCustomerUpdateAddress>,
pub name: Option<StripeCustomerUpdateName>,
pub shipping: Option<StripeCustomerUpdateShipping>,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeCustomerUpdateAddress {
Auto,
Never,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeCustomerUpdateName {
Auto,
Never,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeCustomerUpdateShipping {
Auto,
Never,
}
#[derive(Debug, Default)]
pub struct StripeCreateCheckoutSessionParams<'a> {
pub customer: Option<&'a StripeCustomerId>,
pub client_reference_id: Option<&'a str>,
pub mode: Option<StripeCheckoutSessionMode>,
pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
pub success_url: Option<&'a str>,
pub billing_address_collection: Option<StripeBillingAddressCollection>,
pub customer_update: Option<StripeCustomerUpdate>,
pub tax_id_collection: Option<StripeTaxIdCollection>,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeCheckoutSessionMode {
Payment,
Setup,
Subscription,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripeCreateCheckoutSessionLineItems {
pub price: Option<String>,
pub quantity: Option<u64>,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeCheckoutSessionPaymentMethodCollection {
Always,
IfRequired,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripeCreateCheckoutSessionSubscriptionData {
pub metadata: Option<HashMap<String, String>>,
pub trial_period_days: Option<u32>,
pub trial_settings: Option<StripeSubscriptionTrialSettings>,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripeTaxIdCollection {
pub enabled: bool,
}
#[derive(Debug, Clone)]
pub struct StripeAutomaticTax {
pub enabled: bool,
}
#[derive(Debug)]
pub struct StripeCheckoutSession {
pub url: Option<String>,
}
#[async_trait]
pub trait StripeClient: Send + Sync {
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer>;
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>;
async fn update_customer(
&self,
customer_id: &StripeCustomerId,
params: UpdateCustomerParams<'_>,
) -> Result<StripeCustomer>;
async fn list_subscriptions_for_customer(
&self,
customer_id: &StripeCustomerId,
) -> Result<Vec<StripeSubscription>>;
async fn get_subscription(
&self,
subscription_id: &StripeSubscriptionId,
) -> Result<StripeSubscription>;
async fn create_subscription(
&self,
params: StripeCreateSubscriptionParams,
) -> Result<StripeSubscription>;
async fn update_subscription(
&self,
subscription_id: &StripeSubscriptionId,
params: UpdateSubscriptionParams,
) -> Result<()>;
async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()>;
async fn list_prices(&self) -> Result<Vec<StripePrice>>;
async fn list_meters(&self) -> Result<Vec<StripeMeter>>;
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()>;
async fn create_checkout_session(
&self,
params: StripeCreateCheckoutSessionParams<'_>,
) -> Result<StripeCheckoutSession>;
}

View File

@@ -1,247 +0,0 @@
use std::sync::Arc;
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use chrono::{Duration, Utc};
use collections::HashMap;
use parking_lot::Mutex;
use uuid::Uuid;
use crate::stripe_client::{
CreateCustomerParams, StripeBillingAddressCollection, StripeCheckoutSession,
StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient,
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeCustomerUpdate,
StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription,
StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, StripeTaxIdCollection,
UpdateCustomerParams, UpdateSubscriptionParams,
};
#[derive(Debug, Clone)]
pub struct StripeCreateMeterEventCall {
pub identifier: Arc<str>,
pub event_name: Arc<str>,
pub value: u64,
pub stripe_customer_id: StripeCustomerId,
pub timestamp: Option<i64>,
}
#[derive(Debug, Clone)]
pub struct StripeCreateCheckoutSessionCall {
pub customer: Option<StripeCustomerId>,
pub client_reference_id: Option<String>,
pub mode: Option<StripeCheckoutSessionMode>,
pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
pub success_url: Option<String>,
pub billing_address_collection: Option<StripeBillingAddressCollection>,
pub customer_update: Option<StripeCustomerUpdate>,
pub tax_id_collection: Option<StripeTaxIdCollection>,
}
pub struct FakeStripeClient {
pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
pub update_subscription_calls:
Arc<Mutex<Vec<(StripeSubscriptionId, UpdateSubscriptionParams)>>>,
pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
pub create_meter_event_calls: Arc<Mutex<Vec<StripeCreateMeterEventCall>>>,
pub create_checkout_session_calls: Arc<Mutex<Vec<StripeCreateCheckoutSessionCall>>>,
}
impl FakeStripeClient {
pub fn new() -> Self {
Self {
customers: Arc::new(Mutex::new(HashMap::default())),
subscriptions: Arc::new(Mutex::new(HashMap::default())),
update_subscription_calls: Arc::new(Mutex::new(Vec::new())),
prices: Arc::new(Mutex::new(HashMap::default())),
meters: Arc::new(Mutex::new(HashMap::default())),
create_meter_event_calls: Arc::new(Mutex::new(Vec::new())),
create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())),
}
}
}
#[async_trait]
impl StripeClient for FakeStripeClient {
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
Ok(self
.customers
.lock()
.values()
.filter(|customer| customer.email.as_deref() == Some(email))
.cloned()
.collect())
}
async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer> {
self.customers
.lock()
.get(customer_id)
.cloned()
.ok_or_else(|| anyhow!("no customer found for {customer_id:?}"))
}
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
let customer = StripeCustomer {
id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()),
email: params.email.map(|email| email.to_string()),
};
self.customers
.lock()
.insert(customer.id.clone(), customer.clone());
Ok(customer)
}
async fn update_customer(
&self,
customer_id: &StripeCustomerId,
params: UpdateCustomerParams<'_>,
) -> Result<StripeCustomer> {
let mut customers = self.customers.lock();
if let Some(customer) = customers.get_mut(customer_id) {
if let Some(email) = params.email {
customer.email = Some(email.to_string());
}
Ok(customer.clone())
} else {
Err(anyhow!("no customer found for {customer_id:?}"))
}
}
async fn list_subscriptions_for_customer(
&self,
customer_id: &StripeCustomerId,
) -> Result<Vec<StripeSubscription>> {
let subscriptions = self
.subscriptions
.lock()
.values()
.filter(|subscription| subscription.customer == *customer_id)
.cloned()
.collect();
Ok(subscriptions)
}
async fn get_subscription(
&self,
subscription_id: &StripeSubscriptionId,
) -> Result<StripeSubscription> {
self.subscriptions
.lock()
.get(subscription_id)
.cloned()
.ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}"))
}
async fn create_subscription(
&self,
params: StripeCreateSubscriptionParams,
) -> Result<StripeSubscription> {
let now = Utc::now();
let subscription = StripeSubscription {
id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()),
customer: params.customer,
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: params
.items
.into_iter()
.map(|item| StripeSubscriptionItem {
id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()),
price: item
.price
.and_then(|price_id| self.prices.lock().get(&price_id).cloned()),
})
.collect(),
cancel_at: None,
cancellation_details: None,
};
self.subscriptions
.lock()
.insert(subscription.id.clone(), subscription.clone());
Ok(subscription)
}
async fn update_subscription(
&self,
subscription_id: &StripeSubscriptionId,
params: UpdateSubscriptionParams,
) -> Result<()> {
let subscription = self.get_subscription(subscription_id).await?;
self.update_subscription_calls
.lock()
.push((subscription.id, params));
Ok(())
}
async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> {
// TODO: Implement fake subscription cancellation.
let _ = subscription_id;
Ok(())
}
async fn list_prices(&self) -> Result<Vec<StripePrice>> {
let prices = self.prices.lock().values().cloned().collect();
Ok(prices)
}
async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
let meters = self.meters.lock().values().cloned().collect();
Ok(meters)
}
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
self.create_meter_event_calls
.lock()
.push(StripeCreateMeterEventCall {
identifier: params.identifier.into(),
event_name: params.event_name.into(),
value: params.payload.value,
stripe_customer_id: params.payload.stripe_customer_id.clone(),
timestamp: params.timestamp,
});
Ok(())
}
async fn create_checkout_session(
&self,
params: StripeCreateCheckoutSessionParams<'_>,
) -> Result<StripeCheckoutSession> {
self.create_checkout_session_calls
.lock()
.push(StripeCreateCheckoutSessionCall {
customer: params.customer.cloned(),
client_reference_id: params.client_reference_id.map(|id| id.to_string()),
mode: params.mode,
line_items: params.line_items,
payment_method_collection: params.payment_method_collection,
subscription_data: params.subscription_data,
success_url: params.success_url.map(|url| url.to_string()),
billing_address_collection: params.billing_address_collection,
customer_update: params.customer_update,
tax_id_collection: params.tax_id_collection,
});
Ok(StripeCheckoutSession {
url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()),
})
}
}

View File

@@ -1,612 +0,0 @@
use std::str::FromStr as _;
use std::sync::Arc;
use anyhow::{Context as _, Result, anyhow};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use stripe::{
CancellationDetails, CancellationDetailsReason, CheckoutSession, CheckoutSessionMode,
CheckoutSessionPaymentMethodCollection, CreateCheckoutSession, CreateCheckoutSessionLineItems,
CreateCheckoutSessionSubscriptionData, CreateCheckoutSessionSubscriptionDataTrialSettings,
CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior,
CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod,
CreateCustomer, CreateSubscriptionAutomaticTax, Customer, CustomerId, ListCustomers, Price,
PriceId, Recurring, Subscription, SubscriptionId, SubscriptionItem, SubscriptionItemId,
UpdateCustomer, UpdateSubscriptionItems, UpdateSubscriptionTrialSettings,
UpdateSubscriptionTrialSettingsEndBehavior,
UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
};
use crate::stripe_client::{
CreateCustomerParams, StripeAutomaticTax, StripeBillingAddressCollection,
StripeCancellationDetails, StripeCancellationDetailsReason, StripeCheckoutSession,
StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient,
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeCustomerUpdate,
StripeCustomerUpdateAddress, StripeCustomerUpdateName, StripeCustomerUpdateShipping,
StripeMeter, StripePrice, StripePriceId, StripePriceRecurring, StripeSubscription,
StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, StripeTaxIdCollection,
UpdateCustomerParams, UpdateSubscriptionParams,
};
pub struct RealStripeClient {
client: Arc<stripe::Client>,
}
impl RealStripeClient {
pub fn new(client: Arc<stripe::Client>) -> Self {
Self { client }
}
}
#[async_trait]
impl StripeClient for RealStripeClient {
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
let response = Customer::list(
&self.client,
&ListCustomers {
email: Some(email),
..Default::default()
},
)
.await?;
Ok(response
.data
.into_iter()
.map(StripeCustomer::from)
.collect())
}
async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer> {
let customer_id = customer_id.try_into()?;
let customer = Customer::retrieve(&self.client, &customer_id, &[]).await?;
Ok(StripeCustomer::from(customer))
}
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
let customer = Customer::create(
&self.client,
CreateCustomer {
email: params.email,
..Default::default()
},
)
.await?;
Ok(StripeCustomer::from(customer))
}
async fn update_customer(
&self,
customer_id: &StripeCustomerId,
params: UpdateCustomerParams<'_>,
) -> Result<StripeCustomer> {
let customer = Customer::update(
&self.client,
&customer_id.try_into()?,
UpdateCustomer {
email: params.email,
..Default::default()
},
)
.await?;
Ok(StripeCustomer::from(customer))
}
async fn list_subscriptions_for_customer(
&self,
customer_id: &StripeCustomerId,
) -> Result<Vec<StripeSubscription>> {
let customer_id = customer_id.try_into()?;
let subscriptions = stripe::Subscription::list(
&self.client,
&stripe::ListSubscriptions {
customer: Some(customer_id),
status: None,
..Default::default()
},
)
.await?;
Ok(subscriptions
.data
.into_iter()
.map(StripeSubscription::from)
.collect())
}
async fn get_subscription(
&self,
subscription_id: &StripeSubscriptionId,
) -> Result<StripeSubscription> {
let subscription_id = subscription_id.try_into()?;
let subscription = Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
Ok(StripeSubscription::from(subscription))
}
async fn create_subscription(
&self,
params: StripeCreateSubscriptionParams,
) -> Result<StripeSubscription> {
let customer_id = params.customer.try_into()?;
let mut create_subscription = stripe::CreateSubscription::new(customer_id);
create_subscription.items = Some(
params
.items
.into_iter()
.map(|item| stripe::CreateSubscriptionItems {
price: item.price.map(|price| price.to_string()),
quantity: item.quantity,
..Default::default()
})
.collect(),
);
create_subscription.automatic_tax = params.automatic_tax.map(Into::into);
let subscription = Subscription::create(&self.client, create_subscription).await?;
Ok(StripeSubscription::from(subscription))
}
async fn update_subscription(
&self,
subscription_id: &StripeSubscriptionId,
params: UpdateSubscriptionParams,
) -> Result<()> {
let subscription_id = subscription_id.try_into()?;
stripe::Subscription::update(
&self.client,
&subscription_id,
stripe::UpdateSubscription {
items: params.items.map(|items| {
items
.into_iter()
.map(|item| UpdateSubscriptionItems {
price: item.price.map(|price| price.to_string()),
..Default::default()
})
.collect()
}),
trial_settings: params.trial_settings.map(Into::into),
..Default::default()
},
)
.await?;
Ok(())
}
async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> {
let subscription_id = subscription_id.try_into()?;
Subscription::cancel(
&self.client,
&subscription_id,
stripe::CancelSubscription {
invoice_now: None,
..Default::default()
},
)
.await?;
Ok(())
}
async fn list_prices(&self) -> Result<Vec<StripePrice>> {
let response = stripe::Price::list(
&self.client,
&stripe::ListPrices {
limit: Some(100),
..Default::default()
},
)
.await?;
Ok(response.data.into_iter().map(StripePrice::from).collect())
}
async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
#[derive(Serialize)]
struct Params {
#[serde(skip_serializing_if = "Option::is_none")]
limit: Option<u64>,
}
let response = self
.client
.get_query::<stripe::List<StripeMeter>, _>(
"/billing/meters",
Params { limit: Some(100) },
)
.await?;
Ok(response.data)
}
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
#[derive(Deserialize)]
struct StripeMeterEvent {
pub identifier: String,
}
let identifier = params.identifier;
match self
.client
.post_form::<StripeMeterEvent, _>("/billing/meter_events", params)
.await
{
Ok(_event) => Ok(()),
Err(stripe::StripeError::Stripe(error)) => {
if error.http_status == 400
&& error
.message
.as_ref()
.map_or(false, |message| message.contains(identifier))
{
Ok(())
} else {
Err(anyhow!(stripe::StripeError::Stripe(error)))
}
}
Err(error) => Err(anyhow!("failed to create meter event: {error:?}")),
}
}
async fn create_checkout_session(
&self,
params: StripeCreateCheckoutSessionParams<'_>,
) -> Result<StripeCheckoutSession> {
let params = params.try_into()?;
let session = CheckoutSession::create(&self.client, params).await?;
Ok(session.into())
}
}
impl From<CustomerId> for StripeCustomerId {
fn from(value: CustomerId) -> Self {
Self(value.as_str().into())
}
}
impl TryFrom<StripeCustomerId> for CustomerId {
type Error = anyhow::Error;
fn try_from(value: StripeCustomerId) -> Result<Self, Self::Error> {
Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID")
}
}
impl TryFrom<&StripeCustomerId> for CustomerId {
type Error = anyhow::Error;
fn try_from(value: &StripeCustomerId) -> Result<Self, Self::Error> {
Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID")
}
}
impl From<Customer> for StripeCustomer {
fn from(value: Customer) -> Self {
StripeCustomer {
id: value.id.into(),
email: value.email,
}
}
}
impl From<SubscriptionId> for StripeSubscriptionId {
fn from(value: SubscriptionId) -> Self {
Self(value.as_str().into())
}
}
impl TryFrom<&StripeSubscriptionId> for SubscriptionId {
type Error = anyhow::Error;
fn try_from(value: &StripeSubscriptionId) -> Result<Self, Self::Error> {
Self::from_str(value.0.as_ref()).context("failed to parse Stripe subscription ID")
}
}
impl From<Subscription> for StripeSubscription {
fn from(value: Subscription) -> Self {
Self {
id: value.id.into(),
customer: value.customer.id().into(),
status: value.status,
current_period_start: value.current_period_start,
current_period_end: value.current_period_end,
items: value.items.data.into_iter().map(Into::into).collect(),
cancel_at: value.cancel_at,
cancellation_details: value.cancellation_details.map(Into::into),
}
}
}
impl From<CancellationDetails> for StripeCancellationDetails {
fn from(value: CancellationDetails) -> Self {
Self {
reason: value.reason.map(Into::into),
}
}
}
impl From<CancellationDetailsReason> for StripeCancellationDetailsReason {
fn from(value: CancellationDetailsReason) -> Self {
match value {
CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
}
}
}
impl From<SubscriptionItemId> for StripeSubscriptionItemId {
fn from(value: SubscriptionItemId) -> Self {
Self(value.as_str().into())
}
}
impl From<SubscriptionItem> for StripeSubscriptionItem {
fn from(value: SubscriptionItem) -> Self {
Self {
id: value.id.into(),
price: value.price.map(Into::into),
}
}
}
impl From<StripeAutomaticTax> for CreateSubscriptionAutomaticTax {
fn from(value: StripeAutomaticTax) -> Self {
Self {
enabled: value.enabled,
liability: None,
}
}
}
impl From<StripeSubscriptionTrialSettings> for UpdateSubscriptionTrialSettings {
fn from(value: StripeSubscriptionTrialSettings) -> Self {
Self {
end_behavior: value.end_behavior.into(),
}
}
}
impl From<StripeSubscriptionTrialSettingsEndBehavior>
for UpdateSubscriptionTrialSettingsEndBehavior
{
fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self {
Self {
missing_payment_method: value.missing_payment_method.into(),
}
}
}
impl From<StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
for UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod
{
fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self {
match value {
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
Self::CreateInvoice
}
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
}
}
}
impl From<PriceId> for StripePriceId {
fn from(value: PriceId) -> Self {
Self(value.as_str().into())
}
}
impl TryFrom<StripePriceId> for PriceId {
type Error = anyhow::Error;
fn try_from(value: StripePriceId) -> Result<Self, Self::Error> {
Self::from_str(value.0.as_ref()).context("failed to parse Stripe price ID")
}
}
impl From<Price> for StripePrice {
fn from(value: Price) -> Self {
Self {
id: value.id.into(),
unit_amount: value.unit_amount,
lookup_key: value.lookup_key,
recurring: value.recurring.map(StripePriceRecurring::from),
}
}
}
impl From<Recurring> for StripePriceRecurring {
fn from(value: Recurring) -> Self {
Self { meter: value.meter }
}
}
impl<'a> TryFrom<StripeCreateCheckoutSessionParams<'a>> for CreateCheckoutSession<'a> {
type Error = anyhow::Error;
fn try_from(value: StripeCreateCheckoutSessionParams<'a>) -> Result<Self, Self::Error> {
Ok(Self {
customer: value
.customer
.map(|customer_id| customer_id.try_into())
.transpose()?,
client_reference_id: value.client_reference_id,
mode: value.mode.map(Into::into),
line_items: value
.line_items
.map(|line_items| line_items.into_iter().map(Into::into).collect()),
payment_method_collection: value.payment_method_collection.map(Into::into),
subscription_data: value.subscription_data.map(Into::into),
success_url: value.success_url,
billing_address_collection: value.billing_address_collection.map(Into::into),
customer_update: value.customer_update.map(Into::into),
tax_id_collection: value.tax_id_collection.map(Into::into),
..Default::default()
})
}
}
impl From<StripeCheckoutSessionMode> for CheckoutSessionMode {
fn from(value: StripeCheckoutSessionMode) -> Self {
match value {
StripeCheckoutSessionMode::Payment => Self::Payment,
StripeCheckoutSessionMode::Setup => Self::Setup,
StripeCheckoutSessionMode::Subscription => Self::Subscription,
}
}
}
impl From<StripeCreateCheckoutSessionLineItems> for CreateCheckoutSessionLineItems {
fn from(value: StripeCreateCheckoutSessionLineItems) -> Self {
Self {
price: value.price,
quantity: value.quantity,
..Default::default()
}
}
}
impl From<StripeCheckoutSessionPaymentMethodCollection> for CheckoutSessionPaymentMethodCollection {
fn from(value: StripeCheckoutSessionPaymentMethodCollection) -> Self {
match value {
StripeCheckoutSessionPaymentMethodCollection::Always => Self::Always,
StripeCheckoutSessionPaymentMethodCollection::IfRequired => Self::IfRequired,
}
}
}
impl From<StripeCreateCheckoutSessionSubscriptionData> for CreateCheckoutSessionSubscriptionData {
fn from(value: StripeCreateCheckoutSessionSubscriptionData) -> Self {
Self {
trial_period_days: value.trial_period_days,
trial_settings: value.trial_settings.map(Into::into),
metadata: value.metadata,
..Default::default()
}
}
}
impl From<StripeSubscriptionTrialSettings> for CreateCheckoutSessionSubscriptionDataTrialSettings {
fn from(value: StripeSubscriptionTrialSettings) -> Self {
Self {
end_behavior: value.end_behavior.into(),
}
}
}
impl From<StripeSubscriptionTrialSettingsEndBehavior>
for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior
{
fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self {
Self {
missing_payment_method: value.missing_payment_method.into(),
}
}
}
impl From<StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod
{
fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self {
match value {
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
Self::CreateInvoice
}
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
}
}
}
impl From<CheckoutSession> for StripeCheckoutSession {
fn from(value: CheckoutSession) -> Self {
Self { url: value.url }
}
}
impl From<StripeBillingAddressCollection> for stripe::CheckoutSessionBillingAddressCollection {
fn from(value: StripeBillingAddressCollection) -> Self {
match value {
StripeBillingAddressCollection::Auto => {
stripe::CheckoutSessionBillingAddressCollection::Auto
}
StripeBillingAddressCollection::Required => {
stripe::CheckoutSessionBillingAddressCollection::Required
}
}
}
}
impl From<StripeCustomerUpdateAddress> for stripe::CreateCheckoutSessionCustomerUpdateAddress {
fn from(value: StripeCustomerUpdateAddress) -> Self {
match value {
StripeCustomerUpdateAddress::Auto => {
stripe::CreateCheckoutSessionCustomerUpdateAddress::Auto
}
StripeCustomerUpdateAddress::Never => {
stripe::CreateCheckoutSessionCustomerUpdateAddress::Never
}
}
}
}
impl From<StripeCustomerUpdateName> for stripe::CreateCheckoutSessionCustomerUpdateName {
fn from(value: StripeCustomerUpdateName) -> Self {
match value {
StripeCustomerUpdateName::Auto => stripe::CreateCheckoutSessionCustomerUpdateName::Auto,
StripeCustomerUpdateName::Never => {
stripe::CreateCheckoutSessionCustomerUpdateName::Never
}
}
}
}
impl From<StripeCustomerUpdateShipping> for stripe::CreateCheckoutSessionCustomerUpdateShipping {
fn from(value: StripeCustomerUpdateShipping) -> Self {
match value {
StripeCustomerUpdateShipping::Auto => {
stripe::CreateCheckoutSessionCustomerUpdateShipping::Auto
}
StripeCustomerUpdateShipping::Never => {
stripe::CreateCheckoutSessionCustomerUpdateShipping::Never
}
}
}
}
impl From<StripeCustomerUpdate> for stripe::CreateCheckoutSessionCustomerUpdate {
fn from(value: StripeCustomerUpdate) -> Self {
stripe::CreateCheckoutSessionCustomerUpdate {
address: value.address.map(Into::into),
name: value.name.map(Into::into),
shipping: value.shipping.map(Into::into),
}
}
}
impl From<StripeTaxIdCollection> for stripe::CreateCheckoutSessionTaxIdCollection {
fn from(value: StripeTaxIdCollection) -> Self {
stripe::CreateCheckoutSessionTaxIdCollection {
enabled: value.enabled,
}
}
}

View File

@@ -8,7 +8,6 @@ mod channel_buffer_tests;
mod channel_guest_tests;
mod channel_message_tests;
mod channel_tests;
// mod debug_panel_tests;
mod editor_tests;
mod following_tests;
mod git_tests;
@@ -18,7 +17,6 @@ mod random_channel_buffer_tests;
mod random_project_collaboration_tests;
mod randomized_test_helpers;
mod remote_editing_collaboration_tests;
mod stripe_billing_tests;
mod test_server;
use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};

View File

@@ -1,123 +0,0 @@
use std::sync::Arc;
use pretty_assertions::assert_eq;
use crate::stripe_billing::StripeBilling;
use crate::stripe_client::{FakeStripeClient, StripePrice, StripePriceId, StripePriceRecurring};
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
let stripe_client = Arc::new(FakeStripeClient::new());
let stripe_billing = StripeBilling::test(stripe_client.clone());
(stripe_billing, stripe_client)
}
#[gpui::test]
async fn test_initialize() {
let (stripe_billing, stripe_client) = make_stripe_billing();
// Add test prices
let price1 = StripePrice {
id: StripePriceId("price_1".into()),
unit_amount: Some(1_000),
lookup_key: Some("zed-pro".to_string()),
recurring: None,
};
let price2 = StripePrice {
id: StripePriceId("price_2".into()),
unit_amount: Some(0),
lookup_key: Some("zed-free".to_string()),
recurring: None,
};
let price3 = StripePrice {
id: StripePriceId("price_3".into()),
unit_amount: Some(500),
lookup_key: None,
recurring: Some(StripePriceRecurring {
meter: Some("meter_1".to_string()),
}),
};
stripe_client
.prices
.lock()
.insert(price1.id.clone(), price1);
stripe_client
.prices
.lock()
.insert(price2.id.clone(), price2);
stripe_client
.prices
.lock()
.insert(price3.id.clone(), price3);
// Initialize the billing system
stripe_billing.initialize().await.unwrap();
// Verify that prices can be found by lookup key
let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
assert_eq!(zed_pro_price_id.to_string(), "price_1");
let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
assert_eq!(zed_free_price_id.to_string(), "price_2");
// Verify that a price can be found by lookup key
let zed_pro_price = stripe_billing
.find_price_by_lookup_key("zed-pro")
.await
.unwrap();
assert_eq!(zed_pro_price.id.to_string(), "price_1");
assert_eq!(zed_pro_price.unit_amount, Some(1_000));
// Verify that finding a non-existent lookup key returns an error
let result = stripe_billing
.find_price_by_lookup_key("non-existent")
.await;
assert!(result.is_err());
}
#[gpui::test]
async fn test_find_or_create_customer_by_email() {
let (stripe_billing, stripe_client) = make_stripe_billing();
// Create a customer with an email that doesn't yet correspond to a customer.
{
let email = "user@example.com";
let customer_id = stripe_billing
.find_or_create_customer_by_email(Some(email))
.await
.unwrap();
let customer = stripe_client
.customers
.lock()
.get(&customer_id)
.unwrap()
.clone();
assert_eq!(customer.email.as_deref(), Some(email));
}
// Create a customer with an email that corresponds to an existing customer.
{
let email = "user2@example.com";
let existing_customer_id = stripe_billing
.find_or_create_customer_by_email(Some(email))
.await
.unwrap();
let customer_id = stripe_billing
.find_or_create_customer_by_email(Some(email))
.await
.unwrap();
assert_eq!(customer_id, existing_customer_id);
let customer = stripe_client
.customers
.lock()
.get(&customer_id)
.unwrap()
.clone();
assert_eq!(customer.email.as_deref(), Some(email));
}
}

View File

@@ -1,4 +1,3 @@
use crate::stripe_client::FakeStripeClient;
use crate::{
AppState, Config,
db::{NewUserParams, UserId, tests::TestDb},
@@ -566,12 +565,8 @@ impl TestServer {
) -> Arc<AppState> {
Arc::new(AppState {
db: test_db.db().clone(),
llm_db: None,
livekit_client: Some(Arc::new(livekit_test_server.create_api_client())),
blob_store_client: None,
real_stripe_client: None,
stripe_client: Some(Arc::new(FakeStripeClient::new())),
stripe_billing: None,
executor,
kinesis_client: None,
config: Config {
@@ -608,7 +603,6 @@ impl TestServer {
auto_join_channel_id: None,
migrations_path: None,
seed_path: None,
stripe_api_key: None,
supermaven_admin_api_key: None,
user_backfiller_github_access_token: None,
kinesis_region: None,

View File

@@ -674,7 +674,7 @@ impl ChatPanel {
})
})
.when_some(message_id, |el, message_id| {
let this = cx.entity().clone();
let this = cx.entity();
el.child(
self.render_popover_button(

View File

@@ -95,7 +95,7 @@ pub fn init(cx: &mut App) {
.and_then(|room| room.read(cx).channel_id());
if let Some(channel_id) = channel_id {
let workspace = cx.entity().clone();
let workspace = cx.entity();
window.defer(cx, move |window, cx| {
ChannelView::open(channel_id, None, workspace, window, cx)
.detach_and_log_err(cx)
@@ -1142,7 +1142,7 @@ impl CollabPanel {
window: &mut Window,
cx: &mut Context<Self>,
) {
let this = cx.entity().clone();
let this = cx.entity();
if !(role == proto::ChannelRole::Guest
|| role == proto::ChannelRole::Talker
|| role == proto::ChannelRole::Member)
@@ -1272,7 +1272,7 @@ impl CollabPanel {
.channel_for_id(clipboard.channel_id)
.map(|channel| channel.name.clone())
});
let this = cx.entity().clone();
let this = cx.entity();
let context_menu = ContextMenu::build(window, cx, |mut context_menu, window, cx| {
if self.has_subchannels(ix) {
@@ -1439,7 +1439,7 @@ impl CollabPanel {
window: &mut Window,
cx: &mut Context<Self>,
) {
let this = cx.entity().clone();
let this = cx.entity();
let in_room = ActiveCall::global(cx).read(cx).room().is_some();
let context_menu = ContextMenu::build(window, cx, |mut context_menu, _, _| {

View File

@@ -586,7 +586,7 @@ impl ChannelModalDelegate {
return;
};
let user_id = membership.user.id;
let picker = cx.entity().clone();
let picker = cx.entity();
let context_menu = ContextMenu::build(window, cx, |mut menu, _window, _cx| {
let role = membership.role;

View File

@@ -321,7 +321,7 @@ impl NotificationPanel {
.justify_end()
.child(Button::new("decline", "Decline").on_click({
let notification = notification.clone();
let entity = cx.entity().clone();
let entity = cx.entity();
move |_, _, cx| {
entity.update(cx, |this, cx| {
this.respond_to_notification(
@@ -334,7 +334,7 @@ impl NotificationPanel {
}))
.child(Button::new("accept", "Accept").on_click({
let notification = notification.clone();
let entity = cx.entity().clone();
let entity = cx.entity();
move |_, _, cx| {
entity.update(cx, |this, cx| {
this.respond_to_notification(

View File

@@ -12,6 +12,8 @@ minidumper.workspace = true
paths.workspace = true
release_channel.workspace = true
smol.workspace = true
serde.workspace = true
serde_json.workspace = true
workspace-hack.workspace = true
[lints]

View File

@@ -2,15 +2,17 @@ use crash_handler::CrashHandler;
use log::info;
use minidumper::{Client, LoopAction, MinidumpBinary};
use release_channel::{RELEASE_CHANNEL, ReleaseChannel};
use serde::{Deserialize, Serialize};
use std::{
env,
fs::File,
fs::{self, File},
io,
panic::Location,
path::{Path, PathBuf},
process::{self, Command},
sync::{
LazyLock, OnceLock,
Arc, OnceLock,
atomic::{AtomicBool, Ordering},
},
thread,
@@ -18,19 +20,17 @@ use std::{
};
// set once the crash handler has initialized and the client has connected to it
pub static CRASH_HANDLER: AtomicBool = AtomicBool::new(false);
pub static CRASH_HANDLER: OnceLock<Arc<Client>> = OnceLock::new();
// set when the first minidump request is made to avoid generating duplicate crash reports
pub static REQUESTED_MINIDUMP: AtomicBool = AtomicBool::new(false);
const CRASH_HANDLER_TIMEOUT: Duration = Duration::from_secs(60);
const CRASH_HANDLER_PING_TIMEOUT: Duration = Duration::from_secs(60);
const CRASH_HANDLER_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
pub static GENERATE_MINIDUMPS: LazyLock<bool> = LazyLock::new(|| {
*RELEASE_CHANNEL != ReleaseChannel::Dev || env::var("ZED_GENERATE_MINIDUMPS").is_ok()
});
pub async fn init(id: String) {
if !*GENERATE_MINIDUMPS {
pub async fn init(crash_init: InitCrashHandler) {
if *RELEASE_CHANNEL == ReleaseChannel::Dev && env::var("ZED_GENERATE_MINIDUMPS").is_err() {
return;
}
let exe = env::current_exe().expect("unable to find ourselves");
let zed_pid = process::id();
// TODO: we should be able to get away with using 1 crash-handler process per machine,
@@ -61,9 +61,11 @@ pub async fn init(id: String) {
smol::Timer::after(retry_frequency).await;
}
let client = maybe_client.unwrap();
client.send_message(1, id).unwrap(); // set session id on the server
client
.send_message(1, serde_json::to_vec(&crash_init).unwrap())
.unwrap();
let client = std::sync::Arc::new(client);
let client = Arc::new(client);
let handler = crash_handler::CrashHandler::attach(unsafe {
let client = client.clone();
crash_handler::make_crash_event(move |crash_context: &crash_handler::CrashContext| {
@@ -72,7 +74,6 @@ pub async fn init(id: String) {
.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
.is_ok()
{
client.send_message(2, "mistakes were made").unwrap();
client.ping().unwrap();
client.request_dump(crash_context).is_ok()
} else {
@@ -87,7 +88,7 @@ pub async fn init(id: String) {
{
handler.set_ptracer(Some(server_pid));
}
CRASH_HANDLER.store(true, Ordering::Release);
CRASH_HANDLER.set(client.clone()).ok();
std::mem::forget(handler);
info!("crash handler registered");
@@ -98,14 +99,43 @@ pub async fn init(id: String) {
}
pub struct CrashServer {
session_id: OnceLock<String>,
initialization_params: OnceLock<InitCrashHandler>,
panic_info: OnceLock<CrashPanic>,
has_connection: Arc<AtomicBool>,
}
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct CrashInfo {
pub init: InitCrashHandler,
pub panic: Option<CrashPanic>,
}
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct InitCrashHandler {
pub session_id: String,
pub zed_version: String,
pub release_channel: String,
pub commit_sha: String,
// pub gpu: String,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct CrashPanic {
pub message: String,
pub span: String,
}
impl minidumper::ServerHandler for CrashServer {
fn create_minidump_file(&self) -> Result<(File, PathBuf), io::Error> {
let err_message = "Need to send a message with the ID upon starting the crash handler";
let err_message = "Missing initialization data";
let dump_path = paths::logs_dir()
.join(self.session_id.get().expect(err_message))
.join(
&self
.initialization_params
.get()
.expect(err_message)
.session_id,
)
.with_extension("dmp");
let file = File::create(&dump_path)?;
Ok((file, dump_path))
@@ -122,38 +152,71 @@ impl minidumper::ServerHandler for CrashServer {
info!("failed to write minidump: {:#}", e);
}
}
let crash_info = CrashInfo {
init: self
.initialization_params
.get()
.expect("not initialized")
.clone(),
panic: self.panic_info.get().cloned(),
};
let crash_data_path = paths::logs_dir()
.join(&crash_info.init.session_id)
.with_extension("json");
fs::write(crash_data_path, serde_json::to_vec(&crash_info).unwrap()).ok();
LoopAction::Exit
}
fn on_message(&self, kind: u32, buffer: Vec<u8>) {
let message = String::from_utf8(buffer).expect("invalid utf-8");
info!("kind: {kind}, message: {message}",);
if kind == 1 {
self.session_id
.set(message)
.expect("session id already initialized");
match kind {
1 => {
let init_data =
serde_json::from_slice::<InitCrashHandler>(&buffer).expect("invalid init data");
self.initialization_params
.set(init_data)
.expect("already initialized");
}
2 => {
let panic_data =
serde_json::from_slice::<CrashPanic>(&buffer).expect("invalid panic data");
self.panic_info.set(panic_data).expect("already panicked");
}
_ => {
panic!("invalid message kind");
}
}
}
fn on_client_disconnected(&self, clients: usize) -> LoopAction {
info!("client disconnected, {clients} remaining");
if clients == 0 {
LoopAction::Exit
} else {
LoopAction::Continue
}
fn on_client_disconnected(&self, _clients: usize) -> LoopAction {
LoopAction::Exit
}
fn on_client_connected(&self, _clients: usize) -> LoopAction {
self.has_connection.store(true, Ordering::SeqCst);
LoopAction::Continue
}
}
pub fn handle_panic() {
if !*GENERATE_MINIDUMPS {
return;
}
pub fn handle_panic(message: String, span: Option<&Location>) {
let span = span
.map(|loc| format!("{}:{}", loc.file(), loc.line()))
.unwrap_or_default();
// wait 500ms for the crash handler process to start up
// if it's still not there just write panic info and no minidump
let retry_frequency = Duration::from_millis(100);
for _ in 0..5 {
if CRASH_HANDLER.load(Ordering::Acquire) {
if let Some(client) = CRASH_HANDLER.get() {
client
.send_message(
2,
serde_json::to_vec(&CrashPanic { message, span }).unwrap(),
)
.ok();
log::error!("triggering a crash to generate a minidump...");
#[cfg(target_os = "linux")]
CrashHandler.simulate_signal(crash_handler::Signal::Trap as u32);
@@ -170,14 +233,30 @@ pub fn crash_server(socket: &Path) {
log::info!("Couldn't create socket, there may already be a running crash server");
return;
};
let ab = AtomicBool::new(false);
let shutdown = Arc::new(AtomicBool::new(false));
let has_connection = Arc::new(AtomicBool::new(false));
std::thread::spawn({
let shutdown = shutdown.clone();
let has_connection = has_connection.clone();
move || {
std::thread::sleep(CRASH_HANDLER_CONNECT_TIMEOUT);
if !has_connection.load(Ordering::SeqCst) {
shutdown.store(true, Ordering::SeqCst);
}
}
});
server
.run(
Box::new(CrashServer {
session_id: OnceLock::new(),
initialization_params: OnceLock::new(),
panic_info: OnceLock::new(),
has_connection,
}),
&ab,
Some(CRASH_HANDLER_TIMEOUT),
&shutdown,
Some(CRASH_HANDLER_PING_TIMEOUT),
)
.expect("failed to run server");
}

View File

@@ -291,7 +291,7 @@ pub(crate) fn new_debugger_pane(
let Some(project) = project.upgrade() else {
return ControlFlow::Break(());
};
let this_pane = cx.entity().clone();
let this_pane = cx.entity();
let item = if tab.pane == this_pane {
pane.item_for_index(tab.ix)
} else {
@@ -502,7 +502,7 @@ pub(crate) fn new_debugger_pane(
.on_drag(
DraggedTab {
item: item.boxed_clone(),
pane: cx.entity().clone(),
pane: cx.entity(),
detail: 0,
is_active: selected,
ix,

Some files were not shown because too many files have changed in this diff Show More