Compare commits

...

75 Commits

Author SHA1 Message Date
Michael Sloan
fc3a67f046 Fix mild memory leak for failed cache tasks 2025-05-04 22:32:22 +02:00
Michael Sloan
670813b9de Remove old ideation comment 2025-05-04 22:32:09 +02:00
Michael Sloan
61d615549f Clippy + other polish 2025-05-04 22:19:06 +02:00
Michael Sloan
3931d74275 Merge branch 'main' into gemini-caching 2025-05-04 22:12:18 +02:00
Michael Sloan
bb82d9ca82 agent eval: Fix --model arg and add --provider (#29883)
Release Notes:

- N/A
2025-05-04 13:43:57 -06:00
ZaraPhu
007685f6d4 docs: Add instructions for uninstalling Zed (#29840) 2025-05-04 17:41:36 +00:00
Max Brunsfeld
c3d9cdecab Change cloud language model provider JSON protocol to surface errors and usage information (#29830)
Release Notes:

- N/A

---------

Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Marshall Bowers <git@maxdeviant.com>
2025-05-04 17:37:42 +00:00
Bennet Bo Fenner
3984531a45 agent: Rename @rules to @rule (#29881)
This is purely a cosmetic change, renamed `@rules` to `@rule` which
unifies the @mention experience (for files, threads etc. we also use
`@file`, `@thread` not `@files`, `@thread`). Would also make sense to
rename the rules picker to rule picker, but i do not wanna introduce
conflicts just for the purpose of re-naming.

Release Notes:

- N/A
2025-05-04 16:25:44 +00:00
Michael Sloan
2cea66c5cb Change some todo! to TODO 2025-05-04 18:11:09 +02:00
Michael Sloan
09e26204ad Update cache TTL if it already exists 2025-05-04 18:09:49 +02:00
Marshall Bowers
cceb13b7cd collab: Add use_llm_request_queue to LlmTokenClaims (#29877)
This PR adds a `use_llm_request_queue` field to the LLM token claims,
based on the `llm-request-queue` feature flag.

Release Notes:

- N/A
2025-05-04 12:08:43 -04:00
Marshall Bowers
427101b634 collab: Drop legacy subscription usage and meter tables (#29876)
This PR adds a migration to drop the `subscription_usages` and
`subscription_usage_meters` tables from the database.

We're now using `subscription_usages_v2` and
`subscription_usage_meters_v2` everywhere.

Release Notes:

- N/A
2025-05-04 10:42:40 -04:00
Antonio Scandurra
4d51602e7b Encourage editing over re-creating a file from scratch (#29870)
I also introduced a new eval to prove the encouragement actually makes a
difference.

Release Notes:

- Improved agent behavior when streaming edits, encouraging it to
editing files as opposed to creating them from scratch
2025-05-04 13:18:28 +00:00
Marshall Bowers
ca1dc821cf collab: Fix subscription_usage_id column type (#29871)
This PR fixes the type of the `subscription_usage_id` column on the
`SubscriptionUsageMeter` model.

Release Notes:

- N/A
2025-05-04 13:05:26 +00:00
Michael Sloan
bfea3e5285 Remove todo about special handling of response for cache missing 2025-05-04 14:19:36 +02:00
Michael Sloan
f704e0578a Check cache creation 404 text 2025-05-04 14:14:52 +02:00
Michael Sloan
04b16fedde Reorganize 2025-05-04 14:05:47 +02:00
Danilo Leal
2e3baef299 agent: Polish single-file review toolbar controls (#29866) 2025-05-04 07:53:21 -03:00
Antonio Scandurra
545ae27079 Add the ability to follow the agent as it makes edits (#29839)
Nathan here: I also tacked on a bunch of UI refinement.

Release Notes:

- Introduced the ability to follow the agent around as it reads and
edits files.

---------

Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
2025-05-04 08:28:39 +00:00
Danilo Leal
425f32e068 agent: Add the single_file_review setting to the UI (#29859)
Release Notes:

- agent: Add the `single_file_review` setting to the UI
2025-05-03 21:01:44 -03:00
Agus Zubiaga
9c11d24887 Fix hiding editor toolbar and add agent_review setting (#29854)
Closes #29836

The agent diff toolbar item was causing the editor toolbar to show even
when all the other elements were disabled via settings.

This PR fixes this by setting the location to
`ToolbarItemLocation::Hidden` in the states where it shouldn't show.

It also adds a new a `toolbar.agent_review` setting to hide the agent
review buttons altogether. However, if the other toolbar elements are
hidden and the file isn't under review, the editor toolbar will still be
hidden. So you only need to set this to `false` if you don't want them
to show up even under agent review.

Release Notes:

- N/A
2025-05-03 17:43:46 -03:00
Marshall Bowers
1fc57ea9f5 feature_flags: Add a constant to control Agent-related feature flags (#29853)
This PR adds a singular constant that controls the Agent-related feature
flags.

This way we can tweak this one value when we're ready to build the final
build for the launch.

Release Notes:

- N/A
2025-05-03 20:16:25 +00:00
Marshall Bowers
c3d2831d86 collab: Use new subscription usage tables (#29848)
This PR updates Collab to use the new subscription usage tables added in
#29847.

Release Notes:

- N/A
2025-05-03 17:56:43 +00:00
Marshall Bowers
c1247977ed collab: Add new tables for subscription usages and meters (#29847)
This PR adds two new tables:

- `subscription_usages_v2`
- `subscription_usage_meters_v2`

These are the same as the old ones, except using UUIDs as primary keys.

Release Notes:

- N/A
2025-05-03 17:21:22 +00:00
Marshall Bowers
12c26a4fa6 collab: Don't try to transfer usage when a Zed Pro trial is canceled (#29843)
This PR fixes an issue where we would erroneously try to transfer
existing subscription usage when a Zed Pro trial was canceled.

Release Notes:

- N/A
2025-05-03 14:57:54 +00:00
Michael Sloan
bde4bd5b3d Keep track of models lacking caching, based on cache creation status code 2025-05-03 08:11:49 -06:00
Michael Sloan
85e26b7f02 Resolve some todo! + other polish 2025-05-03 07:48:13 -06:00
Michael Sloan
0cc51f72e5 Fix compilation 2025-05-03 05:00:13 -06:00
Marshall Bowers
7f8e3fd482 ui: Implement ParentElement for Banner (#29834)
This PR implements the `ParentElement` trait for the `Banner` component
so that it can use the real children APIs instead of a bespoke one.

Release Notes:

- N/A
2025-05-03 02:36:53 +00:00
Marshall Bowers
f0515d1c34 agent: Show a notice when reaching consecutive tool use limits (#29833)
This PR adds a notice when reaching consecutive tool use limits when
using normal mode.

Here's an example with the limit artificially lowered to 2 consecutive
tool uses:


https://github.com/user-attachments/assets/32da8d38-67de-4d6b-8f24-754d2518e5d4

Release Notes:

- agent: Added a notice when reaching consecutive tool use limits when
using a model in normal mode.
2025-05-03 02:09:54 +00:00
Danilo Leal
10a7f2a972 agent: Add several UX improvements (#29828)
Still a work in progress.

Release Notes:

- N/A

---------

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Cole Miller <m@cole-miller.net>
Co-authored-by: Nathan Sobo <1789+nathansobo@users.noreply.github.com>
Co-authored-by: Cole Miller <53574922+cole-miller@users.noreply.github.com>
2025-05-02 19:00:55 -06:00
Danilo Leal
5053562e28 agent: Refresh the profile selector and modal design (#29816)
- [x] Separate MCP servers from tools in the profile customization modal
view
- [x] Group MCP tools in the MCP picker and add a heading
- [x] Separate bult-in profiles from custom ones in the dropdown
selector
- [x] Separate bult-in profiles from custom ones in the modal
- [ ] Enable looping through items via keybinding without opening the
dropdown (will be done on a follow-up PR)
- [ ] Stretch: Focus on the currently active item upon opening the
dropdown (will be done on a follow-up PR)

Release Notes:

- N/A

---------

Co-authored-by: Bennet Bo Fenner <53836821+bennetbo@users.noreply.github.com>
Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
Co-authored-by: Bennet Bo Fenner <bennet@zed.dev>
2025-05-02 20:34:36 -03:00
Agus Zubiaga
1877fce609 agent: Fix default cursor position on reviewing editors (#29825)
The cursor wasn't always placed at the first hunk for review editors.

Release Notes:

- N/A
2025-05-02 21:58:00 +00:00
Agus Zubiaga
64316309aa agent: Review edits in single-file editors (#29820)
Enables reviewing agent edits from single-file editors in addition to
the multibuffer experience we already had.


https://github.com/user-attachments/assets/a2c287f0-51d6-43a1-8537-821498b91983


This feature can be turned off by setting `assistant.single_file_review:
false`.

Release Notes:

- agent: Review edits in single-file editors
2025-05-02 17:57:16 -03:00
Max Brunsfeld
04772bf17d Add support for queuing status updates in cloud language model provider (#29818)
This sets us up to display queue position information to the user, once
our language model backend is updated to support request queuing.

The JSON returned by the LLM backend will need to look like this:

```json
{"queue": {"status": "queued", "position": 1}}
{"queue": {"status": "started"}}
{"event": {"THE_UPSTREAM_MODEL_PROVIDER_EVENT": "..."}} 
```

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
2025-05-02 20:36:39 +00:00
Richard Feldman
4d1df7bcd7 Re-enable directory-related tools (#29809)
Also `now` in `write` profile

Release Notes:

- Tools for manipulating directories no longer require confirmation, and
are enabled in the Write profile
- Enabled `now` and `list_directory` tools by default in Write profile

---------

Co-authored-by: Cole Miller <m@cole-miller.net>
Co-authored-by: Cole Miller <cole@zed.dev>
2025-05-02 16:11:16 -04:00
Cole Miller
9547d42b15 Support @-mentions in inline assists and when editing old agent panel messages (#29734)
Closes #ISSUE

Co-authored-by: Bennet <bennet@zed.dev>

Release Notes:

- Added support for context `@mentions` in the inline prompt editor and
when editing past messages in the agent panel.

---------

Co-authored-by: Bennet Bo Fenner <bennet@zed.dev>
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
2025-05-02 20:08:53 +00:00
Michael Sloan
abbb3dc7a6 Attempt to drop expired cache entries. Untested - writteon on a plane 2025-05-02 14:07:43 -06:00
Umesh Yadav
c918f6cde1 agent: Add assistant panel width persistence (#28808)
Previously, the assistant panel width was not persisted across sessions.
This meant that upon restarting the Zed editor, the panel would revert
to its default size, disrupting the user's preferred layout.

This pull request introduces persistence for the assistant panel width.
The width is now saved to the key-value store when the editor is closed
and restored on startup, ensuring a consistent UI experience across
different sessions.

Release Notes:

- agent: Add assistant panel width persistence

---------

Signed-off-by: Umesh Yadav <umesh4257@gmail.com>
2025-05-02 13:05:03 -07:00
Anthony Eid
da98e300cc debugger: Clear active debug line on thread continued (#29811)
I also moved the breakpoint store to session from local mode, because
both remote/local modes will need the ability to remove active debug
lines.

Release Notes:

- N/A
2025-05-02 15:24:28 -04:00
Richard Feldman
e6b0d8e48b Delete obsolete tools (#29808)
Release Notes:

- Removed some obsolete tools: batch_tool, code_actions, code_symbols,
contents, symbol_info, rename

Co-authored-by: Cole Miller <m@cole-miller.net>
2025-05-02 18:52:42 +00:00
Bennet Bo Fenner
9147f89257 zed_extension_api: Release v0.5.0 (#29802)
This PR releases v0.5.0 of the Zed extension API.

Support for this version of the extension API will land in Zed v0.186.x.

Release Notes:

- N/A
2025-05-02 15:58:54 +00:00
Richard Feldman
9efc09c5a6 Add eval for open_tool (#29801)
Also have its description say it should only be used on request

Release Notes:

- N/A
2025-05-02 15:56:07 +00:00
Bennet Bo Fenner
e6f6b351b7 extension_api: Add documentation to context server configuration (#29800)
Release Notes:

- N/A
2025-05-02 15:37:05 +00:00
Bennet Bo Fenner
fde621f0e3 agent: Ensure that web search tool is always available (#29799)
Some changes in the LanguageModelRegistry caused the web search tool not
to show up, because the `DefaultModelChanged` event is not emitted at
startup anymore.

Release Notes:

- agent: Fixed an issue where the web search tool would not be available
after starting Zed (only when using zed.dev as a provider).
2025-05-02 15:34:08 +00:00
Marshall Bowers
c4556e9909 collab: Fix adding users to feature flags when migrating to new billing (#29795)
This PR fixes an issue where users were not being added to the feature
flags when being migrated to the new billing.

Release Notes:

- N/A
2025-05-02 15:07:49 +00:00
Kirill Bulatov
7e2de84155 Properly score fuzzy match queries with multiple chars in lower case (#29794)
Closes https://github.com/zed-industries/zed/issues/29526

Release Notes:

- Fixed file finder crashing for certain file names with multiple chars
in lowercase form
2025-05-02 15:02:53 +00:00
Kirill Bulatov
d1b35be353 Use proper settings in the diagnostics section (#29791)
Follow-up of https://github.com/zed-industries/zed/pull/29706

Release Notes:

- N/A

Co-authored-by: Cole Miller <cole@zed.dev>
2025-05-02 16:48:52 +03:00
Marshall Bowers
49a71ec3b8 collab: Update billing migration endpoint to work for users without active subscriptions (#29792)
This PR updates the billing migration endpoint to work for users who do
not have an active subscription.

This will allow us to use the endpoint to migrate all users.

Release Notes:

- N/A
2025-05-02 13:48:14 +00:00
Nate Butler
3bd7ae6e5b Standardize agent previews (#29790)
This PR makes agent previews render like any other preview in the
component preview list & pages.

Page:

![CleanShot 2025-05-02 at 09 17
12@2x](https://github.com/user-attachments/assets/8b611380-b686-4fd6-9c76-de27e35b0b38)

List:

![CleanShot 2025-05-02 at 09 17
33@2x](https://github.com/user-attachments/assets/ab063649-dc3c-4c95-969b-c3795b2197f2)


Release Notes:

- N/A
2025-05-02 13:32:59 +00:00
Max Brunsfeld
225deb6785 agent: Add animation in the edit file tool card until a diff is assigned (#29773)
This PR prevents this edit card from being shown expanded but empty,
like this:

<img width="590" alt="Screenshot 2025-05-01 at 7 38 47 PM"
src="https://github.com/user-attachments/assets/147d3d73-05b9-4493-8145-0ee915f12cd9"
/>

Now, we will show an animation until it has a diff computed.


https://github.com/user-attachments/assets/52900cdf-ee3d-4c3b-88c7-c53377543bcf

Release Notes:

- N/A

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
2025-05-02 09:48:40 -03:00
Kirill Bulatov
33011f2eaf Open diagnostics editor faster when fetching cargo diagnostics (#29787)
Follow-up of https://github.com/zed-industries/zed/pull/29706

Release Notes:

- N/A
2025-05-02 12:10:01 +00:00
Kirill Bulatov
e14d078f8a Fix tasks not being stopped on reruns (#29786)
Follow-up of https://github.com/zed-industries/zed/pull/28993

* Tone down tasks' cancellation logging
* Fix task terminals' leak, disallowing to fully cancel the task by
dropping the terminal off the pane:

f619d5f02a/crates/terminal_view/src/terminal_panel.rs (L1464-L1471)

Release Notes:

- Fixed tasks not being stopped on reruns
2025-05-02 11:45:43 +00:00
Stanislav Alekseev
460ac96df4 Use project environment in LSP runnables context (#29761)
Release Notes:

- Fixed the tasks from LSP not inheriting the worktree environment

----

cc @SomeoneToIgnore
2025-05-02 11:01:39 +00:00
Michael Sloan
141e0a702a Remove logic for checking if a model supports caching 2025-05-01 21:37:54 -06:00
Michael Sloan
bb47b766a1 Stable IDs for gemini-1.5 variants 2025-05-01 21:37:10 -06:00
Michael Sloan
515cdb9ae6 Progress towards better cache awaiting, written without internet on an airplane 2025-05-01 21:25:15 -06:00
Michael Sloan
2666ff7873 Remove some agent generated comments 2025-05-01 09:11:52 -06:00
Michael Sloan
1d0cc37205 Progress towards blocking on a specific cache
Co-authored-by: Max <max@zed.dev>
2025-04-30 23:26:47 -06:00
Michael Sloan
1257d44998 Initial implementation of also caching every agent request
Co-authored-by: Max <max@zed.dev>
2025-04-30 16:43:42 -06:00
Michael Sloan
5fdcdc1926 Add missing fields in provider code 2025-04-30 12:13:54 -06:00
Michael Sloan
8c91fc3153 Use model IDs which support caching for Gemini 2025-04-30 12:13:50 -06:00
Michael Sloan
b1595dba71 Comment out some code from interface sketching 2025-04-30 11:52:08 -06:00
Michael Sloan
d4702209ea Support for using cache in generation requests etc 2025-04-30 00:40:49 -06:00
Michael Sloan
8440ec03ad Fixes after merge 2025-04-29 23:32:51 -06:00
Michael Sloan
8c8eabe96d Merge branch 'fix-gemini-token-counting' into gemini-caching 2025-04-29 23:23:49 -06:00
Michael Sloan
10dfa36c91 Fix Gemini token counting + add support for counting whole requests
* Now provides the model id in the path instead of always `gemini-pro`, which appears to have stopped working.

* `CountTokensRequest` now takes a full `GenerateContentRequest` instead of just content.

* Fixes handling of `models/` prefix in `model` field of `GenerateContentRequest`, since that's required for use in `CountTokensRequest`. This didn't cause issues before because it was always cleared and used in the path.
2025-04-29 23:19:57 -06:00
Michael Sloan
88c7893913 Merge branch 'main' into gemini-caching 2025-04-29 23:03:14 -06:00
Michael Sloan
c2151f0082 Progress 2025-04-29 23:01:16 -06:00
Michael Sloan
d677117a48 Undo a change from Option<Vec<Tool>> -> Vec<Tool>
While this is safe, it is not safe on the llm worker side as it needs to deserialize `null` from older Zeds. Better to keep the definitions consistent.
2025-04-29 22:45:58 -06:00
Michael Sloan
f86b552a20 Merge branch 'main' into gemini-caching 2025-04-29 19:06:57 -06:00
Michael Sloan
1c6040e54f Add --prompt-file + improve request types 2025-04-29 14:17:37 -06:00
Michael Sloan
b861e1ca8c Add a cli example for google ai API 2025-04-29 13:58:49 -06:00
Michael Sloan
961e7dd52a Wrap gemini cache creation and update APIs 2025-04-29 13:03:06 -06:00
Michael Sloan
9c548fecbc WIP 2025-04-29 11:05:59 -06:00
148 changed files with 7267 additions and 4226 deletions

13
Cargo.lock generated
View File

@@ -79,7 +79,6 @@ dependencies = [
"heed",
"html_to_markdown",
"http_client",
"indexmap",
"indoc",
"itertools 0.14.0",
"jsonschema",
@@ -6199,12 +6198,17 @@ name = "google_ai"
version = "0.1.0"
dependencies = [
"anyhow",
"clap",
"futures 0.3.31",
"http_client",
"log",
"reqwest_client",
"schemars",
"serde",
"serde_json",
"strum 0.27.1",
"time",
"tokio",
"workspace-hack",
]
@@ -7910,6 +7914,7 @@ dependencies = [
"mistral",
"ollama",
"open_ai",
"parking_lot",
"partial-json-fixer",
"project",
"proto",
@@ -7922,6 +7927,7 @@ dependencies = [
"theme",
"thiserror 2.0.12",
"tiktoken-rs",
"time",
"tokio",
"ui",
"util",
@@ -18093,7 +18099,6 @@ dependencies = [
"component",
"dap",
"db",
"derive_more",
"env_logger 0.11.8",
"fs",
"futures 0.3.31",
@@ -18827,9 +18832,9 @@ dependencies = [
[[package]]
name = "zed_llm_client"
version = "0.7.1"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc9ec491b7112cb8c2fba3c17d9a349d8ab695fb1a4ef6c5c4b9fd8d7aa975c1"
checksum = "2adf9bc80def4ec93c190f06eb78111865edc2576019a9753eaef6fd7bc3b72c"
dependencies = [
"anyhow",
"serde",

View File

@@ -611,7 +611,7 @@ wasmtime-wasi = "29"
which = "6.0.0"
wit-component = "0.221"
workspace-hack = "0.1.0"
zed_llm_client = "0.7.1"
zed_llm_client = "0.7.4"
zstd = "0.11"
[workspace.dependencies.async-stripe]

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-crosshair-icon lucide-crosshair"><circle cx="12" cy="12" r="10"/><line x1="22" x2="18" y1="12" y2="12"/><line x1="6" x2="2" y1="12" y2="12"/><line x1="12" x2="12" y1="6" y2="2"/><line x1="12" x2="12" y1="22" y2="18"/></svg>

After

Width:  |  Height:  |  Size: 426 B

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-list-collapse-icon lucide-list-collapse"><path d="m3 10 2.5-2.5L3 5"/><path d="m3 19 2.5-2.5L3 14"/><path d="M10 6h11"/><path d="M10 12h11"/><path d="M10 18h11"/></svg>

After

Width:  |  Height:  |  Size: 371 B

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-message-circle-dashed-icon lucide-message-circle-dashed"><path d="M13.5 3.1c-.5 0-1-.1-1.5-.1s-1 .1-1.5.1"/><path d="M19.3 6.8a10.45 10.45 0 0 0-2.1-2.1"/><path d="M20.9 13.5c.1-.5.1-1 .1-1.5s-.1-1-.1-1.5"/><path d="M17.2 19.3a10.45 10.45 0 0 0 2.1-2.1"/><path d="M10.5 20.9c.5.1 1 .1 1.5.1s1-.1 1.5-.1"/><path d="M3.5 17.5 2 22l4.5-1.5"/><path d="M3.1 10.5c0 .5-.1 1-.1 1.5s.1 1 .1 1.5"/><path d="M6.8 4.7a10.45 10.45 0 0 0-2.1 2.1"/></svg>

After

Width:  |  Height:  |  Size: 644 B

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-scissors-icon lucide-scissors"><circle cx="6" cy="6" r="3"/><path d="M8.12 8.12 12 12"/><path d="M20 4 8.12 15.88"/><circle cx="6" cy="18" r="3"/><path d="M14.8 14.8 20 20"/></svg>

After

Width:  |  Height:  |  Size: 383 B

View File

@@ -194,6 +194,16 @@
"alt-shift-y": "git::UnstageAndNext"
}
},
{
"context": "Editor && editor_agent_diff",
"bindings": {
"ctrl-y": "agent::Keep",
"ctrl-n": "agent::Reject",
"ctrl-shift-y": "agent::KeepAll",
"ctrl-shift-n": "agent::RejectAll",
"shift-ctrl-r": "agent::OpenAgentDiff"
}
},
{
"context": "AgentDiff",
"bindings": {

View File

@@ -247,6 +247,17 @@
"cmd-shift-n": "agent::RejectAll"
}
},
{
"context": "Editor && editor_agent_diff",
"use_key_equivalents": true,
"bindings": {
"cmd-y": "agent::Keep",
"cmd-n": "agent::Reject",
"cmd-shift-y": "agent::KeepAll",
"cmd-shift-n": "agent::RejectAll",
"shift-ctrl-r": "agent::OpenAgentDiff"
}
},
{
"context": "AssistantPanel",
"use_key_equivalents": true,

View File

@@ -300,8 +300,10 @@
"breadcrumbs": true,
// Whether to show quick action buttons.
"quick_actions": true,
// Whether to show the Selections menu in the editor toolbar
"selections_menu": true
// Whether to show the Selections menu in the editor toolbar.
"selections_menu": true,
// Whether to show agent review buttons in the editor toolbar.
"agent_review": true
},
// Scrollbar related settings
"scrollbar": {
@@ -659,6 +661,8 @@
"always_allow_tool_actions": false,
// When enabled, the agent will stream edits.
"stream_edits": false,
// When enabled, agent edits will be displayed in single-file editors for review
"single_file_review": true,
"default_profile": "write",
"profiles": {
"ask": {
@@ -669,7 +673,7 @@
"contents": true,
"diagnostics": true,
"fetch": true,
"list_directory": false,
"list_directory": true,
"now": true,
"find_path": true,
"read_file": true,
@@ -683,24 +687,19 @@
"name": "Write",
"enable_all_context_servers": true,
"tools": {
"batch_tool": false,
"code_actions": false,
"code_symbols": false,
"contents": false,
"copy_path": false,
"copy_path": true,
"create_directory": true,
"create_file": true,
"delete_path": false,
"delete_path": true,
"diagnostics": true,
"edit_file": true,
"fetch": true,
"list_directory": true,
"move_path": false,
"now": false,
"move_path": true,
"now": true,
"find_path": true,
"read_file": true,
"grep": true,
"rename": false,
"symbol_info": false,
"terminal": true,
"thinking": true,
"web_search": true
@@ -934,7 +933,7 @@
// Shows all diagnostics when not specified.
"max_severity": null
},
"rust": {
"cargo": {
// When enabled, Zed disables rust-analyzer's check on save and starts to query
// Cargo diagnostics separately.
"fetch_cargo_diagnostics": false

View File

@@ -46,7 +46,6 @@ gpui.workspace = true
heed.workspace = true
html_to_markdown.workspace = true
http_client.workspace = true
indexmap.workspace = true
itertools.workspace = true
jsonschema.workspace = true
language.workspace = true

View File

@@ -722,7 +722,7 @@ fn open_markdown_link(
}
}),
Some(MentionLink::Fetch(url)) => cx.open_url(&url),
Some(MentionLink::Rules(prompt_id)) => window.dispatch_action(
Some(MentionLink::Rule(prompt_id)) => window.dispatch_action(
Box::new(OpenRulesLibrary {
prompt_to_select: Some(prompt_id.0),
}),
@@ -764,7 +764,6 @@ impl ActiveThread {
.unwrap()
}
});
let mut this = Self {
language_registry,
thread_store,
@@ -953,6 +952,9 @@ impl ActiveThread {
ThreadEvent::UsageUpdated(usage) => {
self.last_usage = Some(*usage);
}
ThreadEvent::NewRequest | ThreadEvent::CompletionCanceled => {
cx.notify();
}
ThreadEvent::StreamedCompletion
| ThreadEvent::SummaryGenerated
| ThreadEvent::SummaryChanged => {
@@ -1723,14 +1725,13 @@ impl ActiveThread {
let tool_uses = thread.tool_uses_for_message(message_id, cx);
let has_tool_uses = !tool_uses.is_empty();
let is_generating = thread.is_generating();
let is_generating_stale = thread.is_generation_stale().unwrap_or(false);
let is_first_message = ix == 0;
let is_last_message = ix == self.messages.len() - 1;
let show_feedback = thread.is_turn_end(ix);
let generating_label = (is_generating && is_last_message)
.then(|| AnimatedLabel::new("Generating").size(LabelSize::Small));
let loading_dots = (is_generating_stale && is_last_message)
.then(|| AnimatedLabel::new("").size(LabelSize::Small));
let editing_message_state = self
.editing_message
@@ -1753,6 +1754,8 @@ impl ActiveThread {
// For all items that should be aligned with the LLM's response.
const RESPONSE_PADDING_X: Pixels = px(19.);
let show_feedback = thread.is_turn_end(ix);
let feedback_container = h_flex()
.group("feedback_container")
.mt_1()
@@ -2029,80 +2032,84 @@ impl ActiveThread {
v_flex()
.w_full()
.when_some(checkpoint, |parent, checkpoint| {
let mut is_pending = false;
let mut error = None;
if let Some(last_restore_checkpoint) =
self.thread.read(cx).last_restore_checkpoint()
{
if last_restore_checkpoint.message_id() == message_id {
match last_restore_checkpoint {
LastRestoreCheckpoint::Pending { .. } => is_pending = true,
LastRestoreCheckpoint::Error { error: err, .. } => {
error = Some(err.clone());
.map(|parent| {
if let Some(checkpoint) = checkpoint.filter(|_| is_generating) {
let mut is_pending = false;
let mut error = None;
if let Some(last_restore_checkpoint) =
self.thread.read(cx).last_restore_checkpoint()
{
if last_restore_checkpoint.message_id() == message_id {
match last_restore_checkpoint {
LastRestoreCheckpoint::Pending { .. } => is_pending = true,
LastRestoreCheckpoint::Error { error: err, .. } => {
error = Some(err.clone());
}
}
}
}
}
let restore_checkpoint_button =
Button::new(("restore-checkpoint", ix), "Restore Checkpoint")
.icon(if error.is_some() {
IconName::XCircle
} else {
IconName::Undo
})
.icon_size(IconSize::XSmall)
.icon_position(IconPosition::Start)
.icon_color(if error.is_some() {
Some(Color::Error)
} else {
None
})
.label_size(LabelSize::XSmall)
.disabled(is_pending)
.on_click(cx.listener(move |this, _, _window, cx| {
this.thread.update(cx, |thread, cx| {
thread
.restore_checkpoint(checkpoint.clone(), cx)
.detach_and_log_err(cx);
});
}));
let restore_checkpoint_button =
Button::new(("restore-checkpoint", ix), "Restore Checkpoint")
.icon(if error.is_some() {
IconName::XCircle
} else {
IconName::Undo
})
.icon_size(IconSize::XSmall)
.icon_position(IconPosition::Start)
.icon_color(if error.is_some() {
Some(Color::Error)
} else {
None
})
.label_size(LabelSize::XSmall)
.disabled(is_pending)
.on_click(cx.listener(move |this, _, _window, cx| {
this.thread.update(cx, |thread, cx| {
thread
.restore_checkpoint(checkpoint.clone(), cx)
.detach_and_log_err(cx);
});
}));
let restore_checkpoint_button = if is_pending {
restore_checkpoint_button
.with_animation(
("pulsating-restore-checkpoint-button", ix),
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.6, 1.)),
|label, delta| label.alpha(delta),
)
.into_any_element()
} else if let Some(error) = error {
restore_checkpoint_button
.tooltip(Tooltip::text(error.to_string()))
.into_any_element()
let restore_checkpoint_button = if is_pending {
restore_checkpoint_button
.with_animation(
("pulsating-restore-checkpoint-button", ix),
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.6, 1.)),
|label, delta| label.alpha(delta),
)
.into_any_element()
} else if let Some(error) = error {
restore_checkpoint_button
.tooltip(Tooltip::text(error.to_string()))
.into_any_element()
} else {
restore_checkpoint_button.into_any_element()
};
parent.child(
h_flex()
.pt_2p5()
.px_2p5()
.w_full()
.gap_1()
.child(ui::Divider::horizontal())
.child(restore_checkpoint_button)
.child(ui::Divider::horizontal()),
)
} else {
restore_checkpoint_button.into_any_element()
};
parent.child(
h_flex()
.pt_2p5()
.px_2p5()
.w_full()
.gap_1()
.child(ui::Divider::horizontal())
.child(restore_checkpoint_button)
.child(ui::Divider::horizontal()),
)
parent
}
})
.when(is_first_message, |parent| {
parent.child(self.render_rules_item(cx))
})
.child(styled_message)
.when(generating_label.is_some(), |this| {
.when(is_generating && is_last_message, |this| {
this.child(
h_flex()
.h_8()
@@ -2110,7 +2117,7 @@ impl ActiveThread {
.mb_4()
.ml_4()
.py_1p5()
.child(generating_label.unwrap()),
.when_some(loading_dots, |this, loading_dots| this.child(loading_dots)),
)
})
.when(show_feedback, move |parent| {

File diff suppressed because it is too large Load Diff

View File

@@ -45,7 +45,7 @@ pub use crate::context::{ContextLoadResult, LoadedContext};
pub use crate::inline_assistant::InlineAssistant;
pub use crate::thread::{Message, MessageSegment, Thread, ThreadEvent};
pub use crate::thread_store::ThreadStore;
pub use agent_diff::{AgentDiff, AgentDiffToolbar};
pub use agent_diff::{AgentDiffPane, AgentDiffToolbar};
pub use context_store::ContextStore;
pub use ui::{all_agent_previews, get_agent_preview};
@@ -77,7 +77,8 @@ actions!(
Keep,
Reject,
RejectAll,
KeepAll
KeepAll,
Follow
]
);

View File

@@ -214,47 +214,91 @@ impl AssistantConfiguration {
fn render_command_permission(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
let always_allow_tool_actions = AssistantSettings::get_global(cx).always_allow_tool_actions;
const HEADING: &str = "Allow running editing tools without asking for confirmation";
h_flex()
.gap_4()
.justify_between()
.flex_wrap()
.child(
v_flex()
.gap_0p5()
.max_w_5_6()
.child(Label::new("Allow running editing tools without asking for confirmation"))
.child(
Label::new(
"The agent can perform potentially destructive actions without asking for your confirmation.",
)
.color(Color::Muted),
),
)
.child(
Switch::new(
"always-allow-tool-actions-switch",
always_allow_tool_actions.into(),
)
.color(SwitchColor::Accent)
.on_click({
let fs = self.fs.clone();
move |state, _window, cx| {
let allow = state == &ToggleState::Selected;
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
move |settings, _| {
settings.set_always_allow_tool_actions(allow);
},
);
}
}),
)
}
fn render_single_file_review(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
let single_file_review = AssistantSettings::get_global(cx).single_file_review;
h_flex()
.gap_4()
.justify_between()
.flex_wrap()
.child(
v_flex()
.gap_0p5()
.max_w_5_6()
.child(Label::new("Enable single-file agent reviews"))
.child(
Label::new(
"Agent edits are also displayed in single-file editors for review.",
)
.color(Color::Muted),
),
)
.child(
Switch::new("single-file-review-switch", single_file_review.into())
.color(SwitchColor::Accent)
.on_click({
let fs = self.fs.clone();
move |state, _window, cx| {
let allow = state == &ToggleState::Selected;
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
move |settings, _| {
settings.set_single_file_review(allow);
},
);
}
}),
)
}
fn render_general_settings_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
v_flex()
.p(DynamicSpacing::Base16.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx))
.gap_2()
.gap_2p5()
.flex_1()
.child(Headline::new("General Settings"))
.child(
h_flex()
.gap_4()
.justify_between()
.flex_wrap()
.child(
v_flex()
.gap_0p5()
.max_w_5_6()
.child(Label::new(HEADING))
.child(Label::new("When enabled, the agent can perform potentially destructive actions without asking for your confirmation.").color(Color::Muted)),
)
.child(
Switch::new(
"always-allow-tool-actions-switch",
always_allow_tool_actions.into(),
)
.color(SwitchColor::Accent)
.on_click({
let fs = self.fs.clone();
move |state, _window, cx| {
let allow = state == &ToggleState::Selected;
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
move |settings, _| {
settings.set_always_allow_tool_actions(allow);
},
);
}
}),
),
)
.child(self.render_command_permission(cx))
.child(self.render_single_file_review(cx))
}
fn render_context_servers_section(
@@ -549,7 +593,7 @@ impl Render for AssistantConfiguration {
.track_scroll(&self.scroll_handle)
.size_full()
.overflow_y_scroll()
.child(self.render_command_permission(cx))
.child(self.render_general_settings_section(cx))
.child(Divider::horizontal().color(DividerColor::Border))
.child(self.render_context_servers_section(window, cx))
.child(Divider::horizontal().color(DividerColor::Border))

View File

@@ -2,7 +2,7 @@ mod profile_modal_header;
use std::sync::Arc;
use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings, builtin_profiles};
use assistant_tool::ToolWorkingSet;
use convert_case::{Case, Casing as _};
use editor::Editor;
@@ -22,6 +22,8 @@ use crate::assistant_configuration::manage_profiles_modal::profile_modal_header:
use crate::assistant_configuration::tool_picker::{ToolPicker, ToolPickerDelegate};
use crate::{AssistantPanel, ManageProfiles, ThreadStore};
use super::tool_picker::ToolPickerMode;
enum Mode {
ChooseProfile(ChooseProfileMode),
NewProfile(NewProfileMode),
@@ -31,26 +33,39 @@ enum Mode {
tool_picker: Entity<ToolPicker>,
_subscription: Subscription,
},
ConfigureMcps {
profile_id: AgentProfileId,
tool_picker: Entity<ToolPicker>,
_subscription: Subscription,
},
}
impl Mode {
pub fn choose_profile(_window: &mut Window, cx: &mut Context<ManageProfilesModal>) -> Self {
let settings = AssistantSettings::get_global(cx);
let mut profiles = settings.profiles.clone();
profiles.sort_unstable_by(|_, a, _, b| a.name.cmp(&b.name));
let mut builtin_profiles = Vec::new();
let mut custom_profiles = Vec::new();
let profiles = profiles
.into_iter()
.map(|(id, profile)| ProfileEntry {
id,
name: profile.name,
for (profile_id, profile) in settings.profiles.iter() {
let entry = ProfileEntry {
id: profile_id.clone(),
name: profile.name.clone(),
navigation: NavigableEntry::focusable(cx),
})
.collect::<Vec<_>>();
};
if builtin_profiles::is_builtin(profile_id) {
builtin_profiles.push(entry);
} else {
custom_profiles.push(entry);
}
}
builtin_profiles.sort_unstable_by(|a, b| a.name.cmp(&b.name));
custom_profiles.sort_unstable_by(|a, b| a.name.cmp(&b.name));
Self::ChooseProfile(ChooseProfileMode {
profiles,
builtin_profiles,
custom_profiles,
add_new_profile: NavigableEntry::focusable(cx),
})
}
@@ -65,7 +80,8 @@ struct ProfileEntry {
#[derive(Clone)]
pub struct ChooseProfileMode {
profiles: Vec<ProfileEntry>,
builtin_profiles: Vec<ProfileEntry>,
custom_profiles: Vec<ProfileEntry>,
add_new_profile: NavigableEntry,
}
@@ -74,6 +90,8 @@ pub struct ViewProfileMode {
profile_id: AgentProfileId,
fork_profile: NavigableEntry,
configure_tools: NavigableEntry,
configure_mcps: NavigableEntry,
cancel_item: NavigableEntry,
}
#[derive(Clone)]
@@ -166,10 +184,50 @@ impl ManageProfilesModal {
profile_id,
fork_profile: NavigableEntry::focusable(cx),
configure_tools: NavigableEntry::focusable(cx),
configure_mcps: NavigableEntry::focusable(cx),
cancel_item: NavigableEntry::focusable(cx),
});
self.focus_handle(cx).focus(window);
}
fn configure_mcps(
&mut self,
profile_id: AgentProfileId,
window: &mut Window,
cx: &mut Context<Self>,
) {
let settings = AssistantSettings::get_global(cx);
let Some(profile) = settings.profiles.get(&profile_id).cloned() else {
return;
};
let tool_picker = cx.new(|cx| {
let delegate = ToolPickerDelegate::new(
ToolPickerMode::McpTools,
self.fs.clone(),
self.tools.clone(),
self.thread_store.clone(),
profile_id.clone(),
profile,
cx,
);
ToolPicker::mcp_tools(delegate, window, cx)
});
let dismiss_subscription = cx.subscribe_in(&tool_picker, window, {
let profile_id = profile_id.clone();
move |this, _tool_picker, _: &DismissEvent, window, cx| {
this.view_profile(profile_id.clone(), window, cx);
}
});
self.mode = Mode::ConfigureMcps {
profile_id,
tool_picker,
_subscription: dismiss_subscription,
};
self.focus_handle(cx).focus(window);
}
fn configure_tools(
&mut self,
profile_id: AgentProfileId,
@@ -183,6 +241,7 @@ impl ManageProfilesModal {
let tool_picker = cx.new(|cx| {
let delegate = ToolPickerDelegate::new(
ToolPickerMode::BuiltinTools,
self.fs.clone(),
self.tools.clone(),
self.thread_store.clone(),
@@ -190,7 +249,7 @@ impl ManageProfilesModal {
profile,
cx,
);
ToolPicker::new(delegate, window, cx)
ToolPicker::builtin_tools(delegate, window, cx)
});
let dismiss_subscription = cx.subscribe_in(&tool_picker, window, {
let profile_id = profile_id.clone();
@@ -241,6 +300,7 @@ impl ManageProfilesModal {
}
Mode::ViewProfile(_) => {}
Mode::ConfigureTools { .. } => {}
Mode::ConfigureMcps { .. } => {}
}
}
@@ -257,7 +317,12 @@ impl ManageProfilesModal {
}
}
Mode::ViewProfile(_) => self.choose_profile(window, cx),
Mode::ConfigureTools { .. } => {}
Mode::ConfigureTools { profile_id, .. } => {
self.view_profile(profile_id.clone(), window, cx)
}
Mode::ConfigureMcps { profile_id, .. } => {
self.view_profile(profile_id.clone(), window, cx)
}
}
}
@@ -284,6 +349,7 @@ impl Focusable for ManageProfilesModal {
Mode::NewProfile(mode) => mode.name_editor.focus_handle(cx),
Mode::ViewProfile(_) => self.focus_handle.clone(),
Mode::ConfigureTools { tool_picker, .. } => tool_picker.focus_handle(cx),
Mode::ConfigureMcps { tool_picker, .. } => tool_picker.focus_handle(cx),
}
}
}
@@ -291,6 +357,51 @@ impl Focusable for ManageProfilesModal {
impl EventEmitter<DismissEvent> for ManageProfilesModal {}
impl ManageProfilesModal {
fn render_profile(
&self,
profile: &ProfileEntry,
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement + use<> {
div()
.id(SharedString::from(format!("profile-{}", profile.id)))
.track_focus(&profile.navigation.focus_handle)
.on_action({
let profile_id = profile.id.clone();
cx.listener(move |this, _: &menu::Confirm, window, cx| {
this.view_profile(profile_id.clone(), window, cx);
})
})
.child(
ListItem::new(SharedString::from(format!("profile-{}", profile.id)))
.toggle_state(profile.navigation.focus_handle.contains_focused(window, cx))
.inset(true)
.spacing(ListItemSpacing::Sparse)
.child(Label::new(profile.name.clone()))
.end_slot(
h_flex()
.gap_1()
.child(
Label::new("Customize")
.size(LabelSize::Small)
.color(Color::Muted),
)
.children(KeyBinding::for_action_in(
&menu::Confirm,
&self.focus_handle,
window,
cx,
)),
)
.on_click({
let profile_id = profile.id.clone();
cx.listener(move |this, _, window, cx| {
this.view_profile(profile_id.clone(), window, cx);
})
}),
)
}
fn render_choose_profile(
&mut self,
mode: ChooseProfileMode,
@@ -301,57 +412,31 @@ impl ManageProfilesModal {
div()
.track_focus(&self.focus_handle(cx))
.size_full()
.child(ProfileModalHeader::new(
"Agent Profiles",
IconName::ZedAssistant,
))
.child(ProfileModalHeader::new("Agent Profiles", None))
.child(
v_flex()
.pb_1()
.child(ListSeparator)
.children(mode.profiles.iter().map(|profile| {
div()
.id(SharedString::from(format!("profile-{}", profile.id)))
.track_focus(&profile.navigation.focus_handle)
.on_action({
let profile_id = profile.id.clone();
cx.listener(move |this, _: &menu::Confirm, window, cx| {
this.view_profile(profile_id.clone(), window, cx);
})
})
.children(
mode.builtin_profiles
.iter()
.map(|profile| self.render_profile(profile, window, cx)),
)
.when(!mode.custom_profiles.is_empty(), |this| {
this.child(ListSeparator)
.child(
ListItem::new(SharedString::from(format!(
"profile-{}",
profile.id
)))
.toggle_state(
profile
.navigation
.focus_handle
.contains_focused(window, cx),
)
.inset(true)
.spacing(ListItemSpacing::Sparse)
.child(Label::new(profile.name.clone()))
.end_slot(
h_flex()
.gap_1()
.child(Label::new("Customize").size(LabelSize::Small))
.children(KeyBinding::for_action_in(
&menu::Confirm,
&self.focus_handle,
window,
cx,
)),
)
.on_click({
let profile_id = profile.id.clone();
cx.listener(move |this, _, window, cx| {
this.view_profile(profile_id.clone(), window, cx);
})
}),
div().pl_2().pb_1().child(
Label::new("Custom Profiles")
.size(LabelSize::Small)
.color(Color::Muted),
),
)
}))
.children(
mode.custom_profiles
.iter()
.map(|profile| self.render_profile(profile, window, cx)),
)
})
.child(ListSeparator)
.child(
div()
@@ -382,7 +467,10 @@ impl ManageProfilesModal {
.into_any_element(),
)
.map(|mut navigable| {
for profile in mode.profiles {
for profile in mode.builtin_profiles {
navigable = navigable.entry(profile.navigation);
}
for profile in mode.custom_profiles {
navigable = navigable.entry(profile.navigation);
}
@@ -411,11 +499,14 @@ impl ManageProfilesModal {
.id("new-profile")
.track_focus(&self.focus_handle(cx))
.child(ProfileModalHeader::new(
match base_profile_name {
match &base_profile_name {
Some(base_profile) => format!("Fork {base_profile}"),
None => "New Profile".into(),
},
IconName::Plus,
match base_profile_name {
Some(_) => Some(IconName::Scissors),
None => Some(IconName::Plus),
},
))
.child(ListSeparator)
.child(h_flex().p_2().child(mode.name_editor.clone()))
@@ -429,20 +520,24 @@ impl ManageProfilesModal {
) -> impl IntoElement {
let settings = AssistantSettings::get_global(cx);
let profile_id = &settings.default_profile;
let profile_name = settings
.profiles
.get(&mode.profile_id)
.map(|profile| profile.name.clone())
.unwrap_or_else(|| "Unknown".into());
let icon = match profile_id.as_str() {
"write" => IconName::Pencil,
"ask" => IconName::MessageBubbles,
_ => IconName::UserRoundPen,
};
Navigable::new(
div()
.track_focus(&self.focus_handle(cx))
.size_full()
.child(ProfileModalHeader::new(
profile_name,
IconName::ZedAssistant,
))
.child(ProfileModalHeader::new(profile_name, Some(icon)))
.child(
v_flex()
.pb_1()
@@ -466,7 +561,11 @@ impl ManageProfilesModal {
)
.inset(true)
.spacing(ListItemSpacing::Sparse)
.start_slot(Icon::new(IconName::GitBranch))
.start_slot(
Icon::new(IconName::Scissors)
.size(IconSize::Small)
.color(Color::Muted),
)
.child(Label::new("Fork Profile"))
.on_click({
let profile_id = mode.profile_id.clone();
@@ -499,7 +598,11 @@ impl ManageProfilesModal {
)
.inset(true)
.spacing(ListItemSpacing::Sparse)
.start_slot(Icon::new(IconName::Cog))
.start_slot(
Icon::new(IconName::Settings)
.size(IconSize::Small)
.color(Color::Muted),
)
.child(Label::new("Configure Tools"))
.on_click({
let profile_id = mode.profile_id.clone();
@@ -512,12 +615,90 @@ impl ManageProfilesModal {
})
}),
),
)
.child(
div()
.id("configure-mcps")
.track_focus(&mode.configure_mcps.focus_handle)
.on_action({
let profile_id = mode.profile_id.clone();
cx.listener(move |this, _: &menu::Confirm, window, cx| {
this.configure_mcps(profile_id.clone(), window, cx);
})
})
.child(
ListItem::new("configure-mcps")
.toggle_state(
mode.configure_mcps
.focus_handle
.contains_focused(window, cx),
)
.inset(true)
.spacing(ListItemSpacing::Sparse)
.start_slot(
Icon::new(IconName::Hammer)
.size(IconSize::Small)
.color(Color::Muted),
)
.child(Label::new("Configure MCP Servers"))
.on_click({
let profile_id = mode.profile_id.clone();
cx.listener(move |this, _, window, cx| {
this.configure_mcps(profile_id.clone(), window, cx);
})
}),
),
)
.child(ListSeparator)
.child(
div()
.id("cancel-item")
.track_focus(&mode.cancel_item.focus_handle)
.on_action({
cx.listener(move |this, _: &menu::Confirm, window, cx| {
this.cancel(window, cx);
})
})
.child(
ListItem::new("cancel-item")
.toggle_state(
mode.cancel_item
.focus_handle
.contains_focused(window, cx),
)
.inset(true)
.spacing(ListItemSpacing::Sparse)
.start_slot(
Icon::new(IconName::ArrowLeft)
.size(IconSize::Small)
.color(Color::Muted),
)
.child(Label::new("Go Back"))
.end_slot(
div().children(
KeyBinding::for_action_in(
&menu::Cancel,
&self.focus_handle,
window,
cx,
)
.map(|kb| kb.size(rems_from_px(12.))),
),
)
.on_click({
cx.listener(move |this, _, window, cx| {
this.cancel(window, cx);
})
}),
),
),
)
.into_any_element(),
)
.entry(mode.fork_profile)
.entry(mode.configure_tools)
.entry(mode.configure_mcps)
.entry(mode.cancel_item)
}
}
@@ -525,6 +706,43 @@ impl Render for ManageProfilesModal {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let settings = AssistantSettings::get_global(cx);
let go_back_item = div()
.id("cancel-item")
.track_focus(&self.focus_handle)
.on_action({
cx.listener(move |this, _: &menu::Confirm, window, cx| {
this.cancel(window, cx);
})
})
.child(
ListItem::new("cancel-item")
.toggle_state(self.focus_handle.contains_focused(window, cx))
.inset(true)
.spacing(ListItemSpacing::Sparse)
.start_slot(
Icon::new(IconName::ArrowLeft)
.size(IconSize::Small)
.color(Color::Muted),
)
.child(Label::new("Go Back"))
.end_slot(
div().children(
KeyBinding::for_action_in(
&menu::Cancel,
&self.focus_handle,
window,
cx,
)
.map(|kb| kb.size(rems_from_px(12.))),
),
)
.on_click({
cx.listener(move |this, _, window, cx| {
this.cancel(window, cx);
})
}),
);
div()
.elevation_3(cx)
.w(rems(34.))
@@ -556,13 +774,39 @@ impl Render for ManageProfilesModal {
.map(|profile| profile.name.clone())
.unwrap_or_else(|| "Unknown".into());
div()
v_flex()
.pb_1()
.child(ProfileModalHeader::new(
format!("{profile_name}: Configure Tools"),
IconName::Cog,
format!("{profile_name} Configure Tools"),
Some(IconName::Cog),
))
.child(ListSeparator)
.child(tool_picker.clone())
.child(ListSeparator)
.child(go_back_item)
.into_any_element()
}
Mode::ConfigureMcps {
profile_id,
tool_picker,
..
} => {
let profile_name = settings
.profiles
.get(profile_id)
.map(|profile| profile.name.clone())
.unwrap_or_else(|| "Unknown".into());
v_flex()
.pb_1()
.child(ProfileModalHeader::new(
format!("{profile_name} — Configure MCP Servers"),
Some(IconName::Hammer),
))
.child(ListSeparator)
.child(tool_picker.clone())
.child(ListSeparator)
.child(go_back_item)
.into_any_element()
}
})

View File

@@ -3,11 +3,11 @@ use ui::prelude::*;
#[derive(IntoElement)]
pub struct ProfileModalHeader {
label: SharedString,
icon: IconName,
icon: Option<IconName>,
}
impl ProfileModalHeader {
pub fn new(label: impl Into<SharedString>, icon: IconName) -> Self {
pub fn new(label: impl Into<SharedString>, icon: Option<IconName>) -> Self {
Self {
label: label.into(),
icon,
@@ -17,22 +17,26 @@ impl ProfileModalHeader {
impl RenderOnce for ProfileModalHeader {
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
h_flex()
let mut container = h_flex()
.w_full()
.px(DynamicSpacing::Base12.rems(cx))
.pt(DynamicSpacing::Base08.rems(cx))
.pb(DynamicSpacing::Base04.rems(cx))
.rounded_t_sm()
.gap_1p5()
.child(Icon::new(self.icon).size(IconSize::XSmall))
.child(
h_flex().gap_1().overflow_x_hidden().child(
div()
.max_w_96()
.overflow_x_hidden()
.text_ellipsis()
.child(Headline::new(self.label).size(HeadlineSize::XSmall)),
),
)
.gap_1p5();
if let Some(icon) = self.icon {
container = container.child(Icon::new(icon).size(IconSize::XSmall).color(Color::Muted));
}
container.child(
h_flex().gap_1().overflow_x_hidden().child(
div()
.max_w_96()
.overflow_x_hidden()
.text_ellipsis()
.child(Headline::new(self.label).size(HeadlineSize::XSmall)),
),
)
}
}

View File

@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{collections::BTreeMap, sync::Arc};
use assistant_settings::{
AgentProfile, AgentProfileContent, AgentProfileId, AssistantSettings, AssistantSettingsContent,
@@ -6,11 +6,10 @@ use assistant_settings::{
};
use assistant_tool::{ToolSource, ToolWorkingSet};
use fs::Fs;
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window};
use picker::{Picker, PickerDelegate};
use settings::{Settings as _, update_settings_file};
use ui::{HighlightedLabel, ListItem, ListItemSpacing, prelude::*};
use ui::{ListItem, ListItemSpacing, prelude::*};
use util::ResultExt as _;
use crate::ThreadStore;
@@ -19,11 +18,30 @@ pub struct ToolPicker {
picker: Entity<Picker<ToolPickerDelegate>>,
}
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ToolPickerMode {
BuiltinTools,
McpTools,
}
impl ToolPicker {
pub fn new(delegate: ToolPickerDelegate, window: &mut Window, cx: &mut Context<Self>) -> Self {
pub fn builtin_tools(
delegate: ToolPickerDelegate,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx).modal(false));
Self { picker }
}
pub fn mcp_tools(
delegate: ToolPickerDelegate,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let picker = cx.new(|cx| Picker::list(delegate, window, cx).modal(false));
Self { picker }
}
}
impl EventEmitter<DismissEvent> for ToolPicker {}
@@ -41,24 +59,31 @@ impl Render for ToolPicker {
}
#[derive(Debug, Clone)]
pub struct ToolEntry {
pub name: Arc<str>,
pub source: ToolSource,
pub enum PickerItem {
Tool {
server_id: Option<Arc<str>>,
name: Arc<str>,
},
ContextServer {
server_id: Arc<str>,
},
}
pub struct ToolPickerDelegate {
tool_picker: WeakEntity<ToolPicker>,
thread_store: WeakEntity<ThreadStore>,
fs: Arc<dyn Fs>,
tools: Vec<ToolEntry>,
items: Arc<Vec<PickerItem>>,
profile_id: AgentProfileId,
profile: AgentProfile,
matches: Vec<StringMatch>,
filtered_items: Vec<PickerItem>,
selected_index: usize,
mode: ToolPickerMode,
}
impl ToolPickerDelegate {
pub fn new(
mode: ToolPickerMode,
fs: Arc<dyn Fs>,
tool_set: Entity<ToolWorkingSet>,
thread_store: WeakEntity<ThreadStore>,
@@ -66,33 +91,60 @@ impl ToolPickerDelegate {
profile: AgentProfile,
cx: &mut Context<ToolPicker>,
) -> Self {
let mut tool_entries = Vec::new();
for (source, tools) in tool_set.read(cx).tools_by_source(cx) {
tool_entries.extend(tools.into_iter().map(|tool| ToolEntry {
name: tool.name().into(),
source: source.clone(),
}));
}
let items = Arc::new(Self::resolve_items(mode, &tool_set, cx));
Self {
tool_picker: cx.entity().downgrade(),
thread_store,
fs,
tools: tool_entries,
items,
profile_id,
profile,
matches: Vec::new(),
filtered_items: Vec::new(),
selected_index: 0,
mode,
}
}
fn resolve_items(
mode: ToolPickerMode,
tool_set: &Entity<ToolWorkingSet>,
cx: &mut App,
) -> Vec<PickerItem> {
let mut items = Vec::new();
for (source, tools) in tool_set.read(cx).tools_by_source(cx) {
match source {
ToolSource::Native => {
if mode == ToolPickerMode::BuiltinTools {
items.extend(tools.into_iter().map(|tool| PickerItem::Tool {
name: tool.name().into(),
server_id: None,
}));
}
}
ToolSource::ContextServer { id } => {
if mode == ToolPickerMode::McpTools && !tools.is_empty() {
let server_id: Arc<str> = id.clone().into();
items.push(PickerItem::ContextServer {
server_id: server_id.clone(),
});
items.extend(tools.into_iter().map(|tool| PickerItem::Tool {
name: tool.name().into(),
server_id: Some(server_id.clone()),
}));
}
}
}
}
items
}
}
impl PickerDelegate for ToolPickerDelegate {
type ListItem = ListItem;
type ListItem = AnyElement;
fn match_count(&self) -> usize {
self.matches.len()
self.filtered_items.len()
}
fn selected_index(&self) -> usize {
@@ -108,8 +160,25 @@ impl PickerDelegate for ToolPickerDelegate {
self.selected_index = ix;
}
fn can_select(
&mut self,
ix: usize,
_window: &mut Window,
_cx: &mut Context<Picker<Self>>,
) -> bool {
let item = &self.filtered_items[ix];
match item {
PickerItem::Tool { .. } => true,
PickerItem::ContextServer { .. } => false,
}
}
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
"Search tools…".into()
match self.mode {
ToolPickerMode::BuiltinTools => "Search built-in tools…",
ToolPickerMode::McpTools => "Search MCP servers…",
}
.into()
}
fn update_matches(
@@ -118,74 +187,76 @@ impl PickerDelegate for ToolPickerDelegate {
window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Task<()> {
let background = cx.background_executor().clone();
let candidates = self
.tools
.iter()
.enumerate()
.map(|(id, profile)| StringMatchCandidate::new(id, profile.name.as_ref()))
.collect::<Vec<_>>();
let all_items = self.items.clone();
cx.spawn_in(window, async move |this, cx| {
let matches = if query.is_empty() {
candidates
.into_iter()
.enumerate()
.map(|(index, candidate)| StringMatch {
candidate_id: index,
string: candidate.string,
positions: Vec::new(),
score: 0.,
})
.collect()
} else {
match_strings(
&candidates,
&query,
false,
100,
&Default::default(),
background,
)
.await
};
let filtered_items = cx
.background_spawn(async move {
let mut tools_by_provider: BTreeMap<Option<Arc<str>>, Vec<Arc<str>>> =
BTreeMap::default();
for item in all_items.iter() {
if let PickerItem::Tool { server_id, name } = item.clone() {
if name.contains(&query) {
tools_by_provider.entry(server_id).or_default().push(name);
}
}
}
let mut items = Vec::new();
for (server_id, names) in tools_by_provider {
if let Some(server_id) = server_id.clone() {
items.push(PickerItem::ContextServer { server_id });
}
for name in names {
items.push(PickerItem::Tool {
server_id: server_id.clone(),
name,
});
}
}
items
})
.await;
this.update(cx, |this, _cx| {
this.delegate.matches = matches;
this.delegate.filtered_items = filtered_items;
this.delegate.selected_index = this
.delegate
.selected_index
.min(this.delegate.matches.len().saturating_sub(1));
.min(this.delegate.filtered_items.len().saturating_sub(1));
})
.log_err();
})
}
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
if self.matches.is_empty() {
if self.filtered_items.is_empty() {
self.dismissed(window, cx);
return;
}
let candidate_id = self.matches[self.selected_index].candidate_id;
let tool = &self.tools[candidate_id];
let item = &self.filtered_items[self.selected_index];
let is_enabled = match &tool.source {
ToolSource::Native => {
let is_enabled = self.profile.tools.entry(tool.name.clone()).or_default();
*is_enabled = !*is_enabled;
*is_enabled
}
ToolSource::ContextServer { id } => {
let preset = self
.profile
.context_servers
.entry(id.clone().into())
.or_default();
let is_enabled = preset.tools.entry(tool.name.clone()).or_default();
*is_enabled = !*is_enabled;
*is_enabled
}
let PickerItem::Tool {
name: tool_name,
server_id,
} = item
else {
return;
};
let is_currently_enabled = if let Some(server_id) = server_id.clone() {
let preset = self.profile.context_servers.entry(server_id).or_default();
let is_enabled = *preset.tools.entry(tool_name.clone()).or_default();
*preset.tools.entry(tool_name.clone()).or_default() = !is_enabled;
is_enabled
} else {
let is_enabled = *self.profile.tools.entry(tool_name.clone()).or_default();
*self.profile.tools.entry(tool_name.clone()).or_default() = !is_enabled;
is_enabled
};
let active_profile_id = &AssistantSettings::get_global(cx).default_profile;
@@ -200,7 +271,8 @@ impl PickerDelegate for ToolPickerDelegate {
update_settings_file::<AssistantSettings>(self.fs.clone(), cx, {
let profile_id = self.profile_id.clone();
let default_profile = self.profile.clone();
let tool = tool.clone();
let server_id = server_id.clone();
let tool_name = tool_name.clone();
move |settings: &mut AssistantSettingsContent, _cx| {
settings
.v2_setting(|v2_settings| {
@@ -228,17 +300,11 @@ impl PickerDelegate for ToolPickerDelegate {
.collect(),
});
match tool.source {
ToolSource::Native => {
*profile.tools.entry(tool.name).or_default() = is_enabled;
}
ToolSource::ContextServer { id } => {
let preset = profile
.context_servers
.entry(id.clone().into())
.or_default();
*preset.tools.entry(tool.name.clone()).or_default() = is_enabled;
}
if let Some(server_id) = server_id {
let preset = profile.context_servers.entry(server_id).or_default();
*preset.tools.entry(tool_name).or_default() = !is_currently_enabled;
} else {
*profile.tools.entry(tool_name).or_default() = !is_currently_enabled;
}
Ok(())
@@ -259,45 +325,53 @@ impl PickerDelegate for ToolPickerDelegate {
ix: usize,
selected: bool,
_window: &mut Window,
_cx: &mut Context<Picker<Self>>,
cx: &mut Context<Picker<Self>>,
) -> Option<Self::ListItem> {
let tool_match = &self.matches[ix];
let tool = &self.tools[tool_match.candidate_id];
let item = &self.filtered_items[ix];
match item {
PickerItem::ContextServer { server_id, .. } => Some(
div()
.px_2()
.pb_1()
.when(ix > 1, |this| {
this.mt_1()
.pt_2()
.border_t_1()
.border_color(cx.theme().colors().border_variant)
})
.child(
Label::new(server_id)
.size(LabelSize::XSmall)
.color(Color::Muted),
)
.into_any_element(),
),
PickerItem::Tool { name, server_id } => {
let is_enabled = if let Some(server_id) = server_id {
self.profile
.context_servers
.get(server_id.as_ref())
.and_then(|preset| preset.tools.get(name))
.copied()
.unwrap_or(self.profile.enable_all_context_servers)
} else {
self.profile.tools.get(name).copied().unwrap_or(false)
};
let is_enabled = match &tool.source {
ToolSource::Native => self.profile.tools.get(&tool.name).copied().unwrap_or(false),
ToolSource::ContextServer { id } => self
.profile
.context_servers
.get(id.as_ref())
.and_then(|preset| preset.tools.get(&tool.name))
.copied()
.unwrap_or(self.profile.enable_all_context_servers),
};
Some(
ListItem::new(ix)
.inset(true)
.spacing(ListItemSpacing::Sparse)
.toggle_state(selected)
.child(
h_flex()
.gap_2()
.child(HighlightedLabel::new(
tool_match.string.clone(),
tool_match.positions.clone(),
))
.map(|parent| match &tool.source {
ToolSource::Native => parent,
ToolSource::ContextServer { id } => parent
.child(Label::new(id).size(LabelSize::XSmall).color(Color::Muted)),
}),
Some(
ListItem::new(ix)
.inset(true)
.spacing(ListItemSpacing::Sparse)
.toggle_state(selected)
.child(Label::new(name.clone()))
.end_slot::<Icon>(is_enabled.then(|| {
Icon::new(IconName::Check)
.size(IconSize::Small)
.color(Color::Success)
}))
.into_any_element(),
)
.end_slot::<Icon>(is_enabled.then(|| {
Icon::new(IconName::Check)
.size(IconSize::Small)
.color(Color::Success)
})),
)
}
}
}
}

View File

@@ -104,10 +104,9 @@ impl Render for AssistantModelSelector {
let focus_handle = self.focus_handle.clone();
let model = self.selector.read(cx).active_model(cx);
let (model_name, model_icon) = match model {
Some(model) => (model.model.name().0, Some(model.provider.icon())),
_ => (SharedString::from("No model selected"), None),
};
let model_name = model
.map(|model| model.model.name().0)
.unwrap_or_else(|| SharedString::from("No model selected"));
LanguageModelSelectorPopoverMenu::new(
self.selector.clone(),
@@ -116,11 +115,6 @@ impl Render for AssistantModelSelector {
.child(
h_flex()
.gap_0p5()
.children(
model_icon.map(|icon| {
Icon::new(icon).color(Color::Muted).size(IconSize::Small)
}),
)
.child(
Label::new(model_name)
.size(LabelSize::Small)

View File

@@ -3,6 +3,9 @@ use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use db::kvp::KEY_VALUE_STORE;
use serde::{Deserialize, Serialize};
use anyhow::{Result, anyhow};
use assistant_context_editor::{
AssistantContext, AssistantPanelDelegate, ConfigurationError, ContextEditor, ContextEvent,
@@ -34,12 +37,13 @@ use ui::{
Banner, ContextMenu, KeyBinding, PopoverMenu, PopoverMenuHandle, Tab, Tooltip, prelude::*,
};
use util::ResultExt as _;
use workspace::Workspace;
use workspace::dock::{DockPosition, Panel, PanelEvent};
use workspace::{CollaboratorId, Workspace};
use zed_actions::agent::OpenConfiguration;
use zed_actions::assistant::{OpenRulesLibrary, ToggleFocus};
use crate::active_thread::{ActiveThread, ActiveThreadEvent};
use crate::agent_diff::AgentDiff;
use crate::assistant_configuration::{AssistantConfiguration, AssistantConfigurationEvent};
use crate::history_store::{HistoryEntry, HistoryStore, RecentEntry};
use crate::message_editor::{MessageEditor, MessageEditorEvent};
@@ -48,11 +52,18 @@ use crate::thread_history::{PastContext, PastThread, ThreadHistory};
use crate::thread_store::ThreadStore;
use crate::ui::UsageBanner;
use crate::{
AddContextServer, AgentDiff, DeleteRecentlyOpenThread, ExpandMessageEditor, InlineAssistant,
NewTextThread, NewThread, OpenActiveThreadAsMarkdown, OpenAgentDiff, OpenHistory, ThreadEvent,
ToggleContextPicker, ToggleNavigationMenu, ToggleOptionsMenu,
AddContextServer, AgentDiffPane, DeleteRecentlyOpenThread, ExpandMessageEditor, Follow,
InlineAssistant, NewTextThread, NewThread, OpenActiveThreadAsMarkdown, OpenAgentDiff,
OpenHistory, ThreadEvent, ToggleContextPicker, ToggleNavigationMenu, ToggleOptionsMenu,
};
const AGENT_PANEL_KEY: &str = "agent_panel";
#[derive(Serialize, Deserialize)]
struct SerializedAssistantPanel {
width: Option<Pixels>,
}
pub fn init(cx: &mut App) {
cx.observe_new(
|workspace: &mut Workspace, _window, _cx: &mut Context<Workspace>| {
@@ -93,9 +104,12 @@ pub fn init(cx: &mut App) {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
workspace.focus_panel::<AssistantPanel>(window, cx);
let thread = panel.read(cx).thread.read(cx).thread().clone();
AgentDiff::deploy_in_workspace(thread, workspace, window, cx);
AgentDiffPane::deploy_in_workspace(thread, workspace, window, cx);
}
})
.register_action(|workspace, _: &Follow, window, cx| {
workspace.follow(CollaboratorId::Agent, window, cx);
})
.register_action(|workspace, _: &ExpandMessageEditor, window, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
workspace.focus_panel::<AssistantPanel>(window, cx);
@@ -302,9 +316,22 @@ pub struct AssistantPanel {
assistant_navigation_menu: Option<Entity<ContextMenu>>,
width: Option<Pixels>,
height: Option<Pixels>,
pending_serialization: Option<Task<Result<()>>>,
}
impl AssistantPanel {
fn serialize(&mut self, cx: &mut Context<Self>) {
let width = self.width;
self.pending_serialization = Some(cx.background_spawn(async move {
KEY_VALUE_STORE
.write_kvp(
AGENT_PANEL_KEY.into(),
serde_json::to_string(&SerializedAssistantPanel { width })?,
)
.await?;
anyhow::Ok(())
}));
}
pub fn load(
workspace: WeakEntity<Workspace>,
prompt_builder: Arc<PromptBuilder>,
@@ -343,8 +370,19 @@ impl AssistantPanel {
})?
.await?;
workspace.update_in(cx, |workspace, window, cx| {
cx.new(|cx| {
let serialized_panel = if let Some(panel) = cx
.background_spawn(async move { KEY_VALUE_STORE.read_kvp(AGENT_PANEL_KEY) })
.await
.log_err()
.flatten()
{
Some(serde_json::from_str::<SerializedAssistantPanel>(&panel)?)
} else {
None
};
let panel = workspace.update_in(cx, |workspace, window, cx| {
let panel = cx.new(|cx| {
Self::new(
workspace,
thread_store,
@@ -353,8 +391,17 @@ impl AssistantPanel {
window,
cx,
)
})
})
});
if let Some(serialized_panel) = serialized_panel {
panel.update(cx, |panel, cx| {
panel.width = serialized_panel.width.map(|w| w.round());
cx.notify();
});
}
panel
})?;
Ok(panel)
})
}
@@ -431,6 +478,7 @@ impl AssistantPanel {
cx,
)
});
AgentDiff::set_active_thread(&workspace, &thread, window, cx);
let active_thread_subscription =
cx.subscribe(&active_thread, |_, _, event, cx| match &event {
@@ -586,6 +634,7 @@ impl AssistantPanel {
assistant_navigation_menu: None,
width: None,
height: None,
pending_serialization: None,
}
}
@@ -673,6 +722,7 @@ impl AssistantPanel {
cx,
)
});
AgentDiff::set_active_thread(&self.workspace, &thread, window, cx);
let active_thread_subscription =
cx.subscribe(&self.thread, |_, _, event, cx| match &event {
@@ -870,6 +920,7 @@ impl AssistantPanel {
cx,
)
});
AgentDiff::set_active_thread(&self.workspace, &thread, window, cx);
let active_thread_subscription =
cx.subscribe(&self.thread, |_, _, event, cx| match &event {
@@ -945,7 +996,7 @@ impl AssistantPanel {
let thread = self.thread.read(cx).thread().clone();
self.workspace
.update(cx, |workspace, cx| {
AgentDiff::deploy_in_workspace(thread, workspace, window, cx)
AgentDiffPane::deploy_in_workspace(thread, workspace, window, cx)
})
.log_err();
}
@@ -1209,6 +1260,7 @@ impl Panel for AssistantPanel {
DockPosition::Left | DockPosition::Right => self.width = size,
DockPosition::Bottom => self.height = size,
}
self.serialize(cx);
cx.notify();
}
@@ -1841,7 +1893,7 @@ impl AssistantPanel {
.child(
Banner::new()
.severity(ui::Severity::Warning)
.children(
.child(
Label::new(
"Configure at least one LLM provider to start using the panel.",
)
@@ -1874,7 +1926,7 @@ impl AssistantPanel {
.child(
Banner::new()
.severity(ui::Severity::Warning)
.children(
.child(
h_flex()
.w_full()
.children(
@@ -1908,6 +1960,41 @@ impl AssistantPanel {
Some(UsageBanner::new(plan, usage).into_any_element())
}
fn render_tool_use_limit_reached(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
let tool_use_limit_reached = self
.thread
.read(cx)
.thread()
.read(cx)
.tool_use_limit_reached();
if !tool_use_limit_reached {
return None;
}
let model = self
.thread
.read(cx)
.thread()
.read(cx)
.configured_model()?
.model;
let max_mode_upsell = if model.supports_max_mode() {
" Enable max mode for unlimited tool use."
} else {
""
};
Some(
Banner::new()
.severity(ui::Severity::Info)
.child(h_flex().child(Label::new(format!(
"Consecutive tool use limit reached.{max_mode_upsell}"
))))
.into_any_element(),
)
}
fn render_last_error(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
let last_error = self.thread.read(cx).last_error()?;
@@ -2189,6 +2276,7 @@ impl Render for AssistantPanel {
.map(|parent| match &self.active_view {
ActiveView::Thread { .. } => parent
.child(self.render_active_thread_or_empty_state(window, cx))
.children(self.render_tool_use_limit_reached(cx))
.children(self.render_usage_banner(cx))
.child(h_flex().child(self.message_editor.clone()))
.children(self.render_last_error(cx)),

View File

@@ -4,10 +4,12 @@ use std::path::PathBuf;
use std::{ops::Range, path::Path, sync::Arc};
use assistant_tool::outline;
use collections::HashSet;
use collections::{HashMap, HashSet};
use editor::display_map::CreaseId;
use editor::{Addon, Editor};
use futures::future;
use futures::{FutureExt, future::Shared};
use gpui::{App, AppContext as _, Entity, SharedString, Task};
use gpui::{App, AppContext as _, Entity, SharedString, Subscription, Task};
use language::{Buffer, ParseStatus};
use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};
use project::{Project, ProjectEntryId, ProjectPath, Worktree};
@@ -15,10 +17,11 @@ use prompt_store::{PromptStore, UserPromptId};
use ref_cast::RefCast;
use rope::Point;
use text::{Anchor, OffsetRangeExt as _};
use ui::{ElementId, IconName};
use ui::{Context, ElementId, IconName};
use util::markdown::MarkdownCodeBlock;
use util::{ResultExt as _, post_inc};
use crate::context_store::{ContextStore, ContextStoreEvent};
use crate::thread::Thread;
pub const RULES_ICON: IconName = IconName::Context;
@@ -67,7 +70,7 @@ pub enum AgentContextHandle {
}
impl AgentContextHandle {
fn id(&self) -> ContextId {
pub fn id(&self) -> ContextId {
match self {
Self::File(context) => context.context_id,
Self::Directory(context) => context.context_id,
@@ -1036,6 +1039,69 @@ impl Hash for AgentContextKey {
}
}
#[derive(Default)]
pub struct ContextCreasesAddon {
creases: HashMap<AgentContextKey, Vec<(CreaseId, SharedString)>>,
_subscription: Option<Subscription>,
}
impl Addon for ContextCreasesAddon {
fn to_any(&self) -> &dyn std::any::Any {
self
}
fn to_any_mut(&mut self) -> Option<&mut dyn std::any::Any> {
Some(self)
}
}
impl ContextCreasesAddon {
pub fn new() -> Self {
Self {
creases: HashMap::default(),
_subscription: None,
}
}
pub fn add_creases(
&mut self,
context_store: &Entity<ContextStore>,
key: AgentContextKey,
creases: impl IntoIterator<Item = (CreaseId, SharedString)>,
cx: &mut Context<Editor>,
) {
self.creases.entry(key).or_default().extend(creases);
self._subscription = Some(cx.subscribe(
&context_store,
|editor, _, event, cx| match event {
ContextStoreEvent::ContextRemoved(key) => {
let Some(this) = editor.addon_mut::<Self>() else {
return;
};
let (crease_ids, replacement_texts): (Vec<_>, Vec<_>) = this
.creases
.remove(key)
.unwrap_or_default()
.into_iter()
.unzip();
let ranges = editor
.remove_creases(crease_ids, cx)
.into_iter()
.map(|(_, range)| range)
.collect::<Vec<_>>();
editor.unfold_ranges(&ranges, false, false, cx);
editor.edit(ranges.into_iter().zip(replacement_texts), cx);
cx.notify();
}
},
))
}
pub fn into_inner(self) -> HashMap<AgentContextKey, Vec<(CreaseId, SharedString)>> {
self.creases
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -11,7 +11,7 @@ use std::sync::Arc;
use anyhow::{Result, anyhow};
pub use completion_provider::ContextPickerCompletionProvider;
use editor::display_map::{Crease, FoldId};
use editor::display_map::{Crease, CreaseId, CreaseMetadata, FoldId};
use editor::{Anchor, AnchorRangeExt as _, Editor, ExcerptId, FoldPlaceholder, ToOffset};
use fetch_context_picker::FetchContextPicker;
use file_context_picker::FileContextPicker;
@@ -111,7 +111,7 @@ impl TryFrom<&str> for ContextPickerMode {
"symbol" => Ok(Self::Symbol),
"fetch" => Ok(Self::Fetch),
"thread" => Ok(Self::Thread),
"rules" => Ok(Self::Rules),
"rule" => Ok(Self::Rules),
_ => Err(format!("Invalid context picker mode: {}", value)),
}
}
@@ -124,7 +124,7 @@ impl ContextPickerMode {
Self::Symbol => "symbol",
Self::Fetch => "fetch",
Self::Thread => "thread",
Self::Rules => "rules",
Self::Rules => "rule",
}
}
@@ -482,7 +482,13 @@ impl ContextPicker {
return vec![];
};
recent_context_picker_entries(context_store, self.thread_store.clone(), workspace, cx)
recent_context_picker_entries(
context_store,
self.thread_store.clone(),
workspace,
None,
cx,
)
}
fn notify_current_picker(&mut self, cx: &mut Context<Self>) {
@@ -578,11 +584,12 @@ fn recent_context_picker_entries(
context_store: Entity<ContextStore>,
thread_store: Option<WeakEntity<ThreadStore>>,
workspace: Entity<Workspace>,
exclude_path: Option<ProjectPath>,
cx: &App,
) -> Vec<RecentEntry> {
let mut recent = Vec::with_capacity(6);
let current_files = context_store.read(cx).file_paths(cx);
let mut current_files = context_store.read(cx).file_paths(cx);
current_files.extend(exclude_path);
let workspace = workspace.read(cx);
let project = workspace.project().read(cx);
@@ -675,21 +682,20 @@ fn selection_ranges(
})
}
pub(crate) fn insert_fold_for_mention(
pub(crate) fn insert_crease_for_mention(
excerpt_id: ExcerptId,
crease_start: text::Anchor,
content_len: usize,
crease_label: SharedString,
crease_icon_path: SharedString,
editor_entity: Entity<Editor>,
window: &mut Window,
cx: &mut App,
) {
) -> Option<CreaseId> {
editor_entity.update(cx, |editor, cx| {
let snapshot = editor.buffer().read(cx).snapshot(cx);
let Some(start) = snapshot.anchor_in_excerpt(excerpt_id, crease_start) else {
return;
};
let start = snapshot.anchor_in_excerpt(excerpt_id, crease_start)?;
let start = start.bias_right(&snapshot);
let end = snapshot.anchor_before(start.to_offset(&snapshot) + content_len);
@@ -701,10 +707,10 @@ pub(crate) fn insert_fold_for_mention(
editor_entity.downgrade(),
);
editor.display_map.update(cx, |display_map, cx| {
display_map.fold(vec![crease], cx);
});
});
let ids = editor.insert_creases(vec![crease.clone()], cx);
editor.fold_creases(vec![crease], false, window, cx);
Some(ids[0])
})
}
pub fn crease_for_mention(
@@ -714,20 +720,20 @@ pub fn crease_for_mention(
editor_entity: WeakEntity<Editor>,
) -> Crease<Anchor> {
let placeholder = FoldPlaceholder {
render: render_fold_icon_button(icon_path, label, editor_entity),
render: render_fold_icon_button(icon_path.clone(), label.clone(), editor_entity),
merge_adjacent: false,
..Default::default()
};
let render_trailer = move |_row, _unfold, _window: &mut Window, _cx: &mut App| Empty.into_any();
let crease = Crease::inline(
Crease::inline(
range,
placeholder.clone(),
fold_toggle("mention"),
render_trailer,
);
crease
)
.with_metadata(CreaseMetadata { icon_path, label })
}
fn render_fold_icon_button(
@@ -821,7 +827,7 @@ pub enum MentionLink {
Selection(ProjectPath, Range<usize>),
Fetch(String),
Thread(ThreadId),
Rules(UserPromptId),
Rule(UserPromptId),
}
impl MentionLink {
@@ -830,7 +836,7 @@ impl MentionLink {
const SELECTION: &str = "@selection";
const THREAD: &str = "@thread";
const FETCH: &str = "@fetch";
const RULES: &str = "@rules";
const RULE: &str = "@rule";
const SEPARATOR: &str = ":";
@@ -840,7 +846,7 @@ impl MentionLink {
|| url.starts_with(Self::FETCH)
|| url.starts_with(Self::SELECTION)
|| url.starts_with(Self::THREAD)
|| url.starts_with(Self::RULES)
|| url.starts_with(Self::RULE)
}
pub fn for_file(file_name: &str, full_path: &str) -> String {
@@ -878,8 +884,8 @@ impl MentionLink {
format!("[@{}]({}:{})", url, Self::FETCH, url)
}
pub fn for_rules(rules: &RulesContextEntry) -> String {
format!("[@{}]({}:{})", rules.title, Self::RULES, rules.prompt_id.0)
pub fn for_rule(rule: &RulesContextEntry) -> String {
format!("[@{}]({}:{})", rule.title, Self::RULE, rule.prompt_id.0)
}
pub fn try_parse(link: &str, workspace: &Entity<Workspace>, cx: &App) -> Option<Self> {
@@ -937,9 +943,9 @@ impl MentionLink {
Some(MentionLink::Thread(thread_id))
}
Self::FETCH => Some(MentionLink::Fetch(argument.to_string())),
Self::RULES => {
Self::RULE => {
let prompt_id = UserPromptId(Uuid::try_parse(argument).ok()?);
Some(MentionLink::Rules(prompt_id))
Some(MentionLink::Rule(prompt_id))
}
_ => None,
}

View File

@@ -19,9 +19,11 @@ use prompt_store::PromptStore;
use rope::Point;
use text::{Anchor, OffsetRangeExt, ToPoint};
use ui::prelude::*;
use util::ResultExt as _;
use workspace::Workspace;
use crate::context::RULES_ICON;
use crate::Thread;
use crate::context::{AgentContextHandle, AgentContextKey, ContextCreasesAddon, RULES_ICON};
use crate::context_store::ContextStore;
use crate::thread_store::ThreadStore;
@@ -235,6 +237,7 @@ pub struct ContextPickerCompletionProvider {
context_store: WeakEntity<ContextStore>,
thread_store: Option<WeakEntity<ThreadStore>>,
editor: WeakEntity<Editor>,
excluded_buffer: Option<WeakEntity<Buffer>>,
}
impl ContextPickerCompletionProvider {
@@ -243,12 +246,14 @@ impl ContextPickerCompletionProvider {
context_store: WeakEntity<ContextStore>,
thread_store: Option<WeakEntity<ThreadStore>>,
editor: WeakEntity<Editor>,
exclude_buffer: Option<WeakEntity<Buffer>>,
) -> Self {
Self {
workspace,
context_store,
thread_store,
editor,
excluded_buffer: exclude_buffer,
}
}
@@ -310,7 +315,7 @@ impl ContextPickerCompletionProvider {
let context_store = context_store.clone();
let selections = selections.clone();
let selection_infos = selection_infos.clone();
move |_, _: &mut Window, cx: &mut App| {
move |_, window: &mut Window, cx: &mut App| {
context_store.update(cx, |context_store, cx| {
for (buffer, range) in &selections {
context_store.add_selection(
@@ -323,7 +328,7 @@ impl ContextPickerCompletionProvider {
let editor = editor.clone();
let selection_infos = selection_infos.clone();
cx.defer(move |cx| {
window.defer(cx, move |window, cx| {
let mut current_offset = 0;
for (file_name, link, line_range) in selection_infos.iter() {
let snapshot =
@@ -354,9 +359,8 @@ impl ContextPickerCompletionProvider {
);
editor.update(cx, |editor, cx| {
editor.display_map.update(cx, |display_map, cx| {
display_map.fold(vec![crease], cx);
});
editor.insert_creases(vec![crease.clone()], cx);
editor.fold_creases(vec![crease], false, window, cx);
});
current_offset += text_len + 1;
@@ -419,21 +423,26 @@ impl ContextPickerCompletionProvider {
source_range.start,
new_text_len,
editor.clone(),
context_store.clone(),
move |cx| {
let thread_id = thread_entry.id.clone();
let context_store = context_store.clone();
let thread_store = thread_store.clone();
cx.spawn(async move |cx| {
let thread = thread_store
cx.spawn::<_, Option<_>>(async move |cx| {
let thread: Entity<Thread> = thread_store
.update(cx, |thread_store, cx| {
thread_store.open_thread(&thread_id, cx)
})?
.await?;
context_store.update(cx, |context_store, cx| {
context_store.add_thread(thread, false, cx)
})
})
.ok()?
.await
.log_err()?;
let context = context_store
.update(cx, |context_store, cx| {
context_store.add_thread(thread, false, cx)
})
.ok()??;
Some(context)
})
.detach_and_log_err(cx);
},
)),
}
@@ -446,7 +455,7 @@ impl ContextPickerCompletionProvider {
editor: Entity<Editor>,
context_store: Entity<ContextStore>,
) -> Completion {
let new_text = MentionLink::for_rules(&rules);
let new_text = MentionLink::for_rule(&rules);
let new_text_len = new_text.len();
Completion {
replace_range: source_range.clone(),
@@ -463,11 +472,13 @@ impl ContextPickerCompletionProvider {
source_range.start,
new_text_len,
editor.clone(),
context_store.clone(),
move |cx| {
let user_prompt_id = rules.prompt_id;
context_store.update(cx, |context_store, cx| {
context_store.add_rules(user_prompt_id, false, cx);
let context = context_store.update(cx, |context_store, cx| {
context_store.add_rules(user_prompt_id, false, cx)
});
Task::ready(context)
},
)),
}
@@ -498,27 +509,33 @@ impl ContextPickerCompletionProvider {
source_range.start,
new_text_len,
editor.clone(),
context_store.clone(),
move |cx| {
let context_store = context_store.clone();
let http_client = http_client.clone();
let url_to_fetch = url_to_fetch.clone();
cx.spawn(async move |cx| {
if context_store.update(cx, |context_store, _| {
context_store.includes_url(&url_to_fetch)
})? {
return Ok(());
if let Some(context) = context_store
.update(cx, |context_store, _| {
context_store.get_url_context(url_to_fetch.clone())
})
.ok()?
{
return Some(context);
}
let content = cx
.background_spawn(fetch_url_content(
http_client,
url_to_fetch.to_string(),
))
.await?;
context_store.update(cx, |context_store, cx| {
context_store.add_fetched_url(url_to_fetch.to_string(), content, cx)
})
.await
.log_err()?;
context_store
.update(cx, |context_store, cx| {
context_store.add_fetched_url(url_to_fetch.to_string(), content, cx)
})
.ok()
})
.detach_and_log_err(cx);
},
)),
}
@@ -577,15 +594,23 @@ impl ContextPickerCompletionProvider {
source_range.start,
new_text_len,
editor,
context_store.clone(),
move |cx| {
context_store.update(cx, |context_store, cx| {
let task = if is_directory {
Task::ready(context_store.add_directory(&project_path, false, cx))
} else {
if is_directory {
Task::ready(
context_store
.update(cx, |context_store, cx| {
context_store.add_directory(&project_path, false, cx)
})
.log_err()
.flatten(),
)
} else {
let result = context_store.update(cx, |context_store, cx| {
context_store.add_file_from_path(project_path.clone(), false, cx)
};
task.detach_and_log_err(cx);
})
});
cx.spawn(async move |_| result.await.log_err().flatten())
}
},
)),
}
@@ -640,18 +665,19 @@ impl ContextPickerCompletionProvider {
source_range.start,
new_text_len,
editor.clone(),
context_store.clone(),
move |cx| {
let symbol = symbol.clone();
let context_store = context_store.clone();
let workspace = workspace.clone();
super::symbol_context_picker::add_symbol(
let result = super::symbol_context_picker::add_symbol(
symbol.clone(),
false,
workspace.clone(),
context_store.downgrade(),
cx,
)
.detach_and_log_err(cx);
);
cx.spawn(async move |_| result.await.log_err()?.0)
},
)),
})
@@ -713,10 +739,18 @@ impl CompletionProvider for ContextPickerCompletionProvider {
let MentionCompletion { mode, argument, .. } = state;
let query = argument.unwrap_or_else(|| "".to_string());
let excluded_path = self
.excluded_buffer
.as_ref()
.and_then(WeakEntity::upgrade)
.and_then(|b| b.read(cx).file())
.map(|file| ProjectPath::from_file(file.as_ref(), cx));
let recent_entries = recent_context_picker_entries(
context_store.clone(),
thread_store.clone(),
workspace.clone(),
excluded_path.clone(),
cx,
);
@@ -749,11 +783,17 @@ impl CompletionProvider for ContextPickerCompletionProvider {
.into_iter()
.filter_map(|mat| match mat {
Match::File(FileMatch { mat, is_recent }) => {
let project_path = ProjectPath {
worktree_id: WorktreeId::from_usize(mat.worktree_id),
path: mat.path.clone(),
};
if excluded_path.as_ref() == Some(&project_path) {
return None;
}
Some(Self::completion_for_path(
ProjectPath {
worktree_id: WorktreeId::from_usize(mat.worktree_id),
path: mat.path.clone(),
},
project_path,
&mat.path_prefix,
is_recent,
mat.is_dir,
@@ -873,24 +913,44 @@ fn confirm_completion_callback(
start: Anchor,
content_len: usize,
editor: Entity<Editor>,
add_context_fn: impl Fn(&mut App) -> () + Send + Sync + 'static,
context_store: Entity<ContextStore>,
add_context_fn: impl Fn(&mut App) -> Task<Option<AgentContextHandle>> + Send + Sync + 'static,
) -> Arc<dyn Fn(CompletionIntent, &mut Window, &mut App) -> bool + Send + Sync> {
Arc::new(move |_, _, cx| {
add_context_fn(cx);
Arc::new(move |_, window, cx| {
let context = add_context_fn(cx);
let crease_text = crease_text.clone();
let crease_icon_path = crease_icon_path.clone();
let editor = editor.clone();
cx.defer(move |cx| {
crate::context_picker::insert_fold_for_mention(
let context_store = context_store.clone();
window.defer(cx, move |window, cx| {
let crease_id = crate::context_picker::insert_crease_for_mention(
excerpt_id,
start,
content_len,
crease_text,
crease_text.clone(),
crease_icon_path,
editor,
editor.clone(),
window,
cx,
);
cx.spawn(async move |cx| {
let crease_id = crease_id?;
let context = context.await?;
editor
.update(cx, |editor, cx| {
if let Some(addon) = editor.addon_mut::<ContextCreasesAddon>() {
addon.add_creases(
&context_store,
AgentContextKey(context),
[(crease_id, crease_text)],
cx,
);
}
})
.ok()
})
.detach();
});
false
})
@@ -1095,6 +1155,7 @@ mod tests {
"five.txt": "",
"six.txt": "",
"seven.txt": "",
"eight.txt": "",
}
}),
)
@@ -1121,9 +1182,12 @@ mod tests {
separator!("b/five.txt"),
separator!("b/six.txt"),
separator!("b/seven.txt"),
separator!("b/eight.txt"),
];
let mut opened_editors = Vec::new();
for path in paths {
workspace
let buffer = workspace
.update_in(&mut cx, |workspace, window, cx| {
workspace.open_path(
ProjectPath {
@@ -1138,6 +1202,7 @@ mod tests {
})
.await
.unwrap();
opened_editors.push(buffer);
}
let editor = workspace.update_in(&mut cx, |workspace, window, cx| {
@@ -1167,12 +1232,23 @@ mod tests {
let editor_entity = editor.downgrade();
editor.update_in(&mut cx, |editor, window, cx| {
let last_opened_buffer = opened_editors.last().and_then(|editor| {
editor
.downcast::<Editor>()?
.read(cx)
.buffer()
.read(cx)
.as_singleton()
.as_ref()
.map(Entity::downgrade)
});
window.focus(&editor.focus_handle(cx));
editor.set_completion_provider(Some(Box::new(ContextPickerCompletionProvider::new(
workspace.downgrade(),
context_store.downgrade(),
None,
editor_entity,
last_opened_buffer,
))));
});

View File

@@ -130,21 +130,19 @@ impl PickerDelegate for FileContextPickerDelegate {
let is_directory = mat.is_dir;
let Some(task) = self
.context_store
self.context_store
.update(cx, |context_store, cx| {
if is_directory {
Task::ready(context_store.add_directory(&project_path, true, cx))
context_store
.add_directory(&project_path, true, cx)
.log_err();
} else {
context_store.add_file_from_path(project_path.clone(), true, cx)
context_store
.add_file_from_path(project_path.clone(), true, cx)
.detach_and_log_err(cx);
}
})
.ok()
else {
return;
};
task.detach_and_log_err(cx);
.ok();
}
fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {

View File

@@ -14,6 +14,7 @@ use ui::{ListItem, prelude::*};
use util::ResultExt as _;
use workspace::Workspace;
use crate::context::AgentContextHandle;
use crate::context_picker::ContextPicker;
use crate::context_store::ContextStore;
@@ -143,7 +144,7 @@ impl PickerDelegate for SymbolContextPickerDelegate {
let selected_index = self.selected_index;
cx.spawn(async move |this, cx| {
let included = add_symbol_task.await?;
let (_, included) = add_symbol_task.await?;
this.update(cx, |this, _| {
if let Some(mat) = this.delegate.matches.get_mut(selected_index) {
mat.is_included = included;
@@ -187,7 +188,7 @@ pub(crate) fn add_symbol(
workspace: Entity<Workspace>,
context_store: WeakEntity<ContextStore>,
cx: &mut App,
) -> Task<Result<bool>> {
) -> Task<Result<(Option<AgentContextHandle>, bool)>> {
let project = workspace.read(cx).project().clone();
let open_buffer_task = project.update(cx, |project, cx| {
project.open_buffer(symbol.path.clone(), cx)

View File

@@ -5,7 +5,7 @@ use std::sync::Arc;
use anyhow::{Result, anyhow};
use collections::{HashSet, IndexSet};
use futures::{self, FutureExt};
use gpui::{App, Context, Entity, Image, SharedString, Task, WeakEntity};
use gpui::{App, Context, Entity, EventEmitter, Image, SharedString, Task, WeakEntity};
use language::Buffer;
use language_model::LanguageModelImage;
use project::image_store::is_image_file;
@@ -31,6 +31,12 @@ pub struct ContextStore {
context_thread_ids: HashSet<ThreadId>,
}
pub enum ContextStoreEvent {
ContextRemoved(AgentContextKey),
}
impl EventEmitter<ContextStoreEvent> for ContextStore {}
impl ContextStore {
pub fn new(
project: WeakEntity<Project>,
@@ -82,7 +88,7 @@ impl ContextStore {
project_path: ProjectPath,
remove_if_exists: bool,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
) -> Task<Result<Option<AgentContextHandle>>> {
let Some(project) = self.project.upgrade() else {
return Task::ready(Err(anyhow!("failed to read project")));
};
@@ -108,21 +114,22 @@ impl ContextStore {
buffer: Entity<Buffer>,
remove_if_exists: bool,
cx: &mut Context<Self>,
) {
) -> Option<AgentContextHandle> {
let context_id = self.next_context_id.post_inc();
let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
let already_included = if self.has_context(&context) {
if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
if remove_if_exists {
self.remove_context(&context, cx);
None
} else {
Some(key.as_ref().clone())
}
true
} else if self.path_included_in_directory(project_path, cx).is_some() {
None
} else {
self.path_included_in_directory(project_path, cx).is_some()
};
if !already_included {
self.insert_context(context, cx);
self.insert_context(context.clone(), cx);
Some(context)
}
}
@@ -131,7 +138,7 @@ impl ContextStore {
project_path: &ProjectPath,
remove_if_exists: bool,
cx: &mut Context<Self>,
) -> Result<()> {
) -> Result<Option<AgentContextHandle>> {
let Some(project) = self.project.upgrade() else {
return Err(anyhow!("failed to read project"));
};
@@ -150,15 +157,20 @@ impl ContextStore {
context_id,
});
if self.has_context(&context) {
if remove_if_exists {
self.remove_context(&context, cx);
}
} else if self.path_included_in_directory(project_path, cx).is_none() {
self.insert_context(context, cx);
}
let context =
if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
if remove_if_exists {
self.remove_context(&context, cx);
None
} else {
Some(existing.as_ref().clone())
}
} else {
self.insert_context(context.clone(), cx);
Some(context)
};
anyhow::Ok(())
anyhow::Ok(context)
}
pub fn add_symbol(
@@ -169,7 +181,7 @@ impl ContextStore {
enclosing_range: Range<Anchor>,
remove_if_exists: bool,
cx: &mut Context<Self>,
) -> bool {
) -> (Option<AgentContextHandle>, bool) {
let context_id = self.next_context_id.post_inc();
let context = AgentContextHandle::Symbol(SymbolContextHandle {
buffer,
@@ -179,14 +191,18 @@ impl ContextStore {
context_id,
});
if self.has_context(&context) {
if remove_if_exists {
if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
let handle = if remove_if_exists {
self.remove_context(&context, cx);
}
return false;
None
} else {
Some(key.as_ref().clone())
};
return (handle, false);
}
self.insert_context(context, cx)
let included = self.insert_context(context.clone(), cx);
(Some(context), included)
}
pub fn add_thread(
@@ -194,16 +210,20 @@ impl ContextStore {
thread: Entity<Thread>,
remove_if_exists: bool,
cx: &mut Context<Self>,
) {
) -> Option<AgentContextHandle> {
let context_id = self.next_context_id.post_inc();
let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
if self.has_context(&context) {
if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
if remove_if_exists {
self.remove_context(&context, cx);
None
} else {
Some(existing.as_ref().clone())
}
} else {
self.insert_context(context, cx);
self.insert_context(context.clone(), cx);
Some(context)
}
}
@@ -212,19 +232,23 @@ impl ContextStore {
prompt_id: UserPromptId,
remove_if_exists: bool,
cx: &mut Context<ContextStore>,
) {
) -> Option<AgentContextHandle> {
let context_id = self.next_context_id.post_inc();
let context = AgentContextHandle::Rules(RulesContextHandle {
prompt_id,
context_id,
});
if self.has_context(&context) {
if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
if remove_if_exists {
self.remove_context(&context, cx);
None
} else {
Some(existing.as_ref().clone())
}
} else {
self.insert_context(context, cx);
self.insert_context(context.clone(), cx);
Some(context)
}
}
@@ -233,14 +257,15 @@ impl ContextStore {
url: String,
text: impl Into<SharedString>,
cx: &mut Context<ContextStore>,
) {
) -> AgentContextHandle {
let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
url: url.into(),
text: text.into(),
context_id: self.next_context_id.post_inc(),
});
self.insert_context(context, cx);
self.insert_context(context.clone(), cx);
context
}
pub fn add_image_from_path(
@@ -248,7 +273,7 @@ impl ContextStore {
project_path: ProjectPath,
remove_if_exists: bool,
cx: &mut Context<ContextStore>,
) -> Task<Result<()>> {
) -> Task<Result<Option<AgentContextHandle>>> {
let project = self.project.clone();
cx.spawn(async move |this, cx| {
let open_image_task = project.update(cx, |project, cx| {
@@ -262,7 +287,7 @@ impl ContextStore {
image,
remove_if_exists,
cx,
);
)
})
})
}
@@ -277,7 +302,7 @@ impl ContextStore {
image: Arc<Image>,
remove_if_exists: bool,
cx: &mut Context<ContextStore>,
) {
) -> Option<AgentContextHandle> {
let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
let context = AgentContextHandle::Image(ImageContext {
project_path,
@@ -288,11 +313,12 @@ impl ContextStore {
if self.has_context(&context) {
if remove_if_exists {
self.remove_context(&context, cx);
return;
return None;
}
}
self.insert_context(context, cx);
self.insert_context(context.clone(), cx);
Some(context)
}
pub fn add_selection(
@@ -364,9 +390,9 @@ impl ContextStore {
}
pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
if self
if let Some((_, key)) = self
.context_set
.shift_remove(AgentContextKey::ref_cast(context))
.shift_remove_full(AgentContextKey::ref_cast(context))
{
match context {
AgentContextHandle::Thread(thread_context) => {
@@ -375,6 +401,7 @@ impl ContextStore {
}
_ => {}
}
cx.emit(ContextStoreEvent::ContextRemoved(key));
cx.notify();
}
}
@@ -451,6 +478,12 @@ impl ContextStore {
.contains(&FetchedUrlContext::lookup_key(url.into()))
}
pub fn get_url_context(&self, url: SharedString) -> Option<AgentContextHandle> {
self.context_set
.get(&FetchedUrlContext::lookup_key(url))
.map(|key| key.as_ref().clone())
}
pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
self.context()
.filter_map(|context| match context {

View File

@@ -11,7 +11,7 @@ use gpui::{
use itertools::Itertools;
use language::Buffer;
use project::ProjectItem;
use ui::{KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*};
use ui::{PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*};
use workspace::Workspace;
use crate::context::{AgentContextHandle, ContextKind};
@@ -357,7 +357,7 @@ impl Focusable for ContextStrip {
}
impl Render for ContextStrip {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let context_picker = self.context_picker.clone();
let focus_handle = self.focus_handle.clone();
@@ -434,30 +434,6 @@ impl Render for ContextStrip {
})
.with_handle(self.context_picker_menu_handle.clone()),
)
.when(no_added_context && suggested_context.is_none(), {
|parent| {
parent.child(
h_flex()
.ml_1p5()
.gap_2()
.child(
Label::new("Add Context")
.size(LabelSize::Small)
.color(Color::Muted),
)
.opacity(0.5)
.children(
KeyBinding::for_action_in(
&ToggleContextPicker,
&focus_handle,
window,
cx,
)
.map(|binding| binding.into_any_element()),
),
)
}
})
.children(
added_contexts
.into_iter()

View File

@@ -1199,6 +1199,7 @@ impl InlineAssistant {
) -> Vec<InlineAssistId> {
let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
assist_group.linked = false;
for assist_id in &assist_group.assist_ids {
let assist = self.assists.get_mut(assist_id).unwrap();
if let Some(editor_decorations) = assist.decorations.as_ref() {

View File

@@ -1,8 +1,10 @@
use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
use crate::buffer_codegen::BufferCodegen;
use crate::context_picker::ContextPicker;
use crate::context::ContextCreasesAddon;
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
use crate::context_store::ContextStore;
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::message_editor::{extract_message_creases, insert_message_creases};
use crate::terminal_codegen::TerminalCodegen;
use crate::thread_store::ThreadStore;
use crate::{CycleNextInlineAssist, CyclePreviousInlineAssist};
@@ -10,7 +12,8 @@ use crate::{RemoveAllContext, ToggleContextPicker};
use client::ErrorExt;
use collections::VecDeque;
use editor::{
Editor, EditorElement, EditorEvent, EditorMode, EditorStyle, GutterDimensions, MultiBuffer,
ContextMenuOptions, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle,
GutterDimensions, MultiBuffer,
actions::{MoveDown, MoveUp},
};
use feature_flags::{FeatureFlagAppExt as _, ZedProFeatureFlag};
@@ -245,13 +248,22 @@ impl<T: 'static> PromptEditor<T> {
pub fn unlink(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let prompt = self.prompt(cx);
let existing_creases = self.editor.update(cx, extract_message_creases);
let focus = self.editor.focus_handle(cx).contains_focused(window, cx);
self.editor = cx.new(|cx| {
let mut editor = Editor::auto_height(Self::MAX_LINES as usize, window, cx);
editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
editor.set_placeholder_text(Self::placeholder_text(&self.mode, window, cx), cx);
editor.set_placeholder_text("Add a prompt…", cx);
editor.set_text(prompt, window, cx);
insert_message_creases(
&mut editor,
&existing_creases,
&self.context_store,
window,
cx,
);
if focus {
window.focus(&editor.focus_handle(cx));
}
@@ -838,6 +850,7 @@ impl PromptEditor<BufferCodegen> {
cx: &mut Context<PromptEditor<BufferCodegen>>,
) -> PromptEditor<BufferCodegen> {
let codegen_subscription = cx.observe(&codegen, Self::handle_codegen_changed);
let codegen_buffer = codegen.read(cx).buffer(cx).read(cx).as_singleton();
let mode = PromptEditorMode::Buffer {
id,
codegen,
@@ -860,8 +873,27 @@ impl PromptEditor<BufferCodegen> {
// typing in one will make what you typed appear in all of them.
editor.set_show_cursor_when_unfocused(true, cx);
editor.set_placeholder_text(Self::placeholder_text(&mode, window, cx), cx);
editor.register_addon(ContextCreasesAddon::new());
editor.set_context_menu_options(ContextMenuOptions {
min_entries_visible: 12,
max_entries_visible: 12,
placement: None,
});
editor
});
let prompt_editor_entity = prompt_editor.downgrade();
prompt_editor.update(cx, |editor, _| {
editor.set_completion_provider(Some(Box::new(ContextPickerCompletionProvider::new(
workspace.clone(),
context_store.downgrade(),
thread_store.clone(),
prompt_editor_entity,
codegen_buffer.as_ref().map(Entity::downgrade),
))));
});
let context_picker_menu_handle = PopoverMenuHandle::default();
let model_selector_menu_handle = PopoverMenuHandle::default();
@@ -1013,8 +1045,25 @@ impl PromptEditor<TerminalCodegen> {
);
editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
editor.set_placeholder_text(Self::placeholder_text(&mode, window, cx), cx);
editor.set_context_menu_options(ContextMenuOptions {
min_entries_visible: 12,
max_entries_visible: 12,
placement: None,
});
editor
});
let prompt_editor_entity = prompt_editor.downgrade();
prompt_editor.update(cx, |editor, _| {
editor.set_completion_provider(Some(Box::new(ContextPickerCompletionProvider::new(
workspace.clone(),
context_store.downgrade(),
thread_store.clone(),
prompt_editor_entity,
None,
))));
});
let context_picker_menu_handle = PopoverMenuHandle::default();
let model_selector_menu_handle = PopoverMenuHandle::default();

View File

@@ -2,15 +2,15 @@ use std::collections::BTreeMap;
use std::sync::Arc;
use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
use crate::context::{ContextLoadResult, load_context};
use crate::context::{AgentContextKey, ContextCreasesAddon, ContextLoadResult, load_context};
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
use crate::ui::{AgentPreview, AnimatedLabel};
use crate::ui::{AgentPreview, AnimatedLabel, MaxModeTooltip};
use buffer_diff::BufferDiff;
use collections::HashSet;
use collections::{HashMap, HashSet};
use editor::actions::{MoveUp, Paste};
use editor::{
ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorEvent, EditorMode,
EditorStyle, MultiBuffer,
AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorEvent,
EditorMode, EditorStyle, MultiBuffer,
};
use feature_flags::{FeatureFlagAppExt, NewBillingFeatureFlag};
use file_icons::FileIcons;
@@ -32,18 +32,18 @@ use std::time::Duration;
use theme::ThemeSettings;
use ui::{Disclosure, KeyBinding, PopoverMenuHandle, Tooltip, prelude::*};
use util::ResultExt as _;
use workspace::Workspace;
use workspace::{CollaboratorId, Workspace};
use zed_llm_client::CompletionMode;
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention};
use crate::context_store::ContextStore;
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::profile_selector::ProfileSelector;
use crate::thread::{Thread, TokenUsageRatio};
use crate::thread::{MessageCrease, Thread, TokenUsageRatio};
use crate::thread_store::ThreadStore;
use crate::{
ActiveThread, AgentDiff, Chat, ExpandMessageEditor, NewThread, OpenAgentDiff, RemoveAllContext,
ToggleContextPicker, ToggleProfileSelector, register_agent_preview,
ActiveThread, AgentDiffPane, Chat, ExpandMessageEditor, Follow, NewThread, OpenAgentDiff,
RemoveAllContext, ToggleContextPicker, ToggleProfileSelector, register_agent_preview,
};
#[derive(RegisterComponent)]
@@ -97,7 +97,7 @@ pub(crate) fn create_editor(
window,
cx,
);
editor.set_placeholder_text("Ask anything, @ to mention, ↑ to select", cx);
editor.set_placeholder_text("Message the agent @ to include context", cx);
editor.set_show_indent_guides(false, cx);
editor.set_soft_wrap();
editor.set_context_menu_options(ContextMenuOptions {
@@ -105,6 +105,7 @@ pub(crate) fn create_editor(
max_entries_visible: 12,
placement: Some(ContextMenuPlacement::Above),
});
editor.register_addon(ContextCreasesAddon::new());
editor
});
@@ -115,6 +116,7 @@ pub(crate) fn create_editor(
context_store,
Some(thread_store),
editor_entity,
None,
))));
});
editor
@@ -167,6 +169,9 @@ impl MessageEditor {
// When context changes, reload it for token counting.
let _ = this.reload_context(cx);
}),
cx.observe(&thread.read(cx).action_log().clone(), |_, _, cx| {
cx.notify()
}),
];
let model_selector = cx.new(|cx| {
@@ -195,8 +200,7 @@ impl MessageEditor {
model_selector,
edits_expanded: false,
editor_is_expanded: false,
profile_selector: cx
.new(|cx| ProfileSelector::new(fs, thread_store, editor.focus_handle(cx), cx)),
profile_selector: cx.new(|cx| ProfileSelector::new(fs, thread_store, cx)),
last_estimated_token_count: None,
update_token_count_task: None,
_subscriptions: subscriptions,
@@ -290,10 +294,11 @@ impl MessageEditor {
return;
}
let user_message = self.editor.update(cx, |editor, cx| {
let (user_message, user_message_creases) = self.editor.update(cx, |editor, cx| {
let creases = extract_message_creases(editor, cx);
let text = editor.text(cx);
editor.clear(window, cx);
text
(text, creases)
});
self.last_estimated_token_count.take();
@@ -311,7 +316,13 @@ impl MessageEditor {
thread
.update(cx, |thread, cx| {
thread.insert_user_message(user_message, loaded_context, checkpoint.ok(), cx);
thread.insert_user_message(
user_message,
loaded_context,
checkpoint.ok(),
user_message_creases,
cx,
);
})
.log_err();
@@ -396,7 +407,7 @@ impl MessageEditor {
fn handle_review_click(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.edits_expanded = true;
AgentDiff::deploy(self.thread.clone(), self.workspace.clone(), window, cx).log_err();
AgentDiffPane::deploy(self.thread.clone(), self.workspace.clone(), window, cx).log_err();
cx.notify();
}
@@ -406,7 +417,8 @@ impl MessageEditor {
window: &mut Window,
cx: &mut Context<Self>,
) {
if let Ok(diff) = AgentDiff::deploy(self.thread.clone(), self.workspace.clone(), window, cx)
if let Ok(diff) =
AgentDiffPane::deploy(self.thread.clone(), self.workspace.clone(), window, cx)
{
let path_key = multi_buffer::PathKey::for_buffer(&buffer, cx);
diff.update(cx, |diff, cx| diff.move_to_path(path_key, window, cx));
@@ -439,11 +451,49 @@ impl MessageEditor {
});
});
}))
.tooltip(Tooltip::text("Toggle Max Mode"))
.tooltip(|_, cx| cx.new(MaxModeTooltip::new).into())
.into_any_element(),
)
}
fn render_follow_toggle(&self, cx: &mut Context<Self>) -> impl IntoElement {
let following = self
.workspace
.read_with(cx, |workspace, _| {
workspace.is_being_followed(CollaboratorId::Agent)
})
.unwrap_or(false);
IconButton::new("follow-agent", IconName::Crosshair)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.toggle_state(following)
.selected_icon_color(Some(Color::Custom(cx.theme().players().agent().cursor)))
.tooltip(move |window, cx| {
if following {
Tooltip::for_action("Stop Following Agent", &Follow, window, cx)
} else {
Tooltip::with_meta(
"Follow Agent",
Some(&Follow),
"Track the agent's location as it reads and edits files.",
window,
cx,
)
}
})
.on_click(cx.listener(move |this, _, window, cx| {
this.workspace
.update(cx, |workspace, cx| {
if following {
workspace.unfollow(CollaboratorId::Agent, window, cx);
} else {
workspace.follow(CollaboratorId::Agent, window, cx);
}
})
.ok();
}))
}
fn render_editor(
&self,
font_size: Rems,
@@ -509,34 +559,39 @@ impl MessageEditor {
.items_start()
.justify_between()
.child(self.context_strip.clone())
.when(focus_handle.is_focused(window), |this| {
this.child(
IconButton::new("toggle-height", expand_icon)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.tooltip({
let focus_handle = focus_handle.clone();
move |window, cx| {
let expand_label = if is_editor_expanded {
"Minimize Message Editor".to_string()
} else {
"Expand Message Editor".to_string()
};
.child(
h_flex()
.gap_1()
.when(focus_handle.is_focused(window), |this| {
this.child(
IconButton::new("toggle-height", expand_icon)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.tooltip({
let focus_handle = focus_handle.clone();
move |window, cx| {
let expand_label = if is_editor_expanded {
"Minimize Message Editor".to_string()
} else {
"Expand Message Editor".to_string()
};
Tooltip::for_action_in(
expand_label,
&ExpandMessageEditor,
&focus_handle,
window,
cx,
)
}
})
.on_click(cx.listener(|_, _, window, cx| {
window.dispatch_action(Box::new(ExpandMessageEditor), cx);
})),
)
}),
Tooltip::for_action_in(
expand_label,
&ExpandMessageEditor,
&focus_handle,
window,
cx,
)
}
})
.on_click(cx.listener(|_, _, window, cx| {
window
.dispatch_action(Box::new(ExpandMessageEditor), cx);
})),
)
}),
),
)
.child(
v_flex()
@@ -579,7 +634,12 @@ impl MessageEditor {
h_flex()
.flex_none()
.justify_between()
.child(h_flex().gap_2().child(self.profile_selector.clone()))
.child(
h_flex()
.gap_1()
.child(self.render_follow_toggle(cx))
.child(self.profile_selector.clone()),
)
.child(
h_flex()
.gap_1()
@@ -1164,6 +1224,53 @@ impl MessageEditor {
}
}
pub fn extract_message_creases(
editor: &mut Editor,
cx: &mut Context<'_, Editor>,
) -> Vec<MessageCrease> {
let buffer_snapshot = editor.buffer().read(cx).snapshot(cx);
let mut contexts_by_crease_id = editor
.addon_mut::<ContextCreasesAddon>()
.map(std::mem::take)
.unwrap_or_default()
.into_inner()
.into_iter()
.flat_map(|(key, creases)| {
let context = key.0;
creases
.into_iter()
.map(move |(id, _)| (id, context.clone()))
})
.collect::<HashMap<_, _>>();
// Filter the addon's list of creases based on what the editor reports,
// since the addon might have removed creases in it.
let creases = editor.display_map.update(cx, |display_map, cx| {
display_map
.snapshot(cx)
.crease_snapshot
.creases()
.filter_map(|(id, crease)| {
Some((
id,
(
crease.range().to_offset(&buffer_snapshot),
crease.metadata()?.clone(),
),
))
})
.map(|(id, (range, metadata))| {
let context = contexts_by_crease_id.remove(&id);
MessageCrease {
range,
metadata,
context,
}
})
.collect()
});
creases
}
impl EventEmitter<MessageEditorEvent> for MessageEditor {}
pub enum MessageEditorEvent {
@@ -1204,6 +1311,43 @@ impl Render for MessageEditor {
}
}
pub fn insert_message_creases(
editor: &mut Editor,
message_creases: &[MessageCrease],
context_store: &Entity<ContextStore>,
window: &mut Window,
cx: &mut Context<'_, Editor>,
) {
let buffer_snapshot = editor.buffer().read(cx).snapshot(cx);
let creases = message_creases
.iter()
.map(|crease| {
let start = buffer_snapshot.anchor_after(crease.range.start);
let end = buffer_snapshot.anchor_before(crease.range.end);
crease_for_mention(
crease.metadata.label.clone(),
crease.metadata.icon_path.clone(),
start..end,
cx.weak_entity(),
)
})
.collect::<Vec<_>>();
let ids = editor.insert_creases(creases.clone(), cx);
editor.fold_creases(creases, false, window, cx);
if let Some(addon) = editor.addon_mut::<ContextCreasesAddon>() {
for (crease, id) in message_creases.iter().zip(ids) {
if let Some(context) = crease.context.as_ref() {
let key = AgentContextKey(context.clone());
addon.add_creases(
context_store,
key,
vec![(id, crease.metadata.label.clone())],
cx,
);
}
}
}
}
impl Component for MessageEditor {
fn scope() -> ComponentScope {
ComponentScope::Agent
@@ -1211,7 +1355,7 @@ impl Component for MessageEditor {
}
impl AgentPreview for MessageEditor {
fn create_preview(
fn agent_preview(
workspace: WeakEntity<Workspace>,
active_thread: Entity<ActiveThread>,
thread_store: WeakEntity<ThreadStore>,

View File

@@ -1,24 +1,23 @@
use std::sync::Arc;
use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
use assistant_settings::{
AgentProfile, AgentProfileId, AssistantSettings, GroupedAgentProfiles, builtin_profiles,
};
use fs::Fs;
use gpui::{Action, Entity, FocusHandle, Subscription, WeakEntity, prelude::*};
use indexmap::IndexMap;
use gpui::{Action, Entity, Subscription, WeakEntity, prelude::*};
use language_model::LanguageModelRegistry;
use settings::{Settings as _, SettingsStore, update_settings_file};
use ui::{
ButtonLike, ContextMenu, ContextMenuEntry, KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip,
prelude::*,
ButtonLike, ContextMenu, ContextMenuEntry, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*,
};
use util::ResultExt as _;
use crate::{ManageProfiles, ThreadStore, ToggleProfileSelector};
use crate::{ManageProfiles, ThreadStore};
pub struct ProfileSelector {
profiles: IndexMap<AgentProfileId, AgentProfile>,
profiles: GroupedAgentProfiles,
fs: Arc<dyn Fs>,
thread_store: WeakEntity<ThreadStore>,
focus_handle: FocusHandle,
menu_handle: PopoverMenuHandle<ContextMenu>,
_subscriptions: Vec<Subscription>,
}
@@ -27,24 +26,19 @@ impl ProfileSelector {
pub fn new(
fs: Arc<dyn Fs>,
thread_store: WeakEntity<ThreadStore>,
focus_handle: FocusHandle,
cx: &mut Context<Self>,
) -> Self {
let settings_subscription = cx.observe_global::<SettingsStore>(move |this, cx| {
this.refresh_profiles(cx);
});
let mut this = Self {
profiles: IndexMap::default(),
Self {
profiles: GroupedAgentProfiles::from_settings(AssistantSettings::get_global(cx)),
fs,
thread_store,
focus_handle,
menu_handle: PopoverMenuHandle::default(),
_subscriptions: vec![settings_subscription],
};
this.refresh_profiles(cx);
this
}
}
pub fn menu_handle(&self) -> PopoverMenuHandle<ContextMenu> {
@@ -52,9 +46,7 @@ impl ProfileSelector {
}
fn refresh_profiles(&mut self, cx: &mut Context<Self>) {
let settings = AssistantSettings::get_global(cx);
self.profiles = settings.profiles.clone();
self.profiles = GroupedAgentProfiles::from_settings(AssistantSettings::get_global(cx));
}
fn build_context_menu(
@@ -64,58 +56,21 @@ impl ProfileSelector {
) -> Entity<ContextMenu> {
ContextMenu::build(window, cx, |mut menu, _window, cx| {
let settings = AssistantSettings::get_global(cx);
let icon_position = IconPosition::End;
menu = menu.header("Profiles");
for (profile_id, profile) in self.profiles.clone() {
let documentation = match profile.name.to_lowercase().as_str() {
"write" => Some("Get help to write anything."),
"ask" => Some("Chat about your codebase."),
"manual" => Some("Chat about anything; no tools."),
_ => None,
};
let entry = ContextMenuEntry::new(profile.name.clone())
.toggleable(icon_position, profile_id == settings.default_profile);
let entry = if let Some(doc_text) = documentation {
entry.documentation_aside(move |_| Label::new(doc_text).into_any_element())
} else {
entry
};
menu = menu.item(entry.handler({
let fs = self.fs.clone();
let thread_store = self.thread_store.clone();
let profile_id = profile_id.clone();
move |_window, cx| {
update_settings_file::<AssistantSettings>(fs.clone(), cx, {
let profile_id = profile_id.clone();
move |settings, _cx| {
settings.set_profile(profile_id.clone());
}
});
thread_store
.update(cx, |this, cx| {
this.load_profile_by_id(profile_id.clone(), cx);
})
.log_err();
}
}));
for (profile_id, profile) in self.profiles.builtin.iter() {
menu =
menu.item(self.menu_entry_for_profile(profile_id.clone(), profile, settings));
}
menu = menu.separator();
menu = menu.header("Customize Current Profile");
menu = menu.item(ContextMenuEntry::new("Tools…").handler({
let profile_id = settings.default_profile.clone();
move |window, cx| {
window.dispatch_action(
ManageProfiles::customize_tools(profile_id.clone()).boxed_clone(),
cx,
);
if !self.profiles.custom.is_empty() {
menu = menu.separator().header("Custom Profiles");
for (profile_id, profile) in self.profiles.custom.iter() {
menu = menu.item(self.menu_entry_for_profile(
profile_id.clone(),
profile,
settings,
));
}
}));
}
menu = menu.separator();
menu = menu.item(ContextMenuEntry::new("Configure Profiles…").handler(
@@ -127,10 +82,53 @@ impl ProfileSelector {
menu
})
}
fn menu_entry_for_profile(
&self,
profile_id: AgentProfileId,
profile: &AgentProfile,
settings: &AssistantSettings,
) -> ContextMenuEntry {
let documentation = match profile.name.to_lowercase().as_str() {
builtin_profiles::WRITE => Some("Get help to write anything."),
builtin_profiles::ASK => Some("Chat about your codebase."),
builtin_profiles::MANUAL => Some("Chat about anything with no tools."),
_ => None,
};
let entry = ContextMenuEntry::new(profile.name.clone())
.toggleable(IconPosition::End, profile_id == settings.default_profile);
let entry = if let Some(doc_text) = documentation {
entry.documentation_aside(move |_| Label::new(doc_text).into_any_element())
} else {
entry
};
entry.handler({
let fs = self.fs.clone();
let thread_store = self.thread_store.clone();
let profile_id = profile_id.clone();
move |_window, cx| {
update_settings_file::<AssistantSettings>(fs.clone(), cx, {
let profile_id = profile_id.clone();
move |settings, _cx| {
settings.set_profile(profile_id.clone());
}
});
thread_store
.update(cx, |this, cx| {
this.load_profile_by_id(profile_id.clone(), cx);
})
.log_err();
}
})
}
}
impl Render for ProfileSelector {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let settings = AssistantSettings::get_global(cx);
let profile_id = &settings.default_profile;
let profile = settings.profiles.get(profile_id);
@@ -144,14 +142,7 @@ impl Render for ProfileSelector {
.default_model()
.map_or(false, |default| default.model.supports_tools());
let icon = match profile_id.as_str() {
"write" => IconName::Pencil,
"ask" => IconName::MessageBubbles,
_ => IconName::UserRoundPen,
};
let this = cx.entity().clone();
let focus_handle = self.focus_handle.clone();
PopoverMenu::new("profile-selector")
.menu(move |window, cx| {
@@ -161,7 +152,6 @@ impl Render for ProfileSelector {
ButtonLike::new("profile-selector-button").child(
h_flex()
.gap_1()
.child(Icon::new(icon).size(IconSize::XSmall).color(Color::Muted))
.child(
Label::new(selected_profile)
.size(LabelSize::Small)
@@ -171,17 +161,7 @@ impl Render for ProfileSelector {
Icon::new(IconName::ChevronDown)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(div().opacity(0.5).children({
let focus_handle = focus_handle.clone();
KeyBinding::for_action_in(
&ToggleProfileSelector,
&focus_handle,
window,
cx,
)
.map(|kb| kb.size(rems_from_px(10.)))
})),
),
)
} else {
ButtonLike::new("tools-not-supported-button")

View File

@@ -9,6 +9,7 @@ use assistant_settings::AssistantSettings;
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::HashMap;
use editor::display_map::CreaseMetadata;
use feature_flags::{self, FeatureFlagAppExt};
use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
@@ -36,13 +37,13 @@ use settings::Settings;
use thiserror::Error;
use util::{ResultExt as _, TryFutureExt as _, post_inc};
use uuid::Uuid;
use zed_llm_client::CompletionMode;
use zed_llm_client::{CompletionMode, CompletionRequestStatus};
use crate::ThreadStore;
use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
use crate::thread_store::{
SerializedLanguageModel, SerializedMessage, SerializedMessageSegment, SerializedThread,
SerializedToolResult, SerializedToolUse, SharedProjectContext,
SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
@@ -96,6 +97,15 @@ impl MessageId {
}
}
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
pub range: Range<usize>,
pub metadata: CreaseMetadata,
/// None for a deserialized message, Some otherwise.
pub context: Option<AgentContextHandle>,
}
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
@@ -103,6 +113,7 @@ pub struct Message {
pub role: Role,
pub segments: Vec<MessageSegment>,
pub loaded_context: LoadedContext,
pub creases: Vec<MessageCrease>,
}
impl Message {
@@ -309,6 +320,13 @@ fn default_completion_mode(cx: &App) -> CompletionMode {
}
}
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
Sending,
Queued { position: usize },
Started,
}
/// A thread of conversation with the LLM.
pub struct Thread {
id: ThreadId,
@@ -337,9 +355,11 @@ pub struct Thread {
request_token_usage: Vec<TokenUsage>,
cumulative_token_usage: TokenUsage,
exceeded_window_error: Option<ExceededWindowError>,
tool_use_limit_reached: bool,
feedback: Option<ThreadFeedback>,
message_feedback: HashMap<MessageId, ThreadFeedback>,
last_auto_capture_at: Option<Instant>,
last_received_chunk_at: Option<Instant>,
request_callback: Option<
Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
>,
@@ -398,9 +418,11 @@ impl Thread {
request_token_usage: Vec::new(),
cumulative_token_usage: TokenUsage::default(),
exceeded_window_error: None,
tool_use_limit_reached: false,
feedback: None,
message_feedback: HashMap::default(),
last_auto_capture_at: None,
last_received_chunk_at: None,
request_callback: None,
remaining_turns: u32::MAX,
configured_model,
@@ -473,6 +495,18 @@ impl Thread {
text: message.context,
images: Vec::new(),
},
creases: message
.creases
.into_iter()
.map(|crease| MessageCrease {
range: crease.start..crease.end,
metadata: CreaseMetadata {
icon_path: crease.icon_path,
label: crease.label,
},
context: None,
})
.collect(),
})
.collect(),
next_message_id,
@@ -492,9 +526,11 @@ impl Thread {
request_token_usage: serialized.request_token_usage,
cumulative_token_usage: serialized.cumulative_token_usage,
exceeded_window_error: None,
tool_use_limit_reached: false,
feedback: None,
message_feedback: HashMap::default(),
last_auto_capture_at: None,
last_received_chunk_at: None,
request_callback: None,
remaining_turns: u32::MAX,
configured_model,
@@ -602,6 +638,25 @@ impl Thread {
!self.pending_completions.is_empty() || !self.all_tools_finished()
}
/// Indicates whether streaming of language model events is stale.
/// When `is_generating()` is false, this method returns `None`.
pub fn is_generation_stale(&self) -> Option<bool> {
const STALE_THRESHOLD: u128 = 250;
self.last_received_chunk_at
.map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
}
fn received_chunk(&mut self) {
self.last_received_chunk_at = Some(Instant::now());
}
pub fn queue_state(&self) -> Option<QueueState> {
self.pending_completions
.first()
.map(|pending_completion| pending_completion.queue_state)
}
pub fn tools(&self) -> &Entity<ToolWorkingSet> {
&self.tools
}
@@ -762,6 +817,10 @@ impl Thread {
.unwrap_or(false)
}
pub fn tool_use_limit_reached(&self) -> bool {
self.tool_use_limit_reached
}
/// Returns whether all of the tool uses have finished running.
pub fn all_tools_finished(&self) -> bool {
// If the only pending tool uses left are the ones with errors, then
@@ -826,6 +885,7 @@ impl Thread {
text: impl Into<String>,
loaded_context: ContextLoadResult,
git_checkpoint: Option<GitStoreCheckpoint>,
creases: Vec<MessageCrease>,
cx: &mut Context<Self>,
) -> MessageId {
if !loaded_context.referenced_buffers.is_empty() {
@@ -840,6 +900,7 @@ impl Thread {
Role::User,
vec![MessageSegment::Text(text.into())],
loaded_context.loaded_context,
creases,
cx,
);
@@ -860,7 +921,13 @@ impl Thread {
segments: Vec<MessageSegment>,
cx: &mut Context<Self>,
) -> MessageId {
self.insert_message(Role::Assistant, segments, LoadedContext::default(), cx)
self.insert_message(
Role::Assistant,
segments,
LoadedContext::default(),
Vec::new(),
cx,
)
}
pub fn insert_message(
@@ -868,6 +935,7 @@ impl Thread {
role: Role,
segments: Vec<MessageSegment>,
loaded_context: LoadedContext,
creases: Vec<MessageCrease>,
cx: &mut Context<Self>,
) -> MessageId {
let id = self.next_message_id.post_inc();
@@ -876,6 +944,7 @@ impl Thread {
role,
segments,
loaded_context,
creases,
});
self.touch_updated_at();
cx.emit(ThreadEvent::MessageAdded(id));
@@ -995,6 +1064,16 @@ impl Thread {
})
.collect(),
context: message.loaded_context.text.clone(),
creases: message
.creases
.iter()
.map(|crease| SerializedCrease {
start: crease.range.start,
end: crease.range.end,
icon_path: crease.metadata.icon_path.clone(),
label: crease.metadata.label.clone(),
})
.collect(),
})
.collect(),
initial_project_snapshot,
@@ -1259,6 +1338,8 @@ impl Thread {
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) {
self.tool_use_limit_reached = false;
let pending_completion_id = post_inc(&mut self.completion_count);
let mut request_callback_parameters = if self.request_callback.is_some() {
Some((request.clone(), Vec::new()))
@@ -1272,23 +1353,23 @@ impl Thread {
prompt_id: prompt_id.clone(),
};
self.last_received_chunk_at = Some(Instant::now());
let task = cx.spawn(async move |thread, cx| {
let stream_completion_future = model.stream_completion_with_usage(request, &cx);
let stream_completion_future = model.stream_completion(request, &cx);
let initial_token_usage =
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
let stream_completion = async {
let (mut events, usage) = stream_completion_future.await?;
let mut events = stream_completion_future.await?;
let mut stop_reason = StopReason::EndTurn;
let mut current_token_usage = TokenUsage::default();
if let Some(usage) = usage {
thread
.update(cx, |_thread, cx| {
cx.emit(ThreadEvent::UsageUpdated(usage));
})
.ok();
}
thread
.update(cx, |_thread, cx| {
cx.emit(ThreadEvent::NewRequest);
})
.ok();
let mut request_assistant_message_id = None;
@@ -1341,6 +1422,8 @@ impl Thread {
current_token_usage = token_usage;
}
LanguageModelCompletionEvent::Text(chunk) => {
thread.received_chunk();
cx.emit(ThreadEvent::ReceivedTextChunk);
if let Some(last_message) = thread.messages.last_mut() {
if last_message.role == Role::Assistant
@@ -1369,6 +1452,8 @@ impl Thread {
text: chunk,
signature,
} => {
thread.received_chunk();
if let Some(last_message) = thread.messages.last_mut() {
if last_message.role == Role::Assistant
&& !thread.tool_use.has_tool_results(last_message.id)
@@ -1427,6 +1512,37 @@ impl Thread {
});
}
}
LanguageModelCompletionEvent::StatusUpdate(status_update) => {
if let Some(completion) = thread
.pending_completions
.iter_mut()
.find(|completion| completion.id == pending_completion_id)
{
match status_update {
CompletionRequestStatus::Queued {
position,
} => {
completion.queue_state = QueueState::Queued { position };
}
CompletionRequestStatus::Started => {
completion.queue_state = QueueState::Started;
}
CompletionRequestStatus::Failed {
code, message
} => {
return Err(anyhow!("completion request failed. code: {code}, message: {message}"));
}
CompletionRequestStatus::UsageUpdated {
amount, limit
} => {
cx.emit(ThreadEvent::UsageUpdated(RequestUsage { limit, amount: amount as i32 }));
}
CompletionRequestStatus::ToolUseLimitReached => {
thread.tool_use_limit_reached = true;
}
}
}
}
}
thread.touch_updated_at();
@@ -1441,6 +1557,7 @@ impl Thread {
}
thread.update(cx, |thread, cx| {
thread.last_received_chunk_at = None;
thread
.pending_completions
.retain(|completion| completion.id != pending_completion_id);
@@ -1469,10 +1586,17 @@ impl Thread {
let tool_uses = thread.use_pending_tools(window, cx, model.clone());
cx.emit(ThreadEvent::UsePendingTools { tool_uses });
}
StopReason::EndTurn => {}
StopReason::MaxTokens => {}
StopReason::EndTurn | StopReason::MaxTokens => {
thread.project.update(cx, |project, cx| {
project.set_agent_location(None, cx);
});
}
},
Err(error) => {
thread.project.update(cx, |project, cx| {
project.set_agent_location(None, cx);
});
if error.is::<PaymentRequiredError>() {
cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
} else if error.is::<MaxMonthlySpendReachedError>() {
@@ -1547,6 +1671,7 @@ impl Thread {
self.pending_completions.push(PendingCompletion {
id: pending_completion_id,
queue_state: QueueState::Sending,
_task: task,
});
}
@@ -1569,19 +1694,27 @@ impl Thread {
self.pending_summary = cx.spawn(async move |this, cx| {
async move {
let stream = model.model.stream_completion_text_with_usage(request, &cx);
let (mut messages, usage) = stream.await?;
if let Some(usage) = usage {
this.update(cx, |_thread, cx| {
cx.emit(ThreadEvent::UsageUpdated(usage));
})
.ok();
}
let mut messages = model.model.stream_completion(request, &cx).await?;
let mut new_summary = String::new();
while let Some(message) = messages.stream.next().await {
let text = message?;
while let Some(event) = messages.next().await {
let event = event?;
let text = match event {
LanguageModelCompletionEvent::Text(text) => text,
LanguageModelCompletionEvent::StatusUpdate(
CompletionRequestStatus::UsageUpdated { amount, limit },
) => {
this.update(cx, |_, cx| {
cx.emit(ThreadEvent::UsageUpdated(RequestUsage {
limit,
amount: amount as i32,
}));
})?;
continue;
}
_ => continue,
};
let mut lines = text.lines();
new_summary.extend(lines.next());
@@ -1919,6 +2052,11 @@ impl Thread {
}
self.finalize_pending_checkpoint(cx);
if canceled {
cx.emit(ThreadEvent::CompletionCanceled);
}
canceled
}
@@ -2420,6 +2558,7 @@ pub enum ThreadEvent {
UsageUpdated(RequestUsage),
StreamedCompletion,
ReceivedTextChunk,
NewRequest,
StreamedAssistantText(MessageId, String),
StreamedAssistantThinking(MessageId, String),
StreamedToolUse {
@@ -2450,12 +2589,14 @@ pub enum ThreadEvent {
CheckpointChanged,
ToolConfirmationNeeded,
CancelEditing,
CompletionCanceled,
}
impl EventEmitter<ThreadEvent> for Thread {}
struct PendingCompletion {
id: usize,
queue_state: QueueState,
_task: Task<()>,
}
@@ -2502,7 +2643,13 @@ mod tests {
// Insert user message with context
let message_id = thread.update(cx, |thread, cx| {
thread.insert_user_message("Please explain this code", loaded_context, None, cx)
thread.insert_user_message(
"Please explain this code",
loaded_context,
None,
Vec::new(),
cx,
)
});
// Check content and context in message object
@@ -2578,7 +2725,7 @@ fn main() {{
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await;
let message1_id = thread.update(cx, |thread, cx| {
thread.insert_user_message("Message 1", loaded_context, None, cx)
thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
});
// Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
@@ -2593,7 +2740,7 @@ fn main() {{
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await;
let message2_id = thread.update(cx, |thread, cx| {
thread.insert_user_message("Message 2", loaded_context, None, cx)
thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
});
// Third message with all three contexts (contexts 1 and 2 should be skipped)
@@ -2609,7 +2756,7 @@ fn main() {{
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await;
let message3_id = thread.update(cx, |thread, cx| {
thread.insert_user_message("Message 3", loaded_context, None, cx)
thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
});
// Check what contexts are included in each message
@@ -2723,6 +2870,7 @@ fn main() {{
"What is the best way to learn Rust?",
ContextLoadResult::default(),
None,
Vec::new(),
cx,
)
});
@@ -2756,6 +2904,7 @@ fn main() {{
"Are there any good books?",
ContextLoadResult::default(),
None,
Vec::new(),
cx,
)
});
@@ -2805,7 +2954,7 @@ fn main() {{
// Insert user message with the buffer as context
thread.update(cx, |thread, cx| {
thread.insert_user_message("Explain this code", loaded_context, None, cx)
thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
});
// Create a request and check that it doesn't have a stale buffer warning yet
@@ -2839,6 +2988,7 @@ fn main() {{
"What does the code do now?",
ContextLoadResult::default(),
None,
Vec::new(),
cx,
)
});

View File

@@ -733,6 +733,8 @@ pub struct SerializedMessage {
pub tool_results: Vec<SerializedToolResult>,
#[serde(default)]
pub context: String,
#[serde(default)]
pub creases: Vec<SerializedCrease>,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -813,10 +815,19 @@ impl LegacySerializedMessage {
tool_uses: self.tool_uses,
tool_results: self.tool_results,
context: String::new(),
creases: Vec::new(),
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedCrease {
pub start: usize,
pub end: usize,
pub icon_path: SharedString,
pub label: SharedString,
}
struct GlobalThreadsDatabase(
Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

View File

@@ -2,6 +2,7 @@ mod agent_notification;
pub mod agent_preview;
mod animated_label;
mod context_pill;
mod max_mode_tooltip;
mod upsell;
mod usage_banner;
@@ -9,4 +10,5 @@ pub use agent_notification::*;
pub use agent_preview::*;
pub use animated_label::*;
pub use context_pill::*;
pub use max_mode_tooltip::*;
pub use usage_banner::*;

View File

@@ -3,7 +3,7 @@ use component::ComponentId;
use gpui::{App, Entity, WeakEntity};
use linkme::distributed_slice;
use std::sync::OnceLock;
use ui::{AnyElement, Component, Window};
use ui::{AnyElement, Component, ComponentScope, Window};
use workspace::Workspace;
use crate::{ActiveThread, ThreadStore};
@@ -22,27 +22,20 @@ pub type PreviewFn = fn(
pub static __ALL_AGENT_PREVIEWS: [fn() -> (ComponentId, PreviewFn)] = [..];
/// Trait that must be implemented by components that provide agent previews.
pub trait AgentPreview: Component {
/// Get the ID for this component
///
/// Eventually this will move to the component trait.
fn id() -> ComponentId
where
Self: Sized,
{
ComponentId(Self::name())
pub trait AgentPreview: Component + Sized {
#[allow(unused)] // We can't know this is used due to the distributed slice
fn scope(&self) -> ComponentScope {
ComponentScope::Agent
}
/// Static method to create a preview for this component type
fn create_preview(
fn agent_preview(
workspace: WeakEntity<Workspace>,
active_thread: Entity<ActiveThread>,
thread_store: WeakEntity<ThreadStore>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyElement>
where
Self: Sized;
) -> Option<AnyElement>;
}
/// Register an agent preview for the given component type
@@ -55,8 +48,8 @@ macro_rules! register_agent_preview {
$crate::ui::agent_preview::PreviewFn,
) = || {
(
<$type as $crate::ui::agent_preview::AgentPreview>::id(),
<$type as $crate::ui::agent_preview::AgentPreview>::create_preview,
<$type as component::Component>::id(),
<$type as $crate::ui::agent_preview::AgentPreview>::agent_preview,
)
};
};

View File

@@ -0,0 +1,33 @@
use gpui::{Context, IntoElement, Render, Window};
use ui::{prelude::*, tooltip_container};
pub struct MaxModeTooltip;
impl MaxModeTooltip {
pub fn new(_cx: &mut Context<Self>) -> Self {
Self
}
}
impl Render for MaxModeTooltip {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
tooltip_container(_window, cx, |this, _, _| {
this.gap_1()
.child(
h_flex()
.gap_1p5()
.child(Icon::new(IconName::ZedMaxMode).size(IconSize::Small))
.child(Label::new("Zed's Max Mode"))
)
.child(
div()
.max_w_72()
.child(
Label::new("This mode enables models to use large context windows, unlimited tool calls, and other capabilities for expanded reasoning, offering an unfettered agentic experience.")
.size(LabelSize::Small)
.color(Color::Muted)
)
)
})
}
}

View File

@@ -66,7 +66,7 @@ impl RenderOnce for UsageBanner {
}),
};
Banner::new().severity(severity).children(
Banner::new().severity(severity).child(
h_flex().flex_1().gap_1().child(Label::new(message)).child(
h_flex()
.flex_1()

View File

@@ -107,6 +107,10 @@ impl Model {
}
}
pub fn matches_id(&self, other_id: &str) -> bool {
self.id() == other_id
}
/// The id of the model that should be used for making API requests
pub fn request_id(&self) -> &str {
match self {

View File

@@ -1815,10 +1815,6 @@ impl PromptEditor {
self.editor = cx.new(|cx| {
let mut editor = Editor::auto_height(Self::MAX_LINES as usize, window, cx);
editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
editor.set_placeholder_text(
Self::placeholder_text(self.codegen.read(cx), window, cx),
cx,
);
editor.set_placeholder_text("Add a prompt…", cx);
editor.set_text(prompt, window, cx);
if focus {

View File

@@ -5,7 +5,7 @@ use assistant_settings::AssistantSettings;
use client::telemetry::Telemetry;
use collections::{HashMap, VecDeque};
use editor::{
Editor, EditorElement, EditorEvent, EditorMode, EditorStyle, MultiBuffer,
ContextMenuOptions, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle, MultiBuffer,
actions::{MoveDown, MoveUp, SelectAll},
};
use fs::Fs;
@@ -730,6 +730,11 @@ impl PromptEditor {
);
editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
editor.set_placeholder_text(Self::placeholder_text(window, cx), cx);
editor.set_context_menu_options(ContextMenuOptions {
min_entries_visible: 12,
max_entries_visible: 12,
placement: None,
});
editor
});

View File

@@ -2371,6 +2371,7 @@ impl AssistantContext {
});
match event {
LanguageModelCompletionEvent::StatusUpdate { .. } => {}
LanguageModelCompletionEvent::StartMessage { .. } => {}
LanguageModelCompletionEvent::Stop(reason) => {
stop_reason = reason;
@@ -2428,8 +2429,8 @@ impl AssistantContext {
cx,
);
}
LanguageModelCompletionEvent::ToolUse(_) => {}
LanguageModelCompletionEvent::UsageUpdate(_) => {}
LanguageModelCompletionEvent::ToolUse(_) |
LanguageModelCompletionEvent::UsageUpdate(_) => {}
}
});

View File

@@ -62,7 +62,10 @@ use ui::{
prelude::*,
};
use util::{ResultExt, maybe};
use workspace::searchable::{Direction, SearchableItemHandle};
use workspace::{
CollaboratorId,
searchable::{Direction, SearchableItemHandle},
};
use workspace::{
Save, Toast, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace,
item::{self, FollowableItem, Item, ItemHandle},
@@ -1054,7 +1057,7 @@ impl ContextEditor {
|_, _, _, _| Empty.into_any_element(),
)
.with_metadata(CreaseMetadata {
icon: IconName::Ai,
icon_path: SharedString::from(IconName::Ai.path()),
label: "Thinking Process".into(),
}),
);
@@ -1097,7 +1100,7 @@ impl ContextEditor {
FoldPlaceholder {
render: render_fold_icon_button(
cx.entity().downgrade(),
section.icon,
section.icon.path().into(),
section.label.clone(),
),
merge_adjacent: false,
@@ -1107,7 +1110,7 @@ impl ContextEditor {
|_, _, _, _| Empty.into_any_element(),
)
.with_metadata(CreaseMetadata {
icon: section.icon,
icon_path: section.icon.path().into(),
label: section.label,
}),
);
@@ -2055,7 +2058,7 @@ impl ContextEditor {
FoldPlaceholder {
render: render_fold_icon_button(
weak_editor.clone(),
metadata.crease.icon,
metadata.crease.icon_path.clone(),
metadata.crease.label.clone(),
),
..Default::default()
@@ -2851,7 +2854,7 @@ fn render_thought_process_fold_icon_button(
fn render_fold_icon_button(
editor: WeakEntity<Editor>,
icon: IconName,
icon_path: SharedString,
label: SharedString,
) -> Arc<dyn Send + Sync + Fn(FoldId, Range<Anchor>, &mut App) -> AnyElement> {
Arc::new(move |fold_id, fold_range, _cx| {
@@ -2859,7 +2862,7 @@ fn render_fold_icon_button(
ButtonLike::new(fold_id)
.style(ButtonStyle::Filled)
.layer(ElevationIndex::ElevatedSurface)
.child(Icon::new(icon))
.child(Icon::from_path(icon_path.clone()))
.child(Label::new(label.clone()).single_line())
.on_click(move |_, window, cx| {
editor
@@ -3417,15 +3420,14 @@ impl FollowableItem for ContextEditor {
true
}
fn set_leader_peer_id(
fn set_leader_id(
&mut self,
leader_peer_id: Option<proto::PeerId>,
leader_id: Option<CollaboratorId>,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.editor.update(cx, |editor, cx| {
editor.set_leader_peer_id(leader_peer_id, window, cx)
})
self.editor
.update(cx, |editor, cx| editor.set_leader_id(leader_id, window, cx))
}
fn dedup(&self, existing: &Self, _window: &Window, cx: &App) -> Option<item::Dedup> {

View File

@@ -5,6 +5,41 @@ use indexmap::IndexMap;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
pub mod builtin_profiles {
use super::AgentProfileId;
pub const WRITE: &str = "write";
pub const ASK: &str = "ask";
pub const MANUAL: &str = "manual";
pub fn is_builtin(profile_id: &AgentProfileId) -> bool {
profile_id.as_str() == WRITE || profile_id.as_str() == ASK || profile_id.as_str() == MANUAL
}
}
#[derive(Default)]
pub struct GroupedAgentProfiles {
pub builtin: IndexMap<AgentProfileId, AgentProfile>,
pub custom: IndexMap<AgentProfileId, AgentProfile>,
}
impl GroupedAgentProfiles {
pub fn from_settings(settings: &crate::AssistantSettings) -> Self {
let mut builtin = IndexMap::default();
let mut custom = IndexMap::default();
for (profile_id, profile) in settings.profiles.clone() {
if builtin_profiles::is_builtin(&profile_id) {
builtin.insert(profile_id, profile);
} else {
custom.insert(profile_id, profile);
}
}
Self { builtin, custom }
}
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize, JsonSchema)]
pub struct AgentProfileId(pub Arc<str>);

View File

@@ -69,7 +69,7 @@ pub enum AssistantProviderContentV1 {
},
}
#[derive(Clone, Debug, Default)]
#[derive(Default, Clone, Debug)]
pub struct AssistantSettings {
pub enabled: bool,
pub button: bool,
@@ -88,6 +88,7 @@ pub struct AssistantSettings {
pub always_allow_tool_actions: bool,
pub notify_when_agent_waiting: NotifyWhenAgentWaiting,
pub stream_edits: bool,
pub single_file_review: bool,
}
impl AssistantSettings {
@@ -224,6 +225,7 @@ impl AssistantSettingsContent {
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
single_file_review: None,
},
VersionedAssistantSettingsContent::V2(ref settings) => settings.clone(),
},
@@ -252,6 +254,7 @@ impl AssistantSettingsContent {
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
single_file_review: None,
},
None => AssistantSettingsContentV2::default(),
}
@@ -430,6 +433,14 @@ impl AssistantSettingsContent {
.ok();
}
pub fn set_single_file_review(&mut self, allow: bool) {
self.v2_setting(|setting| {
setting.single_file_review = Some(allow);
Ok(())
})
.ok();
}
pub fn set_profile(&mut self, profile_id: AgentProfileId) {
self.v2_setting(|setting| {
setting.default_profile = Some(profile_id);
@@ -503,6 +514,7 @@ impl Default for VersionedAssistantSettingsContent {
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
single_file_review: None,
})
}
}
@@ -562,6 +574,10 @@ pub struct AssistantSettingsContentV2 {
///
/// Default: false
stream_edits: Option<bool>,
/// Whether to display agent edits in single-file editors in addition to the review multibuffer pane.
///
/// Default: true
single_file_review: Option<bool>,
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
@@ -725,6 +741,7 @@ impl Settings for AssistantSettings {
value.notify_when_agent_waiting,
);
merge(&mut settings.stream_edits, value.stream_edits);
merge(&mut settings.single_file_review, value.single_file_review);
merge(&mut settings.default_profile, value.default_profile);
if let Some(profiles) = value.profiles {
@@ -857,6 +874,7 @@ mod tests {
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
single_file_review: None,
},
)),
}

View File

@@ -29,6 +29,10 @@ impl ActionLog {
}
}
pub fn project(&self) -> &Entity<Project> {
&self.project
}
/// Notifies a diagnostics check
pub fn checked_project_diagnostics(&mut self) {
self.edited_since_project_diagnostics_check = false;

View File

@@ -1,7 +1,3 @@
mod batch_tool;
mod code_action_tool;
mod code_symbols_tool;
mod contents_tool;
mod copy_path_tool;
mod create_directory_tool;
mod create_file_tool;
@@ -17,11 +13,9 @@ mod move_path_tool;
mod now_tool;
mod open_tool;
mod read_file_tool;
mod rename_tool;
mod replace;
mod schema;
mod streaming_edit_file_tool;
mod symbol_info_tool;
mod templates;
mod terminal_tool;
mod thinking_tool;
@@ -34,7 +28,7 @@ use assistant_settings::AssistantSettings;
use assistant_tool::ToolRegistry;
use copy_path_tool::CopyPathTool;
use feature_flags::{AgentStreamEditsFeatureFlag, FeatureFlagAppExt};
use gpui::App;
use gpui::{App, Entity};
use http_client::HttpClientWithUrl;
use language_model::LanguageModelRegistry;
use move_path_tool::MovePathTool;
@@ -43,55 +37,42 @@ use web_search_tool::WebSearchTool;
pub(crate) use templates::*;
use crate::batch_tool::BatchTool;
use crate::code_action_tool::CodeActionTool;
use crate::code_symbols_tool::CodeSymbolsTool;
use crate::contents_tool::ContentsTool;
use crate::create_directory_tool::CreateDirectoryTool;
use crate::create_file_tool::CreateFileTool;
use crate::delete_path_tool::DeletePathTool;
use crate::diagnostics_tool::DiagnosticsTool;
use crate::edit_file_tool::EditFileTool;
use crate::fetch_tool::FetchTool;
use crate::find_path_tool::FindPathTool;
use crate::grep_tool::GrepTool;
use crate::list_directory_tool::ListDirectoryTool;
use crate::now_tool::NowTool;
use crate::open_tool::OpenTool;
use crate::read_file_tool::ReadFileTool;
use crate::rename_tool::RenameTool;
use crate::streaming_edit_file_tool::StreamingEditFileTool;
use crate::symbol_info_tool::SymbolInfoTool;
use crate::terminal_tool::TerminalTool;
use crate::thinking_tool::ThinkingTool;
pub use create_file_tool::CreateFileToolInput;
pub use edit_file_tool::EditFileToolInput;
pub use create_file_tool::{CreateFileTool, CreateFileToolInput};
pub use edit_file_tool::{EditFileTool, EditFileToolInput};
pub use find_path_tool::FindPathToolInput;
pub use open_tool::OpenTool;
pub use read_file_tool::ReadFileToolInput;
pub use streaming_edit_file_tool::StreamingEditFileToolInput;
pub use terminal_tool::TerminalTool;
pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
assistant_tool::init(cx);
let registry = ToolRegistry::global(cx);
registry.register_tool(TerminalTool);
registry.register_tool(BatchTool);
registry.register_tool(CreateDirectoryTool);
registry.register_tool(CopyPathTool);
registry.register_tool(DeletePathTool);
registry.register_tool(SymbolInfoTool);
registry.register_tool(CodeActionTool);
registry.register_tool(MovePathTool);
registry.register_tool(DiagnosticsTool);
registry.register_tool(ListDirectoryTool);
registry.register_tool(NowTool);
registry.register_tool(OpenTool);
registry.register_tool(CodeSymbolsTool);
registry.register_tool(ContentsTool);
registry.register_tool(FindPathTool);
registry.register_tool(ReadFileTool);
registry.register_tool(GrepTool);
registry.register_tool(RenameTool);
registry.register_tool(ThinkingTool);
registry.register_tool(FetchTool::new(http_client));
@@ -101,19 +82,12 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
cx.observe_global::<SettingsStore>(register_edit_file_tool)
.detach();
register_web_search_tool(&LanguageModelRegistry::global(cx), cx);
cx.subscribe(
&LanguageModelRegistry::global(cx),
move |registry, event, cx| match event {
language_model::Event::DefaultModelChanged => {
let using_zed_provider = registry
.read(cx)
.default_model()
.map_or(false, |default| default.is_provided_by_zed());
if using_zed_provider {
ToolRegistry::global(cx).register_tool(WebSearchTool);
} else {
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
}
register_web_search_tool(&registry, cx);
}
_ => {}
},
@@ -121,6 +95,18 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
.detach();
}
fn register_web_search_tool(registry: &Entity<LanguageModelRegistry>, cx: &mut App) {
let using_zed_provider = registry
.read(cx)
.default_model()
.map_or(false, |default| default.is_provided_by_zed());
if using_zed_provider {
ToolRegistry::global(cx).register_tool(WebSearchTool);
} else {
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
}
}
fn register_edit_file_tool(cx: &mut App) {
let registry = ToolRegistry::global(cx);

View File

@@ -1,314 +0,0 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult, ToolWorkingSet};
use futures::future::join_all;
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use ui::IconName;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ToolInvocation {
/// The name of the tool to invoke
pub name: String,
/// The input to the tool in JSON format
pub input: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct BatchToolInput {
/// The tool invocations to run as a batch. These tools will be run either sequentially
/// or concurrently depending on the `run_tools_concurrently` flag.
///
/// <example>
/// Basic file operations (concurrent)
///
/// ```json
/// {
/// "invocations": [
/// {
/// "name": "read_file",
/// "input": {
/// "path": "src/main.rs"
/// }
/// },
/// {
/// "name": "list_directory",
/// "input": {
/// "path": "src/lib"
/// }
/// },
/// {
/// "name": "grep",
/// "input": {
/// "regex": "fn run\\("
/// }
/// }
/// ],
/// "run_tools_concurrently": true
/// }
/// ```
/// </example>
///
/// <example>
/// Multiple find-replace operations on the same file (sequential)
///
/// ```json
/// {
/// "invocations": [
/// {
/// "name": "find_replace_file",
/// "input": {
/// "path": "src/config.rs",
/// "display_description": "Update default timeout value",
/// "find": "pub const DEFAULT_TIMEOUT: u64 = 30;\n\npub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";",
/// "replace": "pub const DEFAULT_TIMEOUT: u64 = 60;\n\npub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";"
/// }
/// },
/// {
/// "name": "find_replace_file",
/// "input": {
/// "path": "src/config.rs",
/// "display_description": "Update API endpoint URL",
/// "find": "pub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";\n\npub const API_VERSION: &str = \"v1\";",
/// "replace": "pub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.newdomain.com\";\n\npub const API_VERSION: &str = \"v1\";"
/// }
/// }
/// ],
/// "run_tools_concurrently": false
/// }
/// ```
/// </example>
///
/// <example>
/// Searching and analyzing code (concurrent)
///
/// ```json
/// {
/// "invocations": [
/// {
/// "name": "grep",
/// "input": {
/// "regex": "impl Database"
/// }
/// },
/// {
/// "name": "find_path",
/// "input": {
/// "glob": "**/*test*.rs"
/// }
/// }
/// ],
/// "run_tools_concurrently": true
/// }
/// ```
/// </example>
///
/// <example>
/// Multi-file refactoring (concurrent)
///
/// ```json
/// {
/// "invocations": [
/// {
/// "name": "find_replace_file",
/// "input": {
/// "path": "src/models/user.rs",
/// "display_description": "Add email field to User struct",
/// "find": "pub struct User {\n pub id: u64,\n pub username: String,\n pub created_at: DateTime<Utc>,\n}",
/// "replace": "pub struct User {\n pub id: u64,\n pub username: String,\n pub email: String,\n pub created_at: DateTime<Utc>,\n}"
/// }
/// },
/// {
/// "name": "find_replace_file",
/// "input": {
/// "path": "src/db/queries.rs",
/// "display_description": "Update user insertion query",
/// "find": "pub async fn insert_user(conn: &mut Connection, user: &User) -> Result<(), DbError> {\n conn.execute(\n \"INSERT INTO users (id, username, created_at) VALUES ($1, $2, $3)\",\n &[&user.id, &user.username, &user.created_at],\n ).await?;\n \n Ok(())\n}",
/// "replace": "pub async fn insert_user(conn: &mut Connection, user: &User) -> Result<(), DbError> {\n conn.execute(\n \"INSERT INTO users (id, username, email, created_at) VALUES ($1, $2, $3, $4)\",\n &[&user.id, &user.username, &user.email, &user.created_at],\n ).await?;\n \n Ok(())\n}"
/// }
/// }
/// ],
/// "run_tools_concurrently": true
/// }
/// ```
/// </example>
pub invocations: Vec<ToolInvocation>,
/// Whether to run the tools in this batch concurrently. If this is false (the default), the tools will run sequentially.
#[serde(default)]
pub run_tools_concurrently: bool,
}
pub struct BatchTool;
impl Tool for BatchTool {
fn name(&self) -> String {
"batch_tool".into()
}
fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool {
serde_json::from_value::<BatchToolInput>(input.clone())
.map(|input| {
let working_set = ToolWorkingSet::default();
input.invocations.iter().any(|invocation| {
working_set
.tool(&invocation.name, cx)
.map_or(false, |tool| tool.needs_confirmation(&invocation.input, cx))
})
})
.unwrap_or(false)
}
fn description(&self) -> String {
include_str!("./batch_tool/description.md").into()
}
fn icon(&self) -> IconName {
IconName::Cog
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<BatchToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<BatchToolInput>(input.clone()) {
Ok(input) => {
let count = input.invocations.len();
let mode = if input.run_tools_concurrently {
"concurrently"
} else {
"sequentially"
};
let first_tool_name = input
.invocations
.first()
.map(|inv| inv.name.clone())
.unwrap_or_default();
let all_same = input
.invocations
.iter()
.all(|invocation| invocation.name == first_tool_name);
if all_same {
format!(
"Run `{}` {} times {}",
first_tool_name,
input.invocations.len(),
mode
)
} else {
format!("Run {} tools {}", count, mode)
}
}
Err(_) => "Batch tools".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<BatchToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
if input.invocations.is_empty() {
return Task::ready(Err(anyhow!("No tool invocations provided"))).into();
}
let run_tools_concurrently = input.run_tools_concurrently;
let foreground_task = {
let working_set = ToolWorkingSet::default();
let invocations = input.invocations;
let messages = messages.to_vec();
cx.spawn(async move |cx| {
let mut tasks = Vec::new();
let mut tool_names = Vec::new();
for invocation in invocations {
let tool_name = invocation.name.clone();
tool_names.push(tool_name.clone());
let tool = cx
.update(|cx| working_set.tool(&tool_name, cx))
.map_err(|err| {
anyhow!("Failed to look up tool '{}': {}", tool_name, err)
})?;
let Some(tool) = tool else {
return Err(anyhow!("Tool '{}' not found", tool_name));
};
let project = project.clone();
let action_log = action_log.clone();
let messages = messages.clone();
let tool_result = cx
.update(|cx| {
tool.run(invocation.input, &messages, project, action_log, window, cx)
})
.map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;
tasks.push(tool_result.output);
}
Ok((tasks, tool_names))
})
};
cx.background_spawn(async move {
let (tasks, tool_names) = foreground_task.await?;
let mut results = Vec::with_capacity(tasks.len());
if run_tools_concurrently {
results.extend(join_all(tasks).await)
} else {
for task in tasks {
results.push(task.await);
}
};
let mut formatted_results = String::new();
let mut error_occurred = false;
for (i, result) in results.into_iter().enumerate() {
let tool_name = &tool_names[i];
match result {
Ok(output) => {
formatted_results
.push_str(&format!("Tool '{}' result:\n{}\n\n", tool_name, output));
}
Err(err) => {
error_occurred = true;
formatted_results
.push_str(&format!("Tool '{}' error: {}\n\n", tool_name, err));
}
}
}
if error_occurred {
formatted_results
.push_str("Note: Some tool invocations failed. See individual results above.");
}
Ok(formatted_results.trim().to_string())
})
.into()
}
}

View File

@@ -1,388 +0,0 @@
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, Entity, Task};
use language::{self, Anchor, Buffer, ToPointUtf16};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{self, LspAction, Project};
use regex::Regex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{ops::Range, sync::Arc};
use ui::IconName;
use crate::schema::json_schema_for;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct CodeActionToolInput {
/// The relative path to the file containing the text range.
///
/// WARNING: you MUST start this path with one of the project's root directories.
pub path: String,
/// The specific code action to execute.
///
/// If this field is provided, the tool will execute the specified action.
/// If omitted, the tool will list all available code actions for the text range.
///
/// Here are some actions that are commonly supported (but may not be for this particular
/// text range; you can omit this field to list all the actions, if you want to know
/// what your options are, or you can just try an action and if it fails I'll tell you
/// what the available actions were instead):
/// - "quickfix.all" - applies all available quick fixes in the range
/// - "source.organizeImports" - sorts and cleans up import statements
/// - "source.fixAll" - applies all available auto fixes
/// - "refactor.extract" - extracts selected code into a new function or variable
/// - "refactor.inline" - inlines a variable by replacing references with its value
/// - "refactor.rewrite" - general code rewriting operations
/// - "source.addMissingImports" - adds imports for references that lack them
/// - "source.removeUnusedImports" - removes imports that aren't being used
/// - "source.implementInterface" - generates methods required by an interface/trait
/// - "source.generateAccessors" - creates getter/setter methods
/// - "source.convertToAsyncFunction" - converts callback-style code to async/await
///
/// Also, there is a special case: if you specify exactly "textDocument/rename" as the action,
/// then this will rename the symbol to whatever string you specified for the `arguments` field.
pub action: Option<String>,
/// Optional arguments to pass to the code action.
///
/// For rename operations (when action="textDocument/rename"), this should contain the new name.
/// For other code actions, these arguments may be passed to the language server.
pub arguments: Option<serde_json::Value>,
/// The text that comes immediately before the text range in the file.
pub context_before_range: String,
/// The text range. This text must appear in the file right between `context_before_range`
/// and `context_after_range`.
///
/// The file must contain exactly one occurrence of `context_before_range` followed by
/// `text_range` followed by `context_after_range`. If the file contains zero occurrences,
/// or if it contains more than one occurrence, the tool will fail, so it is absolutely
/// critical that you verify ahead of time that the string is unique. You can search
/// the file's contents to verify this ahead of time.
///
/// To make the string more likely to be unique, include a minimum of 1 line of context
/// before the text range, as well as a minimum of 1 line of context after the text range.
/// If these lines of context are not enough to obtain a string that appears only once
/// in the file, then double the number of context lines until the string becomes unique.
/// (Start with 1 line before and 1 line after though, because too much context is
/// needlessly costly.)
///
/// Do not alter the context lines of code in any way, and make sure to preserve all
/// whitespace and indentation for all lines of code. The combined string must be exactly
/// as it appears in the file, or else this tool call will fail.
pub text_range: String,
/// The text that comes immediately after the text range in the file.
pub context_after_range: String,
}
pub struct CodeActionTool;
impl Tool for CodeActionTool {
fn name(&self) -> String {
"code_actions".into()
}
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
false
}
fn description(&self) -> String {
include_str!("./code_action_tool/description.md").into()
}
fn icon(&self) -> IconName {
IconName::Wand
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<CodeActionToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<CodeActionToolInput>(input.clone()) {
Ok(input) => {
if let Some(action) = &input.action {
if action == "textDocument/rename" {
let new_name = match &input.arguments {
Some(serde_json::Value::String(new_name)) => new_name.clone(),
Some(value) => {
if let Ok(new_name) =
serde_json::from_value::<String>(value.clone())
{
new_name
} else {
"invalid name".to_string()
}
}
None => "missing name".to_string(),
};
format!("Rename '{}' to '{}'", input.text_range, new_name)
} else {
format!(
"Execute code action '{}' for '{}'",
action, input.text_range
)
}
} else {
format!("List available code actions for '{}'", input.text_range)
}
}
Err(_) => "Perform code action".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CodeActionToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
cx.spawn(async move |cx| {
let buffer = {
let project_path = project.read_with(cx, |project, cx| {
project
.find_project_path(&input.path, cx)
.context("Path not found in project")
})??;
project.update(cx, |project, cx| project.open_buffer(project_path, cx))?.await?
};
action_log.update(cx, |action_log, cx| {
action_log.track_buffer(buffer.clone(), cx);
})?;
let range = {
let Some(range) = buffer.read_with(cx, |buffer, _cx| {
find_text_range(&buffer, &input.context_before_range, &input.text_range, &input.context_after_range)
})? else {
return Err(anyhow!(
"Failed to locate the text specified by context_before_range, text_range, and context_after_range. Make sure context_before_range and context_after_range each match exactly once in the file."
));
};
range
};
if let Some(action_type) = &input.action {
// Special-case the `rename` operation
let response = if action_type == "textDocument/rename" {
let Some(new_name) = input.arguments.and_then(|args| serde_json::from_value::<String>(args).ok()) else {
return Err(anyhow!("For rename operations, 'arguments' must be a string containing the new name"));
};
let position = buffer.read_with(cx, |buffer, _| {
range.start.to_point_utf16(&buffer.snapshot())
})?;
project
.update(cx, |project, cx| {
project.perform_rename(buffer.clone(), position, new_name.clone(), cx)
})?
.await?;
format!("Renamed '{}' to '{}'", input.text_range, new_name)
} else {
// Get code actions for the range
let actions = project
.update(cx, |project, cx| {
project.code_actions(&buffer, range.clone(), None, cx)
})?
.await?;
if actions.is_empty() {
return Err(anyhow!("No code actions available for this range"));
}
// Find all matching actions
let regex = match Regex::new(action_type) {
Ok(regex) => regex,
Err(err) => return Err(anyhow!("Invalid regex pattern: {}", err)),
};
let mut matching_actions = actions
.into_iter()
.filter(|action| { regex.is_match(action.lsp_action.title()) });
let Some(action) = matching_actions.next() else {
return Err(anyhow!("No code actions match the pattern: {}", action_type));
};
// There should have been exactly one matching action.
if let Some(second) = matching_actions.next() {
let mut all_matches = vec![action, second];
all_matches.extend(matching_actions);
return Err(anyhow!(
"Pattern '{}' matches multiple code actions: {}",
action_type,
all_matches.into_iter().map(|action| action.lsp_action.title().to_string()).collect::<Vec<_>>().join(", ")
));
}
let title = action.lsp_action.title().to_string();
project
.update(cx, |project, cx| {
project.apply_code_action(buffer.clone(), action, true, cx)
})?
.await?;
format!("Completed code action: {}", title)
};
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.await?;
action_log.update(cx, |log, cx| {
log.buffer_edited(buffer.clone(), cx)
})?;
Ok(response)
} else {
// No action specified, so list the available ones.
let (position_start, position_end) = buffer.read_with(cx, |buffer, _| {
let snapshot = buffer.snapshot();
(
range.start.to_point_utf16(&snapshot),
range.end.to_point_utf16(&snapshot)
)
})?;
// Convert position to display coordinates (1-based)
let position_start_display = language::Point {
row: position_start.row + 1,
column: position_start.column + 1,
};
let position_end_display = language::Point {
row: position_end.row + 1,
column: position_end.column + 1,
};
// Get code actions for the range
let actions = project
.update(cx, |project, cx| {
project.code_actions(&buffer, range.clone(), None, cx)
})?
.await?;
let mut response = format!(
"Available code actions for text range '{}' at position {}:{} to {}:{} (UTF-16 coordinates):\n\n",
input.text_range,
position_start_display.row, position_start_display.column,
position_end_display.row, position_end_display.column
);
if actions.is_empty() {
response.push_str("No code actions available for this range.");
} else {
for (i, action) in actions.iter().enumerate() {
let title = match &action.lsp_action {
LspAction::Action(code_action) => code_action.title.as_str(),
LspAction::Command(command) => command.title.as_str(),
LspAction::CodeLens(code_lens) => {
if let Some(cmd) = &code_lens.command {
cmd.title.as_str()
} else {
"Unknown code lens"
}
},
};
let kind = match &action.lsp_action {
LspAction::Action(code_action) => {
if let Some(kind) = &code_action.kind {
kind.as_str()
} else {
"unknown"
}
},
LspAction::Command(_) => "command",
LspAction::CodeLens(_) => "code_lens",
};
response.push_str(&format!("{}. {title} ({kind})\n", i + 1));
}
}
Ok(response)
}
}).into()
}
}
/// Finds the range of the text in the buffer, if it appears between context_before_range
/// and context_after_range, and if that combined string has one unique result in the buffer.
///
/// If an exact match fails, it tries adding a newline to the end of context_before_range and
/// to the beginning of context_after_range to accommodate line-based context matching.
fn find_text_range(
buffer: &Buffer,
context_before_range: &str,
text_range: &str,
context_after_range: &str,
) -> Option<Range<Anchor>> {
let snapshot = buffer.snapshot();
let text = snapshot.text();
// First try with exact match
let search_string = format!("{context_before_range}{text_range}{context_after_range}");
let mut positions = text.match_indices(&search_string);
let position_result = positions.next();
if let Some(position) = position_result {
// Check if the matched string is unique
if positions.next().is_none() {
let range_start = position.0 + context_before_range.len();
let range_end = range_start + text_range.len();
let range_start_anchor = snapshot.anchor_before(snapshot.offset_to_point(range_start));
let range_end_anchor = snapshot.anchor_before(snapshot.offset_to_point(range_end));
return Some(range_start_anchor..range_end_anchor);
}
}
// If exact match fails or is not unique, try with line-based context
// Add a newline to the end of before context and beginning of after context
let line_based_before = if context_before_range.ends_with('\n') {
context_before_range.to_string()
} else {
format!("{context_before_range}\n")
};
let line_based_after = if context_after_range.starts_with('\n') {
context_after_range.to_string()
} else {
format!("\n{context_after_range}")
};
let line_search_string = format!("{line_based_before}{text_range}{line_based_after}");
let mut line_positions = text.match_indices(&line_search_string);
let line_position = line_positions.next()?;
// The line-based search string must also appear exactly once
if line_positions.next().is_some() {
return None;
}
let line_range_start = line_position.0 + line_based_before.len();
let line_range_end = line_range_start + text_range.len();
let line_range_start_anchor =
snapshot.anchor_before(snapshot.offset_to_point(line_range_start));
let line_range_end_anchor = snapshot.anchor_before(snapshot.offset_to_point(line_range_end));
Some(line_range_start_anchor..line_range_end_anchor)
}

View File

@@ -1,247 +0,0 @@
use std::fmt::Write;
use std::path::PathBuf;
use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::outline;
use assistant_tool::{ActionLog, Tool, ToolResult};
use collections::IndexMap;
use gpui::{AnyWindowHandle, App, AsyncApp, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{Project, Symbol};
use regex::{Regex, RegexBuilder};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use ui::IconName;
use util::markdown::MarkdownInlineCode;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct CodeSymbolsInput {
/// The relative path of the source code file to read and get the symbols for.
/// This tool should only be used on source code files, never on any other type of file.
///
/// This path should never be absolute, and the first component
/// of the path should always be a root directory in a project.
///
/// If no path is specified, this tool returns a flat list of all symbols in the project
/// instead of a hierarchical outline of a specific file.
///
/// <example>
/// If the project has the following root directories:
///
/// - directory1
/// - directory2
///
/// If you want to access `file.md` in `directory1`, you should use the path `directory1/file.md`.
/// If you want to access `file.md` in `directory2`, you should use the path `directory2/file.md`.
/// </example>
#[serde(default)]
pub path: Option<String>,
/// Optional regex pattern to filter symbols by name.
/// When provided, only symbols whose names match this pattern will be included in the results.
///
/// <example>
/// To find only symbols that contain the word "test", use the regex pattern "test".
/// To find methods that start with "get_", use the regex pattern "^get_".
/// </example>
#[serde(default)]
pub regex: Option<String>,
/// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
///
/// <example>
/// Set to `true` to make regex matching case-sensitive.
/// </example>
#[serde(default)]
pub case_sensitive: bool,
/// Optional starting position for paginated results (0-based).
/// When not provided, starts from the beginning.
#[serde(default)]
pub offset: u32,
}
impl CodeSymbolsInput {
/// Which page of search results this is.
pub fn page(&self) -> u32 {
1 + (self.offset / RESULTS_PER_PAGE)
}
}
const RESULTS_PER_PAGE: u32 = 2000;
pub struct CodeSymbolsTool;
impl Tool for CodeSymbolsTool {
fn name(&self) -> String {
"code_symbols".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
fn description(&self) -> String {
include_str!("./code_symbols_tool/description.md").into()
}
fn icon(&self) -> IconName {
IconName::Code
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<CodeSymbolsInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<CodeSymbolsInput>(input.clone()) {
Ok(input) => {
let page = input.page();
match &input.path {
Some(path) => {
let path = MarkdownInlineCode(path);
if page > 1 {
format!("List page {page} of code symbols for {path}")
} else {
format!("List code symbols for {path}")
}
}
None => {
if page > 1 {
format!("List page {page} of project symbols")
} else {
"List all project symbols".to_string()
}
}
}
}
Err(_) => "List code symbols".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CodeSymbolsInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let regex = match input.regex {
Some(regex_str) => match RegexBuilder::new(&regex_str)
.case_insensitive(!input.case_sensitive)
.build()
{
Ok(regex) => Some(regex),
Err(err) => return Task::ready(Err(anyhow!("Invalid regex: {err}"))).into(),
},
None => None,
};
cx.spawn(async move |cx| match input.path {
Some(path) => outline::file_outline(project, path, action_log, regex, cx).await,
None => project_symbols(project, regex, input.offset, cx).await,
})
.into()
}
}
async fn project_symbols(
project: Entity<Project>,
regex: Option<Regex>,
offset: u32,
cx: &mut AsyncApp,
) -> anyhow::Result<String> {
let symbols = project
.update(cx, |project, cx| project.symbols("", cx))?
.await?;
if symbols.is_empty() {
return Err(anyhow!("No symbols found in project."));
}
let mut symbols_by_path: IndexMap<PathBuf, Vec<&Symbol>> = IndexMap::default();
for symbol in symbols
.iter()
.filter(|symbol| {
if let Some(regex) = &regex {
regex.is_match(&symbol.name)
} else {
true
}
})
.skip(offset as usize)
// Take 1 more than RESULTS_PER_PAGE so we can tell if there are more results.
.take((RESULTS_PER_PAGE as usize).saturating_add(1))
{
if let Some(worktree_path) = project.read_with(cx, |project, cx| {
project
.worktree_for_id(symbol.path.worktree_id, cx)
.map(|worktree| PathBuf::from(worktree.read(cx).root_name()))
})? {
let path = worktree_path.join(&symbol.path.path);
symbols_by_path.entry(path).or_default().push(symbol);
}
}
// If no symbols matched the filter, return early
if symbols_by_path.is_empty() {
return Err(anyhow!("No symbols found matching the criteria."));
}
let mut symbols_rendered = 0;
let mut has_more_symbols = false;
let mut output = String::new();
'outer: for (file_path, file_symbols) in symbols_by_path {
if symbols_rendered > 0 {
output.push('\n');
}
writeln!(&mut output, "{}", file_path.display()).ok();
for symbol in file_symbols {
if symbols_rendered >= RESULTS_PER_PAGE {
has_more_symbols = true;
break 'outer;
}
write!(&mut output, " {} ", symbol.label.text()).ok();
// Convert to 1-based line numbers for display
let start_line = symbol.range.start.0.row as usize + 1;
let end_line = symbol.range.end.0.row as usize + 1;
if start_line == end_line {
writeln!(&mut output, "[L{}]", start_line).ok();
} else {
writeln!(&mut output, "[L{}-{}]", start_line, end_line).ok();
}
symbols_rendered += 1;
}
}
Ok(if symbols_rendered == 0 {
"No symbols found in the requested page.".to_string()
} else if has_more_symbols {
format!(
"{output}\nShowing symbols {}-{} (more symbols were found; use offset: {} to see next page)",
offset + 1,
offset + symbols_rendered,
offset + RESULTS_PER_PAGE,
)
} else {
output
})
}

View File

@@ -1,236 +0,0 @@
use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult, outline};
use gpui::{AnyWindowHandle, App, Entity, Task};
use itertools::Itertools;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{fmt::Write, path::Path};
use ui::IconName;
use util::markdown::MarkdownInlineCode;
/// If the model requests to read a file whose size exceeds this, then
/// If the model requests to list the entries in a directory with more
/// entries than this, then the tool will return a subset of the entries
/// and suggest trying again.
const MAX_DIR_ENTRIES: usize = 1024;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ContentsToolInput {
/// The relative path of the file or directory to access.
///
/// This path should never be absolute, and the first component
/// of the path should always be a root directory in a project.
///
/// <example>
/// If the project has the following root directories:
///
/// - directory1
/// - directory2
///
/// If you want to access `file.txt` in `directory1`, you should use the path `directory1/file.txt`.
/// If you want to list contents in the directory `directory2/subfolder`, you should use the path `directory2/subfolder`.
/// </example>
pub path: String,
/// Optional position (1-based index) to start reading on, if you want to read a subset of the contents.
/// When reading a file, this refers to a line number in the file (e.g. 1 is the first line).
/// When reading a directory, this refers to the number of the directory entry (e.g. 1 is the first entry).
///
/// Defaults to 1.
pub start: Option<u32>,
/// Optional position (1-based index) to end reading on, if you want to read a subset of the contents.
/// When reading a file, this refers to a line number in the file (e.g. 1 is the first line).
/// When reading a directory, this refers to the number of the directory entry (e.g. 1 is the first entry).
///
/// Defaults to reading until the end of the file or directory.
pub end: Option<u32>,
}
pub struct ContentsTool;
impl Tool for ContentsTool {
fn name(&self) -> String {
"contents".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
fn description(&self) -> String {
include_str!("./contents_tool/description.md").into()
}
fn icon(&self) -> IconName {
IconName::FileSearch
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<ContentsToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<ContentsToolInput>(input.clone()) {
Ok(input) => {
let path = MarkdownInlineCode(&input.path);
match (input.start, input.end) {
(Some(start), None) => format!("Read {path} (from line {start})"),
(Some(start), Some(end)) => {
format!("Read {path} (lines {start}-{end})")
}
_ => format!("Read {path}"),
}
}
Err(_) => "Read file or directory".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<ContentsToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
// Sometimes models will return these even though we tell it to give a path and not a glob.
// When this happens, just list the root worktree directories.
if matches!(input.path.as_str(), "." | "" | "./" | "*") {
let output = project
.read(cx)
.worktrees(cx)
.filter_map(|worktree| {
worktree.read(cx).root_entry().and_then(|entry| {
if entry.is_dir() {
entry.path.to_str()
} else {
None
}
})
})
.collect::<Vec<_>>()
.join("\n");
return Task::ready(Ok(output)).into();
}
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into();
};
let Some(worktree) = project
.read(cx)
.worktree_for_id(project_path.worktree_id, cx)
else {
return Task::ready(Err(anyhow!("Worktree not found"))).into();
};
let worktree = worktree.read(cx);
let Some(entry) = worktree.entry_for_path(&project_path.path) else {
return Task::ready(Err(anyhow!("Path not found: {}", input.path))).into();
};
// If it's a directory, list its contents
if entry.is_dir() {
let mut output = String::new();
let start_index = input
.start
.map(|line| (line as usize).saturating_sub(1))
.unwrap_or(0);
let end_index = input
.end
.map(|line| (line as usize).saturating_sub(1))
.unwrap_or(MAX_DIR_ENTRIES);
let mut skipped = 0;
for (index, entry) in worktree.child_entries(&project_path.path).enumerate() {
if index >= start_index && index <= end_index {
writeln!(
output,
"{}",
Path::new(worktree.root_name()).join(&entry.path).display(),
)
.unwrap();
} else {
skipped += 1;
}
}
if output.is_empty() {
output.push_str(&input.path);
output.push_str(" is empty.");
}
if skipped > 0 {
write!(
output,
"\n\nNote: Skipped {skipped} entries. Adjust start and end to see other entries.",
).ok();
}
Task::ready(Ok(output)).into()
} else {
// It's a file, so read its contents
let file_path = input.path.clone();
cx.spawn(async move |cx| {
let buffer = cx
.update(|cx| {
project.update(cx, |project, cx| project.open_buffer(project_path, cx))
})?
.await?;
if input.start.is_some() || input.end.is_some() {
let result = buffer.read_with(cx, |buffer, _cx| {
let text = buffer.text();
let start = input.start.unwrap_or(1);
let lines = text.split('\n').skip(start as usize - 1);
if let Some(end) = input.end {
let count = end.saturating_sub(start).max(1); // Ensure at least 1 line
Itertools::intersperse(lines.take(count as usize), "\n").collect()
} else {
Itertools::intersperse(lines, "\n").collect()
}
})?;
action_log.update(cx, |log, cx| {
log.track_buffer(buffer, cx);
})?;
Ok(result)
} else {
// No line ranges specified, so check file size to see if it's too big.
let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
if file_size <= outline::AUTO_OUTLINE_SIZE {
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
action_log.update(cx, |log, cx| {
log.track_buffer(buffer, cx);
})?;
Ok(result)
} else {
// File is too big, so return its outline and a suggestion to
// read again with a line number range specified.
let outline = outline::file_outline(project, file_path, action_log, None, cx).await?;
Ok(format!("This file was too big to read all at once. Here is an outline of its symbols:\n\n{outline}\n\nUsing the line numbers in this outline, you can call this tool again while specifying the start and end fields to see the implementations of symbols in the outline."))
}
}
}).into()
}
}
}

View File

@@ -45,7 +45,7 @@ impl Tool for CopyPathTool {
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
true
false
}
fn description(&self) -> String {

View File

@@ -35,7 +35,7 @@ impl Tool for CreateDirectoryTool {
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
true
false
}
fn description(&self) -> String {

View File

@@ -34,7 +34,7 @@ impl Tool for DeletePathTool {
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
true
false
}
fn description(&self) -> String {

View File

@@ -19,6 +19,7 @@ use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelRequest, LanguageModelRequestMessage,
MessageContent, Role,
};
use project::{AgentLocation, Project};
use serde::Serialize;
use std::{cmp, iter, mem, ops::Range, path::PathBuf, sync::Arc, task::Poll};
use streaming_diff::{CharOperation, StreamingDiff};
@@ -59,17 +60,20 @@ pub struct EditAgentOutput {
pub struct EditAgent {
model: Arc<dyn LanguageModel>,
action_log: Entity<ActionLog>,
project: Entity<Project>,
templates: Arc<Templates>,
}
impl EditAgent {
pub fn new(
model: Arc<dyn LanguageModel>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
templates: Arc<Templates>,
) -> Self {
EditAgent {
model,
project,
action_log,
templates,
}
@@ -118,39 +122,74 @@ impl EditAgent {
let (output_events_tx, output_events_rx) = mpsc::unbounded();
let this = self.clone();
let task = cx.spawn(async move |cx| {
// Ensure the buffer is tracked by the action log.
this.action_log
.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx))?;
cx.update(|cx| {
buffer.update(cx, |buffer, cx| buffer.set_text("", cx));
this.action_log
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
})?;
let mut raw_edits = String::new();
pin_mut!(edit_chunks);
while let Some(chunk) = edit_chunks.next().await {
let chunk = chunk?;
raw_edits.push_str(&chunk);
cx.update(|cx| {
buffer.update(cx, |buffer, cx| buffer.append(chunk, cx));
this.action_log
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
})?;
output_events_tx
.unbounded_send(EditAgentOutputEvent::Edited)
.ok();
}
Ok(EditAgentOutput {
_raw_edits: raw_edits,
_parser_metrics: EditParserMetrics::default(),
})
let output = this
.replace_text_with_chunks_internal(buffer, edit_chunks, output_events_tx, cx)
.await;
this.project
.update(cx, |project, cx| project.set_agent_location(None, cx))?;
output
});
(task, output_events_rx)
}
async fn replace_text_with_chunks_internal(
&self,
buffer: Entity<Buffer>,
edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
output_events_tx: mpsc::UnboundedSender<EditAgentOutputEvent>,
cx: &mut AsyncApp,
) -> Result<EditAgentOutput> {
cx.update(|cx| {
buffer.update(cx, |buffer, cx| buffer.set_text("", cx));
self.action_log.update(cx, |log, cx| {
log.buffer_edited(buffer.clone(), cx);
});
self.project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: language::Anchor::MAX,
}),
cx,
)
});
output_events_tx
.unbounded_send(EditAgentOutputEvent::Edited)
.ok();
})?;
let mut raw_edits = String::new();
pin_mut!(edit_chunks);
while let Some(chunk) = edit_chunks.next().await {
let chunk = chunk?;
raw_edits.push_str(&chunk);
cx.update(|cx| {
buffer.update(cx, |buffer, cx| buffer.append(chunk, cx));
self.action_log
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
self.project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: language::Anchor::MAX,
}),
cx,
)
});
})?;
output_events_tx
.unbounded_send(EditAgentOutputEvent::Edited)
.ok();
}
Ok(EditAgentOutput {
_raw_edits: raw_edits,
_parser_metrics: EditParserMetrics::default(),
})
}
pub fn edit(
&self,
buffer: Entity<Buffer>,
@@ -161,6 +200,18 @@ impl EditAgent {
Task<Result<EditAgentOutput>>,
mpsc::UnboundedReceiver<EditAgentOutputEvent>,
) {
self.project
.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: language::Anchor::MIN,
}),
cx,
);
})
.ok();
let this = self.clone();
let (events_tx, events_rx) = mpsc::unbounded();
let output = cx.spawn(async move |cx| {
@@ -194,8 +245,14 @@ impl EditAgent {
let (output_events_tx, output_events_rx) = mpsc::unbounded();
let this = self.clone();
let task = cx.spawn(async move |mut cx| {
this.apply_edits_internal(buffer, edit_chunks, output_events_tx, &mut cx)
.await
this.action_log
.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx))?;
let output = this
.apply_edits_internal(buffer, edit_chunks, output_events_tx, &mut cx)
.await;
this.project
.update(cx, |project, cx| project.set_agent_location(None, cx))?;
output
});
(task, output_events_rx)
}
@@ -207,10 +264,6 @@ impl EditAgent {
output_events: mpsc::UnboundedSender<EditAgentOutputEvent>,
cx: &mut AsyncApp,
) -> Result<EditAgentOutput> {
// Ensure the buffer is tracked by the action log.
self.action_log
.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx))?;
let (output, mut edit_events) = Self::parse_edit_chunks(edit_chunks, cx);
while let Some(edit_event) = edit_events.next().await {
let EditParserEvent::OldText(old_text_query) = edit_event? else {
@@ -275,14 +328,15 @@ impl EditAgent {
match op {
CharOperation::Insert { text } => {
let edit_start = snapshot.anchor_after(edit_start);
edits_tx.unbounded_send((edit_start..edit_start, text))?;
edits_tx
.unbounded_send((edit_start..edit_start, Arc::from(text)))?;
}
CharOperation::Delete { bytes } => {
let edit_end = edit_start + bytes;
let edit_range = snapshot.anchor_after(edit_start)
..snapshot.anchor_before(edit_end);
edit_start = edit_end;
edits_tx.unbounded_send((edit_range, String::new()))?;
edits_tx.unbounded_send((edit_range, Arc::from("")))?;
}
CharOperation::Keep { bytes } => edit_start += bytes,
}
@@ -296,13 +350,35 @@ impl EditAgent {
// TODO: group all edits into one transaction
let mut edits_rx = edits_rx.ready_chunks(32);
while let Some(edits) = edits_rx.next().await {
if edits.is_empty() {
continue;
}
// Edit the buffer and report edits to the action log as part of the
// same effect cycle, otherwise the edit will be reported as if the
// user made it.
cx.update(|cx| {
buffer.update(cx, |buffer, cx| buffer.edit(edits, None, cx));
let max_edit_end = buffer.update(cx, |buffer, cx| {
buffer.edit(edits.iter().cloned(), None, cx);
let max_edit_end = buffer
.summaries_for_anchors::<Point, _>(
edits.iter().map(|(range, _)| &range.end),
)
.max()
.unwrap();
buffer.anchor_before(max_edit_end)
});
self.action_log
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx))
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
self.project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: max_edit_end,
}),
cx,
);
});
})?;
output_events
.unbounded_send(EditAgentOutputEvent::Edited)
@@ -657,7 +733,7 @@ mod tests {
use gpui::{App, AppContext, TestAppContext};
use indoc::indoc;
use language_model::fake_provider::FakeLanguageModel;
use project::Project;
use project::{AgentLocation, Project};
use rand::prelude::*;
use rand::rngs::StdRng;
use std::cmp;
@@ -775,8 +851,11 @@ mod tests {
}
#[gpui::test]
async fn test_events(cx: &mut TestAppContext) {
async fn test_edit_events(cx: &mut TestAppContext) {
let agent = init_test(cx).await;
let project = agent
.action_log
.read_with(cx, |log, _| log.project().clone());
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
let (apply, mut events) = agent.apply_edit_chunks(
@@ -792,6 +871,10 @@ mod tests {
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"abc\ndef\nghi"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
None
);
chunks_tx.unbounded_send("bc</old_text>").unwrap();
cx.run_until_parked();
@@ -800,6 +883,10 @@ mod tests {
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"abc\ndef\nghi"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
None
);
chunks_tx.unbounded_send("<new_text>abX").unwrap();
cx.run_until_parked();
@@ -808,6 +895,13 @@ mod tests {
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"abXc\ndef\nghi"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
Some(AgentLocation {
buffer: buffer.downgrade(),
position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 3)))
})
);
chunks_tx.unbounded_send("cY").unwrap();
cx.run_until_parked();
@@ -816,6 +910,13 @@ mod tests {
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"abXcY\ndef\nghi"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
Some(AgentLocation {
buffer: buffer.downgrade(),
position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
})
);
chunks_tx.unbounded_send("</new_text>").unwrap();
chunks_tx.unbounded_send("<old_text>hall").unwrap();
@@ -825,6 +926,13 @@ mod tests {
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"abXcY\ndef\nghi"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
Some(AgentLocation {
buffer: buffer.downgrade(),
position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
})
);
chunks_tx.unbounded_send("ucinated old</old_text>").unwrap();
chunks_tx.unbounded_send("<new_text>").unwrap();
@@ -839,6 +947,13 @@ mod tests {
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"abXcY\ndef\nghi"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
Some(AgentLocation {
buffer: buffer.downgrade(),
position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
})
);
chunks_tx.unbounded_send("hallucinated new</new_").unwrap();
chunks_tx.unbounded_send("text>").unwrap();
@@ -848,6 +963,13 @@ mod tests {
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"abXcY\ndef\nghi"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
Some(AgentLocation {
buffer: buffer.downgrade(),
position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
})
);
chunks_tx.unbounded_send("<old_text>gh").unwrap();
chunks_tx.unbounded_send("i</old_text>").unwrap();
@@ -858,6 +980,13 @@ mod tests {
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"abXcY\ndef\nghi"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
Some(AgentLocation {
buffer: buffer.downgrade(),
position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
})
);
chunks_tx.unbounded_send("GHI</new_text>").unwrap();
cx.run_until_parked();
@@ -869,6 +998,13 @@ mod tests {
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"abXcY\ndef\nGHI"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
Some(AgentLocation {
buffer: buffer.downgrade(),
position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
})
);
drop(chunks_tx);
apply.await.unwrap();
@@ -877,16 +1013,108 @@ mod tests {
"abXcY\ndef\nGHI"
);
assert_eq!(drain_events(&mut events), vec![]);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
None
);
}
fn drain_events(
stream: &mut UnboundedReceiver<EditAgentOutputEvent>,
) -> Vec<EditAgentOutputEvent> {
let mut events = Vec::new();
while let Ok(Some(event)) = stream.try_next() {
events.push(event);
}
events
}
#[gpui::test]
async fn test_overwrite_events(cx: &mut TestAppContext) {
let agent = init_test(cx).await;
let project = agent
.action_log
.read_with(cx, |log, _| log.project().clone());
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
let (apply, mut events) = agent.replace_text_with_chunks(
buffer.clone(),
chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
&mut cx.to_async(),
);
cx.run_until_parked();
assert_eq!(
drain_events(&mut events),
vec![EditAgentOutputEvent::Edited]
);
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
""
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
Some(AgentLocation {
buffer: buffer.downgrade(),
position: language::Anchor::MAX
})
);
chunks_tx.unbounded_send("jkl\n").unwrap();
cx.run_until_parked();
assert_eq!(
drain_events(&mut events),
vec![EditAgentOutputEvent::Edited]
);
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"jkl\n"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
Some(AgentLocation {
buffer: buffer.downgrade(),
position: language::Anchor::MAX
})
);
chunks_tx.unbounded_send("mno\n").unwrap();
cx.run_until_parked();
assert_eq!(
drain_events(&mut events),
vec![EditAgentOutputEvent::Edited]
);
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"jkl\nmno\n"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
Some(AgentLocation {
buffer: buffer.downgrade(),
position: language::Anchor::MAX
})
);
chunks_tx.unbounded_send("pqr").unwrap();
cx.run_until_parked();
assert_eq!(
drain_events(&mut events),
vec![EditAgentOutputEvent::Edited]
);
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"jkl\nmno\npqr"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
Some(AgentLocation {
buffer: buffer.downgrade(),
position: language::Anchor::MAX
})
);
drop(chunks_tx);
apply.await.unwrap();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"jkl\nmno\npqr"
);
assert_eq!(drain_events(&mut events), vec![]);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
None
);
}
#[gpui::test]
@@ -1173,7 +1401,17 @@ mod tests {
cx.update(Project::init_settings);
let project = Project::test(FakeFs::new(cx.executor()), [], cx).await;
let model = Arc::new(FakeLanguageModel::default());
let action_log = cx.new(|_| ActionLog::new(project));
EditAgent::new(model, action_log, Templates::new())
let action_log = cx.new(|_| ActionLog::new(project.clone()));
EditAgent::new(model, project, action_log, Templates::new())
}
fn drain_events(
stream: &mut UnboundedReceiver<EditAgentOutputEvent>,
) -> Vec<EditAgentOutputEvent> {
let mut events = Vec::new();
while let Ok(Some(event)) = stream.try_next() {
events.push(event);
}
events
}
}

View File

@@ -517,7 +517,7 @@ fn eval_from_pixels_constructor() {
input_path: input_file_path.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::assert_eq(indoc! {"
assertion: EvalAssertion::judge_diff(indoc! {"
- The diff contains a new `from_pixels` constructor
- The diff contains new tests for the `from_pixels` constructor
"}),
@@ -957,7 +957,7 @@ impl EditAgentTest {
cx.spawn(async move |cx| {
let agent_model =
Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
Self::load_model("google", "gemini-2.5-pro-preview-03-25", cx).await;
let judge_model =
Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
(agent_model.unwrap(), judge_model.unwrap())
@@ -967,7 +967,7 @@ impl EditAgentTest {
let action_log = cx.new(|_| ActionLog::new(project.clone()));
Self {
agent: EditAgent::new(agent_model, action_log, Templates::new()),
agent: EditAgent::new(agent_model, project.clone(), action_log, Templates::new()),
project,
judge_model,
}

View File

@@ -7,19 +7,21 @@ use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolUse
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use editor::{Editor, EditorMode, MultiBuffer, PathKey};
use gpui::{
AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EntityId, Task, WeakEntity,
Animation, AnimationExt, AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EntityId,
Task, WeakEntity, pulsating_between,
};
use language::{
Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Rope, TextBuffer,
language_settings::SoftWrap,
};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use project::{AgentLocation, Project};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{
path::{Path, PathBuf},
sync::Arc,
time::Duration,
};
use ui::{Disclosure, Tooltip, Window, prelude::*};
use util::ResultExt;
@@ -162,6 +164,19 @@ impl Tool for EditFileTool {
})?
.await?;
// Set the agent's location to the top of the file
project
.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: language::Anchor::MIN,
}),
cx,
);
})
.ok();
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
if input.old_string.is_empty() {
@@ -224,6 +239,7 @@ impl Tool for EditFileTool {
let snapshot = cx.update(|cx| {
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
let base_version = diff.base_version.clone();
let snapshot = buffer.update(cx, |buffer, cx| {
buffer.finalize_last_transaction();
buffer.apply_diff(diff, cx);
@@ -231,6 +247,21 @@ impl Tool for EditFileTool {
buffer.snapshot()
});
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
// Set the agent's location to the position of the first edit
if let Some(first_edit) = snapshot.edits_since::<usize>(&base_version).next() {
let position = snapshot.anchor_before(first_edit.new.start);
project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position,
}),
cx,
);
})
}
snapshot
})?;
@@ -302,6 +333,7 @@ impl EditFileToolCard {
editor.set_soft_wrap_mode(SoftWrap::None, cx);
editor.scroll_manager.set_forbid_vertical_scroll(true);
editor.set_show_scrollbars(false, cx);
editor.set_show_indent_guides(false, cx);
editor.set_read_only(true);
editor.set_show_breakpoints(false, cx);
editor.set_show_code_actions(false, cx);
@@ -323,6 +355,10 @@ impl EditFileToolCard {
}
}
pub fn has_diff(&self) -> bool {
self.total_lines.is_some()
}
pub fn set_diff(
&mut self,
path: Arc<Path>,
@@ -463,45 +499,44 @@ impl ToolCard for EditFileToolCard {
.rounded_t_md()
.when(!failed, |header| header.bg(codeblock_header_bg))
.child(path_label_button)
.map(|container| {
if failed {
container.child(
h_flex()
.gap_1()
.child(
Icon::new(IconName::Close)
.size(IconSize::Small)
.color(Color::Error),
)
.child(
Disclosure::new(
("edit-file-error-disclosure", self.editor_unique_id),
self.error_expanded,
)
.opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronDown)
.on_click(cx.listener(
move |this, _event, _window, _cx| {
this.error_expanded = !this.error_expanded;
},
)),
),
)
} else {
container.child(
Disclosure::new(
("edit-file-disclosure", self.editor_unique_id),
self.preview_expanded,
.when(failed, |header| {
header.child(
h_flex()
.gap_1()
.child(
Icon::new(IconName::Close)
.size(IconSize::Small)
.color(Color::Error),
)
.opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronDown)
.on_click(cx.listener(
move |this, _event, _window, _cx| {
this.preview_expanded = !this.preview_expanded;
},
)),
.child(
Disclosure::new(
("edit-file-error-disclosure", self.editor_unique_id),
self.error_expanded,
)
.opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronDown)
.on_click(cx.listener(
move |this, _event, _window, _cx| {
this.error_expanded = !this.error_expanded;
},
)),
),
)
})
.when(!failed && self.has_diff(), |header| {
header.child(
Disclosure::new(
("edit-file-disclosure", self.editor_unique_id),
self.preview_expanded,
)
}
.opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronDown)
.on_click(cx.listener(
move |this, _event, _window, _cx| {
this.preview_expanded = !this.preview_expanded;
},
)),
)
});
let (editor, editor_line_height) = self.editor.update(cx, |editor, cx| {
@@ -538,6 +573,50 @@ impl ToolCard for EditFileToolCard {
const DEFAULT_COLLAPSED_LINES: u32 = 10;
let is_collapsible = self.total_lines.unwrap_or(0) > DEFAULT_COLLAPSED_LINES;
let waiting_for_diff = {
let styles = [
("w_4_5", (0.1, 0.85), 2000),
("w_1_4", (0.2, 0.75), 2200),
("w_2_4", (0.15, 0.64), 1900),
("w_3_5", (0.25, 0.72), 2300),
("w_2_5", (0.3, 0.56), 1800),
];
let mut container = v_flex()
.p_3()
.gap_1p5()
.border_t_1()
.border_color(border_color)
.bg(cx.theme().colors().editor_background);
for (width_method, pulse_range, duration_ms) in styles.iter() {
let (min_opacity, max_opacity) = *pulse_range;
let placeholder = match *width_method {
"w_4_5" => div().w_3_4(),
"w_1_4" => div().w_1_4(),
"w_2_4" => div().w_2_4(),
"w_3_5" => div().w_3_5(),
"w_2_5" => div().w_2_5(),
_ => div().w_1_2(),
}
.id("loading_div")
.h_2()
.rounded_full()
.bg(cx.theme().colors().element_active)
.with_animation(
"loading_pulsate",
Animation::new(Duration::from_millis(*duration_ms))
.repeat()
.with_easing(pulsating_between(min_opacity, max_opacity)),
|label, delta| label.opacity(delta),
);
container = container.child(placeholder);
}
container
};
v_flex()
.mb_2()
.border_1()
@@ -573,50 +652,58 @@ impl ToolCard for EditFileToolCard {
),
)
})
.when(!failed && self.preview_expanded, |card| {
card.child(
v_flex()
.relative()
.h_full()
.when(!self.full_height_expanded, |editor_container| {
editor_container
.max_h(DEFAULT_COLLAPSED_LINES as f32 * editor_line_height)
})
.overflow_hidden()
.border_t_1()
.border_color(border_color)
.bg(cx.theme().colors().editor_background)
.child(div().pl_1().child(editor))
.when(
!self.full_height_expanded && is_collapsible,
|editor_container| editor_container.child(gradient_overlay),
),
)
.when(is_collapsible, |editor_container| {
editor_container.child(
h_flex()
.id(("expand-button", self.editor_unique_id))
.flex_none()
.cursor_pointer()
.h_5()
.justify_center()
.rounded_b_md()
.when(!self.has_diff() && !failed, |card| {
card.child(waiting_for_diff)
})
.when(
!failed && self.preview_expanded && self.has_diff(),
|card| {
card.child(
v_flex()
.relative()
.h_full()
.when(!self.full_height_expanded, |editor_container| {
editor_container
.max_h(DEFAULT_COLLAPSED_LINES as f32 * editor_line_height)
})
.overflow_hidden()
.border_t_1()
.border_color(border_color)
.bg(cx.theme().colors().editor_background)
.hover(|style| style.bg(cx.theme().colors().element_hover.opacity(0.1)))
.child(
Icon::new(full_height_icon)
.size(IconSize::Small)
.color(Color::Muted),
)
.tooltip(Tooltip::text(full_height_tooltip_label))
.on_click(cx.listener(move |this, _event, _window, _cx| {
this.full_height_expanded = !this.full_height_expanded;
})),
.child(editor)
.when(
!self.full_height_expanded && is_collapsible,
|editor_container| editor_container.child(gradient_overlay),
),
)
})
})
.when(is_collapsible, |editor_container| {
editor_container.child(
h_flex()
.id(("expand-button", self.editor_unique_id))
.flex_none()
.cursor_pointer()
.h_5()
.justify_center()
.rounded_b_md()
.border_t_1()
.border_color(border_color)
.bg(cx.theme().colors().editor_background)
.hover(|style| {
style.bg(cx.theme().colors().element_hover.opacity(0.1))
})
.child(
Icon::new(full_height_icon)
.size(IconSize::Small)
.color(Color::Muted),
)
.tooltip(Tooltip::text(full_height_tooltip_label))
.on_click(cx.listener(move |this, _event, _window, _cx| {
this.full_height_expanded = !this.full_height_expanded;
})),
)
})
},
)
}
}

View File

@@ -43,7 +43,7 @@ impl Tool for MovePathTool {
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
true
false
}
fn description(&self) -> String {

View File

@@ -4,3 +4,6 @@ This tool opens a file or URL with the default application associated with it on
- On Linux, it uses something like `xdg-open`, `gio open`, `gnome-open`, `kde-open`, `wslview` as appropriate
For example, it can open a web browser with a URL, open a PDF file with the default PDF viewer, etc.
You MUST ONLY use this tool when the user has explicitly requested opening something. You MUST NEVER assume that
the user would like for you to use this tool.

View File

@@ -6,8 +6,9 @@ use gpui::{AnyWindowHandle, App, Entity, Task};
use indoc::formatdoc;
use itertools::Itertools;
use language::{Anchor, Point};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use project::{AgentLocation, Project};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
@@ -35,11 +36,11 @@ pub struct ReadFileToolInput {
/// Optional line number to start reading on (1-based index)
#[serde(default)]
pub start_line: Option<usize>,
pub start_line: Option<u32>,
/// Optional line number to end reading on (1-based index, inclusive)
#[serde(default)]
pub end_line: Option<usize>,
pub end_line: Option<u32>,
}
pub struct ReadFileTool;
@@ -109,7 +110,7 @@ impl Tool for ReadFileTool {
let file_path = input.path.clone();
cx.spawn(async move |cx| {
if !exists.await? {
return Err(anyhow!("{} not found", file_path))
return Err(anyhow!("{} not found", file_path));
}
let buffer = cx
@@ -118,25 +119,54 @@ impl Tool for ReadFileTool {
})?
.await?;
project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: Anchor::MIN,
}),
cx,
);
})?;
// Check if specific line ranges are provided
if input.start_line.is_some() || input.end_line.is_some() {
let mut anchor = None;
let result = buffer.read_with(cx, |buffer, _cx| {
let text = buffer.text();
// .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0.
let start = input.start_line.unwrap_or(1).max(1);
let lines = text.split('\n').skip(start - 1);
let start_row = start - 1;
if start_row <= buffer.max_point().row {
let column = buffer.line_indent_for_row(start_row).raw_len();
anchor = Some(buffer.anchor_before(Point::new(start_row, column)));
}
let lines = text.split('\n').skip(start_row as usize);
if let Some(end) = input.end_line {
let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line
Itertools::intersperse(lines.take(count), "\n").collect()
Itertools::intersperse(lines.take(count as usize), "\n").collect()
} else {
Itertools::intersperse(lines, "\n").collect()
}
})?;
action_log.update(cx, |log, cx| {
log.track_buffer(buffer, cx);
log.track_buffer(buffer.clone(), cx);
})?;
if let Some(anchor) = anchor {
project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: anchor,
}),
cx,
);
})?;
}
Ok(result)
} else {
// No line ranges specified, so check file size to see if it's too big.
@@ -165,7 +195,8 @@ impl Tool for ReadFileTool {
})
}
}
}).into()
})
.into()
}
}

View File

@@ -1,204 +0,0 @@
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, Entity, Task};
use language::{self, Buffer, ToPointUtf16};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use ui::IconName;
use crate::schema::json_schema_for;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RenameToolInput {
/// The relative path to the file containing the symbol to rename.
///
/// WARNING: you MUST start this path with one of the project's root directories.
pub path: String,
/// The new name to give to the symbol.
pub new_name: String,
/// The text that comes immediately before the symbol in the file.
pub context_before_symbol: String,
/// The symbol to rename. This text must appear in the file right between
/// `context_before_symbol` and `context_after_symbol`.
///
/// The file must contain exactly one occurrence of `context_before_symbol` followed by
/// `symbol` followed by `context_after_symbol`. If the file contains zero occurrences,
/// or if it contains more than one occurrence, the tool will fail, so it is absolutely
/// critical that you verify ahead of time that the string is unique. You can search
/// the file's contents to verify this ahead of time.
///
/// To make the string more likely to be unique, include a minimum of 1 line of context
/// before the symbol, as well as a minimum of 1 line of context after the symbol.
/// If these lines of context are not enough to obtain a string that appears only once
/// in the file, then double the number of context lines until the string becomes unique.
/// (Start with 1 line before and 1 line after though, because too much context is
/// needlessly costly.)
///
/// Do not alter the context lines of code in any way, and make sure to preserve all
/// whitespace and indentation for all lines of code. The combined string must be exactly
/// as it appears in the file, or else this tool call will fail.
pub symbol: String,
/// The text that comes immediately after the symbol in the file.
pub context_after_symbol: String,
}
pub struct RenameTool;
impl Tool for RenameTool {
fn name(&self) -> String {
"rename".into()
}
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
false
}
fn description(&self) -> String {
include_str!("./rename_tool/description.md").into()
}
fn icon(&self) -> IconName {
IconName::Pencil
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<RenameToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<RenameToolInput>(input.clone()) {
Ok(input) => {
format!("Rename '{}' to '{}'", input.symbol, input.new_name)
}
Err(_) => "Rename symbol".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<RenameToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
cx.spawn(async move |cx| {
let buffer = {
let project_path = project.read_with(cx, |project, cx| {
project
.find_project_path(&input.path, cx)
.context("Path not found in project")
})??;
project.update(cx, |project, cx| project.open_buffer(project_path, cx))?.await?
};
action_log.update(cx, |action_log, cx| {
action_log.track_buffer(buffer.clone(), cx);
})?;
let position = {
let Some(position) = buffer.read_with(cx, |buffer, _cx| {
find_symbol_position(&buffer, &input.context_before_symbol, &input.symbol, &input.context_after_symbol)
})? else {
return Err(anyhow!(
"Failed to locate the symbol specified by context_before_symbol, symbol, and context_after_symbol. Make sure context_before_symbol and context_after_symbol each match exactly once in the file."
));
};
buffer.read_with(cx, |buffer, _| {
position.to_point_utf16(&buffer.snapshot())
})?
};
project
.update(cx, |project, cx| {
project.perform_rename(buffer.clone(), position, input.new_name.clone(), cx)
})?
.await?;
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.await?;
action_log.update(cx, |log, cx| {
log.buffer_edited(buffer.clone(), cx)
})?;
Ok(format!("Renamed '{}' to '{}'", input.symbol, input.new_name))
}).into()
}
}
/// Finds the position of the symbol in the buffer, if it appears between context_before_symbol
/// and context_after_symbol, and if that combined string has one unique result in the buffer.
///
/// If an exact match fails, it tries adding a newline to the end of context_before_symbol and
/// to the beginning of context_after_symbol to accommodate line-based context matching.
fn find_symbol_position(
buffer: &Buffer,
context_before_symbol: &str,
symbol: &str,
context_after_symbol: &str,
) -> Option<language::Anchor> {
let snapshot = buffer.snapshot();
let text = snapshot.text();
// First try with exact match
let search_string = format!("{context_before_symbol}{symbol}{context_after_symbol}");
let mut positions = text.match_indices(&search_string);
let position_result = positions.next();
if let Some(position) = position_result {
// Check if the matched string is unique
if positions.next().is_none() {
let symbol_start = position.0 + context_before_symbol.len();
let symbol_start_anchor =
snapshot.anchor_before(snapshot.offset_to_point(symbol_start));
return Some(symbol_start_anchor);
}
}
// If exact match fails or is not unique, try with line-based context
// Add a newline to the end of before context and beginning of after context
let line_based_before = if context_before_symbol.ends_with('\n') {
context_before_symbol.to_string()
} else {
format!("{context_before_symbol}\n")
};
let line_based_after = if context_after_symbol.starts_with('\n') {
context_after_symbol.to_string()
} else {
format!("\n{context_after_symbol}")
};
let line_search_string = format!("{line_based_before}{symbol}{line_based_after}");
let mut line_positions = text.match_indices(&line_search_string);
let line_position = line_positions.next()?;
// The line-based search string must also appear exactly once
if line_positions.next().is_some() {
return None;
}
let line_symbol_start = line_position.0 + line_based_before.len();
let line_symbol_start_anchor =
snapshot.anchor_before(snapshot.offset_to_point(line_symbol_start));
Some(line_symbol_start_anchor)
}

View File

@@ -61,6 +61,9 @@ pub struct StreamingEditFileToolInput {
/// If true, this tool will recreate the file from scratch.
/// If false, this tool will produce granular edits to an existing file.
///
/// When a file already exists or you just created it, always prefer editing
/// it as opposed to recreating it from scratch.
pub create_or_overwrite: bool,
}
@@ -170,7 +173,7 @@ impl Tool for StreamingEditFileTool {
.update(|cx| LanguageModelRegistry::read_global(cx).default_model())?
.context("default model not set")?
.model;
let edit_agent = EditAgent::new(model, action_log, Templates::new());
let edit_agent = EditAgent::new(model, project.clone(), action_log, Templates::new());
let buffer = project
.update(cx, |project, cx| {

View File

@@ -1,307 +0,0 @@
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, AsyncApp, Entity, Task};
use language::{self, Anchor, Buffer, BufferSnapshot, Location, Point, ToPoint, ToPointUtf16};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{fmt::Write, ops::Range, sync::Arc};
use ui::IconName;
use util::markdown::MarkdownInlineCode;
use crate::schema::json_schema_for;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct SymbolInfoToolInput {
/// The relative path to the file containing the symbol.
///
/// WARNING: you MUST start this path with one of the project's root directories.
pub path: String,
/// The information to get about the symbol.
pub command: Info,
/// The text that comes immediately before the symbol in the file.
pub context_before_symbol: String,
/// The symbol name. This text must appear in the file right between `context_before_symbol`
/// and `context_after_symbol`.
///
/// The file must contain exactly one occurrence of `context_before_symbol` followed by
/// `symbol` followed by `context_after_symbol`. If the file contains zero occurrences,
/// or if it contains more than one occurrence, the tool will fail, so it is absolutely
/// critical that you verify ahead of time that the string is unique. You can search
/// the file's contents to verify this ahead of time.
///
/// To make the string more likely to be unique, include a minimum of 1 line of context
/// before the symbol, as well as a minimum of 1 line of context after the symbol.
/// If these lines of context are not enough to obtain a string that appears only once
/// in the file, then double the number of context lines until the string becomes unique.
/// (Start with 1 line before and 1 line after though, because too much context is
/// needlessly costly.)
///
/// Do not alter the context lines of code in any way, and make sure to preserve all
/// whitespace and indentation for all lines of code. The combined string must be exactly
/// as it appears in the file, or else this tool call will fail.
pub symbol: String,
/// The text that comes immediately after the symbol in the file.
pub context_after_symbol: String,
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum Info {
/// Get the symbol's definition (where it's first assigned, even if it's declared elsewhere)
Definition,
/// Get the symbol's declaration (where it's first declared)
Declaration,
/// Get the symbol's implementation
Implementation,
/// Get the symbol's type definition
TypeDefinition,
/// Find all references to the symbol in the project
References,
}
pub struct SymbolInfoTool;
impl Tool for SymbolInfoTool {
fn name(&self) -> String {
"symbol_info".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
fn description(&self) -> String {
include_str!("./symbol_info_tool/description.md").into()
}
fn icon(&self) -> IconName {
IconName::Code
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<SymbolInfoToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<SymbolInfoToolInput>(input.clone()) {
Ok(input) => {
let symbol = MarkdownInlineCode(&input.symbol);
match input.command {
Info::Definition => {
format!("Find definition for {symbol}")
}
Info::Declaration => {
format!("Find declaration for {symbol}")
}
Info::Implementation => {
format!("Find implementation for {symbol}")
}
Info::TypeDefinition => {
format!("Find type definition for {symbol}")
}
Info::References => {
format!("Find references for {symbol}")
}
}
}
Err(_) => "Get symbol info".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<SymbolInfoToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
cx.spawn(async move |cx| {
let buffer = {
let project_path = project.read_with(cx, |project, cx| {
project
.find_project_path(&input.path, cx)
.context("Path not found in project")
})??;
project.update(cx, |project, cx| project.open_buffer(project_path, cx))?.await?
};
action_log.update(cx, |action_log, cx| {
action_log.track_buffer(buffer.clone(), cx);
})?;
let position = {
let Some(range) = buffer.read_with(cx, |buffer, _cx| {
find_symbol_range(&buffer, &input.context_before_symbol, &input.symbol, &input.context_after_symbol)
})? else {
return Err(anyhow!(
"Failed to locate the text specified by context_before_symbol, symbol, and context_after_symbol. Make sure context_before_symbol and context_after_symbol each match exactly once in the file."
));
};
buffer.read_with(cx, |buffer, _| {
range.start.to_point_utf16(&buffer.snapshot())
})?
};
let output: String = match input.command {
Info::Definition => {
render_locations(project
.update(cx, |project, cx| {
project.definition(&buffer, position, cx)
})?
.await?.into_iter().map(|link| link.target),
cx)
}
Info::Declaration => {
render_locations(project
.update(cx, |project, cx| {
project.declaration(&buffer, position, cx)
})?
.await?.into_iter().map(|link| link.target),
cx)
}
Info::Implementation => {
render_locations(project
.update(cx, |project, cx| {
project.implementation(&buffer, position, cx)
})?
.await?.into_iter().map(|link| link.target),
cx)
}
Info::TypeDefinition => {
render_locations(project
.update(cx, |project, cx| {
project.type_definition(&buffer, position, cx)
})?
.await?.into_iter().map(|link| link.target),
cx)
}
Info::References => {
render_locations(project
.update(cx, |project, cx| {
project.references(&buffer, position, cx)
})?
.await?,
cx)
}
};
if output.is_empty() {
Err(anyhow!("None found."))
} else {
Ok(output)
}
}).into()
}
}
/// Finds the range of the symbol in the buffer, if it appears between context_before_symbol
/// and context_after_symbol, and if that combined string has one unique result in the buffer.
fn find_symbol_range(
buffer: &Buffer,
context_before_symbol: &str,
symbol: &str,
context_after_symbol: &str,
) -> Option<Range<Anchor>> {
let snapshot = buffer.snapshot();
let text = snapshot.text();
let search_string = format!("{context_before_symbol}{symbol}{context_after_symbol}");
let mut positions = text.match_indices(&search_string);
let position = positions.next()?.0;
// The combined string must appear exactly once.
if positions.next().is_some() {
return None;
}
let symbol_start = position + context_before_symbol.len();
let symbol_end = symbol_start + symbol.len();
let symbol_start_anchor = snapshot.anchor_before(snapshot.offset_to_point(symbol_start));
let symbol_end_anchor = snapshot.anchor_before(snapshot.offset_to_point(symbol_end));
Some(symbol_start_anchor..symbol_end_anchor)
}
fn render_locations(locations: impl IntoIterator<Item = Location>, cx: &mut AsyncApp) -> String {
let mut answer = String::new();
for location in locations {
location
.buffer
.read_with(cx, |buffer, _cx| {
if let Some(target_path) = buffer
.file()
.and_then(|file| file.path().as_os_str().to_str())
{
let snapshot = buffer.snapshot();
let start = location.range.start.to_point(&snapshot);
let end = location.range.end.to_point(&snapshot);
let start_line = start.row + 1;
let start_col = start.column + 1;
let end_line = end.row + 1;
let end_col = end.column + 1;
if start_line == end_line {
writeln!(answer, "{target_path}:{start_line},{start_col}")
} else {
writeln!(
answer,
"{target_path}:{start_line},{start_col}-{end_line},{end_col}",
)
}
.ok();
write_code_excerpt(&mut answer, &snapshot, &location.range);
}
})
.ok();
}
// Trim trailing newlines without reallocating.
answer.truncate(answer.trim_end().len());
answer
}
fn write_code_excerpt(buf: &mut String, snapshot: &BufferSnapshot, range: &Range<Anchor>) {
const MAX_LINE_LEN: u32 = 200;
let start = range.start.to_point(snapshot);
let end = range.end.to_point(snapshot);
for row in start.row..=end.row {
let row_start = Point::new(row, 0);
let row_end = if row < snapshot.max_point().row {
Point::new(row + 1, 0)
} else {
Point::new(row, u32::MAX)
};
buf.extend(
snapshot
.text_for_range(row_start..row_end)
.take(MAX_LINE_LEN as usize),
);
if row_end.column > MAX_LINE_LEN {
buf.push_str("\n");
}
buf.push('\n');
}
}

View File

@@ -10,6 +10,7 @@ use std::{
pub use system_clock::*;
pub const LOCAL_BRANCH_REPLICA_ID: u16 = u16::MAX;
pub const AGENT_REPLICA_ID: u16 = u16::MAX - 1;
/// A unique identifier for each distributed node.
pub type ReplicaId = u16;

View File

@@ -0,0 +1,23 @@
create table subscription_usages_v2 (
id uuid primary key,
user_id integer not null,
period_start_at timestamp without time zone not null,
period_end_at timestamp without time zone not null,
plan text not null,
model_requests int not null default 0,
edit_predictions int not null default 0
);
create unique index uix_subscription_usages_v2_on_user_id_start_at_end_at on subscription_usages_v2 (user_id, period_start_at, period_end_at);
create index ix_subscription_usages_v2_on_plan on subscription_usages_v2 (plan);
create table subscription_usage_meters_v2 (
id uuid primary key,
subscription_usage_id uuid not null references subscription_usages_v2 (id) on delete cascade,
model_id integer not null references models (id) on delete cascade,
mode text not null,
requests integer not null default 0
);
create unique index uix_subscription_usage_meters_v2_on_usage_model_mode on subscription_usage_meters_v2 (subscription_usage_id, model_id, mode);

View File

@@ -0,0 +1,2 @@
drop table subscription_usage_meters;
drop table subscription_usages;

View File

@@ -624,7 +624,7 @@ struct MigrateToNewBillingBody {
#[derive(Debug, Serialize)]
struct MigrateToNewBillingResponse {
/// The ID of the subscription that was canceled.
canceled_subscription_id: String,
canceled_subscription_id: Option<String>,
}
async fn migrate_to_new_billing(
@@ -650,39 +650,39 @@ async fn migrate_to_new_billing(
.get_active_billing_subscriptions(HashSet::from_iter([user.id]))
.await?;
let Some((_billing_customer, billing_subscription)) =
let canceled_subscription_id = if let Some((_billing_customer, billing_subscription)) =
old_billing_subscriptions_by_user.get(&user.id)
else {
return Err(Error::http(
StatusCode::NOT_FOUND,
"No active billing subscriptions to migrate".into(),
));
{
let stripe_subscription_id = billing_subscription
.stripe_subscription_id
.parse::<stripe::SubscriptionId>()
.context("failed to parse Stripe subscription ID from database")?;
Subscription::cancel(
&stripe_client,
&stripe_subscription_id,
stripe::CancelSubscription {
invoice_now: Some(true),
..Default::default()
},
)
.await?;
Some(stripe_subscription_id)
} else {
None
};
let stripe_subscription_id = billing_subscription
.stripe_subscription_id
.parse::<stripe::SubscriptionId>()
.context("failed to parse Stripe subscription ID from database")?;
Subscription::cancel(
&stripe_client,
&stripe_subscription_id,
stripe::CancelSubscription {
invoice_now: Some(true),
..Default::default()
},
)
.await?;
let feature_flags = app.db.list_feature_flags().await?;
let all_feature_flags = app.db.list_feature_flags().await?;
let user_feature_flags = app.db.get_user_flags(user.id).await?;
for feature_flag in ["new-billing", "assistant2"] {
let already_in_feature_flag = feature_flags.iter().any(|flag| flag.flag == feature_flag);
let already_in_feature_flag = user_feature_flags.iter().any(|flag| flag == feature_flag);
if already_in_feature_flag {
continue;
}
let feature_flag = feature_flags
let feature_flag = all_feature_flags
.iter()
.find(|flag| flag.flag == feature_flag)
.context("failed to find feature flag: {feature_flag:?}")?;
@@ -691,7 +691,8 @@ async fn migrate_to_new_billing(
}
Ok(Json(MigrateToNewBillingResponse {
canceled_subscription_id: stripe_subscription_id.to_string(),
canceled_subscription_id: canceled_subscription_id
.map(|subscription_id| subscription_id.to_string()),
}))
}
@@ -1039,6 +1040,7 @@ async fn handle_customer_subscription_event(
billing_customer.user_id,
&existing_subscription,
subscription_kind,
subscription.status.into(),
new_period_start_at,
new_period_end_at,
)

View File

@@ -1,7 +1,7 @@
use chrono::Timelike;
use time::PrimitiveDateTime;
use crate::db::billing_subscription::SubscriptionKind;
use crate::db::billing_subscription::{StripeSubscriptionStatus, SubscriptionKind};
use crate::db::{UserId, billing_subscription};
use super::*;
@@ -69,7 +69,7 @@ impl LlmDatabase {
Ok(
subscription_usage::Entity::insert(subscription_usage::ActiveModel {
id: ActiveValue::not_set(),
id: ActiveValue::set(Uuid::now_v7()),
user_id: ActiveValue::set(user_id),
period_start_at: ActiveValue::set(period_start_at),
period_end_at: ActiveValue::set(period_end_at),
@@ -120,12 +120,13 @@ impl LlmDatabase {
user_id: UserId,
existing_subscription: &billing_subscription::Model,
new_subscription_kind: Option<SubscriptionKind>,
new_subscription_status: StripeSubscriptionStatus,
new_period_start_at: DateTimeUtc,
new_period_end_at: DateTimeUtc,
) -> Result<Option<subscription_usage::Model>> {
self.transaction(|tx| async move {
match existing_subscription.kind {
Some(SubscriptionKind::ZedProTrial) => {
match (existing_subscription.kind, new_subscription_status) {
(Some(SubscriptionKind::ZedProTrial), StripeSubscriptionStatus::Active) => {
let trial_period_start_at = existing_subscription
.current_period_start_at()
.ok_or_else(|| anyhow!("No trial subscription period start"))?;

View File

@@ -4,10 +4,10 @@ use sea_orm::entity::prelude::*;
use time::PrimitiveDateTime;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "subscription_usages")]
#[sea_orm(table_name = "subscription_usages_v2")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub id: Uuid,
pub user_id: UserId,
pub period_start_at: PrimitiveDateTime,
pub period_end_at: PrimitiveDateTime,

View File

@@ -4,11 +4,11 @@ use serde::Serialize;
use crate::llm::db::ModelId;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "subscription_usage_meters")]
#[sea_orm(table_name = "subscription_usage_meters_v2")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub subscription_usage_id: i32,
pub id: Uuid,
pub subscription_usage_id: Uuid,
pub model_id: ModelId,
pub mode: CompletionMode,
pub requests: i32,

View File

@@ -1,7 +1,7 @@
use chrono::{Duration, Utc};
use pretty_assertions::assert_eq;
use crate::db::billing_subscription::SubscriptionKind;
use crate::db::billing_subscription::{StripeSubscriptionStatus, SubscriptionKind};
use crate::db::{UserId, billing_subscription};
use crate::llm::db::LlmDatabase;
use crate::test_llm_db;
@@ -12,58 +12,108 @@ test_llm_db!(
);
async fn test_transfer_existing_subscription_usage(db: &mut LlmDatabase) {
let user_id = UserId(1);
// Test when an existing Zed Pro trial subscription is upgraded to Zed Pro.
{
let user_id = UserId(1);
let now = Utc::now();
let now = Utc::now();
let trial_period_start_at = now - Duration::days(14);
let trial_period_end_at = now;
let trial_period_start_at = now - Duration::days(14);
let trial_period_end_at = now;
let new_period_start_at = now;
let new_period_end_at = now + Duration::days(30);
let new_period_start_at = now;
let new_period_end_at = now + Duration::days(30);
let existing_subscription = billing_subscription::Model {
kind: Some(SubscriptionKind::ZedProTrial),
stripe_current_period_start: Some(trial_period_start_at.timestamp()),
stripe_current_period_end: Some(trial_period_end_at.timestamp()),
..Default::default()
};
let existing_subscription = billing_subscription::Model {
kind: Some(SubscriptionKind::ZedProTrial),
stripe_current_period_start: Some(trial_period_start_at.timestamp()),
stripe_current_period_end: Some(trial_period_end_at.timestamp()),
..Default::default()
};
let existing_usage = db
.create_subscription_usage(
user_id,
trial_period_start_at,
trial_period_end_at,
SubscriptionKind::ZedProTrial,
25,
1_000,
)
.await
.unwrap();
let existing_usage = db
.create_subscription_usage(
user_id,
trial_period_start_at,
trial_period_end_at,
SubscriptionKind::ZedProTrial,
25,
1_000,
)
.await
.unwrap();
let transferred_usage = db
.transfer_existing_subscription_usage(
user_id,
&existing_subscription,
Some(SubscriptionKind::ZedPro),
new_period_start_at,
new_period_end_at,
)
.await
.unwrap();
let transferred_usage = db
.transfer_existing_subscription_usage(
user_id,
&existing_subscription,
Some(SubscriptionKind::ZedPro),
StripeSubscriptionStatus::Active,
new_period_start_at,
new_period_end_at,
)
.await
.unwrap();
assert!(
transferred_usage.is_some(),
"subscription usage not transferred successfully"
);
let transferred_usage = transferred_usage.unwrap();
assert!(
transferred_usage.is_some(),
"subscription usage not transferred successfully"
);
let transferred_usage = transferred_usage.unwrap();
assert_eq!(
transferred_usage.model_requests,
existing_usage.model_requests
);
assert_eq!(
transferred_usage.edit_predictions,
existing_usage.edit_predictions
);
assert_eq!(
transferred_usage.model_requests,
existing_usage.model_requests
);
assert_eq!(
transferred_usage.edit_predictions,
existing_usage.edit_predictions
);
}
// Test when an existing Zed Pro trial subscription is canceled.
{
let user_id = UserId(2);
let now = Utc::now();
let trial_period_start_at = now - Duration::days(14);
let trial_period_end_at = now;
let existing_subscription = billing_subscription::Model {
kind: Some(SubscriptionKind::ZedProTrial),
stripe_current_period_start: Some(trial_period_start_at.timestamp()),
stripe_current_period_end: Some(trial_period_end_at.timestamp()),
..Default::default()
};
let _existing_usage = db
.create_subscription_usage(
user_id,
trial_period_start_at,
trial_period_end_at,
SubscriptionKind::ZedProTrial,
25,
1_000,
)
.await
.unwrap();
let transferred_usage = db
.transfer_existing_subscription_usage(
user_id,
&existing_subscription,
Some(SubscriptionKind::ZedPro),
StripeSubscriptionStatus::Canceled,
trial_period_start_at,
trial_period_end_at,
)
.await
.unwrap();
assert!(
transferred_usage.is_none(),
"subscription usage was transferred when it should not have been"
);
}
}

View File

@@ -30,6 +30,8 @@ pub struct LlmTokenClaims {
pub has_llm_closed_beta_feature_flag: bool,
pub bypass_account_age_check: bool,
pub has_llm_subscription: bool,
#[serde(default)]
pub use_llm_request_queue: bool,
pub max_monthly_spend_in_cents: u32,
pub custom_llm_monthly_allowance_in_cents: Option<u32>,
#[serde(default)]
@@ -93,6 +95,7 @@ impl LlmTokenClaims {
.custom_llm_monthly_allowance_in_cents
.map(|allowance| allowance as u32),
use_new_billing: feature_flags.iter().any(|flag| flag == "new-billing"),
use_llm_request_queue: feature_flags.iter().any(|flag| flag == "llm-request-queue"),
plan: subscription
.as_ref()
.and_then(|subscription| subscription.kind)

View File

@@ -13,6 +13,7 @@ use gpui::{BackgroundExecutor, Context, Entity, TestAppContext, Window};
use rpc::{RECEIVE_TIMEOUT, proto::PeerId};
use serde_json::json;
use std::ops::Range;
use workspace::CollaboratorId;
#[gpui::test]
async fn test_core_channel_buffers(
@@ -300,13 +301,20 @@ fn assert_remote_selections(
cx: &mut Context<Editor>,
) {
let snapshot = editor.snapshot(window, cx);
let hub = editor.collaboration_hub().unwrap();
let collaborators = hub.collaborators(cx);
let range = Anchor::min()..Anchor::max();
let remote_selections = snapshot
.remote_selections_in_range(&range, editor.collaboration_hub().unwrap(), cx)
.remote_selections_in_range(&range, hub, cx)
.map(|s| {
let CollaboratorId::PeerId(peer_id) = s.collaborator_id else {
panic!("unexpected collaborator id");
};
let start = s.selection.start.to_offset(&snapshot.buffer_snapshot);
let end = s.selection.end.to_offset(&snapshot.buffer_snapshot);
(s.participant_index, start..end)
let user_id = collaborators.get(&peer_id).unwrap().user_id;
let participant_index = hub.user_participant_indices(cx).get(&user_id).copied();
(participant_index, start..end)
})
.collect::<Vec<_>>();
assert_eq!(

View File

@@ -18,7 +18,7 @@ use serde_json::json;
use settings::SettingsStore;
use text::{Point, ToPoint};
use util::{path, test::sample_text};
use workspace::{SplitDirection, Workspace, item::ItemHandle as _};
use workspace::{CollaboratorId, SplitDirection, Workspace, item::ItemHandle as _};
use super::TestClient;
@@ -425,7 +425,7 @@ async fn test_basic_following(
executor.run_until_parked();
assert_eq!(
workspace_a.update(cx_a, |workspace, _| workspace.leader_for_pane(&pane_a)),
Some(peer_id_b)
Some(peer_id_b.into())
);
assert_eq!(
workspace_a.update_in(cx_a, |workspace, _, cx| workspace
@@ -1267,7 +1267,7 @@ async fn test_auto_unfollowing(cx_a: &mut TestAppContext, cx_b: &mut TestAppCont
executor.run_until_parked();
assert_eq!(
workspace_b.update(cx_b, |workspace, _| workspace.leader_for_pane(&pane_b)),
Some(leader_id)
Some(leader_id.into())
);
let editor_b2 = workspace_b.update(cx_b, |workspace, cx| {
workspace
@@ -1292,7 +1292,7 @@ async fn test_auto_unfollowing(cx_a: &mut TestAppContext, cx_b: &mut TestAppCont
executor.run_until_parked();
assert_eq!(
workspace_b.update(cx_b, |workspace, _| workspace.leader_for_pane(&pane_b)),
Some(leader_id)
Some(leader_id.into())
);
// When client B edits, it automatically stops following client A.
@@ -1308,7 +1308,7 @@ async fn test_auto_unfollowing(cx_a: &mut TestAppContext, cx_b: &mut TestAppCont
executor.run_until_parked();
assert_eq!(
workspace_b.update(cx_b, |workspace, _| workspace.leader_for_pane(&pane_b)),
Some(leader_id)
Some(leader_id.into())
);
// When client B scrolls, it automatically stops following client A.
@@ -1326,7 +1326,7 @@ async fn test_auto_unfollowing(cx_a: &mut TestAppContext, cx_b: &mut TestAppCont
executor.run_until_parked();
assert_eq!(
workspace_b.update(cx_b, |workspace, _| workspace.leader_for_pane(&pane_b)),
Some(leader_id)
Some(leader_id.into())
);
// When client B activates a different pane, it continues following client A in the original pane.
@@ -1335,7 +1335,7 @@ async fn test_auto_unfollowing(cx_a: &mut TestAppContext, cx_b: &mut TestAppCont
});
assert_eq!(
workspace_b.update(cx_b, |workspace, _| workspace.leader_for_pane(&pane_b)),
Some(leader_id)
Some(leader_id.into())
);
workspace_b.update_in(cx_b, |workspace, window, cx| {
@@ -1343,7 +1343,7 @@ async fn test_auto_unfollowing(cx_a: &mut TestAppContext, cx_b: &mut TestAppCont
});
assert_eq!(
workspace_b.update(cx_b, |workspace, _| workspace.leader_for_pane(&pane_b)),
Some(leader_id)
Some(leader_id.into())
);
// When client B activates a different item in the original pane, it automatically stops following client A.
@@ -1406,13 +1406,13 @@ async fn test_peers_simultaneously_following_each_other(
workspace_a.update(cx_a, |workspace, _| {
assert_eq!(
workspace.leader_for_pane(workspace.active_pane()),
Some(client_b_id)
Some(client_b_id.into())
);
});
workspace_b.update(cx_b, |workspace, _| {
assert_eq!(
workspace.leader_for_pane(workspace.active_pane()),
Some(client_a_id)
Some(client_a_id.into())
);
});
}
@@ -1513,7 +1513,7 @@ async fn test_following_across_workspaces(cx_a: &mut TestAppContext, cx_b: &mut
workspace_b_project_a.update(&mut cx_b2, |workspace, cx| {
assert!(workspace.is_being_followed(client_a.peer_id().unwrap()));
assert_eq!(
client_a.peer_id(),
client_a.peer_id().map(Into::into),
workspace.leader_for_pane(workspace.active_pane())
);
let item = workspace.active_item(cx).unwrap();
@@ -1554,7 +1554,7 @@ async fn test_following_across_workspaces(cx_a: &mut TestAppContext, cx_b: &mut
workspace_a.update(cx_a, |workspace, cx| {
assert!(workspace.is_being_followed(client_b.peer_id().unwrap()));
assert_eq!(
client_b.peer_id(),
client_b.peer_id().map(Into::into),
workspace.leader_for_pane(workspace.active_pane())
);
let item = workspace.active_pane().read(cx).active_item().unwrap();
@@ -1615,7 +1615,7 @@ async fn test_following_across_workspaces(cx_a: &mut TestAppContext, cx_b: &mut
assert_eq!(workspace.project().read(cx).remote_id(), Some(project_b_id));
assert!(workspace.is_being_followed(client_b.peer_id().unwrap()));
assert_eq!(
client_b.peer_id(),
client_b.peer_id().map(Into::into),
workspace.leader_for_pane(workspace.active_pane())
);
let item = workspace.active_item(cx).unwrap();
@@ -1866,7 +1866,11 @@ fn pane_summaries(workspace: &Entity<Workspace>, cx: &mut VisualTestContext) ->
.panes()
.iter()
.map(|pane| {
let leader = workspace.leader_for_pane(pane);
let leader = match workspace.leader_for_pane(pane) {
Some(CollaboratorId::PeerId(peer_id)) => Some(peer_id),
Some(CollaboratorId::Agent) => unimplemented!(),
None => None,
};
let active = pane == active_pane;
let pane = pane.read(cx);
let active_ix = pane.active_item_index();
@@ -1985,7 +1989,7 @@ async fn test_following_to_channel_notes_without_a_shared_project(
let channel_notes_1_b = workspace_b.update(cx_b, |workspace, cx| {
assert_eq!(
workspace.leader_for_pane(workspace.active_pane()),
Some(client_a.peer_id().unwrap())
Some(client_a.peer_id().unwrap().into())
);
workspace
.active_item(cx)
@@ -2015,7 +2019,7 @@ async fn test_following_to_channel_notes_without_a_shared_project(
let channel_notes_2_b = workspace_b.update(cx_b, |workspace, cx| {
assert_eq!(
workspace.leader_for_pane(workspace.active_pane()),
Some(client_a.peer_id().unwrap())
Some(client_a.peer_id().unwrap().into())
);
workspace
.active_item(cx)

View File

@@ -22,7 +22,7 @@ use std::{
};
use ui::prelude::*;
use util::ResultExt;
use workspace::item::TabContentParams;
use workspace::{CollaboratorId, item::TabContentParams};
use workspace::{
ItemNavHistory, Pane, SaveIntent, Toast, ViewId, Workspace, WorkspaceId,
item::{FollowableItem, Item, ItemEvent, ItemHandle},
@@ -654,15 +654,14 @@ impl FollowableItem for ChannelView {
})
}
fn set_leader_peer_id(
fn set_leader_id(
&mut self,
leader_peer_id: Option<PeerId>,
leader_id: Option<CollaboratorId>,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.editor.update(cx, |editor, cx| {
editor.set_leader_peer_id(leader_peer_id, window, cx)
})
self.editor
.update(cx, |editor, cx| editor.set_leader_id(leader_id, window, cx))
}
fn is_project_item(&self, _window: &Window, _cx: &App) -> bool {

View File

@@ -18,6 +18,9 @@ pub trait Component {
fn name() -> &'static str {
std::any::type_name::<Self>()
}
fn id() -> ComponentId {
ComponentId(Self::name())
}
/// Returns a name that the component should be sorted by.
///
/// Implement this if the component should be sorted in an alternate order than its name.
@@ -81,7 +84,7 @@ pub fn register_component<T: Component>() {
let component_data = (T::scope(), T::name(), T::sort_name(), T::description());
let mut data = COMPONENT_DATA.write();
data.components.push(component_data);
data.previews.insert(T::name(), T::preview);
data.previews.insert(T::id().0, T::preview);
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]

View File

@@ -110,18 +110,7 @@ struct ComponentPreview {
active_page: PreviewPage,
components: Vec<ComponentMetadata>,
component_list: ListState,
agent_previews: Vec<
Box<
dyn Fn(
&Self,
WeakEntity<Workspace>,
Entity<ActiveThread>,
WeakEntity<ThreadStore>,
&mut Window,
&mut App,
) -> Option<AnyElement>,
>,
>,
agent_previews: Vec<ComponentId>,
cursor_index: usize,
language_registry: Arc<LanguageRegistry>,
workspace: WeakEntity<Workspace>,
@@ -191,38 +180,7 @@ impl ComponentPreview {
);
// Initialize agent previews
let agent_previews = agent::all_agent_previews()
.into_iter()
.map(|id| {
Box::new(
move |_self: &ComponentPreview,
workspace: WeakEntity<Workspace>,
active_thread: Entity<ActiveThread>,
thread_store: WeakEntity<ThreadStore>,
window: &mut Window,
cx: &mut App| {
agent::get_agent_preview(
&id,
workspace,
active_thread,
thread_store,
window,
cx,
)
},
)
as Box<
dyn Fn(
&ComponentPreview,
WeakEntity<Workspace>,
Entity<ActiveThread>,
WeakEntity<ThreadStore>,
&mut Window,
&mut App,
) -> Option<AnyElement>,
>
})
.collect::<Vec<_>>();
let agent_previews = agent::all_agent_previews();
let mut component_preview = Self {
workspace_id: None,
@@ -635,44 +593,65 @@ impl ComponentPreview {
let description = component.description();
v_flex()
.py_2()
.child(
v_flex()
.border_1()
.border_color(cx.theme().colors().border)
.rounded_sm()
.w_full()
.gap_4()
.py_4()
.px_6()
.flex_none()
.child(
v_flex()
.gap_1()
.child(
h_flex().gap_1().text_xl().child(div().child(name)).when(
!matches!(scope, ComponentScope::None),
|this| {
this.child(div().opacity(0.5).child(format!("({})", scope)))
},
),
// Build the content container
let mut preview_container = v_flex().py_2().child(
v_flex()
.border_1()
.border_color(cx.theme().colors().border)
.rounded_sm()
.w_full()
.gap_4()
.py_4()
.px_6()
.flex_none()
.child(
v_flex()
.gap_1()
.child(
h_flex()
.gap_1()
.text_xl()
.child(div().child(name))
.when(!matches!(scope, ComponentScope::None), |this| {
this.child(div().opacity(0.5).child(format!("({})", scope)))
}),
)
.when_some(description, |this, description| {
this.child(
div()
.text_ui_sm(cx)
.text_color(cx.theme().colors().text_muted)
.max_w(px(600.0))
.child(description),
)
.when_some(description, |this, description| {
this.child(
div()
.text_ui_sm(cx)
.text_color(cx.theme().colors().text_muted)
.max_w(px(600.0))
.child(description),
)
}),
)
.when_some(component.preview(), |this, preview| {
this.children(preview(window, cx))
}),
)
.into_any_element()
}),
),
);
// Check if the component's scope is Agent
if scope == ComponentScope::Agent {
if let (Some(thread_store), Some(active_thread)) = (
self.thread_store.as_ref().map(|ts| ts.downgrade()),
self.active_thread.clone(),
) {
if let Some(element) = agent::get_agent_preview(
&component.id(),
self.workspace.clone(),
active_thread,
thread_store,
window,
cx,
) {
preview_container = preview_container.child(element);
} else if let Some(preview) = component.preview() {
preview_container = preview_container.children(preview(window, cx));
}
}
} else if let Some(preview) = component.preview() {
preview_container = preview_container.children(preview(window, cx));
}
preview_container.into_any_element()
}
fn render_all_components(&self, cx: &Context<Self>) -> impl IntoElement {
@@ -711,7 +690,12 @@ impl ComponentPreview {
v_flex()
.id("render-component-page")
.size_full()
.child(ComponentPreviewPage::new(component.clone()))
.child(ComponentPreviewPage::new(
component.clone(),
self.workspace.clone(),
self.thread_store.as_ref().map(|ts| ts.downgrade()),
self.active_thread.clone(),
))
.into_any_element()
} else {
v_flex()
@@ -732,13 +716,13 @@ impl ComponentPreview {
.id("render-active-thread")
.size_full()
.child(
v_flex().children(self.agent_previews.iter().filter_map(|preview_fn| {
v_flex().children(self.agent_previews.iter().filter_map(|component_id| {
if let (Some(thread_store), Some(active_thread)) = (
self.thread_store.as_ref().map(|ts| ts.downgrade()),
self.active_thread.clone(),
) {
preview_fn(
self,
agent::get_agent_preview(
component_id,
self.workspace.clone(),
active_thread,
thread_store,
@@ -894,7 +878,7 @@ impl Default for ActivePageId {
impl From<ComponentId> for ActivePageId {
fn from(id: ComponentId) -> Self {
ActivePageId(id.0.to_string())
Self(id.0.to_string())
}
}
@@ -1073,16 +1057,25 @@ impl SerializableItem for ComponentPreview {
pub struct ComponentPreviewPage {
// languages: Arc<LanguageRegistry>,
component: ComponentMetadata,
workspace: WeakEntity<Workspace>,
thread_store: Option<WeakEntity<ThreadStore>>,
active_thread: Option<Entity<ActiveThread>>,
}
impl ComponentPreviewPage {
pub fn new(
component: ComponentMetadata,
workspace: WeakEntity<Workspace>,
thread_store: Option<WeakEntity<ThreadStore>>,
active_thread: Option<Entity<ActiveThread>>,
// languages: Arc<LanguageRegistry>
) -> Self {
Self {
// languages,
component,
workspace,
thread_store,
active_thread,
}
}
@@ -1113,12 +1106,32 @@ impl ComponentPreviewPage {
}
fn render_preview(&self, window: &mut Window, cx: &mut App) -> impl IntoElement {
// Try to get agent preview first if we have an active thread
let maybe_agent_preview = if let (Some(thread_store), Some(active_thread)) =
(self.thread_store.as_ref(), self.active_thread.as_ref())
{
agent::get_agent_preview(
&self.component.id(),
self.workspace.clone(),
active_thread.clone(),
thread_store.clone(),
window,
cx,
)
} else {
None
};
v_flex()
.flex_1()
.px_12()
.py_6()
.bg(cx.theme().colors().editor_background)
.child(if let Some(preview) = self.component.preview() {
.child(if let Some(element) = maybe_agent_preview {
// Use agent preview if available
element
} else if let Some(preview) = self.component.preview() {
// Fall back to component preview
preview(window, cx).unwrap_or_else(|| {
div()
.child("Failed to load preview. This path should be unreachable")

View File

@@ -6,7 +6,7 @@ use crate::{
persistence,
};
use crate::{new_session_modal::NewSessionModal, session::DebugSession};
use anyhow::{Result, anyhow};
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
use command_palette_hooks::CommandPaletteFilter;
use dap::DebugRequest;
@@ -353,7 +353,6 @@ impl DebugPanel {
};
let dap_store_handle = self.project.read(cx).dap_store().clone();
let breakpoint_store = self.project.read(cx).breakpoint_store();
let definition = parent_session.read(cx).definition().clone();
let mut binary = parent_session.read(cx).binary().clone();
binary.request_args = request.clone();
@@ -364,13 +363,7 @@ impl DebugPanel {
dap_store.new_session(definition.clone(), Some(parent_session.clone()), cx);
let task = session.update(cx, |session, cx| {
session.boot(
binary,
worktree,
breakpoint_store,
dap_store_handle.downgrade(),
cx,
)
session.boot(binary, worktree, dap_store_handle.downgrade(), cx)
});
(session, task)
})?;
@@ -500,7 +493,7 @@ impl DebugPanel {
workspace.spawn_in_terminal(task.resolved.clone(), window, cx)
})?;
let exit_status = run_build.await?;
let exit_status = run_build.await.transpose()?.context("task cancelled")?;
if !exit_status.success() {
anyhow::bail!("Build failed");
}

View File

@@ -7,11 +7,11 @@ use gpui::{App, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task
use project::Project;
use project::debugger::session::Session;
use project::worktree_store::WorktreeStore;
use rpc::proto::{self, PeerId};
use rpc::proto;
use running::RunningState;
use ui::{Indicator, prelude::*};
use workspace::{
FollowableItem, ViewId, Workspace,
CollaboratorId, FollowableItem, ViewId, Workspace,
item::{self, Item},
};
@@ -189,9 +189,9 @@ impl FollowableItem for DebugSession {
Task::ready(Ok(()))
}
fn set_leader_peer_id(
fn set_leader_id(
&mut self,
_leader_peer_id: Option<PeerId>,
_leader_id: Option<CollaboratorId>,
_window: &mut Window,
_cx: &mut Context<Self>,
) {

View File

@@ -1663,6 +1663,33 @@ async fn test_active_debug_line_setting(executor: BackgroundExecutor, cx: &mut T
"Second stacktrace request handler was not called"
);
client
.fake_event(dap::messages::Events::Continued(dap::ContinuedEvent {
thread_id: 0,
all_threads_continued: Some(true),
}))
.await;
cx.run_until_parked();
second_editor.update(cx, |editor, _| {
let active_debug_lines: Vec<_> = editor.highlighted_rows::<ActiveDebugLine>().collect();
assert!(
active_debug_lines.is_empty(),
"There shouldn't be any active debug lines"
);
});
main_editor.update(cx, |editor, _| {
let active_debug_lines: Vec<_> = editor.highlighted_rows::<ActiveDebugLine>().collect();
assert!(
active_debug_lines.is_empty(),
"There shouldn't be any active debug lines"
);
});
// Clean up
let shutdown_session = project.update(cx, |project, cx| {
project.dap_store().update(cx, |dap_store, cx| {

View File

@@ -226,7 +226,7 @@ impl ProjectDiagnosticsEditor {
cx.observe_global_in::<IncludeWarnings>(window, |this, window, cx| {
this.include_warnings = cx.global::<IncludeWarnings>().0;
this.diagnostics.clear();
this.update_all_diagnostics(window, cx);
this.update_all_diagnostics(false, window, cx);
})
.detach();
cx.observe_release(&cx.entity(), |editor, _, cx| {
@@ -254,7 +254,7 @@ impl ProjectDiagnosticsEditor {
},
_subscription: project_event_subscription,
};
this.update_all_diagnostics(window, cx);
this.update_all_diagnostics(true, window, cx);
this
}
@@ -346,13 +346,13 @@ impl ProjectDiagnosticsEditor {
if self.cargo_diagnostics_fetch.fetch_task.is_some() {
self.stop_cargo_diagnostics_fetch(cx);
} else {
self.update_all_diagnostics(window, cx);
self.update_all_diagnostics(false, window, cx);
}
} else {
if self.update_excerpts_task.is_some() {
self.update_excerpts_task = None;
} else {
self.update_all_diagnostics(window, cx);
self.update_all_diagnostics(false, window, cx);
}
}
cx.notify();
@@ -371,10 +371,17 @@ impl ProjectDiagnosticsEditor {
}
}
fn update_all_diagnostics(&mut self, window: &mut Window, cx: &mut Context<Self>) {
fn update_all_diagnostics(
&mut self,
first_launch: bool,
window: &mut Window,
cx: &mut Context<Self>,
) {
let cargo_diagnostics_sources = self.cargo_diagnostics_sources(cx);
if cargo_diagnostics_sources.is_empty() {
self.update_all_excerpts(window, cx);
} else if first_launch && !self.summary.is_empty() {
self.update_all_excerpts(window, cx);
} else {
self.fetch_cargo_diagnostics(Arc::new(cargo_diagnostics_sources), cx);
}

View File

@@ -391,7 +391,7 @@ impl DisplayMap {
&mut self,
crease_ids: impl IntoIterator<Item = CreaseId>,
cx: &mut Context<Self>,
) {
) -> Vec<(CreaseId, Range<Anchor>)> {
let snapshot = self.buffer.read(cx).snapshot(cx);
self.crease_map.remove(crease_ids, &snapshot)
}

View File

@@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
use std::{cmp::Ordering, fmt::Debug, ops::Range, sync::Arc};
use sum_tree::{Bias, SeekTarget, SumTree};
use text::Point;
use ui::{App, IconName, SharedString, Window};
use ui::{App, SharedString, Window};
use crate::{BlockStyle, FoldPlaceholder, RenderBlock};
@@ -40,6 +40,10 @@ impl CreaseSnapshot {
}
}
pub fn creases(&self) -> impl Iterator<Item = (CreaseId, &Crease<Anchor>)> {
self.creases.iter().map(|item| (item.id, &item.crease))
}
/// Returns the first Crease starting on the specified buffer row.
pub fn query_row<'a>(
&'a self,
@@ -147,7 +151,7 @@ pub enum Crease<T> {
/// Metadata about a [`Crease`], that is used for serialization.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct CreaseMetadata {
pub icon: IconName,
pub icon_path: SharedString,
pub label: SharedString,
}
@@ -237,6 +241,13 @@ impl<T> Crease<T> {
Crease::Block { range, .. } => range,
}
}
pub fn metadata(&self) -> Option<&CreaseMetadata> {
match self {
Self::Inline { metadata, .. } => metadata.as_ref(),
Self::Block { .. } => None,
}
}
}
impl<T> std::fmt::Debug for Crease<T>
@@ -305,7 +316,7 @@ impl CreaseMap {
&mut self,
ids: impl IntoIterator<Item = CreaseId>,
snapshot: &MultiBufferSnapshot,
) {
) -> Vec<(CreaseId, Range<Anchor>)> {
let mut removals = Vec::new();
for id in ids {
if let Some(range) = self.id_to_range.remove(&id) {
@@ -320,11 +331,11 @@ impl CreaseMap {
let mut new_creases = SumTree::new(snapshot);
let mut cursor = self.snapshot.creases.cursor::<ItemSummary>(snapshot);
for (id, range) in removals {
new_creases.append(cursor.slice(&range, Bias::Left, snapshot), snapshot);
for (id, range) in &removals {
new_creases.append(cursor.slice(range, Bias::Left, snapshot), snapshot);
while let Some(item) = cursor.item() {
cursor.next(snapshot);
if item.id == id {
if item.id == *id {
break;
} else {
new_creases.push(item.clone(), snapshot);
@@ -335,6 +346,8 @@ impl CreaseMap {
new_creases.append(cursor.suffix(snapshot), snapshot);
new_creases
};
removals
}
}

View File

@@ -56,7 +56,7 @@ use anyhow::{Context as _, Result, anyhow};
use blink_manager::BlinkManager;
use buffer_diff::DiffHunkStatus;
use client::{Collaborator, ParticipantIndex};
use clock::ReplicaId;
use clock::{AGENT_REPLICA_ID, ReplicaId};
use collections::{BTreeMap, HashMap, HashSet, VecDeque};
use convert_case::{Case, Casing};
use display_map::*;
@@ -201,7 +201,7 @@ use ui::{
};
use util::{RangeExt, ResultExt, TryFutureExt, maybe, post_inc};
use workspace::{
Item as WorkspaceItem, ItemId, ItemNavHistory, OpenInTerminal, OpenTerminal,
CollaboratorId, Item as WorkspaceItem, ItemId, ItemNavHistory, OpenInTerminal, OpenTerminal,
RestoreOnStartupBehavior, SERIALIZATION_THROTTLE_TIME, SplitDirection, TabBarSettings, Toast,
ViewId, Workspace, WorkspaceId, WorkspaceSettings,
item::{ItemHandle, PreviewTabsSettings},
@@ -914,7 +914,7 @@ pub struct Editor {
input_enabled: bool,
use_modal_editing: bool,
read_only: bool,
leader_peer_id: Option<PeerId>,
leader_id: Option<CollaboratorId>,
remote_id: Option<ViewId>,
pub hover_state: HoverState,
pending_mouse_down: Option<Rc<RefCell<Option<MouseDownEvent>>>>,
@@ -981,6 +981,8 @@ pub struct Editor {
addons: HashMap<TypeId, Box<dyn Addon>>,
registered_buffers: HashMap<BufferId, OpenLspBufferHandle>,
load_diff_task: Option<Shared<Task<()>>>,
/// Whether we are temporarily displaying a diff other than git's
temporary_diff_override: bool,
selection_mark_mode: bool,
toggle_fold_multiple_buffers: Task<()>,
_scroll_cursor_center_top_bottom_task: Task<()>,
@@ -1057,10 +1059,10 @@ pub struct RemoteSelection {
pub replica_id: ReplicaId,
pub selection: Selection<Anchor>,
pub cursor_shape: CursorShape,
pub peer_id: PeerId,
pub collaborator_id: CollaboratorId,
pub line_mode: bool,
pub participant_index: Option<ParticipantIndex>,
pub user_name: Option<SharedString>,
pub color: PlayerColor,
}
#[derive(Clone, Debug)]
@@ -1626,7 +1628,8 @@ impl Editor {
let mut load_uncommitted_diff = None;
if let Some(project) = project.clone() {
load_uncommitted_diff = Some(
get_uncommitted_diff_for_buffer(
update_uncommitted_diff_for_buffer(
cx.entity(),
&project,
buffer.read(cx).all_buffers(),
buffer.clone(),
@@ -1720,7 +1723,7 @@ impl Editor {
use_auto_surround: true,
auto_replace_emoji_shortcode: false,
jsx_tag_auto_close_enabled_in_any_buffer: false,
leader_peer_id: None,
leader_id: None,
remote_id: None,
hover_state: Default::default(),
pending_mouse_down: None,
@@ -1802,6 +1805,7 @@ impl Editor {
serialize_folds: Task::ready(()),
text_style_refinement: None,
load_diff_task: load_uncommitted_diff,
temporary_diff_override: false,
mouse_cursor_hidden: false,
hide_mouse_mode: EditorSettings::get_global(cx)
.hide_mouse
@@ -1941,7 +1945,7 @@ impl Editor {
.is_some_and(|menu| menu.context_menu.focus_handle(cx).is_focused(window))
}
fn key_context(&self, window: &Window, cx: &App) -> KeyContext {
pub fn key_context(&self, window: &Window, cx: &App) -> KeyContext {
self.key_context_internal(self.has_active_inline_completion(), window, cx)
}
@@ -2171,8 +2175,8 @@ impl Editor {
});
}
pub fn leader_peer_id(&self) -> Option<PeerId> {
self.leader_peer_id
pub fn leader_id(&self) -> Option<CollaboratorId> {
self.leader_id
}
pub fn buffer(&self) -> &Entity<MultiBuffer> {
@@ -2513,7 +2517,7 @@ impl Editor {
}
}
if self.focus_handle.is_focused(window) && self.leader_peer_id.is_none() {
if self.focus_handle.is_focused(window) && self.leader_id.is_none() {
self.buffer.update(cx, |buffer, cx| {
buffer.set_active_selections(
&self.selections.disjoint_anchors(),
@@ -13649,7 +13653,7 @@ impl Editor {
self.refresh_inline_completion(false, true, window, cx);
}
fn go_to_next_hunk(&mut self, _: &GoToHunk, window: &mut Window, cx: &mut Context<Self>) {
pub fn go_to_next_hunk(&mut self, _: &GoToHunk, window: &mut Window, cx: &mut Context<Self>) {
self.hide_mouse_cursor(&HideMouseCursorOrigin::MovementAction);
let snapshot = self.snapshot(window, cx);
let selection = self.selections.newest::<Point>(cx);
@@ -16239,9 +16243,9 @@ impl Editor {
&mut self,
ids: impl IntoIterator<Item = CreaseId>,
cx: &mut Context<Self>,
) {
) -> Vec<(CreaseId, Range<Anchor>)> {
self.display_map
.update(cx, |map, cx| map.remove_creases(ids, cx));
.update(cx, |map, cx| map.remove_creases(ids, cx))
}
pub fn longest_row(&self, cx: &mut App) -> DisplayRow {
@@ -17820,7 +17824,8 @@ impl Editor {
let buffer_id = buffer.read(cx).remote_id();
if self.buffer.read(cx).diff_for(buffer_id).is_none() {
if let Some(project) = &self.project {
get_uncommitted_diff_for_buffer(
update_uncommitted_diff_for_buffer(
cx.entity(),
project,
[buffer.clone()],
self.buffer.clone(),
@@ -17896,6 +17901,32 @@ impl Editor {
};
}
pub fn start_temporary_diff_override(&mut self) {
self.load_diff_task.take();
self.temporary_diff_override = true;
}
pub fn end_temporary_diff_override(&mut self, cx: &mut Context<Self>) {
self.temporary_diff_override = false;
self.set_render_diff_hunk_controls(Arc::new(render_diff_hunk_controls), cx);
self.buffer.update(cx, |buffer, cx| {
buffer.set_all_diff_hunks_collapsed(cx);
});
if let Some(project) = self.project.clone() {
self.load_diff_task = Some(
update_uncommitted_diff_for_buffer(
cx.entity(),
&project,
self.buffer.read(cx).all_buffers(),
self.buffer.clone(),
cx,
)
.shared(),
);
}
}
fn on_display_map_changed(
&mut self,
_: Entity<DisplayMap>,
@@ -18459,7 +18490,7 @@ impl Editor {
self.show_cursor_names(window, cx);
self.buffer.update(cx, |buffer, cx| {
buffer.finalize_last_transaction(cx);
if self.leader_peer_id.is_none() {
if self.leader_id.is_none() {
buffer.set_active_selections(
&self.selections.disjoint_anchors(),
self.selections.line_mode,
@@ -18875,7 +18906,8 @@ fn insert_extra_newline_tree_sitter(buffer: &MultiBufferSnapshot, range: Range<u
.all(|c| c.is_whitespace() && c != '\n')
}
fn get_uncommitted_diff_for_buffer(
fn update_uncommitted_diff_for_buffer(
editor: Entity<Editor>,
project: &Entity<Project>,
buffers: impl IntoIterator<Item = Entity<Buffer>>,
buffer: Entity<MultiBuffer>,
@@ -18891,6 +18923,13 @@ fn get_uncommitted_diff_for_buffer(
});
cx.spawn(async move |cx| {
let diffs = future::join_all(tasks).await;
if editor
.read_with(cx, |editor, _cx| editor.temporary_diff_override)
.unwrap_or(false)
{
return;
}
buffer
.update(cx, |buffer, cx| {
for diff in diffs.into_iter().flatten() {
@@ -19889,18 +19928,34 @@ impl EditorSnapshot {
self.buffer_snapshot
.selections_in_range(range, false)
.filter_map(move |(replica_id, line_mode, cursor_shape, selection)| {
let collaborator = collaborators_by_replica_id.get(&replica_id)?;
let participant_index = participant_indices.get(&collaborator.user_id).copied();
let user_name = participant_names.get(&collaborator.user_id).cloned();
Some(RemoteSelection {
replica_id,
selection,
cursor_shape,
line_mode,
participant_index,
peer_id: collaborator.peer_id,
user_name,
})
if replica_id == AGENT_REPLICA_ID {
Some(RemoteSelection {
replica_id,
selection,
cursor_shape,
line_mode,
collaborator_id: CollaboratorId::Agent,
user_name: Some("Agent".into()),
color: cx.theme().players().agent(),
})
} else {
let collaborator = collaborators_by_replica_id.get(&replica_id)?;
let participant_index = participant_indices.get(&collaborator.user_id).copied();
let user_name = participant_names.get(&collaborator.user_id).cloned();
Some(RemoteSelection {
replica_id,
selection,
cursor_shape,
line_mode,
collaborator_id: CollaboratorId::PeerId(collaborator.peer_id),
user_name,
color: if let Some(index) = participant_index {
cx.theme().players().color_for_participant(index.0)
} else {
cx.theme().players().absent()
},
})
}
})
}

View File

@@ -101,6 +101,7 @@ pub struct Toolbar {
pub breadcrumbs: bool,
pub quick_actions: bool,
pub selections_menu: bool,
pub agent_review: bool,
}
#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
@@ -400,11 +401,15 @@ pub struct ToolbarContent {
///
/// Default: true
pub quick_actions: Option<bool>,
/// Whether to show the selections menu in the editor toolbar
/// Whether to show the selections menu in the editor toolbar.
///
/// Default: true
pub selections_menu: Option<bool>,
/// Whether to display Agent review buttons in the editor toolbar.
/// Only applicable while reviewing a file edited by the Agent.
///
/// Default: true
pub agent_review: Option<bool>,
}
/// Scrollbar related settings

View File

@@ -12650,7 +12650,7 @@ async fn test_following_with_multiple_excerpts(cx: &mut TestAppContext) {
Editor::from_state_proto(
workspace_entity,
ViewId {
creator: Default::default(),
creator: CollaboratorId::PeerId(PeerId::default()),
id: 0,
},
&mut state_message,
@@ -12737,7 +12737,7 @@ async fn test_following_with_multiple_excerpts(cx: &mut TestAppContext) {
Editor::from_state_proto(
workspace_entity,
ViewId {
creator: Default::default(),
creator: CollaboratorId::PeerId(PeerId::default()),
id: 0,
},
&mut state_message,

View File

@@ -28,7 +28,6 @@ use crate::{
scroll::scroll_amount::ScrollAmount,
};
use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
use client::ParticipantIndex;
use collections::{BTreeMap, HashMap};
use feature_flags::{DebuggerFeatureFlag, FeatureFlagAppExt};
use file_icons::FileIcons;
@@ -82,7 +81,7 @@ use theme::{ActiveTheme, Appearance, BufferLineHeight, PlayerColor};
use ui::{ButtonLike, KeyBinding, POPOVER_Y_PADDING, Tooltip, h_flex, prelude::*};
use unicode_segmentation::UnicodeSegmentation;
use util::{RangeExt, ResultExt, debug_panic};
use workspace::{Workspace, item::Item, notifications::NotifyTaskExt};
use workspace::{CollaboratorId, Workspace, item::Item, notifications::NotifyTaskExt};
const INLINE_BLAME_PADDING_EM_WIDTHS: f32 = 7.;
@@ -1126,7 +1125,7 @@ impl EditorElement {
editor.cursor_shape,
&snapshot.display_snapshot,
is_newest,
editor.leader_peer_id.is_none(),
editor.leader_id.is_none(),
None,
);
if is_newest {
@@ -1150,18 +1149,29 @@ impl EditorElement {
if let Some(collaboration_hub) = &editor.collaboration_hub {
// When following someone, render the local selections in their color.
if let Some(leader_id) = editor.leader_peer_id {
if let Some(collaborator) = collaboration_hub.collaborators(cx).get(&leader_id)
{
if let Some(participant_index) = collaboration_hub
.user_participant_indices(cx)
.get(&collaborator.user_id)
{
if let Some(leader_id) = editor.leader_id {
match leader_id {
CollaboratorId::PeerId(peer_id) => {
if let Some(collaborator) =
collaboration_hub.collaborators(cx).get(&peer_id)
{
if let Some(participant_index) = collaboration_hub
.user_participant_indices(cx)
.get(&collaborator.user_id)
{
if let Some((local_selection_style, _)) = selections.first_mut()
{
*local_selection_style = cx
.theme()
.players()
.color_for_participant(participant_index.0);
}
}
}
}
CollaboratorId::Agent => {
if let Some((local_selection_style, _)) = selections.first_mut() {
*local_selection_style = cx
.theme()
.players()
.color_for_participant(participant_index.0);
*local_selection_style = cx.theme().players().agent();
}
}
}
@@ -1173,12 +1183,9 @@ impl EditorElement {
collaboration_hub.as_ref(),
cx,
) {
let selection_style =
Self::get_participant_color(selection.participant_index, cx);
// Don't re-render the leader's selections, since the local selections
// match theirs.
if Some(selection.peer_id) == editor.leader_peer_id {
if Some(selection.collaborator_id) == editor.leader_id {
continue;
}
let key = HoveredCursor {
@@ -1191,7 +1198,7 @@ impl EditorElement {
remote_selections
.entry(selection.replica_id)
.or_insert((selection_style, Vec::new()))
.or_insert((selection.color, Vec::new()))
.1
.push(SelectionLayout::new(
selection.selection,
@@ -1246,9 +1253,11 @@ impl EditorElement {
collaboration_hub.deref(),
cx,
) {
let color = Self::get_participant_color(remote_selection.participant_index, cx);
add_cursor(remote_selection.selection.head(), color.cursor);
if Some(remote_selection.peer_id) == editor.leader_peer_id {
add_cursor(
remote_selection.selection.head(),
remote_selection.color.cursor,
);
if Some(remote_selection.collaborator_id) == editor.leader_id {
skip_local = true;
}
}
@@ -2446,14 +2455,6 @@ impl EditorElement {
Some(button)
}
fn get_participant_color(participant_index: Option<ParticipantIndex>, cx: &App) -> PlayerColor {
if let Some(index) = participant_index {
cx.theme().players().color_for_participant(index.0)
} else {
cx.theme().players().absent()
}
}
fn calculate_relative_line_numbers(
&self,
snapshot: &EditorSnapshot,

View File

@@ -23,7 +23,7 @@ use project::{
Project, ProjectItem as _, ProjectPath, lsp_store::FormatTrigger,
project_settings::ProjectSettings, search::SearchQuery,
};
use rpc::proto::{self, PeerId, update_view};
use rpc::proto::{self, update_view};
use settings::Settings;
use std::{
any::TypeId,
@@ -39,7 +39,7 @@ use theme::{Theme, ThemeSettings};
use ui::{IconDecorationKind, prelude::*};
use util::{ResultExt, TryFutureExt, paths::PathExt};
use workspace::{
ItemId, ItemNavHistory, ToolbarItemLocation, ViewId, Workspace, WorkspaceId,
CollaboratorId, ItemId, ItemNavHistory, ToolbarItemLocation, ViewId, Workspace, WorkspaceId,
item::{FollowableItem, Item, ItemEvent, ProjectItem},
searchable::{Direction, SearchEvent, SearchableItem, SearchableItemHandle},
};
@@ -170,14 +170,14 @@ impl FollowableItem for Editor {
}))
}
fn set_leader_peer_id(
fn set_leader_id(
&mut self,
leader_peer_id: Option<PeerId>,
leader_id: Option<CollaboratorId>,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.leader_peer_id = leader_peer_id;
if self.leader_peer_id.is_some() {
self.leader_id = leader_id;
if self.leader_id.is_some() {
self.buffer.update(cx, |buffer, cx| {
buffer.remove_active_selections(cx);
});
@@ -350,6 +350,30 @@ impl FollowableItem for Editor {
None
}
}
fn update_agent_location(
&mut self,
location: language::Anchor,
window: &mut Window,
cx: &mut Context<Self>,
) {
let buffer = self.buffer.read(cx);
let buffer = buffer.read(cx);
let Some((excerpt_id, _, _)) = buffer.as_singleton() else {
return;
};
let position = buffer.anchor_in_excerpt(*excerpt_id, location).unwrap();
let selection = Selection {
id: 0,
reversed: false,
start: position,
end: position,
goal: SelectionGoal::None,
};
drop(buffer);
self.set_selections_from_remote(vec![selection], None, window, cx);
self.request_autoscroll_remotely(Autoscroll::center(), cx);
}
}
async fn update_editor_from_message(
@@ -1293,7 +1317,7 @@ impl ProjectItem for Editor {
fn for_project_item(
project: Entity<Project>,
pane: &Pane,
pane: Option<&Pane>,
buffer: Entity<Buffer>,
window: &mut Window,
cx: &mut Context<Self>,
@@ -1304,7 +1328,7 @@ impl ProjectItem for Editor {
{
if WorkspaceSettings::get(None, cx).restore_on_file_reopen {
if let Some(restoration_data) = Self::project_item_kind()
.and_then(|kind| pane.project_item_restoration_data.get(&kind))
.and_then(|kind| pane.as_ref()?.project_item_restoration_data.get(&kind))
.and_then(|data| data.downcast_ref::<EditorRestorationData>())
.and_then(|data| {
let file = project::File::from_dyn(buffer.read(cx).file())?;

View File

@@ -3,6 +3,7 @@ use std::sync::Arc;
use crate::Editor;
use collections::HashMap;
use futures::stream::FuturesUnordered;
use gpui::AsyncApp;
use gpui::{App, AppContext as _, Entity, Task};
use itertools::Itertools;
use language::Buffer;
@@ -74,6 +75,39 @@ where
})
}
async fn lsp_task_context(
project: &Entity<Project>,
buffer: &Entity<Buffer>,
cx: &mut AsyncApp,
) -> Option<TaskContext> {
let worktree_store = project
.update(cx, |project, _| project.worktree_store())
.ok()?;
let worktree_abs_path = cx
.update(|cx| {
let worktree_id = buffer.read(cx).file().map(|f| f.worktree_id(cx));
worktree_id
.and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
.and_then(|worktree| worktree.read(cx).root_dir())
})
.ok()?;
let project_env = project
.update(cx, |project, cx| {
project.buffer_environment(&buffer, &worktree_store, cx)
})
.ok()?
.await;
Some(TaskContext {
cwd: worktree_abs_path.map(|p| p.to_path_buf()),
project_env: project_env.unwrap_or_default(),
..TaskContext::default()
})
}
pub fn lsp_tasks(
project: Entity<Project>,
task_sources: &HashMap<LanguageServerName, Vec<BufferId>>,
@@ -97,13 +131,16 @@ pub fn lsp_tasks(
cx.spawn(async move |cx| {
let mut lsp_tasks = Vec::new();
let lsp_task_context = TaskContext::default();
while let Some(server_to_query) = lsp_task_sources.next().await {
if let Some((server_id, buffers)) = server_to_query {
let source_kind = TaskSourceKind::Lsp(server_id);
let id_base = source_kind.to_id_base();
let mut new_lsp_tasks = Vec::new();
for buffer in buffers {
let lsp_buffer_context = lsp_task_context(&project, &buffer, cx)
.await
.unwrap_or_default();
if let Ok(runnables_task) = project.update(cx, |project, cx| {
let buffer_id = buffer.read(cx).remote_id();
project.request_lsp(
@@ -120,7 +157,7 @@ pub fn lsp_tasks(
new_lsp_tasks.extend(new_runnables.runnables.into_iter().filter_map(
|(location, runnable)| {
let resolved_task =
runnable.resolve_task(&id_base, &lsp_task_context)?;
runnable.resolve_task(&id_base, &lsp_buffer_context)?;
Some((location, resolved_task))
},
));

View File

@@ -46,9 +46,12 @@ struct Args {
/// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run.
#[arg(value_name = "EXAMPLE_SUBSTRING")]
filter: Vec<String>,
/// Model to use (default: "claude-3-7-sonnet-latest")
/// ID of model to use.
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
model: String,
/// Model provider to use.
#[arg(long, default_value = "anthropic")]
provider: String,
#[arg(long, value_delimiter = ',', default_value = "rs,ts")]
languages: Vec<String>,
/// How many times to run each example.
@@ -123,7 +126,7 @@ fn main() {
let mut cumulative_tool_metrics = ToolMetrics::default();
let model_registry = LanguageModelRegistry::read_global(cx);
let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap();
let model = find_model(&args.provider, &args.model, model_registry, cx).unwrap();
let model_provider_id = model.provider_id();
let model_provider = model_registry.provider(&model_provider_id).unwrap();
@@ -169,11 +172,14 @@ fn main() {
continue;
}
if meta.language_server.map_or(false, |language| {
!languages.contains(&language.file_extension)
}) {
skipped.push(meta.name);
continue;
if let Some(language) = meta.language_server {
if !languages.contains(&language.file_extension) {
panic!(
"Eval for {:?} could not be run because no language server was found for extension {:?}",
meta.name,
language.file_extension
);
}
}
// TODO: This creates a worktree per repetition. Ideally these examples should
@@ -449,27 +455,36 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
}
pub fn find_model(
model_name: &str,
provider_id: &str,
model_id: &str,
model_registry: &LanguageModelRegistry,
cx: &App,
) -> anyhow::Result<Arc<dyn LanguageModel>> {
let model = model_registry
let matching_models = model_registry
.available_models(cx)
.find(|model| model.id().0 == model_name);
.filter(|model| model.id().0 == model_id && model.provider_id().0 == provider_id)
.collect::<Vec<_>>();
let Some(model) = model else {
return Err(anyhow!(
"No language model named {} was available. Available models: {}",
model_name,
match matching_models.as_slice() {
[model] => Ok(model.clone()),
[] => Err(anyhow!(
"No language model with ID {} was available. Available models: {}",
model_id,
model_registry
.available_models(cx)
.map(|model| model.id().0.clone())
.collect::<Vec<_>>()
.join(", ")
));
};
Ok(model)
)),
_ => Err(anyhow!(
"Multiple language models with ID {} available - use `--provider` to choose one of: {:?}",
model_id,
matching_models
.iter()
.map(|model| model.provider_id().0)
.collect::<Vec<_>>()
)),
}
}
pub fn commit_sha_for_path(repo_path: &Path) -> String {

View File

@@ -119,6 +119,7 @@ impl ExampleContext {
text.to_string(),
ContextLoadResult::default(),
None,
Vec::new(),
cx,
);
})
@@ -233,9 +234,11 @@ impl ExampleContext {
tx.try_send(Err(anyhow!(err.clone()))).ok();
}
},
ThreadEvent::StreamedAssistantText(_, _)
ThreadEvent::NewRequest
| ThreadEvent::StreamedAssistantText(_, _)
| ThreadEvent::StreamedAssistantThinking(_, _)
| ThreadEvent::UsePendingTools { .. } => {}
| ThreadEvent::UsePendingTools { .. }
| ThreadEvent::CompletionCanceled => {}
ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,

View File

@@ -0,0 +1,61 @@
use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion};
use anyhow::Result;
use assistant_tools::StreamingEditFileToolInput;
use async_trait::async_trait;
pub struct CommentTranslation;
#[async_trait(?Send)]
impl Example for CommentTranslation {
fn meta(&self) -> ExampleMetadata {
ExampleMetadata {
name: "comment_translation".to_string(),
url: "https://github.com/servo/font-kit.git".to_string(),
revision: "504d084e29bce4f60614bc702e91af7f7d9e60ad".to_string(),
language_server: None,
max_assertions: Some(1),
}
}
async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
cx.push_user_message(r#"
Edit the following files and translate all their comments to italian, in this exact order:
- font-kit/src/family.rs
- font-kit/src/canvas.rs
- font-kit/src/error.rs
"#);
cx.run_to_end().await?;
let mut create_or_overwrite_count = 0;
cx.agent_thread().read_with(cx, |thread, cx| {
for message in thread.messages() {
for tool_use in thread.tool_uses_for_message(message.id, cx) {
if tool_use.name == "edit_file" {
let input: StreamingEditFileToolInput =
serde_json::from_value(tool_use.input)?;
if input.create_or_overwrite {
create_or_overwrite_count += 1;
}
}
}
}
anyhow::Ok(())
})??;
cx.assert_eq(create_or_overwrite_count, 0, "no_creation_or_overwrite")?;
Ok(())
}
fn diff_assertions(&self) -> Vec<JudgeAssertion> {
vec![JudgeAssertion {
id: "comments_translated".to_string(),
description: concat!(
"- Only `family.rs`, `canvas.rs` and `error.rs` should have changed.\n",
"- Their doc comments should have been all translated to Italian."
)
.into(),
}]
}
}

View File

@@ -13,13 +13,17 @@ use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion};
mod add_arg_to_trait_method;
mod code_block_citations;
mod comment_translation;
mod file_search;
mod planets;
pub fn all(examples_dir: &Path) -> Vec<Rc<dyn Example>> {
let mut threads: Vec<Rc<dyn Example>> = vec![
Rc::new(file_search::FileSearchExample),
Rc::new(add_arg_to_trait_method::AddArgToTraitMethod),
Rc::new(code_block_citations::CodeBlockCitations),
Rc::new(planets::Planets),
Rc::new(comment_translation::CommentTranslation),
];
for example_path in list_declarative_examples(examples_dir).unwrap() {

View File

@@ -0,0 +1,73 @@
use anyhow::Result;
use assistant_tool::Tool;
use assistant_tools::{OpenTool, TerminalTool};
use async_trait::async_trait;
use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion};
pub struct Planets;
#[async_trait(?Send)]
impl Example for Planets {
fn meta(&self) -> ExampleMetadata {
ExampleMetadata {
name: "planets".to_string(),
url: "https://github.com/roc-lang/roc".to_string(), // This commit in this repo is just the Apache2 license,
revision: "59e49c75214f60b4dc4a45092292061c8c26ce27".to_string(), // so effectively a blank project.
language_server: None,
max_assertions: None,
}
}
async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
cx.push_user_message(
r#"
Make a plain JavaScript web page which renders an animated 3D solar system.
Let me drag to rotate the camera around.
Do not use npm.
"#
.to_string(),
);
let response = cx.run_to_end().await?;
let mut open_tool_uses = 0;
let mut terminal_tool_uses = 0;
for tool_use in response.tool_uses() {
if tool_use.name == OpenTool.name() {
open_tool_uses += 1;
} else if tool_use.name == TerminalTool.name() {
terminal_tool_uses += 1;
}
}
// The open tool should only be used when requested, which it was not.
cx.assert_eq(open_tool_uses, 0, "`open` tool was not used")
.ok();
// No reason to use the terminal if not using npm.
cx.assert_eq(terminal_tool_uses, 0, "`terminal` tool was not used")
.ok();
Ok(())
}
fn diff_assertions(&self) -> Vec<JudgeAssertion> {
vec![
JudgeAssertion {
id: "animated solar system".to_string(),
description: "This page should render a solar system, and it should be animated."
.to_string(),
},
JudgeAssertion {
id: "drag to rotate camera".to_string(),
description: "The user can drag to rotate the camera around.".to_string(),
},
JudgeAssertion {
id: "plain JavaScript".to_string(),
description:
"The code base uses plain JavaScript and no npm, along with HTML and CSS."
.to_string(),
},
]
}
}

View File

@@ -1017,7 +1017,8 @@ pub fn response_events_to_markdown(
}
Ok(
LanguageModelCompletionEvent::UsageUpdate(_)
| LanguageModelCompletionEvent::StartMessage { .. },
| LanguageModelCompletionEvent::StartMessage { .. }
| LanguageModelCompletionEvent::StatusUpdate { .. },
) => {}
Err(error) => {
flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
@@ -1092,6 +1093,7 @@ impl ThreadDialog {
// Skip these
Ok(LanguageModelCompletionEvent::UsageUpdate(_))
| Ok(LanguageModelCompletionEvent::StatusUpdate { .. })
| Ok(LanguageModelCompletionEvent::StartMessage { .. })
| Ok(LanguageModelCompletionEvent::Stop(_)) => {}

View File

@@ -1,10 +1,10 @@
/// Configuration for a context server.
/// Configuration for context server setup and installation.
#[derive(Debug, Clone)]
pub struct ContextServerConfiguration {
/// Installation instructions for the user.
/// Installation instructions in Markdown format.
pub installation_instructions: String,
/// Default settings for the context server.
pub default_settings: String,
/// JSON schema describing server settings.
/// JSON schema for settings validation.
pub settings_schema: serde_json::Value,
/// Default settings template.
pub default_settings: String,
}

View File

@@ -6,8 +6,7 @@ repository = "https://github.com/zed-industries/zed"
documentation = "https://docs.rs/zed_extension_api"
keywords = ["zed", "extension"]
edition.workspace = true
# Change back to `true` when we're ready to publish v0.5.0.
publish = false
publish = true
license = "Apache-2.0"
[lints]

View File

@@ -23,7 +23,7 @@ need to set your `crate-type` accordingly:
```toml
[dependencies]
zed_extension_api = "0.4.0"
zed_extension_api = "0.5.0"
[lib]
crate-type = ["cdylib"]
@@ -63,6 +63,7 @@ Here is the compatibility of the `zed_extension_api` with versions of Zed:
| Zed version | `zed_extension_api` version |
| ----------- | --------------------------- |
| `0.186.x` | `0.0.1` - `0.5.0` |
| `0.184.x` | `0.0.1` - `0.4.0` |
| `0.178.x` | `0.0.1` - `0.3.0` |
| `0.162.x` | `0.0.1` - `0.2.0` |

View File

@@ -1,11 +1,11 @@
interface context-server {
///
/// Configuration for context server setup and installation.
record context-server-configuration {
///
/// Installation instructions in Markdown format.
installation-instructions: string,
///
/// JSON schema for settings validation.
settings-schema: string,
///
/// Default settings template.
default-settings: string,
}
}

View File

@@ -62,7 +62,7 @@ pub fn wasm_api_version_range(release_channel: ReleaseChannel) -> RangeInclusive
let max_version = match release_channel {
ReleaseChannel::Dev | ReleaseChannel::Nightly => latest::MAX_VERSION,
ReleaseChannel::Stable | ReleaseChannel::Preview => since_v0_4_0::MAX_VERSION,
ReleaseChannel::Stable | ReleaseChannel::Preview => latest::MAX_VERSION,
};
since_v0_0_1::MIN_VERSION..=max_version
@@ -113,8 +113,6 @@ impl Extension {
let _ = release_channel;
if version >= latest::MIN_VERSION {
authorize_access_to_unreleased_wasm_api_version(release_channel)?;
let extension =
latest::Extension::instantiate_async(store, component, latest::linker())
.await

View File

@@ -8,7 +8,6 @@ use wasmtime::component::{Linker, Resource};
use super::latest;
pub const MIN_VERSION: SemanticVersion = SemanticVersion::new(0, 4, 0);
pub const MAX_VERSION: SemanticVersion = SemanticVersion::new(0, 4, 0);
wasmtime::component::bindgen!({
async: true,

View File

@@ -19,12 +19,11 @@ pub static ZED_DISABLE_STAFF: LazyLock<bool> = LazyLock::new(|| {
impl FeatureFlags {
fn has_flag<T: FeatureFlag>(&self) -> bool {
if self.staff && T::enabled_for_staff() {
if T::enabled_for_all() {
return true;
}
#[cfg(debug_assertions)]
if T::enabled_in_development() {
if self.staff && T::enabled_for_staff() {
return true;
}
@@ -48,21 +47,38 @@ pub trait FeatureFlag {
true
}
fn enabled_in_development() -> bool {
Self::enabled_for_staff() && !*ZED_DISABLE_STAFF
/// Returns whether this feature flag is enabled for everyone.
///
/// This is generally done on the server, but we provide this as a way to entirely enable a feature flag client-side
/// without needing to remove all of the call sites.
fn enabled_for_all() -> bool {
false
}
}
/// Controls the values of various feature flags for the Agent launch.
///
/// Change this to `true` when we're ready to build the release candidate.
const AGENT_LAUNCH: bool = false;
pub struct Assistant2FeatureFlag;
impl FeatureFlag for Assistant2FeatureFlag {
const NAME: &'static str = "assistant2";
fn enabled_for_all() -> bool {
AGENT_LAUNCH
}
}
pub struct AgentStreamEditsFeatureFlag;
impl FeatureFlag for AgentStreamEditsFeatureFlag {
const NAME: &'static str = "agent-stream-edits";
fn enabled_for_all() -> bool {
AGENT_LAUNCH
}
}
pub struct NewBillingFeatureFlag;
@@ -70,8 +86,8 @@ pub struct NewBillingFeatureFlag;
impl FeatureFlag for NewBillingFeatureFlag {
const NAME: &'static str = "new-billing";
fn enabled_for_staff() -> bool {
false
fn enabled_for_all() -> bool {
AGENT_LAUNCH
}
}
@@ -144,7 +160,6 @@ where
if self
.try_global::<FeatureFlags>()
.is_some_and(|f| f.has_flag::<T>())
|| cfg!(debug_assertions) && T::enabled_in_development()
{
self.defer_in(window, move |view, window, cx| {
callback(view, window, cx);

Some files were not shown because too many files have changed in this diff Show More