Compare commits

..

67 Commits

Author SHA1 Message Date
Conrad Irwin
d36304963e History, is history 2025-08-13 13:44:39 -06:00
Conrad Irwin
0d71351b02 Merge branch 'main' into message-editor 2025-08-13 13:07:25 -06:00
Conrad Irwin
b06fe288f3 Clip Clop 2025-08-13 13:05:28 -06:00
smit
4a35498829 copilot: Fix Copilot fails to sign in (#36138)
Closes #36093

Pin copilot version to 1.354 for now until further investigation.

Release Notes:

- Fixes issue where Copilot failed to sign in.

Co-authored-by: MrSubidubi <dev@bahn.sh>
2025-08-14 00:19:37 +05:30
Conrad Irwin
fd0ffb737f Create a new MessageEditor 2025-08-13 12:01:50 -06:00
Joseph T. Lyons
e52f148304 Bump Zed to v0.201 (#36132)
Release Notes:

-N/A
2025-08-13 17:56:51 +00:00
Danilo Leal
cb0bc463f1 agent2: Add new "new thread" selector in the toolbar (#36133)
Release Notes:

- N/A
2025-08-13 14:45:37 -03:00
ponychicken
9a375f1419 Add some documentation for Helix mode (#35641)
Because there is literally no mention of it in the docs

Release Notes:

- N/A

---------

Co-authored-by: ponychicken <183302+ponychicken@users.noreply.github.com>
Co-authored-by: Ben Kunkle <ben@zed.dev>
2025-08-13 17:36:18 +00:00
Marshall Bowers
4238e640fa emmet: Bump to v0.0.6 (#36129)
This PR bumps the Emmet extension to v0.0.6.

Changes:

- https://github.com/zed-industries/zed/pull/36126

Release Notes:

- N/A
2025-08-13 16:55:02 +00:00
Anthony Eid
0b9c9f5f2d onboarding: Make Welcome page persistent (#36127)
Release Notes:

- N/A
2025-08-13 16:42:09 +00:00
Marshall Bowers
2da80e4641 emmet: Use index.js directly to launch language server (#36126)
This PR updates the Emmet extension to use the `index.js` file directly
to launch the language server.

This provides better cross-platform support, as we're not relying on
platform-specific `.bin` wrappers.

Release Notes:

- N/A
2025-08-13 16:34:18 +00:00
Danilo Leal
d9a94a5496 onboarding: Remove feature flag and old welcome crate (#36110)
Release Notes:

- N/A

---------

Co-authored-by: MrSubidubi <dev@bahn.sh>
Co-authored-by: Anthony <anthony@zed.dev>
2025-08-13 13:18:24 -03:00
Anthony Eid
a7442d8880 onboarding: Add more telemetry (#36121)
1. Welcome Page Open
2. Welcome Nav clicked
3. Skip clicked
4. Font changed
5. Import settings clicked
6. Inlay Hints
7. Git Blame
8. Format on Save
9. Font Ligature
10. Ai Enabled
11. Ai Provider Modal open


Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
2025-08-13 12:02:14 -04:00
Gilmar Sales
6c1f19571a Enhance icon detection for files with custom suffixes (#34170)
Fixes custom file suffixes (module.ts) of some icon themes like: 

- **Symbols Icon Theme** 
<img width="212" alt="image"
src="https://github.com/user-attachments/assets/419ba1b4-9d8e-46cd-891b-62fb63a8c5ae"
/>

- **Bearded Icon Theme**
<img width="209" alt="image"
src="https://github.com/user-attachments/assets/72974fce-fa72-4368-8d96-7feea7b59b7a"
/>

Release Notes:

- Fixed icon detection for files with custom suffixes like `module.ts`
that are overwritten by the language's icon `.ts`
2025-08-13 11:59:59 -04:00
Ben Brandt
23cd5b59b2 agent2: Initial infra for checkpoints and message editing (#36120)
Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
2025-08-13 15:46:28 +00:00
Marshall Bowers
f4b0332f78 Hoist rodio to workspace level (#36113)
This PR hoists `rodio` up to a workspace dependency.

Release Notes:

- N/A
2025-08-13 13:50:13 +00:00
Danilo Leal
abde7306e3 onboarding: Adjust page layout (#36112)
Fix max-height and make it scrollable as well, if needed.

Release Notes:

- N/A
2025-08-13 10:35:47 -03:00
Ben Brandt
2b3dbe8815 agent2: Allow tools to be provider specific (#36111)
Our WebSearch tool requires access to a Zed provider

Release Notes:

- N/A
2025-08-13 13:22:05 +00:00
Finn Evers
7f1a5c6ad7 ui: Make toggle button group responsive (#36100)
This PR improves the toggle button group to be more responsive across
different layouts. This is accomplished by ensuring each button takes up
the same amount of space in the parent containers layout.

Ideally, this should be done with grids instead of a flexbox container,
as this would be much better suited for this purpose. Yet, since we lack
support for this, we go with this route for now.

| Before | After |
| --- | --- |
| <img width="1608" height="1094" alt="Bildschirmfoto 2025-08-13 um 11
24 26"
src="https://github.com/user-attachments/assets/2a4b5a59-6483-4f79-8fcb-e26e22071795"
/> | <img width="1608" height="1094" alt="Bildschirmfoto 2025-08-13 um
11 29 36"
src="https://github.com/user-attachments/assets/e6402729-6a8f-4a44-b79e-a569406edfff"
/> |


Release Notes:

- N/A
2025-08-13 14:02:20 +02:00
localcc
6307105976 Don't show default shell breadcrumbs (#36070)
Release Notes:

- N/A
2025-08-13 13:58:09 +02:00
Kirill Bulatov
8d63312eca Small worktree scan style fixes (#36104)
Part of https://github.com/zed-industries/zed/issues/35780

Release Notes:

- N/A
2025-08-13 14:29:53 +03:00
Finn Evers
81474a3de0 Change default pane split directions (#36101)
Closes #32538

This PR adjusts the defaults for splitting panes along the horizontal
and vertical actions. Based upon user feedback, the adjusted values seem
more reasonable as default settings, hence, go with these instead.

Release Notes:

- Changed the default split directions for the `pane: split horizontal`
and `pane: split vertical` actions. You can restore the old behavior by
modifying the `pane_split_direction_horizontal` and
`pane_split_direction_vertical` values in your settings.
2025-08-13 13:17:03 +02:00
Ben Brandt
db497ac867 Agent2 Model Selector (#36028)
Release Notes:

- N/A

---------

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
2025-08-13 09:01:02 +00:00
Cretezy
8ff2e3e195 language_models: Add reasoning_effort for custom models (#35929)
Release Notes:

- Added `reasoning_effort` support to custom models

Tested using the following config:
```json5
  "language_models": {
    "openai": {
      "available_models": [
        {
          "name": "gpt-5-mini",
          "display_name": "GPT 5 Mini (custom reasoning)",
          "max_output_tokens": 128000,
          "max_tokens": 272000,
          "reasoning_effort": "high" // Can be minimal, low, medium (default), and high
        }
      ],
      "version": "1"
    }
  }
```

Docs:
https://platform.openai.com/docs/api-reference/chat/create#chat_create-reasoning_effort

This work could be used to split the GPT 5/5-mini/5-nano into each of
it's reasoning effort variant. E.g. `gpt-5`, `gpt-5 low`, `gpt-5
minimal`, `gpt-5 high`, and same for mini/nano.

Release Notes:

* Added a setting to control `reasoning_effort` in OpenAI models
2025-08-13 06:09:16 +00:00
Anthony Eid
96093aa465 onboarding: Link git clone button with action (#35999)
Release Notes:

- N/A
2025-08-13 01:18:11 -04:00
morgankrey
dc87f4b32e Add 4.1 to models page (#36086)
Adds opus 4.1 to models page in docs

Release Notes:

- N/A
2025-08-12 21:15:48 -06:00
Cole Miller
1957e1f642 Add locations to native agent tool calls, and wire them up to UI (#36058)
Release Notes:

- N/A

---------

Co-authored-by: Conrad <conrad@zed.dev>
2025-08-13 01:48:28 +00:00
Cole Miller
d78bd8f1d7 Fix external agent still being marked as generating after error response (#35992)
Release Notes:

- N/A
2025-08-12 21:41:00 -04:00
张小白
32975c4208 windows: Fix auto update failure when launching from the cli (#34303)
Release Notes:

- N/A

---------

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
2025-08-12 17:04:30 -07:00
Piotr Osiewicz
658d56bd72 cli: Do not rely on Spotlight for --channel support (#36082)
I've recently disabled Spotlight on my Mac and found that this code path
(which I rely on a lot) ceased working for me.

Closes #ISSUE

Release Notes:

- N/A
2025-08-12 22:37:11 +00:00
Anthony Eid
13a2c53381 onboarding: Fix onboarding font context menu not scrolling to selected entry open (#36080)
The fix was changing the picker kind we used from `list` variant to a
`uniform` list

`Picker::list()` still has a bug where it's unable to scroll to it's
selected entry when the list is first openned. This is likely caused by
list not knowing the pixel offset of each element it would have to
scroll pass to get to the selected element


Release Notes:

- N/A

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
2025-08-12 18:02:10 -04:00
Max Brunsfeld
cd234e28ce Eliminate host targets from rust toolchain file (#36077)
Only cross-compilation targets need to be listed in the rust toolchain.
So we only need to list the wasi target for extensions, and the musl
target for the linux remote server. Previously, we were causing mac,
linux, and windows target to get installed onto all developer
workstations, which is unnecessary.

Release Notes:

- N/A
2025-08-12 14:36:48 -07:00
Michael Sloan
b564b1d5d0 outline: Fix nesting in multi-name declarations in Go and C++ (#36076)
An alternative might be to adjust the logic to not nest items when their
ranges are the same, but then clicking them doesn't work properly /
moving the cursor does not change which is selected. This could probably
be made to work with some extra logic there, but it seems overkill.

The downside of fixing it at the query level is that other parts of the
declaration are not inside the item range. This seems to be fine for
single line declarations - the nearest outline item is highlighted.
However, if a part of the declaration is not included in an item range
and is on its own line, then no outline item is highlighted.

Release Notes:

- Outline Panel: Fixed nesting of var and field declarations with
multiple identifiers in Go and C++

C++ before:

<img width="743" height="227" alt="image"
src="https://github.com/user-attachments/assets/af1a1d76-ecdc-4999-ae9c-95591726ccca"
/>

C++ after:

<img width="795" height="250" alt="image"
src="https://github.com/user-attachments/assets/49667ed3-e088-48b3-a9f0-6a119b5e7648"
/>

Go before:

<img width="859" height="306" alt="image"
src="https://github.com/user-attachments/assets/ecc7530a-ca16-4f37-b8d1-60687f178b12"
/>

Go after:

<img width="900" height="334" alt="image"
src="https://github.com/user-attachments/assets/d741cfb0-59e5-4d27-bd6a-f422204dc972"
/>
2025-08-12 21:08:19 +00:00
Richard Feldman
48ae02c1ca Don't retry for PaymentRequiredError or ModelRequestLimitReachedError (#36075)
Release Notes:

- Don't auto-retry for "payment required" or "model request limit
reached" errors (since retrying won't help)
2025-08-12 21:06:01 +00:00
Anthony Eid
255bb0a3f8 telemetry: Reduce the amount of telemetry events fired (#36060)
1. Extension loaded events are now condensed into a single event with a
Vec of (extension_id, extension_version) called id_and_versions.
2. Editor Saved & AutoSaved are merged into a singular event with a type
field that is either "manual" or "autosave”.
3. Editor Edited event will only fire once every 10 minutes now.
4. Editor Closed event is fired when an editor item (tab) is removed
from a pane



cc: @katie-z-geer 

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
2025-08-12 19:56:27 +00:00
Danilo Leal
628b1058be agent2: Fix some UI glitches (#36067)
Release Notes:

- N/A
2025-08-12 16:31:54 -03:00
Oleksiy Syvokon
7167f193c0 open_ai: Send prompt_cache_key to improve caching (#36065)
Release Notes:

- N/A

Co-authored-by: Michael Sloan <mgsloan@gmail.com>
2025-08-12 21:51:23 +03:00
Oleksiy Syvokon
7ff0f1525e open_ai: Log inputs that caused parsing errors (#36063)
Release Notes:

- N/A

Co-authored-by: Michael Sloan <mgsloan@gmail.com>
2025-08-12 21:49:19 +03:00
Filip Binkiewicz
7df8e05ad9 Ignore whitespace in git blame invocation (#35960)
This works around a bug wherein inline git blame is unavailable for
files with CRLF line endings. At the same time, this prevents users from
seeing whitespace-only changes in the editor's git blame

Closes #35836

Release Notes:

- N/A *or* Added/Fixed/Improved ...
2025-08-12 11:47:15 -07:00
Marshall Bowers
d030bb6281 emmet: Bump to v0.0.5 (#36066)
This PR bumps the Emmet extension to v0.0.5.

Changes:

- https://github.com/zed-industries/zed/pull/35599
- https://github.com/zed-industries/zed/pull/36064

Release Notes:

- N/A
2025-08-12 18:41:26 +00:00
张小白
b62f959528 windows: Fix message loop using too much CPU (#35969)
Closes #34374

This is a leftover issue from #34374. Back in #34374, I wanted to use
DirectX to handle vsync, after all, that’s how 99% of Windows apps do
it. But after discussing with @maxbrunsfeld , we decided to stick with
the original vsync approach given gpui’s architecture.

In my tests, there’s no noticeable performance difference between this
PR’s approach and DirectX vsync. That said, this PR’s method does have a
theoretical advantage, it doesn’t block the main thread while waiting
for vsync.


The only difference is that in this PR, on Windows 11 we use a newer API
instead of `DwmFlush`, since Chrome’s tests have shown that `DwmFlush`
has some problems. This PR also removes the use of
`MsgWaitForMultipleObjects`.


Release Notes:

- N/A

---------

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
2025-08-12 11:28:47 -07:00
Marshall Bowers
3a04657730 emmet: Add workaround for leading / on Windows paths (#36064)
This PR adds a workaround for the leading `/` on Windows paths
(https://github.com/zed-industries/zed/issues/20559).

Release Notes:

- N/A
2025-08-12 18:24:25 +00:00
Mikayla Maki
42b7dbeaee Remove beta tag from cursor keymap (#36061)
Release Notes:

- N/A

Co-authored-by: Anthony Eid <hello@anthonyeid.me>
2025-08-12 17:53:19 +00:00
Max Brunsfeld
bfbb18476f Fix management of rust-analyzer binaries on windows (#36056)
Closes https://github.com/zed-industries/zed/issues/34472


* Avoid removing the just-downloaded exe
* Invoke exe within nested version directory

Release Notes:

- Fix issue where Rust-analyzer was not installed correctly on windows

Co-authored-by: Lukas Wirth <lukas@zed.dev>
2025-08-12 17:26:56 +00:00
Dino
978b75bba9 vim: Support filename in :tabedit and :tabnew commands (#35775)
Update both `:tabedit` and `:tabnew` commands in order to support a
single argument, a filename, that, when provided, ensures that the new
tab either opens an existing file or associates the new tab with the
filename, so that when saving the buffer's content, the file is created.

Relates to #21112 

Release Notes:

- vim: Added support for filenames in both `:tabnew` and `:tabedit` commands
2025-08-12 11:13:36 -06:00
localcc
1f20d5bf54 Fix nightly icon (#36051)
Release Notes:

- N/A
2025-08-12 16:18:42 +00:00
Rishabh Bothra
9de04ce215 language_models: Add vision support for OpenAI gpt-5, gpt-5-mini, and gpt-5-nano models (#36047)
## Summary
Enable image processing capabilities for GPT-5 series models by updating
the `supports_images()` method.

## Changes
- Add vision support for `gpt-5`, `gpt-5-mini`, and `gpt-5-nano` models
- Update `supports_images()` method in
`crates/language_models/src/provider/open_ai.rs`

## Models with Vision Support (after this PR)
- gpt-4o
- gpt-4o-mini
- gpt-4.1
- gpt-4.1-mini
- gpt-4.1-nano
- gpt-5 (new)
- gpt-5-mini (new)
- gpt-5-nano (new)
- o1
- o3
- o4-mini

This brings GPT-5 vision capabilities in line with other OpenAI models
that support image processing.

Release Notes:

- Added vision support for OpenAI models
2025-08-12 16:04:51 +00:00
Oleksiy Syvokon
d8fc53608e docs: Update OpenAI models list (#36050)
Closes #ISSUE

Release Notes:

- N/A
2025-08-12 16:03:13 +00:00
Joseph T. Lyons
39c19abdfd Update windows alpha GitHub Issue template (#36049)
Release Notes:

- N/A
2025-08-12 15:55:10 +00:00
Danilo Leal
b105028c05 agent2: Add custom UI for resource link content blocks (#36005)
Release Notes:

- N/A

---------

Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-08-12 12:39:27 -03:00
Piotr Osiewicz
d2162446d0 python: Fix venv activation in remote projects (#36043)
Crux of the issue was that we were checking whether a venv activation
script exists on local filesystem, which is obviously wrong for remote
projects. This PR also does away with `source` for venv activation in
favor of `.`, which is compliant with `sh`

Co-authored-by: Lukas Wirth <lukas@zed.dev>

Closes #34648

Release Notes:

- Python: fixed activation of virtual environments in terminals for
remote projects

Co-authored-by: Lukas Wirth <lukas@zed.dev>
2025-08-12 14:33:46 +00:00
Piotr Osiewicz
360d4db87c python: Fix flickering in the status bar (#36039)
- **util: Have maybe! use async closures instead of async blocks**
- **python: Fix flickering of virtual environment indicator in status
bar**

Closes #30723

Release Notes:

- Python: Fixed flickering of the status bar virtual environment
indicator

---------

Co-authored-by: Lukas Wirth <lukas@zed.dev>
2025-08-12 13:36:28 +00:00
Agus Zubiaga
44953375cc Include mention context in acp-based native agent (#36006)
Also adds data-layer support for symbols, thread, and rules.

Release Notes:

- N/A

---------

Co-authored-by: Cole Miller <cole@zed.dev>
2025-08-12 13:12:58 +00:00
Antonio Scandurra
2444321756 Support profiles in agent2 (#36034)
We still need a profile selector.

Release Notes:

- N/A

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
2025-08-12 12:17:48 +00:00
Piotr Osiewicz
13bf45dd4a python: Fix toolchain serialization not working with multiple venvs in a single worktree (#36035)
Our database did not allow more than entry for a given toolchain for a
single worktree (due to incorrect primary key)

Co-authored-by: Lukas Wirth <lukas@zed.dev>

Release Notes:

- Python: Fixed toolchain selector not working with multiple venvs in a
single worktree.

Co-authored-by: Lukas Wirth <lukas@zed.dev>
2025-08-12 12:10:53 +00:00
Lukas Spiss
b61b71405d go: Add support for running sub-tests in table tests (#35657)
One killer feature for the Go runner is to execute individual subtests
within a table-test easily. Goland has had this feature forever, while
in VSCode this has been notably missing.


https://github.com/user-attachments/assets/363417a2-d1b1-43ca-8377-08ce062d6104


Release Notes:

- Added support to run Go table-test subtests.
2025-08-12 11:56:33 +03:00
Michael Sloan
cc5eb24066 zeta: Add latency telemetry for 1% of edit predictions (#36020)
Release Notes:

- N/A

Co-authored-by: Oleksiy <oleksiy@zed.dev>
2025-08-12 06:47:54 +00:00
Conrad Irwin
52a9101970 vim: Add ctrl-y/e in insert mode (#36017)
Closes #17292

Release Notes:

- vim: Added ctrl-y/ctrl-e in insert mode to copy the next character
from the line above or below
2025-08-11 23:20:09 -06:00
Conrad Irwin
1a798830cb Fix running vim tests with --features neovim (#36014)
This was broken incidentally in
https://github.com/zed-industries/zed/pull/33417

A better fix would be to fix app shutdown to take control of the
executor so that we *can* run
foreground tasks; but that is a bit fiddly (draft #36015) 

Release Notes:

- N/A
2025-08-12 05:08:58 +00:00
Kirill Bulatov
481e3e5092 Ignore capability registrations with empty capabilities (#36000) 2025-08-12 07:53:20 +03:00
Matt
b35e69692d docs: Add a missing comma in Rust debugging JSON (#36007)
Update the Rust debugging doc to include a missing comma in one of the
example JSON's.
2025-08-12 03:06:02 +00:00
Conrad Irwin
add67bde43 Remove unnecessary argument from Vim#update_editor (#36001)
Release Notes:

- N/A
2025-08-11 16:10:06 -06:00
Victor Tran
fa3d0aaed4 gpui: Allow selection of "Services" menu independent of menu title (#34115)
Release Notes:

- N/A

---

In the same vein as #29538, the "Services" menu on macOS depended on the
text being exactly "Services", not allowing for i18n of the menu name.

This PR introduces a new menu type called `OsMenu` that defines a
special menu that can be populated by the system. Currently, it takes
one enum value, `ServicesMenu` that tells the system to populate its
contents with the items it would usually populate the "Services" menu
with.

An example of this being used has been implemented in the `set_menus`
example:
`cargo run -p gpui --example set_menus`

---

Point to consider:

In `mac/platform.rs:414` the existing code for setting the "Services"
menu remains for backwards compatibility. Should this remain now that
this new method exists to set the menu, or should it be removed?

---------

Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
2025-08-11 21:10:14 +00:00
Danilo Leal
094e878ccf agent2: Refine terminal tool call display (#35984)
Release Notes:

- N/A
2025-08-11 17:50:47 -03:00
Joseph T. Lyons
54d4665100 Add windows issue template (#35998)
Release Notes:

- N/A
2025-08-11 19:25:18 +00:00
localcc
2c84e33b7b Fix icon padding (#35990)
Release Notes:

- N/A
2025-08-11 19:57:39 +02:00
Bennet Bo Fenner
bb6ea22944 agent2: Port more tools (#35987)
Release Notes:

- N/A

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
2025-08-11 17:24:48 +00:00
188 changed files with 7958 additions and 3344 deletions

View File

@@ -0,0 +1,35 @@
name: Bug Report (Windows Alpha)
description: Zed Windows Alpha Related Bugs
type: "Bug"
labels: ["windows"]
title: "Windows Alpha: <a short description of the Windows bug>"
body:
- type: textarea
attributes:
label: Summary
description: Describe the bug with a one-line summary, and provide detailed reproduction steps
value: |
<!-- Please insert a one-line summary of the issue below -->
SUMMARY_SENTENCE_HERE
### Description
<!-- Describe with sufficient detail to reproduce from a clean Zed install. -->
Steps to trigger the problem:
1.
2.
3.
**Expected Behavior**:
**Actual Behavior**:
validations:
required: true
- type: textarea
id: environment
attributes:
label: Zed Version and System Specs
description: 'Open Zed, and in the command palette select "zed: copy system specs into clipboard"'
placeholder: |
Output of "zed: copy system specs into clipboard"
validations:
required: true

49
Cargo.lock generated
View File

@@ -10,6 +10,7 @@ dependencies = [
"agent-client-protocol",
"anyhow",
"buffer_diff",
"collections",
"editor",
"env_logger 0.11.8",
"futures 0.3.31",
@@ -17,7 +18,6 @@ dependencies = [
"indoc",
"itertools 0.14.0",
"language",
"language_model",
"markdown",
"parking_lot",
"project",
@@ -29,7 +29,10 @@ dependencies = [
"tempfile",
"terminal",
"ui",
"url",
"util",
"uuid",
"watch",
"workspace-hack",
]
@@ -196,6 +199,7 @@ dependencies = [
"clock",
"cloud_llm_client",
"collections",
"context_server",
"ctor",
"editor",
"env_logger 0.11.8",
@@ -204,6 +208,8 @@ dependencies = [
"gpui",
"gpui_tokio",
"handlebars 4.5.0",
"html_to_markdown",
"http_client",
"indoc",
"itertools 0.14.0",
"language",
@@ -227,6 +233,7 @@ dependencies = [
"task",
"tempfile",
"terminal",
"text",
"theme",
"tree-sitter-rust",
"ui",
@@ -6440,6 +6447,7 @@ dependencies = [
"log",
"parking_lot",
"pretty_assertions",
"rand 0.8.5",
"regex",
"rope",
"schemars",
@@ -11144,14 +11152,13 @@ dependencies = [
"ai_onboarding",
"anyhow",
"client",
"command_palette_hooks",
"component",
"db",
"documented",
"editor",
"feature_flags",
"fs",
"fuzzy",
"git",
"gpui",
"itertools 0.14.0",
"language",
@@ -11163,6 +11170,7 @@ dependencies = [
"schemars",
"serde",
"settings",
"telemetry",
"theme",
"ui",
"util",
@@ -11238,6 +11246,7 @@ dependencies = [
"anyhow",
"futures 0.3.31",
"http_client",
"log",
"schemars",
"serde",
"serde_json",
@@ -18023,6 +18032,7 @@ dependencies = [
"command_palette_hooks",
"db",
"editor",
"env_logger 0.11.8",
"futures 0.3.31",
"git_ui",
"gpui",
@@ -18876,33 +18886,6 @@ version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082"
[[package]]
name = "welcome"
version = "0.1.0"
dependencies = [
"anyhow",
"client",
"component",
"db",
"documented",
"editor",
"fuzzy",
"gpui",
"install_cli",
"language",
"picker",
"project",
"serde",
"settings",
"telemetry",
"ui",
"util",
"vim_mode_setting",
"workspace",
"workspace-hack",
"zed_actions",
]
[[package]]
name = "which"
version = "4.4.2"
@@ -20517,7 +20500,7 @@ dependencies = [
[[package]]
name = "zed"
version = "0.200.0"
version = "0.201.0"
dependencies = [
"activity_indicator",
"agent",
@@ -20657,7 +20640,6 @@ dependencies = [
"watch",
"web_search",
"web_search_providers",
"welcome",
"windows 0.61.1",
"winresource",
"workspace",
@@ -20681,7 +20663,7 @@ dependencies = [
[[package]]
name = "zed_emmet"
version = "0.0.4"
version = "0.0.6"
dependencies = [
"zed_extension_api 0.1.0",
]
@@ -20920,6 +20902,7 @@ dependencies = [
"menu",
"postage",
"project",
"rand 0.8.5",
"regex",
"release_channel",
"reqwest_client",

View File

@@ -185,7 +185,6 @@ members = [
"crates/watch",
"crates/web_search",
"crates/web_search_providers",
"crates/welcome",
"crates/workspace",
"crates/worktree",
"crates/x_ai",
@@ -412,7 +411,6 @@ vim_mode_setting = { path = "crates/vim_mode_setting" }
watch = { path = "crates/watch" }
web_search = { path = "crates/web_search" }
web_search_providers = { path = "crates/web_search_providers" }
welcome = { path = "crates/welcome" }
workspace = { path = "crates/workspace" }
worktree = { path = "crates/worktree" }
x_ai = { path = "crates/x_ai" }
@@ -566,6 +564,7 @@ reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "951c77
"socks",
"stream",
] }
rodio = { version = "0.21.1", default-features = false }
rsa = "0.9.6"
runtimelib = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734", default-features = false, features = [
"async-dispatcher-runtime",
@@ -714,6 +713,7 @@ features = [
"Win32_System_LibraryLoader",
"Win32_System_Memory",
"Win32_System_Ole",
"Win32_System_Performance",
"Win32_System_Pipes",
"Win32_System_SystemInformation",
"Win32_System_SystemServices",

View File

@@ -239,6 +239,7 @@
"ctrl-shift-a": "agent::ToggleContextPicker",
"ctrl-shift-j": "agent::ToggleNavigationMenu",
"ctrl-shift-i": "agent::ToggleOptionsMenu",
"ctrl-alt-shift-n": "agent::ToggleNewThreadMenu",
"shift-alt-escape": "agent::ExpandMessageEditor",
"ctrl->": "assistant::QuoteSelection",
"ctrl-alt-e": "agent::RemoveAllContext",
@@ -330,8 +331,6 @@
"use_key_equivalents": true,
"bindings": {
"enter": "agent::Chat",
"up": "agent::PreviousHistoryMessage",
"down": "agent::NextHistoryMessage",
"shift-ctrl-r": "agent::OpenAgentDiff",
"ctrl-shift-y": "agent::KeepAll",
"ctrl-shift-n": "agent::RejectAll"

View File

@@ -279,6 +279,7 @@
"cmd-shift-a": "agent::ToggleContextPicker",
"cmd-shift-j": "agent::ToggleNavigationMenu",
"cmd-shift-i": "agent::ToggleOptionsMenu",
"cmd-alt-shift-n": "agent::ToggleNewThreadMenu",
"shift-alt-escape": "agent::ExpandMessageEditor",
"cmd->": "assistant::QuoteSelection",
"cmd-alt-e": "agent::RemoveAllContext",
@@ -382,8 +383,6 @@
"use_key_equivalents": true,
"bindings": {
"enter": "agent::Chat",
"up": "agent::PreviousHistoryMessage",
"down": "agent::NextHistoryMessage",
"shift-ctrl-r": "agent::OpenAgentDiff",
"cmd-shift-y": "agent::KeepAll",
"cmd-shift-n": "agent::RejectAll"

View File

@@ -333,10 +333,14 @@
"ctrl-x ctrl-c": "editor::ShowEditPrediction", // zed specific
"ctrl-x ctrl-l": "editor::ToggleCodeActions", // zed specific
"ctrl-x ctrl-z": "editor::Cancel",
"ctrl-x ctrl-e": "vim::LineDown",
"ctrl-x ctrl-y": "vim::LineUp",
"ctrl-w": "editor::DeleteToPreviousWordStart",
"ctrl-u": "editor::DeleteToBeginningOfLine",
"ctrl-t": "vim::Indent",
"ctrl-d": "vim::Outdent",
"ctrl-y": "vim::InsertFromAbove",
"ctrl-e": "vim::InsertFromBelow",
"ctrl-k": ["vim::PushDigraph", {}],
"ctrl-v": ["vim::PushLiteral", {}],
"ctrl-shift-v": "editor::Paste", // note: this is *very* similar to ctrl-v in vim, but ctrl-shift-v on linux is the typical shortcut for paste when ctrl-v is already in use.

View File

@@ -82,10 +82,10 @@
// Layout mode of the bottom dock. Defaults to "contained"
// choices: contained, full, left_aligned, right_aligned
"bottom_dock_layout": "contained",
// The direction that you want to split panes horizontally. Defaults to "up"
"pane_split_direction_horizontal": "up",
// The direction that you want to split panes vertically. Defaults to "left"
"pane_split_direction_vertical": "left",
// The direction that you want to split panes horizontally. Defaults to "down"
"pane_split_direction_horizontal": "down",
// The direction that you want to split panes vertically. Defaults to "right"
"pane_split_direction_vertical": "right",
// Centered layout related settings.
"centered_layout": {
// The relative width of the left padding of the central pane from the

View File

@@ -20,12 +20,12 @@ action_log.workspace = true
agent-client-protocol.workspace = true
anyhow.workspace = true
buffer_diff.workspace = true
collections.workspace = true
editor.workspace = true
futures.workspace = true
gpui.workspace = true
itertools.workspace = true
language.workspace = true
language_model.workspace = true
markdown.workspace = true
project.workspace = true
serde.workspace = true
@@ -34,7 +34,10 @@ settings.workspace = true
smol.workspace = true
terminal.workspace = true
ui.workspace = true
url.workspace = true
util.workspace = true
uuid.workspace = true
watch.workspace = true
workspace-hack.workspace = true
[dev-dependencies]

File diff suppressed because it is too large Load Diff

View File

@@ -1,18 +1,78 @@
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use crate::AcpThread;
use agent_client_protocol::{self as acp};
use anyhow::Result;
use gpui::{AsyncApp, Entity, Task};
use language_model::LanguageModel;
use collections::IndexMap;
use gpui::{AsyncApp, Entity, SharedString, Task};
use project::Project;
use ui::App;
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use ui::{App, IconName};
use uuid::Uuid;
use crate::AcpThread;
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct UserMessageId(Arc<str>);
impl UserMessageId {
pub fn new() -> Self {
Self(Uuid::new_v4().to_string().into())
}
}
pub trait AgentConnection {
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>>;
fn auth_methods(&self) -> &[acp::AuthMethod];
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
fn prompt(
&self,
user_message_id: Option<UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>>;
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
fn session_editor(
&self,
_session_id: &acp::SessionId,
_cx: &mut App,
) -> Option<Rc<dyn AgentSessionEditor>> {
None
}
/// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
///
/// If the agent does not support model selection, returns [None].
/// This allows sharing the selector in UI components.
fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
None
}
}
pub trait AgentSessionEditor {
fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
}
#[derive(Debug)]
pub struct AuthRequired;
impl Error for AuthRequired {}
impl fmt::Display for AuthRequired {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "AuthRequired")
}
}
/// Trait for agents that support listing, selecting, and querying language models.
///
/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
pub trait ModelSelector: 'static {
pub trait AgentModelSelector: 'static {
/// Lists all available language models for this agent.
///
/// # Parameters
@@ -20,7 +80,7 @@ pub trait ModelSelector: 'static {
///
/// # Returns
/// A task resolving to the list of models or an error (e.g., if no models are configured).
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>>;
fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
/// Selects a model for a specific session (thread).
///
@@ -37,8 +97,8 @@ pub trait ModelSelector: 'static {
fn select_model(
&self,
session_id: acp::SessionId,
model: Arc<dyn LanguageModel>,
cx: &mut AsyncApp,
model_id: AgentModelId,
cx: &mut App,
) -> Task<Result<()>>;
/// Retrieves the currently selected model for a specific session (thread).
@@ -52,42 +112,51 @@ pub trait ModelSelector: 'static {
fn selected_model(
&self,
session_id: &acp::SessionId,
cx: &mut AsyncApp,
) -> Task<Result<Arc<dyn LanguageModel>>>;
cx: &mut App,
) -> Task<Result<AgentModelInfo>>;
/// Whenever the model list is updated the receiver will be notified.
fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
}
pub trait AgentConnection {
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>>;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AgentModelId(pub SharedString);
fn auth_methods(&self) -> &[acp::AuthMethod];
impl std::ops::Deref for AgentModelId {
type Target = SharedString;
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
fn prompt(&self, params: acp::PromptRequest, cx: &mut App)
-> Task<Result<acp::PromptResponse>>;
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
/// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
///
/// If the agent does not support model selection, returns [None].
/// This allows sharing the selector in UI components.
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
None // Default impl for agents that don't support it
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[derive(Debug)]
pub struct AuthRequired;
impl Error for AuthRequired {}
impl fmt::Display for AuthRequired {
impl fmt::Display for AgentModelId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "AuthRequired")
self.0.fmt(f)
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentModelInfo {
pub id: AgentModelId,
pub name: SharedString,
pub icon: Option<IconName>,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AgentModelGroupName(pub SharedString);
#[derive(Debug, Clone)]
pub enum AgentModelList {
Flat(Vec<AgentModelInfo>),
Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
}
impl AgentModelList {
pub fn is_empty(&self) -> bool {
match self {
AgentModelList::Flat(models) => models.is_empty(),
AgentModelList::Grouped(groups) => groups.is_empty(),
}
}
}

View File

@@ -0,0 +1,125 @@
use agent_client_protocol as acp;
use anyhow::{Result, bail};
use std::path::PathBuf;
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum MentionUri {
File(PathBuf),
Symbol(PathBuf, String),
Thread(acp::SessionId),
Rule(String),
}
impl MentionUri {
pub fn parse(input: &str) -> Result<Self> {
let url = url::Url::parse(input)?;
let path = url.path();
match url.scheme() {
"file" => {
if let Some(fragment) = url.fragment() {
Ok(Self::Symbol(path.into(), fragment.into()))
} else {
let file_path =
PathBuf::from(format!("{}{}", url.host_str().unwrap_or(""), path));
Ok(Self::File(file_path))
}
}
"zed" => {
if let Some(thread) = path.strip_prefix("/agent/thread/") {
Ok(Self::Thread(acp::SessionId(thread.into())))
} else if let Some(rule) = path.strip_prefix("/agent/rule/") {
Ok(Self::Rule(rule.into()))
} else {
bail!("invalid zed url: {:?}", input);
}
}
other => bail!("unrecognized scheme {:?}", other),
}
}
pub fn name(&self) -> String {
match self {
MentionUri::File(path) => path.file_name().unwrap().to_string_lossy().into_owned(),
MentionUri::Symbol(_path, name) => name.clone(),
MentionUri::Thread(thread) => thread.to_string(),
MentionUri::Rule(rule) => rule.clone(),
}
}
pub fn to_link(&self) -> String {
let name = self.name();
let uri = self.to_uri();
format!("[{name}]({uri})")
}
pub fn to_uri(&self) -> String {
match self {
MentionUri::File(path) => {
format!("file://{}", path.display())
}
MentionUri::Symbol(path, name) => {
format!("file://{}#{}", path.display(), name)
}
MentionUri::Thread(thread) => {
format!("zed:///agent/thread/{}", thread.0)
}
MentionUri::Rule(rule) => {
format!("zed:///agent/rule/{}", rule)
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mention_uri_parse_and_display() {
// Test file URI
let file_uri = "file:///path/to/file.rs";
let parsed = MentionUri::parse(file_uri).unwrap();
match &parsed {
MentionUri::File(path) => assert_eq!(path.to_str().unwrap(), "/path/to/file.rs"),
_ => panic!("Expected File variant"),
}
assert_eq!(parsed.to_uri(), file_uri);
// Test symbol URI
let symbol_uri = "file:///path/to/file.rs#MySymbol";
let parsed = MentionUri::parse(symbol_uri).unwrap();
match &parsed {
MentionUri::Symbol(path, symbol) => {
assert_eq!(path.to_str().unwrap(), "/path/to/file.rs");
assert_eq!(symbol, "MySymbol");
}
_ => panic!("Expected Symbol variant"),
}
assert_eq!(parsed.to_uri(), symbol_uri);
// Test thread URI
let thread_uri = "zed:///agent/thread/session123";
let parsed = MentionUri::parse(thread_uri).unwrap();
match &parsed {
MentionUri::Thread(session_id) => assert_eq!(session_id.0.as_ref(), "session123"),
_ => panic!("Expected Thread variant"),
}
assert_eq!(parsed.to_uri(), thread_uri);
// Test rule URI
let rule_uri = "zed:///agent/rule/my_rule";
let parsed = MentionUri::parse(rule_uri).unwrap();
match &parsed {
MentionUri::Rule(rule) => assert_eq!(rule, "my_rule"),
_ => panic!("Expected Rule variant"),
}
assert_eq!(parsed.to_uri(), rule_uri);
// Test invalid scheme
assert!(MentionUri::parse("http://example.com").is_err());
// Test invalid zed path
assert!(MentionUri::parse("zed:///invalid/path").is_err());
}
}

View File

@@ -29,8 +29,14 @@ impl Terminal {
cx: &mut Context<Self>,
) -> Self {
Self {
command: cx
.new(|cx| Markdown::new(command.into(), Some(language_registry.clone()), None, cx)),
command: cx.new(|cx| {
Markdown::new(
format!("```\n{}\n```", command).into(),
Some(language_registry.clone()),
None,
cx,
)
}),
working_dir,
terminal,
started_at: Instant::now(),

View File

@@ -17,8 +17,6 @@ use util::{
pub struct ActionLog {
/// Buffers that we want to notify the model about when they change.
tracked_buffers: BTreeMap<Entity<Buffer>, TrackedBuffer>,
/// Has the model edited a file since it last checked diagnostics?
edited_since_project_diagnostics_check: bool,
/// The project this action log is associated with
project: Entity<Project>,
}
@@ -28,7 +26,6 @@ impl ActionLog {
pub fn new(project: Entity<Project>) -> Self {
Self {
tracked_buffers: BTreeMap::default(),
edited_since_project_diagnostics_check: false,
project,
}
}
@@ -37,16 +34,6 @@ impl ActionLog {
&self.project
}
/// Notifies a diagnostics check
pub fn checked_project_diagnostics(&mut self) {
self.edited_since_project_diagnostics_check = false;
}
/// Returns true if any files have been edited since the last project diagnostics check
pub fn has_edited_files_since_project_diagnostics_check(&self) -> bool {
self.edited_since_project_diagnostics_check
}
pub fn latest_snapshot(&self, buffer: &Entity<Buffer>) -> Option<text::BufferSnapshot> {
Some(self.tracked_buffers.get(buffer)?.snapshot.clone())
}
@@ -543,14 +530,11 @@ impl ActionLog {
/// Mark a buffer as created by agent, so we can refresh it in the context
pub fn buffer_created(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
self.edited_since_project_diagnostics_check = true;
self.track_buffer_internal(buffer.clone(), true, cx);
}
/// Mark a buffer as edited by agent, so we can refresh it in the context
pub fn buffer_edited(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
self.edited_since_project_diagnostics_check = true;
let tracked_buffer = self.track_buffer_internal(buffer.clone(), false, cx);
if let TrackedBufferStatus::Deleted = tracked_buffer.status {
tracked_buffer.status = TrackedBufferStatus::Modified;

View File

@@ -716,18 +716,10 @@ impl ActivityIndicator {
})),
tooltip_message: Some(Self::version_tooltip_message(&version)),
}),
AutoUpdateStatus::Updated {
binary_path,
version,
} => Some(Content {
AutoUpdateStatus::Updated { version } => Some(Content {
icon: None,
message: "Click to restart and update Zed".to_string(),
on_click: Some(Arc::new({
let reload = workspace::Reload {
binary_path: Some(binary_path.clone()),
};
move |_, _, cx| workspace::reload(&reload, cx)
})),
on_click: Some(Arc::new(move |_, _, cx| workspace::reload(cx))),
tooltip_message: Some(Self::version_tooltip_message(&version)),
}),
AutoUpdateStatus::Errored => Some(Content {

View File

@@ -2268,6 +2268,15 @@ impl Thread {
max_attempts: 3,
})
}
Other(err)
if err.is::<PaymentRequiredError>()
|| err.is::<ModelRequestLimitReachedError>() =>
{
// Retrying won't help for Payment Required or Model Request Limit errors (where
// the user must upgrade to usage-based billing to get more requests, or else wait
// for a significant amount of time for the request limit to reset).
None
}
// Conservatively assume that any other errors are non-retryable
HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
delay: BASE_RETRY_DELAY,

View File

@@ -23,10 +23,13 @@ assistant_tools.workspace = true
chrono.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
context_server.workspace = true
fs.workspace = true
futures.workspace = true
gpui.workspace = true
handlebars = { workspace = true, features = ["rust-embed"] }
html_to_markdown.workspace = true
http_client.workspace = true
indoc.workspace = true
itertools.workspace = true
language.workspace = true
@@ -46,6 +49,7 @@ settings.workspace = true
smol.workspace = true
task.workspace = true
terminal.workspace = true
text.workspace = true
ui.workspace = true
util.workspace = true
uuid.workspace = true
@@ -58,6 +62,7 @@ workspace-hack.workspace = true
ctor.workspace = true
client = { workspace = true, "features" = ["test-support"] }
clock = { workspace = true, "features" = ["test-support"] }
context_server = { workspace = true, "features" = ["test-support"] }
editor = { workspace = true, "features" = ["test-support"] }
env_logger.workspace = true
fs = { workspace = true, "features" = ["test-support"] }

View File

@@ -1,21 +1,26 @@
use crate::{AgentResponseEvent, Thread, templates::Templates};
use crate::{
CopyPathTool, CreateDirectoryTool, EditFileTool, FindPathTool, GrepTool, ListDirectoryTool,
MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool,
ToolCallAuthorization, WebSearchTool,
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool,
FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool,
ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent,
WebSearchTool,
};
use acp_thread::ModelSelector;
use acp_thread::AgentModelSelector;
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result, anyhow};
use collections::{HashSet, IndexMap};
use fs::Fs;
use futures::{StreamExt, future};
use gpui::{
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
};
use language_model::{LanguageModel, LanguageModelRegistry};
use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
use project::{Project, ProjectItem, ProjectPath, Worktree};
use prompt_store::{
ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
};
use settings::update_settings_file;
use std::cell::RefCell;
use std::collections::HashMap;
use std::path::Path;
@@ -48,6 +53,104 @@ struct Session {
_subscription: Subscription,
}
pub struct LanguageModels {
/// Access language model by ID
models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
/// Cached list for returning language model information
model_list: acp_thread::AgentModelList,
refresh_models_rx: watch::Receiver<()>,
refresh_models_tx: watch::Sender<()>,
}
impl LanguageModels {
fn new(cx: &App) -> Self {
let (refresh_models_tx, refresh_models_rx) = watch::channel(());
let mut this = Self {
models: HashMap::default(),
model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
refresh_models_rx,
refresh_models_tx,
};
this.refresh_list(cx);
this
}
fn refresh_list(&mut self, cx: &App) {
let providers = LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.into_iter()
.filter(|provider| provider.is_authenticated(cx))
.collect::<Vec<_>>();
let mut language_model_list = IndexMap::default();
let mut recommended_models = HashSet::default();
let mut recommended = Vec::new();
for provider in &providers {
for model in provider.recommended_models(cx) {
recommended_models.insert(model.id());
recommended.push(Self::map_language_model_to_info(&model, &provider));
}
}
if !recommended.is_empty() {
language_model_list.insert(
acp_thread::AgentModelGroupName("Recommended".into()),
recommended,
);
}
let mut models = HashMap::default();
for provider in providers {
let mut provider_models = Vec::new();
for model in provider.provided_models(cx) {
let model_info = Self::map_language_model_to_info(&model, &provider);
let model_id = model_info.id.clone();
if !recommended_models.contains(&model.id()) {
provider_models.push(model_info);
}
models.insert(model_id, model);
}
if !provider_models.is_empty() {
language_model_list.insert(
acp_thread::AgentModelGroupName(provider.name().0.clone()),
provider_models,
);
}
}
self.models = models;
self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
self.refresh_models_tx.send(()).ok();
}
fn watch(&self) -> watch::Receiver<()> {
self.refresh_models_rx.clone()
}
pub fn model_from_id(
&self,
model_id: &acp_thread::AgentModelId,
) -> Option<Arc<dyn LanguageModel>> {
self.models.get(model_id).cloned()
}
fn map_language_model_to_info(
model: &Arc<dyn LanguageModel>,
provider: &Arc<dyn LanguageModelProvider>,
) -> acp_thread::AgentModelInfo {
acp_thread::AgentModelInfo {
id: Self::model_id(model),
name: model.name().0,
icon: Some(provider.icon()),
}
}
fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
}
}
pub struct NativeAgent {
/// Session ID -> Session mapping
sessions: HashMap<acp::SessionId, Session>,
@@ -55,10 +158,14 @@ pub struct NativeAgent {
project_context: Rc<RefCell<ProjectContext>>,
project_context_needs_refresh: watch::Sender<()>,
_maintain_project_context: Task<Result<()>>,
context_server_registry: Entity<ContextServerRegistry>,
/// Shared templates for all threads
templates: Arc<Templates>,
/// Cached model information
models: LanguageModels,
project: Entity<Project>,
prompt_store: Option<Entity<PromptStore>>,
fs: Arc<dyn Fs>,
_subscriptions: Vec<Subscription>,
}
@@ -67,6 +174,7 @@ impl NativeAgent {
project: Entity<Project>,
templates: Arc<Templates>,
prompt_store: Option<Entity<PromptStore>>,
fs: Arc<dyn Fs>,
cx: &mut AsyncApp,
) -> Result<Entity<NativeAgent>> {
log::info!("Creating new NativeAgent");
@@ -76,7 +184,13 @@ impl NativeAgent {
.await;
cx.new(|cx| {
let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
let mut subscriptions = vec![
cx.subscribe(&project, Self::handle_project_event),
cx.subscribe(
&LanguageModelRegistry::global(cx),
Self::handle_models_updated_event,
),
];
if let Some(prompt_store) = prompt_store.as_ref() {
subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
}
@@ -90,14 +204,23 @@ impl NativeAgent {
_maintain_project_context: cx.spawn(async move |this, cx| {
Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
}),
context_server_registry: cx.new(|cx| {
ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
}),
templates,
models: LanguageModels::new(cx),
project,
prompt_store,
fs,
_subscriptions: subscriptions,
}
})
}
pub fn models(&self) -> &LanguageModels {
&self.models
}
async fn maintain_project_context(
this: WeakEntity<Self>,
mut needs_refresh: watch::Receiver<()>,
@@ -293,75 +416,104 @@ impl NativeAgent {
) {
self.project_context_needs_refresh.send(()).ok();
}
fn handle_models_updated_event(
&mut self,
_registry: Entity<LanguageModelRegistry>,
_event: &language_model::Event,
cx: &mut Context<Self>,
) {
self.models.refresh_list(cx);
for session in self.sessions.values_mut() {
session.thread.update(cx, |thread, _| {
let model_id = LanguageModels::model_id(&thread.selected_model);
if let Some(model) = self.models.model_from_id(&model_id) {
thread.selected_model = model.clone();
}
});
}
}
}
/// Wrapper struct that implements the AgentConnection trait
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
impl ModelSelector for NativeAgentConnection {
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
impl AgentModelSelector for NativeAgentConnection {
fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
log::debug!("NativeAgentConnection::list_models called");
cx.spawn(async move |cx| {
cx.update(|cx| {
let registry = LanguageModelRegistry::read_global(cx);
let models = registry.available_models(cx).collect::<Vec<_>>();
log::info!("Found {} available models", models.len());
if models.is_empty() {
Err(anyhow::anyhow!("No models available"))
} else {
Ok(models)
}
})?
let list = self.0.read(cx).models.model_list.clone();
Task::ready(if list.is_empty() {
Err(anyhow::anyhow!("No models available"))
} else {
Ok(list)
})
}
fn select_model(
&self,
session_id: acp::SessionId,
model: Arc<dyn LanguageModel>,
cx: &mut AsyncApp,
model_id: acp_thread::AgentModelId,
cx: &mut App,
) -> Task<Result<()>> {
log::info!(
"Setting model for session {}: {:?}",
session_id,
model.name()
);
let agent = self.0.clone();
log::info!("Setting model for session {}: {}", session_id, model_id);
let Some(thread) = self
.0
.read(cx)
.sessions
.get(&session_id)
.map(|session| session.thread.clone())
else {
return Task::ready(Err(anyhow!("Session not found")));
};
cx.spawn(async move |cx| {
agent.update(cx, |agent, cx| {
if let Some(session) = agent.sessions.get(&session_id) {
session.thread.update(cx, |thread, _cx| {
thread.selected_model = model;
});
Ok(())
} else {
Err(anyhow!("Session not found"))
}
})?
})
let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
};
thread.update(cx, |thread, _cx| {
thread.selected_model = model.clone();
});
update_settings_file::<AgentSettings>(
self.0.read(cx).fs.clone(),
cx,
move |settings, _cx| {
settings.set_model(model);
},
);
Task::ready(Ok(()))
}
fn selected_model(
&self,
session_id: &acp::SessionId,
cx: &mut AsyncApp,
) -> Task<Result<Arc<dyn LanguageModel>>> {
let agent = self.0.clone();
cx: &mut App,
) -> Task<Result<acp_thread::AgentModelInfo>> {
let session_id = session_id.clone();
cx.spawn(async move |cx| {
let thread = agent
.read_with(cx, |agent, _| {
agent
.sessions
.get(&session_id)
.map(|session| session.thread.clone())
})?
.ok_or_else(|| anyhow::anyhow!("Session not found"))?;
let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
Ok(selected)
})
let Some(thread) = self
.0
.read(cx)
.sessions
.get(&session_id)
.map(|session| session.thread.clone())
else {
return Task::ready(Err(anyhow!("Session not found")));
};
let model = thread.read(cx).selected_model.clone();
let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
else {
return Task::ready(Err(anyhow!("Provider not found")));
};
Task::ready(Ok(LanguageModels::map_language_model_to_info(
&model, &provider,
)))
}
fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
self.0.read(cx).models.watch()
}
}
@@ -385,7 +537,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
// Create AcpThread
let acp_thread = cx.update(|cx| {
cx.new(|cx| {
acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx)
acp_thread::AcpThread::new(
"agent2",
self.clone(),
project.clone(),
session_id.clone(),
cx,
)
})
})?;
let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
@@ -403,28 +561,37 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
let default_model = registry
.default_model()
.map(|configured| {
log::info!(
"Using configured default model: {:?} from provider: {:?}",
configured.model.name(),
configured.provider.name()
);
configured.model
.and_then(|default_model| {
agent
.models
.model_from_id(&LanguageModels::model_id(&default_model.model))
})
.ok_or_else(|| {
log::warn!("No default model configured in settings");
anyhow!("No default model configured. Please configure a default model in settings.")
anyhow!(
"No default model. Please configure a default model in settings."
)
})?;
let thread = cx.new(|cx| {
let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
let mut thread = Thread::new(
project.clone(),
agent.project_context.clone(),
agent.context_server_registry.clone(),
action_log.clone(),
agent.templates.clone(),
default_model,
cx,
);
thread.add_tool(CreateDirectoryTool::new(project.clone()));
thread.add_tool(CopyPathTool::new(project.clone()));
thread.add_tool(DiagnosticsTool::new(project.clone()));
thread.add_tool(MovePathTool::new(project.clone()));
thread.add_tool(ListDirectoryTool::new(project.clone()));
thread.add_tool(OpenTool::new(project.clone()));
thread.add_tool(ThinkingTool);
thread.add_tool(FindPathTool::new(project.clone()));
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
thread.add_tool(GrepTool::new(project.clone()));
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
thread.add_tool(EditFileTool::new(cx.entity()));
@@ -448,7 +615,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
acp_thread: acp_thread.downgrade(),
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
this.sessions.remove(acp_thread.session_id());
})
}),
},
);
})?;
@@ -465,15 +632,17 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
Task::ready(Ok(()))
}
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
}
fn prompt(
&self,
id: Option<acp_thread::UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {
let id = id.expect("UserMessageId is required");
let session_id = params.session_id.clone();
let agent = self.0.clone();
log::info!("Received prompt request for session: {}", session_id);
@@ -494,10 +663,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
})?;
log::debug!("Found session for: {}", session_id);
// Convert prompt to message
let message = convert_prompt_to_message(params.prompt);
log::info!("Converted prompt to message: {} chars", message.len());
log::debug!("Message content: {}", message);
let content: Vec<UserMessageContent> = params
.prompt
.into_iter()
.map(Into::into)
.collect::<Vec<_>>();
log::info!("Converted prompt to message: {} chars", content.len());
log::debug!("Message id: {:?}", id);
log::debug!("Message content: {:?}", content);
// Get model using the ModelSelector capability (always available for agent2)
// Get the selected model from the thread directly
@@ -505,7 +678,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
// Send to thread
log::info!("Sending message to thread with model: {:?}", model.name());
let mut response_stream = thread.update(cx, |thread, cx| thread.send(message, cx))?;
let mut response_stream =
thread.update(cx, |thread, cx| thread.send(id, content, cx))?;
// Handle response stream and forward to session.acp_thread
while let Some(result) = response_stream.next().await {
@@ -599,44 +773,33 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
}
});
}
fn session_editor(
&self,
session_id: &agent_client_protocol::SessionId,
cx: &mut App,
) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
self.0.update(cx, |agent, _cx| {
agent
.sessions
.get(session_id)
.map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
})
}
}
/// Convert ACP content blocks to a message string
fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
log::debug!("Converting {} content blocks to message", blocks.len());
let mut message = String::new();
struct NativeAgentSessionEditor(Entity<Thread>);
for block in blocks {
match block {
acp::ContentBlock::Text(text) => {
log::trace!("Processing text block: {} chars", text.text.len());
message.push_str(&text.text);
}
acp::ContentBlock::ResourceLink(link) => {
log::trace!("Processing resource link: {}", link.uri);
message.push_str(&format!(" @{} ", link.uri));
}
acp::ContentBlock::Image(_) => {
log::trace!("Processing image block");
message.push_str(" [image] ");
}
acp::ContentBlock::Audio(_) => {
log::trace!("Processing audio block");
message.push_str(" [audio] ");
}
acp::ContentBlock::Resource(resource) => {
log::trace!("Processing resource block: {:?}", resource.resource);
message.push_str(&format!(" [resource: {:?}] ", resource.resource));
}
}
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
}
message
}
#[cfg(test)]
mod tests {
use super::*;
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
use fs::FakeFs;
use gpui::TestAppContext;
use serde_json::json;
@@ -654,9 +817,15 @@ mod tests {
)
.await;
let project = Project::test(fs.clone(), [], cx).await;
let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async())
.await
.unwrap();
let agent = NativeAgent::new(
project.clone(),
Templates::new(),
None,
fs.clone(),
&mut cx.to_async(),
)
.await
.unwrap();
agent.read_with(cx, |agent, _| {
assert_eq!(agent.project_context.borrow().worktrees, vec![])
});
@@ -697,13 +866,131 @@ mod tests {
});
}
#[gpui::test]
async fn test_listing_models(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree("/", json!({ "a": {} })).await;
let project = Project::test(fs.clone(), [], cx).await;
let connection = NativeAgentConnection(
NativeAgent::new(
project.clone(),
Templates::new(),
None,
fs.clone(),
&mut cx.to_async(),
)
.await
.unwrap(),
);
let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();
let acp_thread::AgentModelList::Grouped(models) = models else {
panic!("Unexpected model group");
};
assert_eq!(
models,
IndexMap::from_iter([(
AgentModelGroupName("Fake".into()),
vec![AgentModelInfo {
id: AgentModelId("fake/fake".into()),
name: "Fake".into(),
icon: Some(ui::IconName::ZedAssistant),
}]
)])
);
}
#[gpui::test]
async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.create_dir(paths::settings_file().parent().unwrap())
.await
.unwrap();
fs.insert_file(
paths::settings_file(),
json!({
"agent": {
"default_model": {
"provider": "foo",
"model": "bar"
}
}
})
.to_string()
.into_bytes(),
)
.await;
let project = Project::test(fs.clone(), [], cx).await;
// Create the agent and connection
let agent = NativeAgent::new(
project.clone(),
Templates::new(),
None,
fs.clone(),
&mut cx.to_async(),
)
.await
.unwrap();
let connection = NativeAgentConnection(agent.clone());
// Create a thread/session
let acp_thread = cx
.update(|cx| {
Rc::new(connection.clone()).new_thread(
project.clone(),
Path::new("/a"),
&mut cx.to_async(),
)
})
.await
.unwrap();
let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
// Select a model
let model_id = AgentModelId("fake/fake".into());
cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
.await
.unwrap();
// Verify the thread has the selected model
agent.read_with(cx, |agent, _| {
let session = agent.sessions.get(&session_id).unwrap();
session.thread.read_with(cx, |thread, _| {
assert_eq!(thread.selected_model.id().0, "fake");
});
});
cx.run_until_parked();
// Verify settings file was updated
let settings_content = fs.load(paths::settings_file()).await.unwrap();
let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
// Check that the agent settings contain the selected model
assert_eq!(
settings_json["agent"]["default_model"]["model"],
json!("fake")
);
assert_eq!(
settings_json["agent"]["default_model"]["provider"],
json!("fake")
);
}
fn init_test(cx: &mut TestAppContext) {
env_logger::try_init().ok();
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
Project::init_settings(cx);
agent_settings::init(cx);
language::init(cx);
LanguageModelRegistry::test(cx);
});
}
}

View File

@@ -1,8 +1,8 @@
use std::path::Path;
use std::rc::Rc;
use std::{path::Path, rc::Rc, sync::Arc};
use agent_servers::AgentServer;
use anyhow::Result;
use fs::Fs;
use gpui::{App, Entity, Task};
use project::Project;
use prompt_store::PromptStore;
@@ -10,7 +10,15 @@ use prompt_store::PromptStore;
use crate::{NativeAgent, NativeAgentConnection, templates::Templates};
#[derive(Clone)]
pub struct NativeAgentServer;
pub struct NativeAgentServer {
fs: Arc<dyn Fs>,
}
impl NativeAgentServer {
pub fn new(fs: Arc<dyn Fs>) -> Self {
Self { fs }
}
}
impl AgentServer for NativeAgentServer {
fn name(&self) -> &'static str {
@@ -41,6 +49,7 @@ impl AgentServer for NativeAgentServer {
_root_dir
);
let project = project.clone();
let fs = self.fs.clone();
let prompt_store = PromptStore::global(cx);
cx.spawn(async move |cx| {
log::debug!("Creating templates for native agent");
@@ -48,7 +57,7 @@ impl AgentServer for NativeAgentServer {
let prompt_store = prompt_store.await?;
log::debug!("Creating native agent entity");
let agent = NativeAgent::new(project, templates, Some(prompt_store), cx).await?;
let agent = NativeAgent::new(project, templates, Some(prompt_store), fs, cx).await?;
// Create the connection wrapper
let connection = NativeAgentConnection(agent);

View File

@@ -1,7 +1,8 @@
use super::*;
use acp_thread::AgentConnection;
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
use action_log::ActionLog;
use agent_client_protocol::{self as acp};
use agent_settings::AgentProfileId;
use anyhow::Result;
use client::{Client, UserStore};
use fs::{FakeFs, Fs};
@@ -12,8 +13,8 @@ use gpui::{
use indoc::indoc;
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
StopReason, fake_provider::FakeLanguageModel,
LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, Role, StopReason,
fake_provider::FakeLanguageModel,
};
use project::Project;
use prompt_store::ProjectContext;
@@ -36,15 +37,19 @@ async fn test_echo(cx: &mut TestAppContext) {
let events = thread
.update(cx, |thread, cx| {
thread.send("Testing: Reply with 'Hello'", cx)
thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
})
.collect()
.await;
thread.update(cx, |thread, _cx| {
assert_eq!(
thread.messages().last().unwrap().content,
vec![MessageContent::Text("Hello".to_string())]
);
thread.last_message().unwrap().to_markdown(),
indoc! {"
## Assistant
Hello
"}
)
});
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
@@ -57,12 +62,13 @@ async fn test_thinking(cx: &mut TestAppContext) {
let events = thread
.update(cx, |thread, cx| {
thread.send(
indoc! {"
UserMessageId::new(),
[indoc! {"
Testing:
Generate a thinking step where you just think the word 'Think',
and have your final answer be 'Hello'
"},
"}],
cx,
)
})
@@ -70,9 +76,10 @@ async fn test_thinking(cx: &mut TestAppContext) {
.await;
thread.update(cx, |thread, _cx| {
assert_eq!(
thread.messages().last().unwrap().to_markdown(),
thread.last_message().unwrap().to_markdown(),
indoc! {"
## assistant
## Assistant
<think>Think</think>
Hello
"}
@@ -93,7 +100,9 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
project_context.borrow_mut().shell = "test-shell".into();
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
thread.update(cx, |thread, cx| thread.send("abc", cx));
thread.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["abc"], cx)
});
cx.run_until_parked();
let mut pending_completions = fake_model.pending_completions();
assert_eq!(
@@ -130,7 +139,8 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
.update(cx, |thread, cx| {
thread.add_tool(EchoTool);
thread.send(
"Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
UserMessageId::new(),
["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
cx,
)
})
@@ -144,7 +154,11 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
thread.remove_tool(&AgentTool::name(&EchoTool));
thread.add_tool(DelayTool);
thread.send(
"Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
UserMessageId::new(),
[
"Now call the delay tool with 200ms.",
"When the timer goes off, then you echo the output of the tool.",
],
cx,
)
})
@@ -154,18 +168,21 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
thread.update(cx, |thread, _cx| {
assert!(
thread
.messages()
.last()
.last_message()
.unwrap()
.as_agent_message()
.unwrap()
.content
.iter()
.any(|content| {
if let MessageContent::Text(text) = content {
if let AgentMessageContent::Text(text) = content {
text.contains("Ding")
} else {
false
}
})
}),
"{}",
thread.to_markdown()
);
});
}
@@ -178,7 +195,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
// Test a tool call that's likely to complete *before* streaming stops.
let mut events = thread.update(cx, |thread, cx| {
thread.add_tool(WordListTool);
thread.send("Test the word_list tool.", cx)
thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
});
let mut saw_partial_tool_use = false;
@@ -186,8 +203,10 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
thread.update(cx, |thread, _cx| {
// Look for a tool use in the thread's last message
let last_content = thread.messages().last().unwrap().content.last().unwrap();
if let MessageContent::ToolUse(last_tool_use) = last_content {
let message = thread.last_message().unwrap();
let agent_message = message.as_agent_message().unwrap();
let last_content = agent_message.content.last().unwrap();
if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
assert_eq!(last_tool_use.name.as_ref(), "word_list");
if tool_call.status == acp::ToolCallStatus::Pending {
if !last_tool_use.is_input_complete
@@ -225,7 +244,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
let mut events = thread.update(cx, |thread, cx| {
thread.add_tool(ToolRequiringPermission);
thread.send("abc", cx)
thread.send(UserMessageId::new(), ["abc"], cx)
});
cx.run_until_parked();
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
@@ -269,14 +288,14 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
assert_eq!(
message.content,
vec![
MessageContent::ToolResult(LanguageModelToolResult {
language_model::MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
tool_name: ToolRequiringPermission.name().into(),
is_error: false,
content: "Allowed".into(),
output: Some("Allowed".into())
}),
MessageContent::ToolResult(LanguageModelToolResult {
language_model::MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
tool_name: ToolRequiringPermission.name().into(),
is_error: true,
@@ -309,13 +328,15 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
let message = completion.messages.last().unwrap();
assert_eq!(
message.content,
vec![MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
tool_name: ToolRequiringPermission.name().into(),
is_error: false,
content: "Allowed".into(),
output: Some("Allowed".into())
})]
vec![language_model::MessageContent::ToolResult(
LanguageModelToolResult {
tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
tool_name: ToolRequiringPermission.name().into(),
is_error: false,
content: "Allowed".into(),
output: Some("Allowed".into())
}
)]
);
// Simulate a final tool call, ensuring we don't trigger authorization.
@@ -334,13 +355,15 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
let message = completion.messages.last().unwrap();
assert_eq!(
message.content,
vec![MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: "tool_id_4".into(),
tool_name: ToolRequiringPermission.name().into(),
is_error: false,
content: "Allowed".into(),
output: Some("Allowed".into())
})]
vec![language_model::MessageContent::ToolResult(
LanguageModelToolResult {
tool_use_id: "tool_id_4".into(),
tool_name: ToolRequiringPermission.name().into(),
is_error: false,
content: "Allowed".into(),
output: Some("Allowed".into())
}
)]
);
}
@@ -349,7 +372,9 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let mut events = thread.update(cx, |thread, cx| thread.send("abc", cx));
let mut events = thread.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["abc"], cx)
});
cx.run_until_parked();
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
@@ -441,7 +466,12 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
.update(cx, |thread, cx| {
thread.add_tool(DelayTool);
thread.send(
"Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
UserMessageId::new(),
[
"Call the delay tool twice in the same message.",
"Once with 100ms. Once with 300ms.",
"When both timers are complete, describe the outputs.",
],
cx,
)
})
@@ -452,12 +482,13 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
thread.update(cx, |thread, _cx| {
let last_message = thread.messages().last().unwrap();
let text = last_message
let last_message = thread.last_message().unwrap();
let agent_message = last_message.as_agent_message().unwrap();
let text = agent_message
.content
.iter()
.filter_map(|content| {
if let MessageContent::Text(text) = content {
if let AgentMessageContent::Text(text) = content {
Some(text.as_str())
} else {
None
@@ -469,6 +500,82 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
});
}
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
let ThreadTest {
model, thread, fs, ..
} = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
thread.update(cx, |thread, _cx| {
thread.add_tool(DelayTool);
thread.add_tool(EchoTool);
thread.add_tool(InfiniteTool);
});
// Override profiles and wait for settings to be loaded.
fs.insert_file(
paths::settings_file(),
json!({
"agent": {
"profiles": {
"test-1": {
"name": "Test Profile 1",
"tools": {
EchoTool.name(): true,
DelayTool.name(): true,
}
},
"test-2": {
"name": "Test Profile 2",
"tools": {
InfiniteTool.name(): true,
}
}
}
}
})
.to_string()
.into_bytes(),
)
.await;
cx.run_until_parked();
// Test that test-1 profile (default) has echo and delay tools
thread.update(cx, |thread, cx| {
thread.set_profile(AgentProfileId("test-1".into()));
thread.send(UserMessageId::new(), ["test"], cx);
});
cx.run_until_parked();
let mut pending_completions = fake_model.pending_completions();
assert_eq!(pending_completions.len(), 1);
let completion = pending_completions.pop().unwrap();
let tool_names: Vec<String> = completion
.tools
.iter()
.map(|tool| tool.name.clone())
.collect();
assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
fake_model.end_last_completion_stream();
// Switch to test-2 profile, and verify that it has only the infinite tool.
thread.update(cx, |thread, cx| {
thread.set_profile(AgentProfileId("test-2".into()));
thread.send(UserMessageId::new(), ["test2"], cx)
});
cx.run_until_parked();
let mut pending_completions = fake_model.pending_completions();
assert_eq!(pending_completions.len(), 1);
let completion = pending_completions.pop().unwrap();
let tool_names: Vec<String> = completion
.tools
.iter()
.map(|tool| tool.name.clone())
.collect();
assert_eq!(tool_names, vec![InfiniteTool.name()]);
}
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_cancellation(cx: &mut TestAppContext) {
@@ -478,7 +585,8 @@ async fn test_cancellation(cx: &mut TestAppContext) {
thread.add_tool(InfiniteTool);
thread.add_tool(EchoTool);
thread.send(
"Call the echo tool and then call the infinite tool, then explain their output",
UserMessageId::new(),
["Call the echo tool, then call the infinite tool, then explain their output"],
cx,
)
});
@@ -523,14 +631,20 @@ async fn test_cancellation(cx: &mut TestAppContext) {
// Ensure we can still send a new message after cancellation.
let events = thread
.update(cx, |thread, cx| {
thread.send("Testing: reply with 'Hello' then stop.", cx)
thread.send(
UserMessageId::new(),
["Testing: reply with 'Hello' then stop."],
cx,
)
})
.collect::<Vec<_>>()
.await;
thread.update(cx, |thread, _cx| {
let message = thread.last_message().unwrap();
let agent_message = message.as_agent_message().unwrap();
assert_eq!(
thread.messages().last().unwrap().content,
vec![MessageContent::Text("Hello".to_string())]
agent_message.content,
vec![AgentMessageContent::Text("Hello".to_string())]
);
});
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
@@ -541,13 +655,16 @@ async fn test_refusal(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let events = thread.update(cx, |thread, cx| thread.send("Hello", cx));
let events = thread.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hello"], cx)
});
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
thread.to_markdown(),
indoc! {"
## user
## User
Hello
"}
);
@@ -559,9 +676,12 @@ async fn test_refusal(cx: &mut TestAppContext) {
assert_eq!(
thread.to_markdown(),
indoc! {"
## user
## User
Hello
## assistant
## Assistant
Hey!
"}
);
@@ -577,6 +697,85 @@ async fn test_refusal(cx: &mut TestAppContext) {
});
}
#[gpui::test]
async fn test_truncate(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let message_id = UserMessageId::new();
thread.update(cx, |thread, cx| {
thread.send(message_id.clone(), ["Hello"], cx)
});
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
thread.to_markdown(),
indoc! {"
## User
Hello
"}
);
});
fake_model.send_last_completion_stream_text_chunk("Hey!");
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
thread.to_markdown(),
indoc! {"
## User
Hello
## Assistant
Hey!
"}
);
});
thread
.update(cx, |thread, _cx| thread.truncate(message_id))
.unwrap();
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(thread.to_markdown(), "");
});
// Ensure we can still send a new message after truncation.
thread.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hi"], cx)
});
thread.update(cx, |thread, _cx| {
assert_eq!(
thread.to_markdown(),
indoc! {"
## User
Hi
"}
);
});
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Ahoy!");
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
thread.to_markdown(),
indoc! {"
## User
Hi
## Assistant
Ahoy!
"}
);
});
}
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
cx.update(settings::init);
@@ -595,19 +794,26 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
language_models::init(user_store.clone(), client.clone(), cx);
Project::init_settings(cx);
LanguageModelRegistry::test(cx);
agent_settings::init(cx);
});
cx.executor().forbid_parking();
// Create a project for new_thread
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
fake_fs.insert_tree(path!("/test"), json!({})).await;
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
let cwd = Path::new("/test");
// Create agent and connection
let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
.await
.unwrap();
let agent = NativeAgent::new(
project.clone(),
templates.clone(),
None,
fake_fs.clone(),
&mut cx.to_async(),
)
.await
.unwrap();
let connection = NativeAgentConnection(agent.clone());
// Test model_selector returns Some
@@ -620,22 +826,22 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
// Test list_models
let listed_models = cx
.update(|cx| {
let mut async_cx = cx.to_async();
selector.list_models(&mut async_cx)
})
.update(|cx| selector.list_models(cx))
.await
.expect("list_models should succeed");
let AgentModelList::Grouped(listed_models) = listed_models else {
panic!("Unexpected model list type");
};
assert!(!listed_models.is_empty(), "should have at least one model");
assert_eq!(listed_models[0].id().0, "fake");
assert_eq!(
listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
"fake/fake"
);
// Create a thread using new_thread
let connection_rc = Rc::new(connection.clone());
let acp_thread = cx
.update(|cx| {
let mut async_cx = cx.to_async();
connection_rc.new_thread(project, cwd, &mut async_cx)
})
.update(|cx| connection_rc.new_thread(project, cwd, &mut cx.to_async()))
.await
.expect("new_thread should succeed");
@@ -644,12 +850,12 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
// Test selected_model returns the default
let model = cx
.update(|cx| {
let mut async_cx = cx.to_async();
selector.selected_model(&session_id, &mut async_cx)
})
.update(|cx| selector.selected_model(&session_id, cx))
.await
.expect("selected_model should succeed");
let model = cx
.update(|cx| agent.read(cx).models().model_from_id(&model.id))
.unwrap();
let model = model.as_fake();
assert_eq!(model.id().0, "fake", "should return default model");
@@ -683,6 +889,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
let result = cx
.update(|cx| {
connection.prompt(
Some(acp_thread::UserMessageId::new()),
acp::PromptRequest {
session_id: session_id.clone(),
prompt: vec!["ghi".into()],
@@ -705,7 +912,9 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
let fake_model = model.as_fake();
let mut events = thread.update(cx, |thread, cx| thread.send("Think", cx));
let mut events = thread.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Think"], cx)
});
cx.run_until_parked();
// Simulate streaming partial input.
@@ -790,6 +999,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
id: acp::ToolCallId("1".into()),
fields: acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::Completed),
raw_output: Some("Finished thinking.".into()),
..Default::default()
},
}
@@ -813,6 +1023,7 @@ struct ThreadTest {
model: Arc<dyn LanguageModel>,
thread: Entity<Thread>,
project_context: Rc<RefCell<ProjectContext>>,
fs: Arc<FakeFs>,
}
enum TestModel {
@@ -835,30 +1046,57 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
cx.executor().allow_parking();
let fs = FakeFs::new(cx.background_executor.clone());
fs.create_dir(paths::settings_file().parent().unwrap())
.await
.unwrap();
fs.insert_file(
paths::settings_file(),
json!({
"agent": {
"default_profile": "test-profile",
"profiles": {
"test-profile": {
"name": "Test Profile",
"tools": {
EchoTool.name(): true,
DelayTool.name(): true,
WordListTool.name(): true,
ToolRequiringPermission.name(): true,
InfiniteTool.name(): true,
}
}
}
}
})
.to_string()
.into_bytes(),
)
.await;
cx.update(|cx| {
settings::init(cx);
watch_settings(fs.clone(), cx);
Project::init_settings(cx);
agent_settings::init(cx);
gpui_tokio::init(cx);
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
cx.set_http_client(Arc::new(http_client));
client::init_settings(cx);
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
watch_settings(fs.clone(), cx);
});
let templates = Templates::new();
fs.insert_tree(path!("/test"), json!({})).await;
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
let model = cx
.update(|cx| {
gpui_tokio::init(cx);
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
cx.set_http_client(Arc::new(http_client));
client::init_settings(cx);
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
if let TestModel::Fake = model {
Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
} else {
@@ -881,20 +1119,25 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
.await;
let project_context = Rc::new(RefCell::new(ProjectContext::default()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let thread = cx.new(|_| {
let thread = cx.new(|cx| {
Thread::new(
project,
project_context.clone(),
context_server_registry,
action_log,
templates,
model.clone(),
cx,
)
});
ThreadTest {
model,
thread,
project_context,
fs,
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,10 @@
mod context_server_registry;
mod copy_path_tool;
mod create_directory_tool;
mod delete_path_tool;
mod diagnostics_tool;
mod edit_file_tool;
mod fetch_tool;
mod find_path_tool;
mod grep_tool;
mod list_directory_tool;
@@ -13,10 +16,13 @@ mod terminal_tool;
mod thinking_tool;
mod web_search_tool;
pub use context_server_registry::*;
pub use copy_path_tool::*;
pub use create_directory_tool::*;
pub use delete_path_tool::*;
pub use diagnostics_tool::*;
pub use edit_file_tool::*;
pub use fetch_tool::*;
pub use find_path_tool::*;
pub use grep_tool::*;
pub use list_directory_tool::*;

View File

@@ -0,0 +1,231 @@
use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
use agent_client_protocol::ToolKind;
use anyhow::{Result, anyhow, bail};
use collections::{BTreeMap, HashMap};
use context_server::ContextServerId;
use gpui::{App, Context, Entity, SharedString, Task};
use project::context_server_store::{ContextServerStatus, ContextServerStore};
use std::sync::Arc;
use util::ResultExt;
pub struct ContextServerRegistry {
server_store: Entity<ContextServerStore>,
registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
_subscription: gpui::Subscription,
}
struct RegisteredContextServer {
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
load_tools: Task<Result<()>>,
}
impl ContextServerRegistry {
pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
let mut this = Self {
server_store: server_store.clone(),
registered_servers: HashMap::default(),
_subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
};
for server in server_store.read(cx).running_servers() {
this.reload_tools_for_server(server.id(), cx);
}
this
}
pub fn servers(
&self,
) -> impl Iterator<
Item = (
&ContextServerId,
&BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
),
> {
self.registered_servers
.iter()
.map(|(id, server)| (id, &server.tools))
}
fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
return;
};
let Some(client) = server.client() else {
return;
};
if !client.capable(context_server::protocol::ServerCapability::Tools) {
return;
}
let registered_server =
self.registered_servers
.entry(server_id.clone())
.or_insert(RegisteredContextServer {
tools: BTreeMap::default(),
load_tools: Task::ready(Ok(())),
});
registered_server.load_tools = cx.spawn(async move |this, cx| {
let response = client
.request::<context_server::types::requests::ListTools>(())
.await;
this.update(cx, |this, cx| {
let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
return;
};
registered_server.tools.clear();
if let Some(response) = response.log_err() {
for tool in response.tools {
let tool = Arc::new(ContextServerTool::new(
this.server_store.clone(),
server.id(),
tool,
));
registered_server.tools.insert(tool.name(), tool);
}
cx.notify();
}
})
});
}
fn handle_context_server_store_event(
&mut self,
_: Entity<ContextServerStore>,
event: &project::context_server_store::Event,
cx: &mut Context<Self>,
) {
match event {
project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
match status {
ContextServerStatus::Starting => {}
ContextServerStatus::Running => {
self.reload_tools_for_server(server_id.clone(), cx);
}
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
self.registered_servers.remove(&server_id);
cx.notify();
}
}
}
}
}
}
struct ContextServerTool {
store: Entity<ContextServerStore>,
server_id: ContextServerId,
tool: context_server::types::Tool,
}
impl ContextServerTool {
fn new(
store: Entity<ContextServerStore>,
server_id: ContextServerId,
tool: context_server::types::Tool,
) -> Self {
Self {
store,
server_id,
tool,
}
}
}
impl AnyAgentTool for ContextServerTool {
fn name(&self) -> SharedString {
self.tool.name.clone().into()
}
fn description(&self) -> SharedString {
self.tool.description.clone().unwrap_or_default().into()
}
fn kind(&self) -> ToolKind {
ToolKind::Other
}
fn initial_title(&self, _input: serde_json::Value) -> SharedString {
format!("Run MCP tool `{}`", self.tool.name).into()
}
fn input_schema(
&self,
format: language_model::LanguageModelToolSchemaFormat,
) -> Result<serde_json::Value> {
let mut schema = self.tool.input_schema.clone();
assistant_tool::adapt_schema_to_format(&mut schema, format)?;
Ok(match schema {
serde_json::Value::Null => {
serde_json::json!({ "type": "object", "properties": [] })
}
serde_json::Value::Object(map) if map.is_empty() => {
serde_json::json!({ "type": "object", "properties": [] })
}
_ => schema,
})
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<AgentToolOutput>> {
let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
return Task::ready(Err(anyhow!("Context server not found")));
};
let tool_name = self.tool.name.clone();
let server_clone = server.clone();
let input_clone = input.clone();
cx.spawn(async move |_cx| {
let Some(protocol) = server_clone.client() else {
bail!("Context server not initialized");
};
let arguments = if let serde_json::Value::Object(map) = input_clone {
Some(map.into_iter().collect())
} else {
None
};
log::trace!(
"Running tool: {} with arguments: {:?}",
tool_name,
arguments
);
let response = protocol
.request::<context_server::types::requests::CallTool>(
context_server::types::CallToolParams {
name: tool_name,
arguments,
meta: None,
},
)
.await?;
let mut result = String::new();
for content in response.content {
match content {
context_server::types::ToolResponseContent::Text { text } => {
result.push_str(&text);
}
context_server::types::ToolResponseContent::Image { .. } => {
log::warn!("Ignoring image content from tool response");
}
context_server::types::ToolResponseContent::Audio { .. } => {
log::warn!("Ignoring audio content from tool response");
}
context_server::types::ToolResponseContent::Resource { .. } => {
log::warn!("Ignoring resource content from tool response");
}
}
}
Ok(AgentToolOutput {
raw_output: result.clone().into(),
llm_output: result.into(),
})
})
}
}

View File

@@ -0,0 +1,163 @@
use crate::{AgentTool, ToolCallEventStream};
use agent_client_protocol as acp;
use anyhow::{Result, anyhow};
use gpui::{App, Entity, Task};
use language::{DiagnosticSeverity, OffsetRangeExt};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{fmt::Write, path::Path, sync::Arc};
use ui::SharedString;
use util::markdown::MarkdownInlineCode;
/// Get errors and warnings for the project or a specific file.
///
/// This tool can be invoked after a series of edits to determine if further edits are necessary, or if the user asks to fix errors or warnings in their codebase.
///
/// When a path is provided, shows all diagnostics for that specific file.
/// When no path is provided, shows a summary of error and warning counts for all files in the project.
///
/// <example>
/// To get diagnostics for a specific file:
/// {
/// "path": "src/main.rs"
/// }
///
/// To get a project-wide diagnostic summary:
/// {}
/// </example>
///
/// <guidelines>
/// - If you think you can fix a diagnostic, make 1-2 attempts and then give up.
/// - Don't remove code you've generated just because you can't fix an error. The user can help you fix it.
/// </guidelines>
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct DiagnosticsToolInput {
/// The path to get diagnostics for. If not provided, returns a project-wide summary.
///
/// This path should never be absolute, and the first component
/// of the path should always be a root directory in a project.
///
/// <example>
/// If the project has the following root directories:
///
/// - lorem
/// - ipsum
///
/// If you wanna access diagnostics for `dolor.txt` in `ipsum`, you should use the path `ipsum/dolor.txt`.
/// </example>
pub path: Option<String>,
}
pub struct DiagnosticsTool {
project: Entity<Project>,
}
impl DiagnosticsTool {
pub fn new(project: Entity<Project>) -> Self {
Self { project }
}
}
impl AgentTool for DiagnosticsTool {
type Input = DiagnosticsToolInput;
type Output = String;
fn name(&self) -> SharedString {
"diagnostics".into()
}
fn kind(&self) -> acp::ToolKind {
acp::ToolKind::Read
}
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
if let Some(path) = input.ok().and_then(|input| match input.path {
Some(path) if !path.is_empty() => Some(path),
_ => None,
}) {
format!("Check diagnostics for {}", MarkdownInlineCode(&path)).into()
} else {
"Check project diagnostics".into()
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
_event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
match input.path {
Some(path) if !path.is_empty() => {
let Some(project_path) = self.project.read(cx).find_project_path(&path, cx) else {
return Task::ready(Err(anyhow!("Could not find path {path} in project",)));
};
let buffer = self
.project
.update(cx, |project, cx| project.open_buffer(project_path, cx));
cx.spawn(async move |cx| {
let mut output = String::new();
let buffer = buffer.await?;
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
for (_, group) in snapshot.diagnostic_groups(None) {
let entry = &group.entries[group.primary_ix];
let range = entry.range.to_point(&snapshot);
let severity = match entry.diagnostic.severity {
DiagnosticSeverity::ERROR => "error",
DiagnosticSeverity::WARNING => "warning",
_ => continue,
};
writeln!(
output,
"{} at line {}: {}",
severity,
range.start.row + 1,
entry.diagnostic.message
)?;
}
if output.is_empty() {
Ok("File doesn't have errors or warnings!".to_string())
} else {
Ok(output)
}
})
}
_ => {
let project = self.project.read(cx);
let mut output = String::new();
let mut has_diagnostics = false;
for (project_path, _, summary) in project.diagnostic_summaries(true, cx) {
if summary.error_count > 0 || summary.warning_count > 0 {
let Some(worktree) = project.worktree_for_id(project_path.worktree_id, cx)
else {
continue;
};
has_diagnostics = true;
output.push_str(&format!(
"{}: {} error(s), {} warning(s)\n",
Path::new(worktree.read(cx).root_name())
.join(project_path.path)
.display(),
summary.error_count,
summary.warning_count
));
}
}
if has_diagnostics {
Task::ready(Ok(output))
} else {
Task::ready(Ok("No errors or warnings found in the project.".into()))
}
}
}
}
}

View File

@@ -1,12 +1,13 @@
use crate::{AgentTool, Thread, ToolCallEventStream};
use acp_thread::Diff;
use agent_client_protocol as acp;
use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
use anyhow::{Context as _, Result, anyhow};
use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
use cloud_llm_client::CompletionIntent;
use collections::HashSet;
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use indoc::formatdoc;
use language::ToPoint;
use language::language_settings::{self, FormatOnSave};
use language_model::LanguageModelToolResultContent;
use paths;
@@ -225,6 +226,16 @@ impl AgentTool for EditFileTool {
Ok(path) => path,
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let abs_path = project.read(cx).absolute_path(&project_path, cx);
if let Some(abs_path) = abs_path.clone() {
event_stream.update_fields(ToolCallUpdateFields {
locations: Some(vec![acp::ToolCallLocation {
path: abs_path,
line: None,
}]),
..Default::default()
});
}
let request = self.thread.update(cx, |thread, cx| {
thread.build_completion_request(CompletionIntent::ToolResults, cx)
@@ -283,13 +294,38 @@ impl AgentTool for EditFileTool {
let mut hallucinated_old_text = false;
let mut ambiguous_ranges = Vec::new();
let mut emitted_location = false;
while let Some(event) = events.next().await {
match event {
EditAgentOutputEvent::Edited => {},
EditAgentOutputEvent::Edited(range) => {
if !emitted_location {
let line = buffer.update(cx, |buffer, _cx| {
range.start.to_point(&buffer.snapshot()).row
}).ok();
if let Some(abs_path) = abs_path.clone() {
event_stream.update_fields(ToolCallUpdateFields {
locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
..Default::default()
});
}
emitted_location = true;
}
},
EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
EditAgentOutputEvent::ResolvingEditRange(range) => {
diff.update(cx, |card, cx| card.reveal_range(range, cx))?;
diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
// if !emitted_location {
// let line = buffer.update(cx, |buffer, _cx| {
// range.start.to_point(&buffer.snapshot()).row
// }).ok();
// if let Some(abs_path) = abs_path.clone() {
// event_stream.update_fields(ToolCallUpdateFields {
// locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
// ..Default::default()
// });
// }
// }
}
}
}
@@ -454,9 +490,8 @@ fn resolve_path(
#[cfg(test)]
mod tests {
use crate::Templates;
use super::*;
use crate::{ContextServerRegistry, Templates};
use action_log::ActionLog;
use client::TelemetrySettings;
use fs::Fs;
@@ -475,9 +510,20 @@ mod tests {
fs.insert_tree("/root", json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread =
cx.new(|_| Thread::new(project, Rc::default(), action_log, Templates::new(), model));
let thread = cx.new(|cx| {
Thread::new(
project,
Rc::default(),
context_server_registry,
action_log,
Templates::new(),
model,
cx,
)
});
let result = cx
.update(|cx| {
let input = EditFileToolInput {
@@ -661,14 +707,18 @@ mod tests {
});
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|_| {
let thread = cx.new(|cx| {
Thread::new(
project,
Rc::default(),
context_server_registry,
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
@@ -792,15 +842,19 @@ mod tests {
.unwrap();
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|_| {
let thread = cx.new(|cx| {
Thread::new(
project,
Rc::default(),
context_server_registry,
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
@@ -914,15 +968,19 @@ mod tests {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|_| {
let thread = cx.new(|cx| {
Thread::new(
project,
Rc::default(),
context_server_registry,
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
@@ -1041,15 +1099,19 @@ mod tests {
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/project", json!({})).await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|_| {
let thread = cx.new(|cx| {
Thread::new(
project,
Rc::default(),
context_server_registry,
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
@@ -1148,14 +1210,18 @@ mod tests {
.await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|_| {
let thread = cx.new(|cx| {
Thread::new(
project.clone(),
Rc::default(),
context_server_registry.clone(),
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
@@ -1225,14 +1291,18 @@ mod tests {
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|_| {
let thread = cx.new(|cx| {
Thread::new(
project.clone(),
Rc::default(),
context_server_registry.clone(),
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
@@ -1305,14 +1375,18 @@ mod tests {
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|_| {
let thread = cx.new(|cx| {
Thread::new(
project.clone(),
Rc::default(),
context_server_registry.clone(),
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
@@ -1382,14 +1456,18 @@ mod tests {
let fs = project::FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|_| {
let thread = cx.new(|cx| {
Thread::new(
project.clone(),
Rc::default(),
context_server_registry,
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
let tool = Arc::new(EditFileTool { thread });

View File

@@ -0,0 +1,155 @@
use std::rc::Rc;
use std::sync::Arc;
use std::{borrow::Cow, cell::RefCell};
use agent_client_protocol as acp;
use anyhow::{Context as _, Result, bail};
use futures::AsyncReadExt as _;
use gpui::{App, AppContext as _, Task};
use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
use http_client::{AsyncBody, HttpClientWithUrl};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use ui::SharedString;
use util::markdown::MarkdownEscaped;
use crate::{AgentTool, ToolCallEventStream};
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
enum ContentType {
Html,
Plaintext,
Json,
}
/// Fetches a URL and returns the content as Markdown.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FetchToolInput {
/// The URL to fetch.
url: String,
}
pub struct FetchTool {
http_client: Arc<HttpClientWithUrl>,
}
impl FetchTool {
pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
Self { http_client }
}
async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
let url = if !url.starts_with("https://") && !url.starts_with("http://") {
Cow::Owned(format!("https://{url}"))
} else {
Cow::Borrowed(url)
};
let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
let mut body = Vec::new();
response
.body_mut()
.read_to_end(&mut body)
.await
.context("error reading response body")?;
if response.status().is_client_error() {
let text = String::from_utf8_lossy(body.as_slice());
bail!(
"status error {}, response: {text:?}",
response.status().as_u16()
);
}
let Some(content_type) = response.headers().get("content-type") else {
bail!("missing Content-Type header");
};
let content_type = content_type
.to_str()
.context("invalid Content-Type header")?;
let content_type = if content_type.starts_with("text/plain") {
ContentType::Plaintext
} else if content_type.starts_with("application/json") {
ContentType::Json
} else {
ContentType::Html
};
match content_type {
ContentType::Html => {
let mut handlers: Vec<TagHandler> = vec![
Rc::new(RefCell::new(markdown::WebpageChromeRemover)),
Rc::new(RefCell::new(markdown::ParagraphHandler)),
Rc::new(RefCell::new(markdown::HeadingHandler)),
Rc::new(RefCell::new(markdown::ListHandler)),
Rc::new(RefCell::new(markdown::TableHandler::new())),
Rc::new(RefCell::new(markdown::StyledTextHandler)),
];
if url.contains("wikipedia.org") {
use html_to_markdown::structure::wikipedia;
handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover)));
handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler)));
handlers.push(Rc::new(
RefCell::new(wikipedia::WikipediaCodeHandler::new()),
));
} else {
handlers.push(Rc::new(RefCell::new(markdown::CodeHandler)));
}
convert_html_to_markdown(&body[..], &mut handlers)
}
ContentType::Plaintext => Ok(std::str::from_utf8(&body)?.to_owned()),
ContentType::Json => {
let json: serde_json::Value = serde_json::from_slice(&body)?;
Ok(format!(
"```json\n{}\n```",
serde_json::to_string_pretty(&json)?
))
}
}
}
}
impl AgentTool for FetchTool {
type Input = FetchToolInput;
type Output = String;
fn name(&self) -> SharedString {
"fetch".into()
}
fn kind(&self) -> acp::ToolKind {
acp::ToolKind::Fetch
}
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
match input {
Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)).into(),
Err(_) => "Fetch URL".into(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
_event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
let text = cx.background_spawn({
let http_client = self.http_client.clone();
async move { Self::build_message(http_client, &input.url).await }
});
cx.foreground_executor().spawn(async move {
let text = text.await?;
if text.trim().is_empty() {
bail!("no textual content found");
}
Ok(text)
})
}
}

View File

@@ -139,9 +139,6 @@ impl AgentTool for FindPathTool {
})
.collect(),
),
raw_output: Some(serde_json::json!({
"paths": &matches,
})),
..Default::default()
});

View File

@@ -101,7 +101,7 @@ impl AgentTool for GrepTool {
fn run(
self: Arc<Self>,
input: Self::Input,
event_stream: ToolCallEventStream,
_event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
const CONTEXT_LINES: u32 = 2;
@@ -282,33 +282,22 @@ impl AgentTool for GrepTool {
}
}
event_stream.update_fields(acp::ToolCallUpdateFields {
content: Some(vec![output.clone().into()]),
..Default::default()
});
matches_found += 1;
}
}
let output = if matches_found == 0 {
"No matches found".to_string()
if matches_found == 0 {
Ok("No matches found".into())
} else if has_more_matches {
format!(
Ok(format!(
"Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
input.offset + 1,
input.offset + matches_found,
input.offset + RESULTS_PER_PAGE,
)
))
} else {
format!("Found {matches_found} matches:\n{output}")
};
event_stream.update_fields(acp::ToolCallUpdateFields {
content: Some(vec![output.clone().into()]),
..Default::default()
});
Ok(output)
Ok(format!("Found {matches_found} matches:\n{output}"))
}
})
}
}

View File

@@ -47,20 +47,13 @@ impl AgentTool for NowTool {
fn run(
self: Arc<Self>,
input: Self::Input,
event_stream: ToolCallEventStream,
_event_stream: ToolCallEventStream,
_cx: &mut App,
) -> Task<Result<String>> {
let now = match input.timezone {
Timezone::Utc => Utc::now().to_rfc3339(),
Timezone::Local => Local::now().to_rfc3339(),
};
let content = format!("The current datetime is {now}.");
event_stream.update_fields(acp::ToolCallUpdateFields {
content: Some(vec![content.clone().into()]),
..Default::default()
});
Task::ready(Ok(content))
Task::ready(Ok(format!("The current datetime is {now}.")))
}
}

View File

@@ -1,10 +1,10 @@
use action_log::ActionLog;
use agent_client_protocol::{self as acp};
use agent_client_protocol::{self as acp, ToolCallUpdateFields};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::outline;
use gpui::{App, Entity, SharedString, Task};
use indoc::formatdoc;
use language::{Anchor, Point};
use language::Point;
use language_model::{LanguageModelImage, LanguageModelToolResultContent};
use project::{AgentLocation, ImageItem, Project, WorktreeSettings, image_store};
use schemars::JsonSchema;
@@ -97,7 +97,7 @@ impl AgentTool for ReadFileTool {
fn run(
self: Arc<Self>,
input: Self::Input,
_event_stream: ToolCallEventStream,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<LanguageModelToolResultContent>> {
let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else {
@@ -166,7 +166,9 @@ impl AgentTool for ReadFileTool {
cx.spawn(async move |cx| {
let buffer = cx
.update(|cx| {
project.update(cx, |project, cx| project.open_buffer(project_path, cx))
project.update(cx, |project, cx| {
project.open_buffer(project_path.clone(), cx)
})
})?
.await?;
if buffer.read_with(cx, |buffer, _| {
@@ -178,19 +180,10 @@ impl AgentTool for ReadFileTool {
anyhow::bail!("{file_path} not found");
}
project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: Anchor::MIN,
}),
cx,
);
})?;
let mut anchor = None;
// Check if specific line ranges are provided
if input.start_line.is_some() || input.end_line.is_some() {
let mut anchor = None;
let result = if input.start_line.is_some() || input.end_line.is_some() {
let result = buffer.read_with(cx, |buffer, _cx| {
let text = buffer.text();
// .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0.
@@ -214,18 +207,6 @@ impl AgentTool for ReadFileTool {
log.buffer_read(buffer.clone(), cx);
})?;
if let Some(anchor) = anchor {
project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: anchor,
}),
cx,
);
})?;
}
Ok(result.into())
} else {
// No line ranges specified, so check file size to see if it's too big.
@@ -236,7 +217,7 @@ impl AgentTool for ReadFileTool {
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
action_log.update(cx, |log, cx| {
log.buffer_read(buffer, cx);
log.buffer_read(buffer.clone(), cx);
})?;
Ok(result.into())
@@ -244,7 +225,8 @@ impl AgentTool for ReadFileTool {
// File is too big, so return the outline
// and a suggestion to read again with line numbers.
let outline =
outline::file_outline(project, file_path, action_log, None, cx).await?;
outline::file_outline(project.clone(), file_path, action_log, None, cx)
.await?;
Ok(formatdoc! {"
This file was too big to read all at once.
@@ -261,7 +243,28 @@ impl AgentTool for ReadFileTool {
}
.into())
}
}
};
project.update(cx, |project, cx| {
if let Some(abs_path) = project.absolute_path(&project_path, cx) {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: anchor.unwrap_or(text::Anchor::MIN),
}),
cx,
);
event_stream.update_fields(ToolCallUpdateFields {
locations: Some(vec![acp::ToolCallLocation {
path: abs_path,
line: input.start_line.map(|line| line.saturating_sub(1)),
}]),
..Default::default()
});
}
})?;
result
})
}
}

View File

@@ -5,7 +5,9 @@ use agent_client_protocol as acp;
use anyhow::{Result, anyhow};
use cloud_llm_client::WebSearchResponse;
use gpui::{App, AppContext, Task};
use language_model::LanguageModelToolResultContent;
use language_model::{
LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use ui::prelude::*;
@@ -50,6 +52,11 @@ impl AgentTool for WebSearchTool {
"Searching the Web".into()
}
/// We currently only support Zed Cloud as a provider.
fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
provider == &ZED_CLOUD_PROVIDER_ID
}
fn run(
self: Arc<Self>,
input: Self::Input,

View File

@@ -467,6 +467,7 @@ impl AgentConnection for AcpConnection {
fn prompt(
&self,
_id: Option<acp_thread::UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {

View File

@@ -171,6 +171,7 @@ impl AgentConnection for AcpConnection {
fn prompt(
&self,
_id: Option<acp_thread::UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {

View File

@@ -210,6 +210,7 @@ impl AgentConnection for ClaudeAgentConnection {
fn prompt(
&self,
_id: Option<acp_thread::UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {
@@ -423,7 +424,7 @@ impl ClaudeAgentSession {
if !turn_state.borrow().is_cancelled() {
thread
.update(cx, |thread, cx| {
thread.push_user_content_block(text.into(), cx)
thread.push_user_content_block(None, text.into(), cx)
})
.log_err();
}

View File

@@ -48,6 +48,20 @@ pub struct AgentProfileSettings {
pub context_servers: IndexMap<Arc<str>, ContextServerPreset>,
}
impl AgentProfileSettings {
pub fn is_tool_enabled(&self, tool_name: &str) -> bool {
self.tools.get(tool_name) == Some(&true)
}
pub fn is_context_server_tool_enabled(&self, server_id: &str, tool_name: &str) -> bool {
self.enable_all_context_servers
|| self
.context_servers
.get(server_id)
.map_or(false, |preset| preset.tools.get(tool_name) == Some(&true))
}
}
#[derive(Debug, Clone, Default)]
pub struct ContextServerPreset {
pub tools: IndexMap<Arc<str>, bool>,

View File

@@ -1,6 +1,9 @@
mod completion_provider;
mod message_history;
mod message_editor;
mod model_selector;
mod model_selector_popover;
mod thread_view;
pub use message_history::MessageHistory;
pub use model_selector::AcpModelSelector;
pub use model_selector_popover::AcpModelSelectorPopover;
pub use thread_view::AcpThreadView;

View File

@@ -1,18 +1,20 @@
use std::ops::Range;
use std::path::Path;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use anyhow::Result;
use acp_thread::MentionUri;
use anyhow::{Context as _, Result};
use collections::HashMap;
use editor::display_map::CreaseId;
use editor::{CompletionProvider, Editor, ExcerptId};
use file_icons::FileIcons;
use futures::future::try_join_all;
use gpui::{App, Entity, Task, WeakEntity};
use language::{Buffer, CodeLabel, HighlightId};
use lsp::CompletionContext;
use parking_lot::Mutex;
use project::{Completion, CompletionIntent, CompletionResponse, ProjectPath, WorktreeId};
use project::{Completion, CompletionIntent, CompletionResponse, Project, ProjectPath, WorktreeId};
use rope::Point;
use text::{Anchor, ToPoint};
use ui::prelude::*;
@@ -23,21 +25,63 @@ use crate::context_picker::file_context_picker::{extract_file_name_and_directory
#[derive(Default)]
pub struct MentionSet {
paths_by_crease_id: HashMap<CreaseId, ProjectPath>,
paths_by_crease_id: HashMap<CreaseId, MentionUri>,
}
impl MentionSet {
pub fn insert(&mut self, crease_id: CreaseId, path: ProjectPath) {
self.paths_by_crease_id.insert(crease_id, path);
}
pub fn path_for_crease_id(&self, crease_id: CreaseId) -> Option<ProjectPath> {
self.paths_by_crease_id.get(&crease_id).cloned()
pub fn insert(&mut self, crease_id: CreaseId, path: PathBuf) {
self.paths_by_crease_id
.insert(crease_id, MentionUri::File(path));
}
pub fn drain(&mut self) -> impl Iterator<Item = CreaseId> {
self.paths_by_crease_id.drain().map(|(id, _)| id)
}
pub fn contents(
&self,
project: Entity<Project>,
cx: &mut App,
) -> Task<Result<HashMap<CreaseId, Mention>>> {
let contents = self
.paths_by_crease_id
.iter()
.map(|(crease_id, uri)| match uri {
MentionUri::File(path) => {
let crease_id = *crease_id;
let uri = uri.clone();
let path = path.to_path_buf();
let buffer_task = project.update(cx, |project, cx| {
let path = project
.find_project_path(path, cx)
.context("Failed to find project path")?;
anyhow::Ok(project.open_buffer(path, cx))
});
cx.spawn(async move |cx| {
let buffer = buffer_task?.await?;
let content = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
anyhow::Ok((crease_id, Mention { uri, content }))
})
}
_ => {
// TODO
unimplemented!()
}
})
.collect::<Vec<_>>();
cx.spawn(async move |_cx| {
let contents = try_join_all(contents).await?.into_iter().collect();
anyhow::Ok(contents)
})
}
}
pub struct Mention {
pub uri: MentionUri,
pub content: String,
}
pub struct ContextPickerCompletionProvider {
@@ -68,6 +112,7 @@ impl ContextPickerCompletionProvider {
source_range: Range<Anchor>,
editor: Entity<Editor>,
mention_set: Arc<Mutex<MentionSet>>,
project: Entity<Project>,
cx: &App,
) -> Completion {
let (file_name, directory) =
@@ -112,6 +157,7 @@ impl ContextPickerCompletionProvider {
new_text_len - 1,
editor,
mention_set,
project,
)),
}
}
@@ -159,6 +205,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
return Task::ready(Ok(Vec::new()));
};
let project = workspace.read(cx).project().clone();
let snapshot = buffer.read(cx).snapshot();
let source_range = snapshot.anchor_before(state.source_range.start)
..snapshot.anchor_after(state.source_range.end);
@@ -195,6 +242,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
source_range.clone(),
editor.clone(),
mention_set.clone(),
project.clone(),
cx,
)
})
@@ -254,6 +302,7 @@ fn confirm_completion_callback(
content_len: usize,
editor: Entity<Editor>,
mention_set: Arc<Mutex<MentionSet>>,
project: Entity<Project>,
) -> Arc<dyn Fn(CompletionIntent, &mut Window, &mut App) -> bool + Send + Sync> {
Arc::new(move |_, window, cx| {
let crease_text = crease_text.clone();
@@ -261,6 +310,7 @@ fn confirm_completion_callback(
let editor = editor.clone();
let project_path = project_path.clone();
let mention_set = mention_set.clone();
let project = project.clone();
window.defer(cx, move |window, cx| {
let crease_id = crate::context_picker::insert_crease_for_mention(
excerpt_id,
@@ -272,8 +322,13 @@ fn confirm_completion_callback(
window,
cx,
);
let Some(path) = project.read(cx).absolute_path(&project_path, cx) else {
return;
};
if let Some(crease_id) = crease_id {
mention_set.lock().insert(crease_id, project_path);
mention_set.lock().insert(crease_id, path);
}
});
false

View File

@@ -0,0 +1,479 @@
use crate::acp::completion_provider::ContextPickerCompletionProvider;
use crate::acp::completion_provider::MentionSet;
use acp_thread::MentionUri;
use agent_client_protocol as acp;
use anyhow::Result;
use collections::HashSet;
use editor::{
AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorMode,
EditorStyle, MultiBuffer,
};
use file_icons::FileIcons;
use gpui::{
AppContext, Context, Entity, EventEmitter, FocusHandle, Focusable, Task, TextStyle, WeakEntity,
};
use language::Language;
use language::{Buffer, BufferSnapshot};
use parking_lot::Mutex;
use project::{CompletionIntent, Project};
use settings::Settings;
use std::path::Path;
use std::rc::Rc;
use std::sync::Arc;
use theme::ThemeSettings;
use ui::{
ActiveTheme, App, IconName, InteractiveElement, IntoElement, ParentElement, Render,
SharedString, Styled, TextSize, Window, div,
};
use util::ResultExt;
use workspace::Workspace;
use zed_actions::agent::Chat;
pub const MIN_EDITOR_LINES: usize = 4;
pub const MAX_EDITOR_LINES: usize = 8;
pub struct MessageEditor {
editor: Entity<Editor>,
project: Entity<Project>,
mention_set: Arc<Mutex<MentionSet>>,
}
pub enum MessageEditorEvent {
Chat,
}
impl EventEmitter<MessageEditorEvent> for MessageEditor {}
impl MessageEditor {
pub fn new(
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let language = Language::new(
language::LanguageConfig {
completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']),
..Default::default()
},
None,
);
let mention_set = Arc::new(Mutex::new(MentionSet::default()));
let editor = cx.new(|cx| {
let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx));
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
let mut editor = Editor::new(
editor::EditorMode::AutoHeight {
min_lines: MIN_EDITOR_LINES,
max_lines: Some(MAX_EDITOR_LINES),
},
buffer,
None,
window,
cx,
);
editor.set_placeholder_text("Message the agent @ to include files", cx);
editor.set_show_indent_guides(false, cx);
editor.set_soft_wrap();
editor.set_use_modal_editing(true);
editor.set_completion_provider(Some(Rc::new(ContextPickerCompletionProvider::new(
mention_set.clone(),
workspace,
cx.weak_entity(),
))));
editor.set_context_menu_options(ContextMenuOptions {
min_entries_visible: 12,
max_entries_visible: 12,
placement: Some(ContextMenuPlacement::Above),
});
editor
});
Self {
editor,
project,
mention_set,
}
}
pub fn is_empty(&self, cx: &App) -> bool {
self.editor.read(cx).is_empty(cx)
}
pub fn contents(&self, cx: &mut Context<Self>) -> Task<Result<Vec<acp::ContentBlock>>> {
let contents = self.mention_set.lock().contents(self.project.clone(), cx);
let editor = self.editor.clone();
cx.spawn(async move |_, cx| {
let contents = contents.await?;
editor.update(cx, |editor, cx| {
let mut ix = 0;
let mut chunks: Vec<acp::ContentBlock> = Vec::new();
let text = editor.text(cx);
editor.display_map.update(cx, |map, cx| {
let snapshot = map.snapshot(cx);
for (crease_id, crease) in snapshot.crease_snapshot.creases() {
// Skip creases that have been edited out of the message buffer.
if !crease.range().start.is_valid(&snapshot.buffer_snapshot) {
continue;
}
if let Some(mention) = contents.get(&crease_id) {
let crease_range = crease.range().to_offset(&snapshot.buffer_snapshot);
if crease_range.start > ix {
chunks.push(text[ix..crease_range.start].into());
}
chunks.push(acp::ContentBlock::Resource(acp::EmbeddedResource {
annotations: None,
resource: acp::EmbeddedResourceResource::TextResourceContents(
acp::TextResourceContents {
mime_type: None,
text: mention.content.clone(),
uri: mention.uri.to_uri(),
},
),
}));
ix = crease_range.end;
}
}
if ix < text.len() {
let last_chunk = text[ix..].trim_end();
if !last_chunk.is_empty() {
chunks.push(last_chunk.into());
}
}
});
chunks
})
})
}
pub fn clear(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.editor.update(cx, |editor, cx| {
editor.clear(window, cx);
editor.remove_creases(self.mention_set.lock().drain(), cx)
});
}
fn chat(&mut self, _: &Chat, _: &mut Window, cx: &mut Context<Self>) {
cx.emit(MessageEditorEvent::Chat)
}
pub fn insert_dragged_files(
&self,
paths: Vec<project::ProjectPath>,
window: &mut Window,
cx: &mut Context<Self>,
) {
let buffer = self.editor.read(cx).buffer().clone();
let Some((&excerpt_id, _, _)) = buffer.read(cx).snapshot(cx).as_singleton() else {
return;
};
let Some(buffer) = buffer.read(cx).as_singleton() else {
return;
};
for path in paths {
let Some(entry) = self.project.read(cx).entry_for_path(&path, cx) else {
continue;
};
let Some(abs_path) = self.project.read(cx).absolute_path(&path, cx) else {
continue;
};
let anchor = buffer.update(cx, |buffer, _cx| buffer.anchor_before(buffer.len()));
let path_prefix = abs_path
.file_name()
.unwrap_or(path.path.as_os_str())
.display()
.to_string();
let completion = ContextPickerCompletionProvider::completion_for_path(
path,
&path_prefix,
false,
entry.is_dir(),
excerpt_id,
anchor..anchor,
self.editor.clone(),
self.mention_set.clone(),
self.project.clone(),
cx,
);
self.editor.update(cx, |message_editor, cx| {
message_editor.edit(
[(
multi_buffer::Anchor::max()..multi_buffer::Anchor::max(),
completion.new_text,
)],
cx,
);
});
if let Some(confirm) = completion.confirm.clone() {
confirm(CompletionIntent::Complete, window, cx);
}
}
}
pub fn set_expanded(&mut self, expanded: bool, cx: &mut Context<Self>) {
self.editor.update(cx, |editor, cx| {
if expanded {
editor.set_mode(EditorMode::Full {
scale_ui_elements_with_buffer_font_size: false,
show_active_line_background: false,
sized_by_content: false,
})
} else {
editor.set_mode(EditorMode::AutoHeight {
min_lines: MIN_EDITOR_LINES,
max_lines: Some(MAX_EDITOR_LINES),
})
}
cx.notify()
});
}
#[allow(unused)]
fn set_draft_message(
message_editor: Entity<Editor>,
mention_set: Arc<Mutex<MentionSet>>,
project: Entity<Project>,
message: Option<&[acp::ContentBlock]>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Option<BufferSnapshot> {
cx.notify();
let message = message?;
let mut text = String::new();
let mut mentions = Vec::new();
for chunk in message {
match chunk {
acp::ContentBlock::Text(text_content) => {
text.push_str(&text_content.text);
}
acp::ContentBlock::Resource(acp::EmbeddedResource {
resource: acp::EmbeddedResourceResource::TextResourceContents(resource),
..
}) => {
if let Some(ref mention @ MentionUri::File(ref abs_path)) =
MentionUri::parse(&resource.uri).log_err()
{
let project_path = project
.read(cx)
.project_path_for_absolute_path(&abs_path, cx);
let start = text.len();
let content = mention.to_uri();
text.push_str(&content);
let end = text.len();
if let Some(project_path) = project_path {
let filename: SharedString = project_path
.path
.file_name()
.unwrap_or_default()
.to_string_lossy()
.to_string()
.into();
mentions.push((start..end, project_path, filename));
}
}
}
acp::ContentBlock::Image(_)
| acp::ContentBlock::Audio(_)
| acp::ContentBlock::Resource(_)
| acp::ContentBlock::ResourceLink(_) => {}
}
}
let snapshot = message_editor.update(cx, |editor, cx| {
editor.set_text(text, window, cx);
editor.buffer().read(cx).snapshot(cx)
});
for (range, project_path, filename) in mentions {
let crease_icon_path = if project_path.path.is_dir() {
FileIcons::get_folder_icon(false, cx)
.unwrap_or_else(|| IconName::Folder.path().into())
} else {
FileIcons::get_icon(Path::new(project_path.path.as_ref()), cx)
.unwrap_or_else(|| IconName::File.path().into())
};
let anchor = snapshot.anchor_before(range.start);
if let Some(project_path) = project.read(cx).absolute_path(&project_path, cx) {
let crease_id = crate::context_picker::insert_crease_for_mention(
anchor.excerpt_id,
anchor.text_anchor,
range.end - range.start,
filename,
crease_icon_path,
message_editor.clone(),
window,
cx,
);
if let Some(crease_id) = crease_id {
mention_set.lock().insert(crease_id, project_path);
}
}
}
let snapshot = snapshot.as_singleton().unwrap().2.clone();
Some(snapshot)
}
#[cfg(test)]
pub fn set_text(&mut self, text: &str, window: &mut Window, cx: &mut Context<Self>) {
self.editor.update(cx, |editor, cx| {
editor.set_text(text, window, cx);
});
}
}
impl Focusable for MessageEditor {
fn focus_handle(&self, cx: &App) -> FocusHandle {
self.editor.focus_handle(cx)
}
}
impl Render for MessageEditor {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
div()
.key_context("MessageEditor")
.on_action(cx.listener(Self::chat))
.flex_1()
.child({
let settings = ThemeSettings::get_global(cx);
let font_size = TextSize::Small
.rems(cx)
.to_pixels(settings.agent_font_size(cx));
let line_height = settings.buffer_line_height.value() * font_size;
let text_style = TextStyle {
color: cx.theme().colors().text,
font_family: settings.buffer_font.family.clone(),
font_fallbacks: settings.buffer_font.fallbacks.clone(),
font_features: settings.buffer_font.features.clone(),
font_size: font_size.into(),
line_height: line_height.into(),
..Default::default()
};
EditorElement::new(
&self.editor,
EditorStyle {
background: cx.theme().colors().editor_background,
local_player: cx.theme().players().local(),
text: text_style,
syntax: cx.theme().syntax().clone(),
..Default::default()
},
)
})
}
}
#[cfg(test)]
mod tests {
use std::path::Path;
use agent_client_protocol as acp;
use fs::FakeFs;
use gpui::{AppContext, TestAppContext};
use lsp::{CompletionContext, CompletionTriggerKind};
use project::{CompletionIntent, Project};
use serde_json::json;
use util::path;
use workspace::Workspace;
use crate::acp::{message_editor::MessageEditor, thread_view::tests::init_test};
#[gpui::test]
async fn test_at_mention_removal(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree("/project", json!({"file": ""})).await;
let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
let message_editor = cx.update(|window, cx| {
cx.new(|cx| MessageEditor::new(workspace.downgrade(), project.clone(), window, cx))
});
let editor = message_editor.update(cx, |message_editor, _| message_editor.editor.clone());
cx.run_until_parked();
let excerpt_id = editor.update(cx, |editor, cx| {
editor
.buffer()
.read(cx)
.excerpt_ids()
.into_iter()
.next()
.unwrap()
});
let completions = editor.update_in(cx, |editor, window, cx| {
editor.set_text("Hello @", window, cx);
let buffer = editor.buffer().read(cx).as_singleton().unwrap();
let completion_provider = editor.completion_provider().unwrap();
completion_provider.completions(
excerpt_id,
&buffer,
text::Anchor::MAX,
CompletionContext {
trigger_kind: CompletionTriggerKind::TRIGGER_CHARACTER,
trigger_character: Some("@".into()),
},
window,
cx,
)
});
let [_, completion]: [_; 2] = completions
.await
.unwrap()
.into_iter()
.flat_map(|response| response.completions)
.collect::<Vec<_>>()
.try_into()
.unwrap();
editor.update_in(cx, |editor, window, cx| {
let snapshot = editor.buffer().read(cx).snapshot(cx);
let start = snapshot
.anchor_in_excerpt(excerpt_id, completion.replace_range.start)
.unwrap();
let end = snapshot
.anchor_in_excerpt(excerpt_id, completion.replace_range.end)
.unwrap();
editor.edit([(start..end, completion.new_text)], cx);
(completion.confirm.unwrap())(CompletionIntent::Complete, window, cx);
});
cx.run_until_parked();
// Backspace over the inserted crease (and the following space).
editor.update_in(cx, |editor, window, cx| {
editor.backspace(&Default::default(), window, cx);
editor.backspace(&Default::default(), window, cx);
});
let content = message_editor
.update_in(cx, |message_editor, _window, cx| {
message_editor.contents(cx)
})
.await
.unwrap();
// We don't send a resource link for the deleted crease.
pretty_assertions::assert_matches!(content.as_slice(), [acp::ContentBlock::Text { .. }]);
}
}

View File

@@ -1,92 +0,0 @@
pub struct MessageHistory<T> {
items: Vec<T>,
current: Option<usize>,
}
impl<T> Default for MessageHistory<T> {
fn default() -> Self {
MessageHistory {
items: Vec::new(),
current: None,
}
}
}
impl<T> MessageHistory<T> {
pub fn push(&mut self, message: T) {
self.current.take();
self.items.push(message);
}
pub fn reset_position(&mut self) {
self.current.take();
}
pub fn prev(&mut self) -> Option<&T> {
if self.items.is_empty() {
return None;
}
let new_ix = self
.current
.get_or_insert(self.items.len())
.saturating_sub(1);
self.current = Some(new_ix);
self.items.get(new_ix)
}
pub fn next(&mut self) -> Option<&T> {
let current = self.current.as_mut()?;
*current += 1;
self.items.get(*current).or_else(|| {
self.current.take();
None
})
}
#[cfg(test)]
pub fn items(&self) -> &[T] {
&self.items
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_prev_next() {
let mut history = MessageHistory::default();
// Test empty history
assert_eq!(history.prev(), None);
assert_eq!(history.next(), None);
// Add some messages
history.push("first");
history.push("second");
history.push("third");
// Test prev navigation
assert_eq!(history.prev(), Some(&"third"));
assert_eq!(history.prev(), Some(&"second"));
assert_eq!(history.prev(), Some(&"first"));
assert_eq!(history.prev(), Some(&"first"));
assert_eq!(history.next(), Some(&"second"));
// Test mixed navigation
history.push("fourth");
assert_eq!(history.prev(), Some(&"fourth"));
assert_eq!(history.prev(), Some(&"third"));
assert_eq!(history.next(), Some(&"fourth"));
assert_eq!(history.next(), None);
// Test that push resets navigation
history.prev();
history.prev();
history.push("fifth");
assert_eq!(history.prev(), Some(&"fifth"));
}
}

View File

@@ -0,0 +1,472 @@
use std::{cmp::Reverse, rc::Rc, sync::Arc};
use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
use agent_client_protocol as acp;
use anyhow::Result;
use collections::IndexMap;
use futures::FutureExt;
use fuzzy::{StringMatchCandidate, match_strings};
use gpui::{Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, Task, WeakEntity};
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
use ui::{
AnyElement, App, Context, IntoElement, ListItem, ListItemSpacing, SharedString, Window,
prelude::*, rems,
};
use util::ResultExt;
pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
pub fn acp_model_selector(
session_id: acp::SessionId,
selector: Rc<dyn AgentModelSelector>,
window: &mut Window,
cx: &mut Context<AcpModelSelector>,
) -> AcpModelSelector {
let delegate = AcpModelPickerDelegate::new(session_id, selector, window, cx);
Picker::list(delegate, window, cx)
.show_scrollbar(true)
.width(rems(20.))
.max_height(Some(rems(20.).into()))
}
enum AcpModelPickerEntry {
Separator(SharedString),
Model(AgentModelInfo),
}
pub struct AcpModelPickerDelegate {
session_id: acp::SessionId,
selector: Rc<dyn AgentModelSelector>,
filtered_entries: Vec<AcpModelPickerEntry>,
models: Option<AgentModelList>,
selected_index: usize,
selected_model: Option<AgentModelInfo>,
_refresh_models_task: Task<()>,
}
impl AcpModelPickerDelegate {
fn new(
session_id: acp::SessionId,
selector: Rc<dyn AgentModelSelector>,
window: &mut Window,
cx: &mut Context<AcpModelSelector>,
) -> Self {
let mut rx = selector.watch(cx);
let refresh_models_task = cx.spawn_in(window, {
let session_id = session_id.clone();
async move |this, cx| {
async fn refresh(
this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
session_id: &acp::SessionId,
cx: &mut AsyncWindowContext,
) -> Result<()> {
let (models_task, selected_model_task) = this.update(cx, |this, cx| {
(
this.delegate.selector.list_models(cx),
this.delegate.selector.selected_model(session_id, cx),
)
})?;
let (models, selected_model) = futures::join!(models_task, selected_model_task);
this.update_in(cx, |this, window, cx| {
this.delegate.models = models.ok();
this.delegate.selected_model = selected_model.ok();
this.delegate.update_matches(this.query(cx), window, cx)
})?
.await;
Ok(())
}
refresh(&this, &session_id, cx).await.log_err();
while let Ok(()) = rx.recv().await {
refresh(&this, &session_id, cx).await.log_err();
}
}
});
Self {
session_id,
selector,
filtered_entries: Vec::new(),
models: None,
selected_model: None,
selected_index: 0,
_refresh_models_task: refresh_models_task,
}
}
pub fn active_model(&self) -> Option<&AgentModelInfo> {
self.selected_model.as_ref()
}
}
impl PickerDelegate for AcpModelPickerDelegate {
type ListItem = AnyElement;
fn match_count(&self) -> usize {
self.filtered_entries.len()
}
fn selected_index(&self) -> usize {
self.selected_index
}
fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
cx.notify();
}
fn can_select(
&mut self,
ix: usize,
_window: &mut Window,
_cx: &mut Context<Picker<Self>>,
) -> bool {
match self.filtered_entries.get(ix) {
Some(AcpModelPickerEntry::Model(_)) => true,
Some(AcpModelPickerEntry::Separator(_)) | None => false,
}
}
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
"Select a model…".into()
}
fn update_matches(
&mut self,
query: String,
window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Task<()> {
cx.spawn_in(window, async move |this, cx| {
let filtered_models = match this
.read_with(cx, |this, cx| {
this.delegate.models.clone().map(move |models| {
fuzzy_search(models, query, cx.background_executor().clone())
})
})
.ok()
.flatten()
{
Some(task) => task.await,
None => AgentModelList::Flat(vec![]),
};
this.update_in(cx, |this, window, cx| {
this.delegate.filtered_entries =
info_list_to_picker_entries(filtered_models).collect();
// Finds the currently selected model in the list
let new_index = this
.delegate
.selected_model
.as_ref()
.and_then(|selected| {
this.delegate.filtered_entries.iter().position(|entry| {
if let AcpModelPickerEntry::Model(model_info) = entry {
model_info.id == selected.id
} else {
false
}
})
})
.unwrap_or(0);
this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
cx.notify();
})
.ok();
})
}
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
if let Some(AcpModelPickerEntry::Model(model_info)) =
self.filtered_entries.get(self.selected_index)
{
self.selector
.select_model(self.session_id.clone(), model_info.id.clone(), cx)
.detach_and_log_err(cx);
self.selected_model = Some(model_info.clone());
let current_index = self.selected_index;
self.set_selected_index(current_index, window, cx);
cx.emit(DismissEvent);
}
}
fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
cx.emit(DismissEvent);
}
fn render_match(
&self,
ix: usize,
selected: bool,
_: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Option<Self::ListItem> {
match self.filtered_entries.get(ix)? {
AcpModelPickerEntry::Separator(title) => Some(
div()
.px_2()
.pb_1()
.when(ix > 1, |this| {
this.mt_1()
.pt_2()
.border_t_1()
.border_color(cx.theme().colors().border_variant)
})
.child(
Label::new(title)
.size(LabelSize::XSmall)
.color(Color::Muted),
)
.into_any_element(),
),
AcpModelPickerEntry::Model(model_info) => {
let is_selected = Some(model_info) == self.selected_model.as_ref();
let model_icon_color = if is_selected {
Color::Accent
} else {
Color::Muted
};
Some(
ListItem::new(ix)
.inset(true)
.spacing(ListItemSpacing::Sparse)
.toggle_state(selected)
.start_slot::<Icon>(model_info.icon.map(|icon| {
Icon::new(icon)
.color(model_icon_color)
.size(IconSize::Small)
}))
.child(
h_flex()
.w_full()
.pl_0p5()
.gap_1p5()
.w(px(240.))
.child(Label::new(model_info.name.clone()).truncate()),
)
.end_slot(div().pr_3().when(is_selected, |this| {
this.child(
Icon::new(IconName::Check)
.color(Color::Accent)
.size(IconSize::Small),
)
}))
.into_any_element(),
)
}
}
}
fn render_footer(
&self,
_: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Option<gpui::AnyElement> {
Some(
h_flex()
.w_full()
.border_t_1()
.border_color(cx.theme().colors().border_variant)
.p_1()
.gap_4()
.justify_between()
.child(
Button::new("configure", "Configure")
.icon(IconName::Settings)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.icon_position(IconPosition::Start)
.on_click(|_, window, cx| {
window.dispatch_action(
zed_actions::agent::OpenSettings.boxed_clone(),
cx,
);
}),
)
.into_any(),
)
}
}
fn info_list_to_picker_entries(
model_list: AgentModelList,
) -> impl Iterator<Item = AcpModelPickerEntry> {
match model_list {
AgentModelList::Flat(list) => {
itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
}
AgentModelList::Grouped(index_map) => {
itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
.chain(models.into_iter().map(AcpModelPickerEntry::Model))
}))
}
}
}
async fn fuzzy_search(
model_list: AgentModelList,
query: String,
executor: BackgroundExecutor,
) -> AgentModelList {
async fn fuzzy_search_list(
model_list: Vec<AgentModelInfo>,
query: &str,
executor: BackgroundExecutor,
) -> Vec<AgentModelInfo> {
let candidates = model_list
.iter()
.enumerate()
.map(|(ix, model)| {
StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name))
})
.collect::<Vec<_>>();
let mut matches = match_strings(
&candidates,
&query,
false,
true,
100,
&Default::default(),
executor,
)
.await;
matches.sort_unstable_by_key(|mat| {
let candidate = &candidates[mat.candidate_id];
(Reverse(OrderedFloat(mat.score)), candidate.id)
});
matches
.into_iter()
.map(|mat| model_list[mat.candidate_id].clone())
.collect()
}
match model_list {
AgentModelList::Flat(model_list) => {
AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
}
AgentModelList::Grouped(index_map) => {
let groups =
futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
fuzzy_search_list(models, &query, executor.clone())
.map(|results| (group_name, results))
}))
.await;
AgentModelList::Grouped(IndexMap::from_iter(
groups
.into_iter()
.filter(|(_, results)| !results.is_empty()),
))
}
}
}
#[cfg(test)]
mod tests {
use gpui::TestAppContext;
use super::*;
fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
|(group, models)| {
(
acp_thread::AgentModelGroupName(group.to_string().into()),
models
.into_iter()
.map(|model| acp_thread::AgentModelInfo {
id: acp_thread::AgentModelId(model.to_string().into()),
name: model.to_string().into(),
icon: None,
})
.collect::<Vec<_>>(),
)
},
)))
}
fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
let AgentModelList::Grouped(groups) = result else {
panic!("Expected LanguageModelInfoList::Grouped, got {:?}", result);
};
assert_eq!(
groups.len(),
expected.len(),
"Number of groups doesn't match"
);
for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
let (actual_group, actual_models) = groups.get_index(i).unwrap();
assert_eq!(
actual_group.0.as_ref(),
*expected_group,
"Group at position {} doesn't match expected group",
i
);
assert_eq!(
actual_models.len(),
expected_models.len(),
"Number of models in group {} doesn't match",
expected_group
);
for (j, expected_model_name) in expected_models.iter().enumerate() {
assert_eq!(
actual_models[j].name, *expected_model_name,
"Model at position {} in group {} doesn't match expected model",
j, expected_group
);
}
}
}
#[gpui::test]
async fn test_fuzzy_match(cx: &mut TestAppContext) {
let models = create_model_list(vec![
(
"zed",
vec![
"Claude 3.7 Sonnet",
"Claude 3.7 Sonnet Thinking",
"gpt-4.1",
"gpt-4.1-nano",
],
),
("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
("ollama", vec!["mistral", "deepseek"]),
]);
// Results should preserve models order whenever possible.
// In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
// similarity scores, but `zed/gpt-4.1` was higher in the models list,
// so it should appear first in the results.
let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
assert_models_eq(
results,
vec![
("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
],
);
// Fuzzy search
let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
assert_models_eq(
results,
vec![
("zed", vec!["gpt-4.1-nano"]),
("openai", vec!["gpt-4.1-nano"]),
],
);
}
}

View File

@@ -0,0 +1,85 @@
use std::rc::Rc;
use acp_thread::AgentModelSelector;
use agent_client_protocol as acp;
use gpui::{Entity, FocusHandle};
use picker::popover_menu::PickerPopoverMenu;
use ui::{
ButtonLike, Context, IntoElement, PopoverMenuHandle, SharedString, Tooltip, Window, prelude::*,
};
use zed_actions::agent::ToggleModelSelector;
use crate::acp::{AcpModelSelector, model_selector::acp_model_selector};
pub struct AcpModelSelectorPopover {
selector: Entity<AcpModelSelector>,
menu_handle: PopoverMenuHandle<AcpModelSelector>,
focus_handle: FocusHandle,
}
impl AcpModelSelectorPopover {
pub(crate) fn new(
session_id: acp::SessionId,
selector: Rc<dyn AgentModelSelector>,
menu_handle: PopoverMenuHandle<AcpModelSelector>,
focus_handle: FocusHandle,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
Self {
selector: cx.new(move |cx| acp_model_selector(session_id, selector, window, cx)),
menu_handle,
focus_handle,
}
}
pub fn toggle(&self, window: &mut Window, cx: &mut Context<Self>) {
self.menu_handle.toggle(window, cx);
}
}
impl Render for AcpModelSelectorPopover {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let model = self.selector.read(cx).delegate.active_model();
let model_name = model
.as_ref()
.map(|model| model.name.clone())
.unwrap_or_else(|| SharedString::from("Select a Model"));
let model_icon = model.as_ref().and_then(|model| model.icon);
let focus_handle = self.focus_handle.clone();
PickerPopoverMenu::new(
self.selector.clone(),
ButtonLike::new("active-model")
.when_some(model_icon, |this, icon| {
this.child(Icon::new(icon).color(Color::Muted).size(IconSize::XSmall))
})
.child(
Label::new(model_name)
.color(Color::Muted)
.size(LabelSize::Small)
.ml_0p5(),
)
.child(
Icon::new(IconName::ChevronDown)
.color(Color::Muted)
.size(IconSize::XSmall),
),
move |window, cx| {
Tooltip::for_action_in(
"Change Model",
&ToggleModelSelector,
&focus_handle,
window,
cx,
)
},
gpui::Corner::BottomRight,
cx,
)
.with_handle(self.menu_handle.clone())
.render(window, cx)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1521,7 +1521,8 @@ impl AgentDiff {
self.update_reviewing_editors(workspace, window, cx);
}
}
AcpThreadEvent::Stopped
AcpThreadEvent::EntriesRemoved(_)
| AcpThreadEvent::Stopped
| AcpThreadEvent::ToolAuthorizationRequired
| AcpThreadEvent::Error
| AcpThreadEvent::ServerExited(_) => {}

File diff suppressed because it is too large Load Diff

View File

@@ -64,6 +64,8 @@ actions!(
NewTextThread,
/// Toggles the context picker interface for adding files, symbols, or other context.
ToggleContextPicker,
/// Toggles the menu to create new agent threads.
ToggleNewThreadMenu,
/// Toggles the navigation menu for switching between threads and views.
ToggleNavigationMenu,
/// Toggles the options menu for agent settings and preferences.
@@ -155,11 +157,11 @@ enum ExternalAgent {
}
impl ExternalAgent {
pub fn server(&self) -> Rc<dyn agent_servers::AgentServer> {
pub fn server(&self, fs: Arc<dyn fs::Fs>) -> Rc<dyn agent_servers::AgentServer> {
match self {
ExternalAgent::Gemini => Rc::new(agent_servers::Gemini),
ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer),
ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer::new(fs)),
}
}
}

View File

@@ -2,7 +2,7 @@ mod agent_notification;
mod burn_mode_tooltip;
mod context_pill;
mod end_trial_upsell;
mod new_thread_button;
// mod new_thread_button;
mod onboarding_modal;
pub mod preview;
@@ -10,5 +10,5 @@ pub use agent_notification::*;
pub use burn_mode_tooltip::*;
pub use context_pill::*;
pub use end_trial_upsell::*;
pub use new_thread_button::*;
// pub use new_thread_button::*;
pub use onboarding_modal::*;

View File

@@ -11,7 +11,7 @@ pub struct NewThreadButton {
}
impl NewThreadButton {
pub fn new(id: impl Into<ElementId>, label: impl Into<SharedString>, icon: IconName) -> Self {
fn new(id: impl Into<ElementId>, label: impl Into<SharedString>, icon: IconName) -> Self {
Self {
id: id.into(),
label: label.into(),
@@ -21,12 +21,12 @@ impl NewThreadButton {
}
}
pub fn keybinding(mut self, keybinding: Option<ui::KeyBinding>) -> Self {
fn keybinding(mut self, keybinding: Option<ui::KeyBinding>) -> Self {
self.keybinding = keybinding;
self
}
pub fn on_click<F>(mut self, handler: F) -> Self
fn on_click<F>(mut self, handler: F) -> Self
where
F: Fn(&mut Window, &mut App) + 'static,
{

View File

@@ -86,7 +86,7 @@ impl Tool for DiagnosticsTool {
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
_action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
@@ -159,10 +159,6 @@ impl Tool for DiagnosticsTool {
}
}
action_log.update(cx, |action_log, _cx| {
action_log.checked_project_diagnostics();
});
if has_diagnostics {
Task::ready(Ok(output.into())).into()
} else {

View File

@@ -65,7 +65,7 @@ pub enum EditAgentOutputEvent {
ResolvingEditRange(Range<Anchor>),
UnresolvedEditRange,
AmbiguousEditRange(Vec<Range<usize>>),
Edited,
Edited(Range<Anchor>),
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
@@ -178,7 +178,9 @@ impl EditAgent {
)
});
output_events_tx
.unbounded_send(EditAgentOutputEvent::Edited)
.unbounded_send(EditAgentOutputEvent::Edited(
language::Anchor::MIN..language::Anchor::MAX,
))
.ok();
})?;
@@ -200,7 +202,9 @@ impl EditAgent {
});
})?;
output_events_tx
.unbounded_send(EditAgentOutputEvent::Edited)
.unbounded_send(EditAgentOutputEvent::Edited(
language::Anchor::MIN..language::Anchor::MAX,
))
.ok();
}
}
@@ -336,8 +340,8 @@ impl EditAgent {
// Edit the buffer and report edits to the action log as part of the
// same effect cycle, otherwise the edit will be reported as if the
// user made it.
cx.update(|cx| {
let max_edit_end = buffer.update(cx, |buffer, cx| {
let (min_edit_start, max_edit_end) = cx.update(|cx| {
let (min_edit_start, max_edit_end) = buffer.update(cx, |buffer, cx| {
buffer.edit(edits.iter().cloned(), None, cx);
let max_edit_end = buffer
.summaries_for_anchors::<Point, _>(
@@ -345,7 +349,16 @@ impl EditAgent {
)
.max()
.unwrap();
buffer.anchor_before(max_edit_end)
let min_edit_start = buffer
.summaries_for_anchors::<Point, _>(
edits.iter().map(|(range, _)| &range.start),
)
.min()
.unwrap();
(
buffer.anchor_after(min_edit_start),
buffer.anchor_before(max_edit_end),
)
});
self.action_log
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
@@ -358,9 +371,10 @@ impl EditAgent {
cx,
);
});
(min_edit_start, max_edit_end)
})?;
output_events
.unbounded_send(EditAgentOutputEvent::Edited)
.unbounded_send(EditAgentOutputEvent::Edited(min_edit_start..max_edit_end))
.ok();
}
@@ -755,6 +769,7 @@ mod tests {
use gpui::{AppContext, TestAppContext};
use indoc::indoc;
use language_model::fake_provider::FakeLanguageModel;
use pretty_assertions::assert_matches;
use project::{AgentLocation, Project};
use rand::prelude::*;
use rand::rngs::StdRng;
@@ -992,7 +1007,10 @@ mod tests {
model.send_last_completion_stream_text_chunk("<new_text>abX");
cx.run_until_parked();
assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
assert_matches!(
drain_events(&mut events).as_slice(),
[EditAgentOutputEvent::Edited(_)]
);
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"abXc\ndef\nghi\njkl"
@@ -1007,7 +1025,10 @@ mod tests {
model.send_last_completion_stream_text_chunk("cY");
cx.run_until_parked();
assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
assert_matches!(
drain_events(&mut events).as_slice(),
[EditAgentOutputEvent::Edited { .. }]
);
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"abXcY\ndef\nghi\njkl"
@@ -1118,9 +1139,9 @@ mod tests {
model.send_last_completion_stream_text_chunk("GHI</new_text>");
cx.run_until_parked();
assert_eq!(
drain_events(&mut events),
vec![EditAgentOutputEvent::Edited]
assert_matches!(
drain_events(&mut events).as_slice(),
[EditAgentOutputEvent::Edited { .. }]
);
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
@@ -1165,9 +1186,9 @@ mod tests {
);
cx.run_until_parked();
assert_eq!(
drain_events(&mut events),
vec![EditAgentOutputEvent::Edited]
assert_matches!(
drain_events(&mut events).as_slice(),
[EditAgentOutputEvent::Edited(_)]
);
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
@@ -1183,9 +1204,9 @@ mod tests {
chunks_tx.unbounded_send("```\njkl\n").unwrap();
cx.run_until_parked();
assert_eq!(
drain_events(&mut events),
vec![EditAgentOutputEvent::Edited]
assert_matches!(
drain_events(&mut events).as_slice(),
[EditAgentOutputEvent::Edited { .. }]
);
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
@@ -1201,9 +1222,9 @@ mod tests {
chunks_tx.unbounded_send("mno\n").unwrap();
cx.run_until_parked();
assert_eq!(
drain_events(&mut events),
vec![EditAgentOutputEvent::Edited]
assert_matches!(
drain_events(&mut events).as_slice(),
[EditAgentOutputEvent::Edited { .. }]
);
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
@@ -1219,9 +1240,9 @@ mod tests {
chunks_tx.unbounded_send("pqr\n```").unwrap();
cx.run_until_parked();
assert_eq!(
drain_events(&mut events),
vec![EditAgentOutputEvent::Edited]
assert_matches!(
drain_events(&mut events).as_slice(),
[EditAgentOutputEvent::Edited(_)],
);
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),

View File

@@ -307,7 +307,7 @@ impl Tool for EditFileTool {
let mut ambiguous_ranges = Vec::new();
while let Some(event) = events.next().await {
match event {
EditAgentOutputEvent::Edited => {
EditAgentOutputEvent::Edited { .. } => {
if let Some(card) = card_clone.as_ref() {
card.update(cx, |card, cx| card.update_diff(cx))?;
}

View File

@@ -18,6 +18,6 @@ collections.workspace = true
derive_more.workspace = true
gpui.workspace = true
parking_lot.workspace = true
rodio = { version = "0.21.1", default-features = false, features = ["wav", "playback", "tracing"] }
rodio = { workspace = true, features = ["wav", "playback", "tracing"] }
util.workspace = true
workspace-hack.workspace = true

View File

@@ -59,16 +59,9 @@ pub enum VersionCheckType {
pub enum AutoUpdateStatus {
Idle,
Checking,
Downloading {
version: VersionCheckType,
},
Installing {
version: VersionCheckType,
},
Updated {
binary_path: PathBuf,
version: VersionCheckType,
},
Downloading { version: VersionCheckType },
Installing { version: VersionCheckType },
Updated { version: VersionCheckType },
Errored,
}
@@ -83,6 +76,7 @@ pub struct AutoUpdater {
current_version: SemanticVersion,
http_client: Arc<HttpClientWithUrl>,
pending_poll: Option<Task<Option<()>>>,
quit_subscription: Option<gpui::Subscription>,
}
#[derive(Deserialize, Clone, Debug)]
@@ -164,7 +158,7 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
AutoUpdateSetting::register(cx);
cx.observe_new(|workspace: &mut Workspace, _window, _cx| {
workspace.register_action(|_, action: &Check, window, cx| check(action, window, cx));
workspace.register_action(|_, action, window, cx| check(action, window, cx));
workspace.register_action(|_, action, _, cx| {
view_release_notes(action, cx);
@@ -174,7 +168,7 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
let version = release_channel::AppVersion::global(cx);
let auto_updater = cx.new(|cx| {
let updater = AutoUpdater::new(version, http_client);
let updater = AutoUpdater::new(version, http_client, cx);
let poll_for_updates = ReleaseChannel::try_global(cx)
.map(|channel| channel.poll_for_updates())
@@ -321,12 +315,34 @@ impl AutoUpdater {
cx.default_global::<GlobalAutoUpdate>().0.clone()
}
fn new(current_version: SemanticVersion, http_client: Arc<HttpClientWithUrl>) -> Self {
fn new(
current_version: SemanticVersion,
http_client: Arc<HttpClientWithUrl>,
cx: &mut Context<Self>,
) -> Self {
// On windows, executable files cannot be overwritten while they are
// running, so we must wait to overwrite the application until quitting
// or restarting. When quitting the app, we spawn the auto update helper
// to finish the auto update process after Zed exits. When restarting
// the app after an update, we use `set_restart_path` to run the auto
// update helper instead of the app, so that it can overwrite the app
// and then spawn the new binary.
let quit_subscription = Some(cx.on_app_quit(|_, _| async move {
#[cfg(target_os = "windows")]
finalize_auto_update_on_quit();
}));
cx.on_app_restart(|this, _| {
this.quit_subscription.take();
})
.detach();
Self {
status: AutoUpdateStatus::Idle,
current_version,
http_client,
pending_poll: None,
quit_subscription,
}
}
@@ -536,6 +552,8 @@ impl AutoUpdater {
)
})?;
Self::check_dependencies()?;
this.update(&mut cx, |this, cx| {
this.status = AutoUpdateStatus::Checking;
cx.notify();
@@ -582,13 +600,15 @@ impl AutoUpdater {
cx.notify();
})?;
let binary_path = Self::binary_path(installer_dir, target_path, &cx).await?;
let new_binary_path = Self::install_release(installer_dir, target_path, &cx).await?;
if let Some(new_binary_path) = new_binary_path {
cx.update(|cx| cx.set_restart_path(new_binary_path))?;
}
this.update(&mut cx, |this, cx| {
this.set_should_show_update_notification(true, cx)
.detach_and_log_err(cx);
this.status = AutoUpdateStatus::Updated {
binary_path,
version: newer_version,
};
cx.notify();
@@ -639,6 +659,15 @@ impl AutoUpdater {
}
}
fn check_dependencies() -> Result<()> {
#[cfg(not(target_os = "windows"))]
anyhow::ensure!(
which::which("rsync").is_ok(),
"Aborting. Could not find rsync which is required for auto-updates."
);
Ok(())
}
async fn target_path(installer_dir: &InstallerDir) -> Result<PathBuf> {
let filename = match OS {
"macos" => anyhow::Ok("Zed.dmg"),
@@ -647,20 +676,14 @@ impl AutoUpdater {
unsupported_os => anyhow::bail!("not supported: {unsupported_os}"),
}?;
#[cfg(not(target_os = "windows"))]
anyhow::ensure!(
which::which("rsync").is_ok(),
"Aborting. Could not find rsync which is required for auto-updates."
);
Ok(installer_dir.path().join(filename))
}
async fn binary_path(
async fn install_release(
installer_dir: InstallerDir,
target_path: PathBuf,
cx: &AsyncApp,
) -> Result<PathBuf> {
) -> Result<Option<PathBuf>> {
match OS {
"macos" => install_release_macos(&installer_dir, target_path, cx).await,
"linux" => install_release_linux(&installer_dir, target_path, cx).await,
@@ -801,7 +824,7 @@ async fn install_release_linux(
temp_dir: &InstallerDir,
downloaded_tar_gz: PathBuf,
cx: &AsyncApp,
) -> Result<PathBuf> {
) -> Result<Option<PathBuf>> {
let channel = cx.update(|cx| ReleaseChannel::global(cx).dev_name())?;
let home_dir = PathBuf::from(env::var("HOME").context("no HOME env var set")?);
let running_app_path = cx.update(|cx| cx.app_path())??;
@@ -861,14 +884,14 @@ async fn install_release_linux(
String::from_utf8_lossy(&output.stderr)
);
Ok(to.join(expected_suffix))
Ok(Some(to.join(expected_suffix)))
}
async fn install_release_macos(
temp_dir: &InstallerDir,
downloaded_dmg: PathBuf,
cx: &AsyncApp,
) -> Result<PathBuf> {
) -> Result<Option<PathBuf>> {
let running_app_path = cx.update(|cx| cx.app_path())??;
let running_app_filename = running_app_path
.file_name()
@@ -910,10 +933,10 @@ async fn install_release_macos(
String::from_utf8_lossy(&output.stderr)
);
Ok(running_app_path)
Ok(None)
}
async fn install_release_windows(downloaded_installer: PathBuf) -> Result<PathBuf> {
async fn install_release_windows(downloaded_installer: PathBuf) -> Result<Option<PathBuf>> {
let output = Command::new(downloaded_installer)
.arg("/verysilent")
.arg("/update=true")
@@ -926,29 +949,36 @@ async fn install_release_windows(downloaded_installer: PathBuf) -> Result<PathBu
"failed to start installer: {:?}",
String::from_utf8_lossy(&output.stderr)
);
Ok(std::env::current_exe()?)
// We return the path to the update helper program, because it will
// perform the final steps of the update process, copying the new binary,
// deleting the old one, and launching the new binary.
let helper_path = std::env::current_exe()?
.parent()
.context("No parent dir for Zed.exe")?
.join("tools\\auto_update_helper.exe");
Ok(Some(helper_path))
}
pub fn check_pending_installation() -> bool {
pub fn finalize_auto_update_on_quit() {
let Some(installer_path) = std::env::current_exe()
.ok()
.and_then(|p| p.parent().map(|p| p.join("updates")))
else {
return false;
return;
};
// The installer will create a flag file after it finishes updating
let flag_file = installer_path.join("versions.txt");
if flag_file.exists() {
if let Some(helper) = installer_path
if flag_file.exists()
&& let Some(helper) = installer_path
.parent()
.map(|p| p.join("tools\\auto_update_helper.exe"))
{
let _ = std::process::Command::new(helper).spawn();
return true;
}
{
let mut command = std::process::Command::new(helper);
command.arg("--launch");
command.arg("false");
let _ = command.spawn();
}
false
}
#[cfg(test)]
@@ -1002,7 +1032,6 @@ mod tests {
let app_commit_sha = Ok(Some("a".to_string()));
let installed_version = SemanticVersion::new(1, 0, 0);
let status = AutoUpdateStatus::Updated {
binary_path: PathBuf::new(),
version: VersionCheckType::Semantic(SemanticVersion::new(1, 0, 1)),
};
let fetched_version = SemanticVersion::new(1, 0, 1);
@@ -1024,7 +1053,6 @@ mod tests {
let app_commit_sha = Ok(Some("a".to_string()));
let installed_version = SemanticVersion::new(1, 0, 0);
let status = AutoUpdateStatus::Updated {
binary_path: PathBuf::new(),
version: VersionCheckType::Semantic(SemanticVersion::new(1, 0, 1)),
};
let fetched_version = SemanticVersion::new(1, 0, 2);
@@ -1090,7 +1118,6 @@ mod tests {
let app_commit_sha = Ok(Some("a".to_string()));
let installed_version = SemanticVersion::new(1, 0, 0);
let status = AutoUpdateStatus::Updated {
binary_path: PathBuf::new(),
version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())),
};
let fetched_sha = "b".to_string();
@@ -1112,7 +1139,6 @@ mod tests {
let app_commit_sha = Ok(Some("a".to_string()));
let installed_version = SemanticVersion::new(1, 0, 0);
let status = AutoUpdateStatus::Updated {
binary_path: PathBuf::new(),
version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())),
};
let fetched_sha = "c".to_string();
@@ -1160,7 +1186,6 @@ mod tests {
let app_commit_sha = Ok(None);
let installed_version = SemanticVersion::new(1, 0, 0);
let status = AutoUpdateStatus::Updated {
binary_path: PathBuf::new(),
version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())),
};
let fetched_sha = "b".to_string();
@@ -1183,7 +1208,6 @@ mod tests {
let app_commit_sha = Ok(None);
let installed_version = SemanticVersion::new(1, 0, 0);
let status = AutoUpdateStatus::Updated {
binary_path: PathBuf::new(),
version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())),
};
let fetched_sha = "c".to_string();

View File

@@ -37,6 +37,11 @@ mod windows_impl {
pub(crate) const WM_JOB_UPDATED: u32 = WM_USER + 1;
pub(crate) const WM_TERMINATE: u32 = WM_USER + 2;
#[derive(Debug)]
struct Args {
launch: Option<bool>,
}
pub(crate) fn run() -> Result<()> {
let helper_dir = std::env::current_exe()?
.parent()
@@ -51,8 +56,9 @@ mod windows_impl {
log::info!("======= Starting Zed update =======");
let (tx, rx) = std::sync::mpsc::channel();
let hwnd = create_dialog_window(rx)?.0 as isize;
let args = parse_args();
std::thread::spawn(move || {
let result = perform_update(app_dir.as_path(), Some(hwnd));
let result = perform_update(app_dir.as_path(), Some(hwnd), args.launch.unwrap_or(true));
tx.send(result).ok();
unsafe { PostMessageW(Some(HWND(hwnd as _)), WM_TERMINATE, WPARAM(0), LPARAM(0)) }.ok();
});
@@ -77,6 +83,41 @@ mod windows_impl {
Ok(())
}
fn parse_args() -> Args {
let mut result = Args { launch: None };
if let Some(candidate) = std::env::args().nth(1) {
parse_single_arg(&candidate, &mut result);
}
result
}
fn parse_single_arg(arg: &str, result: &mut Args) {
let Some((key, value)) = arg.strip_prefix("--").and_then(|arg| arg.split_once('=')) else {
log::error!(
"Invalid argument format: '{}'. Expected format: --key=value",
arg
);
return;
};
match key {
"launch" => parse_launch_arg(value, &mut result.launch),
_ => log::error!("Unknown argument: --{}", key),
}
}
fn parse_launch_arg(value: &str, arg: &mut Option<bool>) {
match value {
"true" => *arg = Some(true),
"false" => *arg = Some(false),
_ => log::error!(
"Invalid value for --launch: '{}'. Expected 'true' or 'false'",
value
),
}
}
pub(crate) fn show_error(mut content: String) {
if content.len() > 600 {
content.truncate(600);
@@ -91,4 +132,47 @@ mod windows_impl {
)
};
}
#[cfg(test)]
mod tests {
use crate::windows_impl::{Args, parse_launch_arg, parse_single_arg};
#[test]
fn test_parse_launch_arg() {
let mut arg = None;
parse_launch_arg("true", &mut arg);
assert_eq!(arg, Some(true));
let mut arg = None;
parse_launch_arg("false", &mut arg);
assert_eq!(arg, Some(false));
let mut arg = None;
parse_launch_arg("invalid", &mut arg);
assert_eq!(arg, None);
}
#[test]
fn test_parse_single_arg() {
let mut args = Args { launch: None };
parse_single_arg("--launch=true", &mut args);
assert_eq!(args.launch, Some(true));
let mut args = Args { launch: None };
parse_single_arg("--launch=false", &mut args);
assert_eq!(args.launch, Some(false));
let mut args = Args { launch: None };
parse_single_arg("--launch=invalid", &mut args);
assert_eq!(args.launch, None);
let mut args = Args { launch: None };
parse_single_arg("--launch", &mut args);
assert_eq!(args.launch, None);
let mut args = Args { launch: None };
parse_single_arg("--unknown", &mut args);
assert_eq!(args.launch, None);
}
}
}

View File

@@ -72,7 +72,7 @@ pub(crate) fn create_dialog_window(receiver: Receiver<Result<()>>) -> Result<HWN
let hwnd = CreateWindowExW(
WS_EX_TOPMOST,
class_name,
windows::core::w!("Zed Editor"),
windows::core::w!("Zed"),
WS_VISIBLE | WS_POPUP | WS_CAPTION,
rect.right / 2 - width / 2,
rect.bottom / 2 - height / 2,
@@ -171,7 +171,7 @@ unsafe extern "system" fn wnd_proc(
&HSTRING::from(font_name),
);
let temp = SelectObject(hdc, font.into());
let string = HSTRING::from("Zed Editor is updating...");
let string = HSTRING::from("Updating Zed...");
return_if_failed!(TextOutW(hdc, 20, 15, &string).ok());
return_if_failed!(DeleteObject(temp).ok());

View File

@@ -118,7 +118,7 @@ pub(crate) const JOBS: [Job; 2] = [
},
];
pub(crate) fn perform_update(app_dir: &Path, hwnd: Option<isize>) -> Result<()> {
pub(crate) fn perform_update(app_dir: &Path, hwnd: Option<isize>, launch: bool) -> Result<()> {
let hwnd = hwnd.map(|ptr| HWND(ptr as _));
for job in JOBS.iter() {
@@ -145,9 +145,11 @@ pub(crate) fn perform_update(app_dir: &Path, hwnd: Option<isize>) -> Result<()>
}
}
}
let _ = std::process::Command::new(app_dir.join("Zed.exe"))
.creation_flags(CREATE_NEW_PROCESS_GROUP.0)
.spawn();
if launch {
let _ = std::process::Command::new(app_dir.join("Zed.exe"))
.creation_flags(CREATE_NEW_PROCESS_GROUP.0)
.spawn();
}
log::info!("Update completed successfully");
Ok(())
}
@@ -159,11 +161,11 @@ mod test {
#[test]
fn test_perform_update() {
let app_dir = std::path::Path::new("C:/");
assert!(perform_update(app_dir, None).is_ok());
assert!(perform_update(app_dir, None, false).is_ok());
// Simulate a timeout
unsafe { std::env::set_var("ZED_AUTO_UPDATE", "err") };
let ret = perform_update(app_dir, None);
let ret = perform_update(app_dir, None, false);
assert!(ret.is_err_and(|e| e.to_string().as_str() == "Timed out"));
}
}

View File

@@ -957,17 +957,14 @@ mod mac_os {
) -> Result<()> {
use anyhow::bail;
let app_id_prompt = format!("id of app \"{}\"", channel.display_name());
let app_id_output = Command::new("osascript")
let app_path_prompt = format!(
"POSIX path of (path to application \"{}\")",
channel.display_name()
);
let app_path_output = Command::new("osascript")
.arg("-e")
.arg(&app_id_prompt)
.arg(&app_path_prompt)
.output()?;
if !app_id_output.status.success() {
bail!("Could not determine app id for {}", channel.display_name());
}
let app_name = String::from_utf8(app_id_output.stdout)?.trim().to_owned();
let app_path_prompt = format!("kMDItemCFBundleIdentifier == '{app_name}'");
let app_path_output = Command::new("mdfind").arg(app_path_prompt).output()?;
if !app_path_output.status.success() {
bail!(
"Could not determine app path for {}",

View File

@@ -340,22 +340,35 @@ impl Telemetry {
}
pub fn log_edit_event(self: &Arc<Self>, environment: &'static str, is_via_ssh: bool) {
static LAST_EVENT_TIME: Mutex<Option<Instant>> = Mutex::new(None);
let mut state = self.state.lock();
let period_data = state.event_coalescer.log_event(environment);
drop(state);
if let Some((start, end, environment)) = period_data {
let duration = end
.saturating_duration_since(start)
.min(Duration::from_secs(60 * 60 * 24))
.as_millis() as i64;
if let Some(mut last_event) = LAST_EVENT_TIME.try_lock() {
let current_time = std::time::Instant::now();
let last_time = last_event.get_or_insert(current_time);
telemetry::event!(
"Editor Edited",
duration = duration,
environment = environment,
is_via_ssh = is_via_ssh
);
if current_time.duration_since(*last_time) > Duration::from_secs(60 * 10) {
*last_time = current_time;
} else {
return;
}
if let Some((start, end, environment)) = period_data {
let duration = end
.saturating_duration_since(start)
.min(Duration::from_secs(60 * 60 * 24))
.as_millis() as i64;
telemetry::event!(
"Editor Edited",
duration = duration,
environment = environment,
is_via_ssh = is_via_ssh
);
}
}
}

View File

@@ -21,7 +21,7 @@ use language::{
point_from_lsp, point_to_lsp,
};
use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId, LanguageServerName};
use node_runtime::NodeRuntime;
use node_runtime::{NodeRuntime, VersionCheck};
use parking_lot::Mutex;
use project::DisableAiSettings;
use request::StatusNotification;
@@ -1169,9 +1169,8 @@ async fn get_copilot_lsp(fs: Arc<dyn Fs>, node_runtime: NodeRuntime) -> anyhow::
const SERVER_PATH: &str =
"node_modules/@github/copilot-language-server/dist/language-server.js";
let latest_version = node_runtime
.npm_package_latest_version(PACKAGE_NAME)
.await?;
// pinning it: https://github.com/zed-industries/zed/issues/36093
const PINNED_VERSION: &str = "1.354";
let server_path = paths::copilot_dir().join(SERVER_PATH);
fs.create_dir(paths::copilot_dir()).await?;
@@ -1181,12 +1180,13 @@ async fn get_copilot_lsp(fs: Arc<dyn Fs>, node_runtime: NodeRuntime) -> anyhow::
PACKAGE_NAME,
&server_path,
paths::copilot_dir(),
&latest_version,
&PINNED_VERSION,
VersionCheck::VersionMismatch,
)
.await;
if should_install {
node_runtime
.npm_install_packages(paths::copilot_dir(), &[(PACKAGE_NAME, &latest_version)])
.npm_install_packages(paths::copilot_dir(), &[(PACKAGE_NAME, &PINNED_VERSION)])
.await?;
}

View File

@@ -250,6 +250,24 @@ pub type RenderDiffHunkControlsFn = Arc<
) -> AnyElement,
>;
enum ReportEditorEvent {
Saved { auto_saved: bool },
EditorOpened,
ZetaTosClicked,
Closed,
}
impl ReportEditorEvent {
pub fn event_type(&self) -> &'static str {
match self {
Self::Saved { .. } => "Editor Saved",
Self::EditorOpened => "Editor Opened",
Self::ZetaTosClicked => "Edit Prediction Provider ToS Clicked",
Self::Closed => "Editor Closed",
}
}
}
struct InlineValueCache {
enabled: bool,
inlays: Vec<InlayId>,
@@ -2325,7 +2343,7 @@ impl Editor {
}
if editor.mode.is_full() {
editor.report_editor_event("Editor Opened", None, cx);
editor.report_editor_event(ReportEditorEvent::EditorOpened, None, cx);
}
editor
@@ -9124,7 +9142,7 @@ impl Editor {
.on_mouse_down(MouseButton::Left, |_, window, _| window.prevent_default())
.on_click(cx.listener(|this, _event, window, cx| {
cx.stop_propagation();
this.report_editor_event("Edit Prediction Provider ToS Clicked", None, cx);
this.report_editor_event(ReportEditorEvent::ZetaTosClicked, None, cx);
window.dispatch_action(
zed_actions::OpenZedPredictOnboarding.boxed_clone(),
cx,
@@ -20547,7 +20565,7 @@ impl Editor {
fn report_editor_event(
&self,
event_type: &'static str,
reported_event: ReportEditorEvent,
file_extension: Option<String>,
cx: &App,
) {
@@ -20581,15 +20599,30 @@ impl Editor {
.show_edit_predictions;
let project = project.read(cx);
telemetry::event!(
event_type,
file_extension,
vim_mode,
copilot_enabled,
copilot_enabled_for_language,
edit_predictions_provider,
is_via_ssh = project.is_via_ssh(),
);
let event_type = reported_event.event_type();
if let ReportEditorEvent::Saved { auto_saved } = reported_event {
telemetry::event!(
event_type,
type = if auto_saved {"autosave"} else {"manual"},
file_extension,
vim_mode,
copilot_enabled,
copilot_enabled_for_language,
edit_predictions_provider,
is_via_ssh = project.is_via_ssh(),
);
} else {
telemetry::event!(
event_type,
file_extension,
vim_mode,
copilot_enabled,
copilot_enabled_for_language,
edit_predictions_provider,
is_via_ssh = project.is_via_ssh(),
);
};
}
/// Copy the highlighted chunks to the clipboard as JSON. The format is an array of lines,

View File

@@ -22456,7 +22456,7 @@ async fn test_invisible_worktree_servers(cx: &mut TestAppContext) {
);
cx.update(|_, cx| {
workspace::reload(&workspace::Reload::default(), cx);
workspace::reload(cx);
});
assert_language_servers_count(
1,

View File

@@ -3011,7 +3011,7 @@ impl EditorElement {
.icon_color(Color::Custom(cx.theme().colors().editor_line_number))
.selected_icon_color(Color::Custom(cx.theme().colors().editor_foreground))
.icon_size(IconSize::Custom(rems(editor_font_size / window.rem_size())))
.width(width.into())
.width(width)
.on_click(move |_, window, cx| {
editor.update(cx, |editor, cx| {
editor.expand_excerpt(excerpt_id, direction, window, cx);
@@ -3627,7 +3627,7 @@ impl EditorElement {
ButtonLike::new("toggle-buffer-fold")
.style(ui::ButtonStyle::Transparent)
.height(px(28.).into())
.width(px(28.).into())
.width(px(28.))
.children(toggle_chevron_icon)
.tooltip({
let focus_handle = focus_handle.clone();

View File

@@ -1,7 +1,7 @@
use crate::{
Anchor, Autoscroll, Editor, EditorEvent, EditorSettings, ExcerptId, ExcerptRange, FormatTarget,
MultiBuffer, MultiBufferSnapshot, NavigationData, SearchWithinRange, SelectionEffects,
ToPoint as _,
MultiBuffer, MultiBufferSnapshot, NavigationData, ReportEditorEvent, SearchWithinRange,
SelectionEffects, ToPoint as _,
display_map::HighlightKey,
editor_settings::SeedQuerySetting,
persistence::{DB, SerializedEditor},
@@ -776,6 +776,10 @@ impl Item for Editor {
}
}
fn on_removed(&self, cx: &App) {
self.report_editor_event(ReportEditorEvent::Closed, None, cx);
}
fn deactivated(&mut self, _: &mut Window, cx: &mut Context<Self>) {
let selection = self.selections.newest_anchor();
self.push_to_nav_history(selection.head(), None, true, false, cx);
@@ -815,9 +819,9 @@ impl Item for Editor {
) -> Task<Result<()>> {
// Add meta data tracking # of auto saves
if options.autosave {
self.report_editor_event("Editor Autosaved", None, cx);
self.report_editor_event(ReportEditorEvent::Saved { auto_saved: true }, None, cx);
} else {
self.report_editor_event("Editor Saved", None, cx);
self.report_editor_event(ReportEditorEvent::Saved { auto_saved: false }, None, cx);
}
let buffers = self.buffer().clone().read(cx).all_buffers();
@@ -896,7 +900,11 @@ impl Item for Editor {
.path
.extension()
.map(|a| a.to_string_lossy().to_string());
self.report_editor_event("Editor Saved", file_extension, cx);
self.report_editor_event(
ReportEditorEvent::Saved { auto_saved: false },
file_extension,
cx,
);
project.update(cx, |project, cx| project.save_buffer_as(buffer, path, cx))
}
@@ -997,12 +1005,16 @@ impl Item for Editor {
) {
self.workspace = Some((workspace.weak_handle(), workspace.database_id()));
if let Some(workspace) = &workspace.weak_handle().upgrade() {
cx.subscribe(&workspace, |editor, _, event: &workspace::Event, _cx| {
if matches!(event, workspace::Event::ModalOpened) {
editor.mouse_context_menu.take();
editor.inline_blame_popover.take();
}
})
cx.subscribe(
&workspace,
|editor, _, event: &workspace::Event, _cx| match event {
workspace::Event::ModalOpened => {
editor.mouse_context_menu.take();
editor.inline_blame_popover.take();
}
_ => {}
},
)
.detach();
}
}

View File

@@ -1118,15 +1118,17 @@ impl ExtensionStore {
extensions_to_unload.len() - reload_count
);
for extension_id in &extensions_to_load {
if let Some(extension) = new_index.extensions.get(extension_id) {
telemetry::event!(
"Extension Loaded",
extension_id,
version = extension.manifest.version
);
}
}
let extension_ids = extensions_to_load
.iter()
.filter_map(|id| {
Some((
id.clone(),
new_index.extensions.get(id)?.manifest.version.clone(),
))
})
.collect::<Vec<_>>();
telemetry::event!("Extensions Loaded", id_and_versions = extension_ids);
let themes_to_remove = old_index
.themes

View File

@@ -33,13 +33,23 @@ impl FileIcons {
// TODO: Associate a type with the languages and have the file's language
// override these associations
// check if file name is in suffixes
// e.g. catch file named `eslint.config.js` instead of `.eslint.config.js`
if let Some(typ) = path.file_name().and_then(|typ| typ.to_str()) {
if let Some(mut typ) = path.file_name().and_then(|typ| typ.to_str()) {
// check if file name is in suffixes
// e.g. catch file named `eslint.config.js` instead of `.eslint.config.js`
let maybe_path = get_icon_from_suffix(typ);
if maybe_path.is_some() {
return maybe_path;
}
// check if suffix based on first dot is in suffixes
// e.g. consider `module.js` as suffix to angular's module file named `auth.module.js`
while let Some((_, suffix)) = typ.split_once('.') {
let maybe_path = get_icon_from_suffix(suffix);
if maybe_path.is_some() {
return maybe_path;
}
typ = suffix;
}
}
// primary case: check if the files extension or the hidden file name

View File

@@ -51,6 +51,7 @@ ashpd.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }
git = { workspace = true, features = ["test-support"] }
[features]
test-support = ["gpui/test-support", "git/test-support"]

View File

@@ -1,8 +1,9 @@
use crate::{FakeFs, Fs};
use crate::{FakeFs, FakeFsEntry, Fs};
use anyhow::{Context as _, Result};
use collections::{HashMap, HashSet};
use futures::future::{self, BoxFuture, join_all};
use git::{
Oid,
blame::Blame,
repository::{
AskPassDelegate, Branch, CommitDetails, CommitOptions, FetchOptions, GitRepository,
@@ -10,8 +11,9 @@ use git::{
},
status::{FileStatus, GitStatus, StatusCode, TrackedStatus, UnmergedStatus},
};
use gpui::{AsyncApp, BackgroundExecutor, SharedString};
use gpui::{AsyncApp, BackgroundExecutor, SharedString, Task};
use ignore::gitignore::GitignoreBuilder;
use parking_lot::Mutex;
use rope::Rope;
use smol::future::FutureExt as _;
use std::{path::PathBuf, sync::Arc};
@@ -19,6 +21,7 @@ use std::{path::PathBuf, sync::Arc};
#[derive(Clone)]
pub struct FakeGitRepository {
pub(crate) fs: Arc<FakeFs>,
pub(crate) checkpoints: Arc<Mutex<HashMap<Oid, FakeFsEntry>>>,
pub(crate) executor: BackgroundExecutor,
pub(crate) dot_git_path: PathBuf,
pub(crate) repository_dir_path: PathBuf,
@@ -183,7 +186,7 @@ impl GitRepository for FakeGitRepository {
async move { None }.boxed()
}
fn status(&self, path_prefixes: &[RepoPath]) -> BoxFuture<'_, Result<GitStatus>> {
fn status(&self, path_prefixes: &[RepoPath]) -> Task<Result<GitStatus>> {
let workdir_path = self.dot_git_path.parent().unwrap();
// Load gitignores
@@ -311,7 +314,10 @@ impl GitRepository for FakeGitRepository {
entries: entries.into(),
})
});
async move { result? }.boxed()
Task::ready(match result {
Ok(result) => result,
Err(e) => Err(e),
})
}
fn branches(&self) -> BoxFuture<'_, Result<Vec<Branch>>> {
@@ -466,22 +472,57 @@ impl GitRepository for FakeGitRepository {
}
fn checkpoint(&self) -> BoxFuture<'static, Result<GitRepositoryCheckpoint>> {
unimplemented!()
let executor = self.executor.clone();
let fs = self.fs.clone();
let checkpoints = self.checkpoints.clone();
let repository_dir_path = self.repository_dir_path.parent().unwrap().to_path_buf();
async move {
executor.simulate_random_delay().await;
let oid = Oid::random(&mut executor.rng());
let entry = fs.entry(&repository_dir_path)?;
checkpoints.lock().insert(oid, entry);
Ok(GitRepositoryCheckpoint { commit_sha: oid })
}
.boxed()
}
fn restore_checkpoint(
&self,
_checkpoint: GitRepositoryCheckpoint,
) -> BoxFuture<'_, Result<()>> {
unimplemented!()
fn restore_checkpoint(&self, checkpoint: GitRepositoryCheckpoint) -> BoxFuture<'_, Result<()>> {
let executor = self.executor.clone();
let fs = self.fs.clone();
let checkpoints = self.checkpoints.clone();
let repository_dir_path = self.repository_dir_path.parent().unwrap().to_path_buf();
async move {
executor.simulate_random_delay().await;
let checkpoints = checkpoints.lock();
let entry = checkpoints
.get(&checkpoint.commit_sha)
.context(format!("invalid checkpoint: {}", checkpoint.commit_sha))?;
fs.insert_entry(&repository_dir_path, entry.clone())?;
Ok(())
}
.boxed()
}
fn compare_checkpoints(
&self,
_left: GitRepositoryCheckpoint,
_right: GitRepositoryCheckpoint,
left: GitRepositoryCheckpoint,
right: GitRepositoryCheckpoint,
) -> BoxFuture<'_, Result<bool>> {
unimplemented!()
let executor = self.executor.clone();
let checkpoints = self.checkpoints.clone();
async move {
executor.simulate_random_delay().await;
let checkpoints = checkpoints.lock();
let left = checkpoints
.get(&left.commit_sha)
.context(format!("invalid left checkpoint: {}", left.commit_sha))?;
let right = checkpoints
.get(&right.commit_sha)
.context(format!("invalid right checkpoint: {}", right.commit_sha))?;
Ok(left == right)
}
.boxed()
}
fn diff_checkpoints(
@@ -496,3 +537,63 @@ impl GitRepository for FakeGitRepository {
unimplemented!()
}
}
#[cfg(test)]
mod tests {
use crate::{FakeFs, Fs};
use gpui::BackgroundExecutor;
use serde_json::json;
use std::path::Path;
use util::path;
#[gpui::test]
async fn test_checkpoints(executor: BackgroundExecutor) {
let fs = FakeFs::new(executor);
fs.insert_tree(
path!("/"),
json!({
"bar": {
"baz": "qux"
},
"foo": {
".git": {},
"a": "lorem",
"b": "ipsum",
},
}),
)
.await;
fs.with_git_state(Path::new("/foo/.git"), true, |_git| {})
.unwrap();
let repository = fs.open_repo(Path::new("/foo/.git")).unwrap();
let checkpoint_1 = repository.checkpoint().await.unwrap();
fs.write(Path::new("/foo/b"), b"IPSUM").await.unwrap();
fs.write(Path::new("/foo/c"), b"dolor").await.unwrap();
let checkpoint_2 = repository.checkpoint().await.unwrap();
let checkpoint_3 = repository.checkpoint().await.unwrap();
assert!(
repository
.compare_checkpoints(checkpoint_2.clone(), checkpoint_3.clone())
.await
.unwrap()
);
assert!(
!repository
.compare_checkpoints(checkpoint_1.clone(), checkpoint_2.clone())
.await
.unwrap()
);
repository.restore_checkpoint(checkpoint_1).await.unwrap();
assert_eq!(
fs.files_with_contents(Path::new("")),
[
(Path::new("/bar/baz").into(), b"qux".into()),
(Path::new("/foo/a").into(), b"lorem".into()),
(Path::new("/foo/b").into(), b"ipsum".into())
]
);
}
}

View File

@@ -924,7 +924,7 @@ pub struct FakeFs {
#[cfg(any(test, feature = "test-support"))]
struct FakeFsState {
root: Arc<Mutex<FakeFsEntry>>,
root: FakeFsEntry,
next_inode: u64,
next_mtime: SystemTime,
git_event_tx: smol::channel::Sender<PathBuf>,
@@ -939,7 +939,7 @@ struct FakeFsState {
}
#[cfg(any(test, feature = "test-support"))]
#[derive(Debug)]
#[derive(Clone, Debug)]
enum FakeFsEntry {
File {
inode: u64,
@@ -953,7 +953,7 @@ enum FakeFsEntry {
inode: u64,
mtime: MTime,
len: u64,
entries: BTreeMap<String, Arc<Mutex<FakeFsEntry>>>,
entries: BTreeMap<String, FakeFsEntry>,
git_repo_state: Option<Arc<Mutex<FakeGitRepositoryState>>>,
},
Symlink {
@@ -961,6 +961,67 @@ enum FakeFsEntry {
},
}
#[cfg(any(test, feature = "test-support"))]
impl PartialEq for FakeFsEntry {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(
Self::File {
inode: l_inode,
mtime: l_mtime,
len: l_len,
content: l_content,
git_dir_path: l_git_dir_path,
},
Self::File {
inode: r_inode,
mtime: r_mtime,
len: r_len,
content: r_content,
git_dir_path: r_git_dir_path,
},
) => {
l_inode == r_inode
&& l_mtime == r_mtime
&& l_len == r_len
&& l_content == r_content
&& l_git_dir_path == r_git_dir_path
}
(
Self::Dir {
inode: l_inode,
mtime: l_mtime,
len: l_len,
entries: l_entries,
git_repo_state: l_git_repo_state,
},
Self::Dir {
inode: r_inode,
mtime: r_mtime,
len: r_len,
entries: r_entries,
git_repo_state: r_git_repo_state,
},
) => {
let same_repo_state = match (l_git_repo_state.as_ref(), r_git_repo_state.as_ref()) {
(Some(l), Some(r)) => Arc::ptr_eq(l, r),
(None, None) => true,
_ => false,
};
l_inode == r_inode
&& l_mtime == r_mtime
&& l_len == r_len
&& l_entries == r_entries
&& same_repo_state
}
(Self::Symlink { target: l_target }, Self::Symlink { target: r_target }) => {
l_target == r_target
}
_ => false,
}
}
}
#[cfg(any(test, feature = "test-support"))]
impl FakeFsState {
fn get_and_increment_mtime(&mut self) -> MTime {
@@ -975,25 +1036,9 @@ impl FakeFsState {
inode
}
fn read_path(&self, target: &Path) -> Result<Arc<Mutex<FakeFsEntry>>> {
Ok(self
.try_read_path(target, true)
.ok_or_else(|| {
anyhow!(io::Error::new(
io::ErrorKind::NotFound,
format!("not found: {target:?}")
))
})?
.0)
}
fn try_read_path(
&self,
target: &Path,
follow_symlink: bool,
) -> Option<(Arc<Mutex<FakeFsEntry>>, PathBuf)> {
let mut path = target.to_path_buf();
fn canonicalize(&self, target: &Path, follow_symlink: bool) -> Option<PathBuf> {
let mut canonical_path = PathBuf::new();
let mut path = target.to_path_buf();
let mut entry_stack = Vec::new();
'outer: loop {
let mut path_components = path.components().peekable();
@@ -1003,7 +1048,7 @@ impl FakeFsState {
Component::Prefix(prefix_component) => prefix = Some(prefix_component),
Component::RootDir => {
entry_stack.clear();
entry_stack.push(self.root.clone());
entry_stack.push(&self.root);
canonical_path.clear();
match prefix {
Some(prefix_component) => {
@@ -1020,20 +1065,18 @@ impl FakeFsState {
canonical_path.pop();
}
Component::Normal(name) => {
let current_entry = entry_stack.last().cloned()?;
let current_entry = current_entry.lock();
if let FakeFsEntry::Dir { entries, .. } = &*current_entry {
let entry = entries.get(name.to_str().unwrap()).cloned()?;
let current_entry = *entry_stack.last()?;
if let FakeFsEntry::Dir { entries, .. } = current_entry {
let entry = entries.get(name.to_str().unwrap())?;
if path_components.peek().is_some() || follow_symlink {
let entry = entry.lock();
if let FakeFsEntry::Symlink { target, .. } = &*entry {
if let FakeFsEntry::Symlink { target, .. } = entry {
let mut target = target.clone();
target.extend(path_components);
path = target;
continue 'outer;
}
}
entry_stack.push(entry.clone());
entry_stack.push(entry);
canonical_path = canonical_path.join(name);
} else {
return None;
@@ -1043,19 +1086,72 @@ impl FakeFsState {
}
break;
}
Some((entry_stack.pop()?, canonical_path))
if entry_stack.is_empty() {
None
} else {
Some(canonical_path)
}
}
fn write_path<Fn, T>(&self, path: &Path, callback: Fn) -> Result<T>
fn try_entry(
&mut self,
target: &Path,
follow_symlink: bool,
) -> Option<(&mut FakeFsEntry, PathBuf)> {
let canonical_path = self.canonicalize(target, follow_symlink)?;
let mut components = canonical_path.components();
let Some(Component::RootDir) = components.next() else {
panic!(
"the path {:?} was not canonicalized properly {:?}",
target, canonical_path
)
};
let mut entry = &mut self.root;
for component in components {
match component {
Component::Normal(name) => {
if let FakeFsEntry::Dir { entries, .. } = entry {
entry = entries.get_mut(name.to_str().unwrap())?;
} else {
return None;
}
}
_ => {
panic!(
"the path {:?} was not canonicalized properly {:?}",
target, canonical_path
)
}
}
}
Some((entry, canonical_path))
}
fn entry(&mut self, target: &Path) -> Result<&mut FakeFsEntry> {
Ok(self
.try_entry(target, true)
.ok_or_else(|| {
anyhow!(io::Error::new(
io::ErrorKind::NotFound,
format!("not found: {target:?}")
))
})?
.0)
}
fn write_path<Fn, T>(&mut self, path: &Path, callback: Fn) -> Result<T>
where
Fn: FnOnce(btree_map::Entry<String, Arc<Mutex<FakeFsEntry>>>) -> Result<T>,
Fn: FnOnce(btree_map::Entry<String, FakeFsEntry>) -> Result<T>,
{
let path = normalize_path(path);
let filename = path.file_name().context("cannot overwrite the root")?;
let parent_path = path.parent().unwrap();
let parent = self.read_path(parent_path)?;
let mut parent = parent.lock();
let parent = self.entry(parent_path)?;
let new_entry = parent
.dir_entries(parent_path)?
.entry(filename.to_str().unwrap().into());
@@ -1105,13 +1201,13 @@ impl FakeFs {
this: this.clone(),
executor: executor.clone(),
state: Arc::new(Mutex::new(FakeFsState {
root: Arc::new(Mutex::new(FakeFsEntry::Dir {
root: FakeFsEntry::Dir {
inode: 0,
mtime: MTime(UNIX_EPOCH),
len: 0,
entries: Default::default(),
git_repo_state: None,
})),
},
git_event_tx: tx,
next_mtime: UNIX_EPOCH + Self::SYSTEMTIME_INTERVAL,
next_inode: 1,
@@ -1161,15 +1257,15 @@ impl FakeFs {
.write_path(path, move |entry| {
match entry {
btree_map::Entry::Vacant(e) => {
e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
e.insert(FakeFsEntry::File {
inode: new_inode,
mtime: new_mtime,
content: Vec::new(),
len: 0,
git_dir_path: None,
})));
});
}
btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut().lock() {
btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut() {
FakeFsEntry::File { mtime, .. } => *mtime = new_mtime,
FakeFsEntry::Dir { mtime, .. } => *mtime = new_mtime,
FakeFsEntry::Symlink { .. } => {}
@@ -1188,7 +1284,7 @@ impl FakeFs {
pub async fn insert_symlink(&self, path: impl AsRef<Path>, target: PathBuf) {
let mut state = self.state.lock();
let path = path.as_ref();
let file = Arc::new(Mutex::new(FakeFsEntry::Symlink { target }));
let file = FakeFsEntry::Symlink { target };
state
.write_path(path.as_ref(), move |e| match e {
btree_map::Entry::Vacant(e) => {
@@ -1221,13 +1317,13 @@ impl FakeFs {
match entry {
btree_map::Entry::Vacant(e) => {
kind = Some(PathEventKind::Created);
e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
e.insert(FakeFsEntry::File {
inode: new_inode,
mtime: new_mtime,
len: new_len,
content: new_content,
git_dir_path: None,
})));
});
}
btree_map::Entry::Occupied(mut e) => {
kind = Some(PathEventKind::Changed);
@@ -1237,7 +1333,7 @@ impl FakeFs {
len,
content,
..
} = &mut *e.get_mut().lock()
} = e.get_mut()
{
*mtime = new_mtime;
*content = new_content;
@@ -1259,9 +1355,8 @@ impl FakeFs {
pub fn read_file_sync(&self, path: impl AsRef<Path>) -> Result<Vec<u8>> {
let path = path.as_ref();
let path = normalize_path(path);
let state = self.state.lock();
let entry = state.read_path(&path)?;
let entry = entry.lock();
let mut state = self.state.lock();
let entry = state.entry(&path)?;
entry.file_content(&path).cloned()
}
@@ -1269,9 +1364,8 @@ impl FakeFs {
let path = path.as_ref();
let path = normalize_path(path);
self.simulate_random_delay().await;
let state = self.state.lock();
let entry = state.read_path(&path)?;
let entry = entry.lock();
let mut state = self.state.lock();
let entry = state.entry(&path)?;
entry.file_content(&path).cloned()
}
@@ -1292,6 +1386,25 @@ impl FakeFs {
self.state.lock().flush_events(count);
}
pub(crate) fn entry(&self, target: &Path) -> Result<FakeFsEntry> {
self.state.lock().entry(target).cloned()
}
pub(crate) fn insert_entry(&self, target: &Path, new_entry: FakeFsEntry) -> Result<()> {
let mut state = self.state.lock();
state.write_path(target, |entry| {
match entry {
btree_map::Entry::Vacant(vacant_entry) => {
vacant_entry.insert(new_entry);
}
btree_map::Entry::Occupied(mut occupied_entry) => {
occupied_entry.insert(new_entry);
}
}
Ok(())
})
}
#[must_use]
pub fn insert_tree<'a>(
&'a self,
@@ -1361,20 +1474,19 @@ impl FakeFs {
F: FnOnce(&mut FakeGitRepositoryState, &Path, &Path) -> T,
{
let mut state = self.state.lock();
let entry = state.read_path(dot_git).context("open .git")?;
let mut entry = entry.lock();
let git_event_tx = state.git_event_tx.clone();
let entry = state.entry(dot_git).context("open .git")?;
if let FakeFsEntry::Dir { git_repo_state, .. } = &mut *entry {
if let FakeFsEntry::Dir { git_repo_state, .. } = entry {
let repo_state = git_repo_state.get_or_insert_with(|| {
log::debug!("insert git state for {dot_git:?}");
Arc::new(Mutex::new(FakeGitRepositoryState::new(
state.git_event_tx.clone(),
)))
Arc::new(Mutex::new(FakeGitRepositoryState::new(git_event_tx)))
});
let mut repo_state = repo_state.lock();
let result = f(&mut repo_state, dot_git, dot_git);
drop(repo_state);
if emit_git_event {
state.emit_event([(dot_git, None)]);
}
@@ -1398,21 +1510,20 @@ impl FakeFs {
}
}
.clone();
drop(entry);
let Some((git_dir_entry, canonical_path)) = state.try_read_path(&path, true) else {
let Some((git_dir_entry, canonical_path)) = state.try_entry(&path, true) else {
anyhow::bail!("pointed-to git dir {path:?} not found")
};
let FakeFsEntry::Dir {
git_repo_state,
entries,
..
} = &mut *git_dir_entry.lock()
} = git_dir_entry
else {
anyhow::bail!("gitfile points to a non-directory")
};
let common_dir = if let Some(child) = entries.get("commondir") {
Path::new(
std::str::from_utf8(child.lock().file_content("commondir".as_ref())?)
std::str::from_utf8(child.file_content("commondir".as_ref())?)
.context("commondir content")?,
)
.to_owned()
@@ -1420,15 +1531,14 @@ impl FakeFs {
canonical_path.clone()
};
let repo_state = git_repo_state.get_or_insert_with(|| {
Arc::new(Mutex::new(FakeGitRepositoryState::new(
state.git_event_tx.clone(),
)))
Arc::new(Mutex::new(FakeGitRepositoryState::new(git_event_tx)))
});
let mut repo_state = repo_state.lock();
let result = f(&mut repo_state, &canonical_path, &common_dir);
if emit_git_event {
drop(repo_state);
state.emit_event([(canonical_path, None)]);
}
@@ -1655,14 +1765,12 @@ impl FakeFs {
pub fn paths(&self, include_dot_git: bool) -> Vec<PathBuf> {
let mut result = Vec::new();
let mut queue = collections::VecDeque::new();
queue.push_back((
PathBuf::from(util::path!("/")),
self.state.lock().root.clone(),
));
let state = &*self.state.lock();
queue.push_back((PathBuf::from(util::path!("/")), &state.root));
while let Some((path, entry)) = queue.pop_front() {
if let FakeFsEntry::Dir { entries, .. } = &*entry.lock() {
if let FakeFsEntry::Dir { entries, .. } = entry {
for (name, entry) in entries {
queue.push_back((path.join(name), entry.clone()));
queue.push_back((path.join(name), entry));
}
}
if include_dot_git
@@ -1679,14 +1787,12 @@ impl FakeFs {
pub fn directories(&self, include_dot_git: bool) -> Vec<PathBuf> {
let mut result = Vec::new();
let mut queue = collections::VecDeque::new();
queue.push_back((
PathBuf::from(util::path!("/")),
self.state.lock().root.clone(),
));
let state = &*self.state.lock();
queue.push_back((PathBuf::from(util::path!("/")), &state.root));
while let Some((path, entry)) = queue.pop_front() {
if let FakeFsEntry::Dir { entries, .. } = &*entry.lock() {
if let FakeFsEntry::Dir { entries, .. } = entry {
for (name, entry) in entries {
queue.push_back((path.join(name), entry.clone()));
queue.push_back((path.join(name), entry));
}
if include_dot_git
|| !path
@@ -1703,17 +1809,14 @@ impl FakeFs {
pub fn files(&self) -> Vec<PathBuf> {
let mut result = Vec::new();
let mut queue = collections::VecDeque::new();
queue.push_back((
PathBuf::from(util::path!("/")),
self.state.lock().root.clone(),
));
let state = &*self.state.lock();
queue.push_back((PathBuf::from(util::path!("/")), &state.root));
while let Some((path, entry)) = queue.pop_front() {
let e = entry.lock();
match &*e {
match entry {
FakeFsEntry::File { .. } => result.push(path),
FakeFsEntry::Dir { entries, .. } => {
for (name, entry) in entries {
queue.push_back((path.join(name), entry.clone()));
queue.push_back((path.join(name), entry));
}
}
FakeFsEntry::Symlink { .. } => {}
@@ -1725,13 +1828,10 @@ impl FakeFs {
pub fn files_with_contents(&self, prefix: &Path) -> Vec<(PathBuf, Vec<u8>)> {
let mut result = Vec::new();
let mut queue = collections::VecDeque::new();
queue.push_back((
PathBuf::from(util::path!("/")),
self.state.lock().root.clone(),
));
let state = &*self.state.lock();
queue.push_back((PathBuf::from(util::path!("/")), &state.root));
while let Some((path, entry)) = queue.pop_front() {
let e = entry.lock();
match &*e {
match entry {
FakeFsEntry::File { content, .. } => {
if path.starts_with(prefix) {
result.push((path, content.clone()));
@@ -1739,7 +1839,7 @@ impl FakeFs {
}
FakeFsEntry::Dir { entries, .. } => {
for (name, entry) in entries {
queue.push_back((path.join(name), entry.clone()));
queue.push_back((path.join(name), entry));
}
}
FakeFsEntry::Symlink { .. } => {}
@@ -1805,10 +1905,7 @@ impl FakeFsEntry {
}
}
fn dir_entries(
&mut self,
path: &Path,
) -> Result<&mut BTreeMap<String, Arc<Mutex<FakeFsEntry>>>> {
fn dir_entries(&mut self, path: &Path) -> Result<&mut BTreeMap<String, FakeFsEntry>> {
if let Self::Dir { entries, .. } = self {
Ok(entries)
} else {
@@ -1855,12 +1952,12 @@ struct FakeHandle {
impl FileHandle for FakeHandle {
fn current_path(&self, fs: &Arc<dyn Fs>) -> Result<PathBuf> {
let fs = fs.as_fake();
let state = fs.state.lock();
let Some(target) = state.moves.get(&self.inode) else {
let mut state = fs.state.lock();
let Some(target) = state.moves.get(&self.inode).cloned() else {
anyhow::bail!("fake fd not moved")
};
if state.try_read_path(&target, false).is_some() {
if state.try_entry(&target, false).is_some() {
return Ok(target.clone());
}
anyhow::bail!("fake fd target not found")
@@ -1888,13 +1985,13 @@ impl Fs for FakeFs {
state.write_path(&cur_path, |entry| {
entry.or_insert_with(|| {
created_dirs.push((cur_path.clone(), Some(PathEventKind::Created)));
Arc::new(Mutex::new(FakeFsEntry::Dir {
FakeFsEntry::Dir {
inode,
mtime,
len: 0,
entries: Default::default(),
git_repo_state: None,
}))
}
});
Ok(())
})?
@@ -1909,13 +2006,13 @@ impl Fs for FakeFs {
let mut state = self.state.lock();
let inode = state.get_and_increment_inode();
let mtime = state.get_and_increment_mtime();
let file = Arc::new(Mutex::new(FakeFsEntry::File {
let file = FakeFsEntry::File {
inode,
mtime,
len: 0,
content: Vec::new(),
git_dir_path: None,
}));
};
let mut kind = Some(PathEventKind::Created);
state.write_path(path, |entry| {
match entry {
@@ -1939,7 +2036,7 @@ impl Fs for FakeFs {
async fn create_symlink(&self, path: &Path, target: PathBuf) -> Result<()> {
let mut state = self.state.lock();
let file = Arc::new(Mutex::new(FakeFsEntry::Symlink { target }));
let file = FakeFsEntry::Symlink { target };
state
.write_path(path.as_ref(), move |e| match e {
btree_map::Entry::Vacant(e) => {
@@ -2002,7 +2099,7 @@ impl Fs for FakeFs {
}
})?;
let inode = match *moved_entry.lock() {
let inode = match moved_entry {
FakeFsEntry::File { inode, .. } => inode,
FakeFsEntry::Dir { inode, .. } => inode,
_ => 0,
@@ -2051,8 +2148,8 @@ impl Fs for FakeFs {
let mut state = self.state.lock();
let mtime = state.get_and_increment_mtime();
let inode = state.get_and_increment_inode();
let source_entry = state.read_path(&source)?;
let content = source_entry.lock().file_content(&source)?.clone();
let source_entry = state.entry(&source)?;
let content = source_entry.file_content(&source)?.clone();
let mut kind = Some(PathEventKind::Created);
state.write_path(&target, |e| match e {
btree_map::Entry::Occupied(e) => {
@@ -2066,13 +2163,13 @@ impl Fs for FakeFs {
}
}
btree_map::Entry::Vacant(e) => Ok(Some(
e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
e.insert(FakeFsEntry::File {
inode,
mtime,
len: content.len() as u64,
content,
git_dir_path: None,
})))
})
.clone(),
)),
})?;
@@ -2088,8 +2185,7 @@ impl Fs for FakeFs {
let base_name = path.file_name().context("cannot remove the root")?;
let mut state = self.state.lock();
let parent_entry = state.read_path(parent_path)?;
let mut parent_entry = parent_entry.lock();
let parent_entry = state.entry(parent_path)?;
let entry = parent_entry
.dir_entries(parent_path)?
.entry(base_name.to_str().unwrap().into());
@@ -2100,15 +2196,14 @@ impl Fs for FakeFs {
anyhow::bail!("{path:?} does not exist");
}
}
btree_map::Entry::Occupied(e) => {
btree_map::Entry::Occupied(mut entry) => {
{
let mut entry = e.get().lock();
let children = entry.dir_entries(&path)?;
let children = entry.get_mut().dir_entries(&path)?;
if !options.recursive && !children.is_empty() {
anyhow::bail!("{path:?} is not empty");
}
}
e.remove();
entry.remove();
}
}
state.emit_event([(path, Some(PathEventKind::Removed))]);
@@ -2122,8 +2217,7 @@ impl Fs for FakeFs {
let parent_path = path.parent().context("cannot remove the root")?;
let base_name = path.file_name().unwrap();
let mut state = self.state.lock();
let parent_entry = state.read_path(parent_path)?;
let mut parent_entry = parent_entry.lock();
let parent_entry = state.entry(parent_path)?;
let entry = parent_entry
.dir_entries(parent_path)?
.entry(base_name.to_str().unwrap().into());
@@ -2133,9 +2227,9 @@ impl Fs for FakeFs {
anyhow::bail!("{path:?} does not exist");
}
}
btree_map::Entry::Occupied(e) => {
e.get().lock().file_content(&path)?;
e.remove();
btree_map::Entry::Occupied(mut entry) => {
entry.get_mut().file_content(&path)?;
entry.remove();
}
}
state.emit_event([(path, Some(PathEventKind::Removed))]);
@@ -2149,12 +2243,10 @@ impl Fs for FakeFs {
async fn open_handle(&self, path: &Path) -> Result<Arc<dyn FileHandle>> {
self.simulate_random_delay().await;
let state = self.state.lock();
let entry = state.read_path(&path)?;
let entry = entry.lock();
let inode = match *entry {
FakeFsEntry::File { inode, .. } => inode,
FakeFsEntry::Dir { inode, .. } => inode,
let mut state = self.state.lock();
let inode = match state.entry(&path)? {
FakeFsEntry::File { inode, .. } => *inode,
FakeFsEntry::Dir { inode, .. } => *inode,
_ => unreachable!(),
};
Ok(Arc::new(FakeHandle { inode }))
@@ -2204,8 +2296,8 @@ impl Fs for FakeFs {
let path = normalize_path(path);
self.simulate_random_delay().await;
let state = self.state.lock();
let (_, canonical_path) = state
.try_read_path(&path, true)
let canonical_path = state
.canonicalize(&path, true)
.with_context(|| format!("path does not exist: {path:?}"))?;
Ok(canonical_path)
}
@@ -2213,9 +2305,9 @@ impl Fs for FakeFs {
async fn is_file(&self, path: &Path) -> bool {
let path = normalize_path(path);
self.simulate_random_delay().await;
let state = self.state.lock();
if let Some((entry, _)) = state.try_read_path(&path, true) {
entry.lock().is_file()
let mut state = self.state.lock();
if let Some((entry, _)) = state.try_entry(&path, true) {
entry.is_file()
} else {
false
}
@@ -2232,17 +2324,16 @@ impl Fs for FakeFs {
let path = normalize_path(path);
let mut state = self.state.lock();
state.metadata_call_count += 1;
if let Some((mut entry, _)) = state.try_read_path(&path, false) {
let is_symlink = entry.lock().is_symlink();
if let Some((mut entry, _)) = state.try_entry(&path, false) {
let is_symlink = entry.is_symlink();
if is_symlink {
if let Some(e) = state.try_read_path(&path, true).map(|e| e.0) {
if let Some(e) = state.try_entry(&path, true).map(|e| e.0) {
entry = e;
} else {
return Ok(None);
}
}
let entry = entry.lock();
Ok(Some(match &*entry {
FakeFsEntry::File {
inode, mtime, len, ..
@@ -2274,12 +2365,11 @@ impl Fs for FakeFs {
async fn read_link(&self, path: &Path) -> Result<PathBuf> {
self.simulate_random_delay().await;
let path = normalize_path(path);
let state = self.state.lock();
let mut state = self.state.lock();
let (entry, _) = state
.try_read_path(&path, false)
.try_entry(&path, false)
.with_context(|| format!("path does not exist: {path:?}"))?;
let entry = entry.lock();
if let FakeFsEntry::Symlink { target } = &*entry {
if let FakeFsEntry::Symlink { target } = entry {
Ok(target.clone())
} else {
anyhow::bail!("not a symlink: {path:?}")
@@ -2294,8 +2384,7 @@ impl Fs for FakeFs {
let path = normalize_path(path);
let mut state = self.state.lock();
state.read_dir_call_count += 1;
let entry = state.read_path(&path)?;
let mut entry = entry.lock();
let entry = state.entry(&path)?;
let children = entry.dir_entries(&path)?;
let paths = children
.keys()
@@ -2359,6 +2448,7 @@ impl Fs for FakeFs {
dot_git_path: abs_dot_git.to_path_buf(),
repository_dir_path: repository_dir_path.to_owned(),
common_dir_path: common_dir_path.to_owned(),
checkpoints: Arc::default(),
}) as _
},
)

View File

@@ -12,7 +12,7 @@ workspace = true
path = "src/git.rs"
[features]
test-support = []
test-support = ["rand"]
[dependencies]
anyhow.workspace = true
@@ -26,6 +26,7 @@ http_client.workspace = true
log.workspace = true
parking_lot.workspace = true
regex.workspace = true
rand = { workspace = true, optional = true }
rope.workspace = true
schemars.workspace = true
serde.workspace = true
@@ -47,3 +48,4 @@ text = { workspace = true, features = ["test-support"] }
unindent.workspace = true
gpui = { workspace = true, features = ["test-support"] }
tempfile.workspace = true
rand.workspace = true

View File

@@ -73,6 +73,7 @@ async fn run_git_blame(
.current_dir(working_directory)
.arg("blame")
.arg("--incremental")
.arg("-w")
.arg("--contents")
.arg("-")
.arg(path.as_os_str())

View File

@@ -119,6 +119,13 @@ impl Oid {
Ok(Self(oid))
}
#[cfg(any(test, feature = "test-support"))]
pub fn random(rng: &mut impl rand::Rng) -> Self {
let mut bytes = [0; 20];
rng.fill(&mut bytes);
Self::from_bytes(&bytes).unwrap()
}
pub fn as_bytes(&self) -> &[u8] {
self.0.as_bytes()
}

View File

@@ -6,7 +6,7 @@ use collections::HashMap;
use futures::future::BoxFuture;
use futures::{AsyncWriteExt, FutureExt as _, select_biased};
use git2::BranchType;
use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, SharedString};
use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, SharedString, Task};
use parking_lot::Mutex;
use rope::Rope;
use schemars::JsonSchema;
@@ -338,7 +338,7 @@ pub trait GitRepository: Send + Sync {
fn merge_message(&self) -> BoxFuture<'_, Option<String>>;
fn status(&self, path_prefixes: &[RepoPath]) -> BoxFuture<'_, Result<GitStatus>>;
fn status(&self, path_prefixes: &[RepoPath]) -> Task<Result<GitStatus>>;
fn branches(&self) -> BoxFuture<'_, Result<Vec<Branch>>>;
@@ -953,25 +953,27 @@ impl GitRepository for RealGitRepository {
.boxed()
}
fn status(&self, path_prefixes: &[RepoPath]) -> BoxFuture<'_, Result<GitStatus>> {
fn status(&self, path_prefixes: &[RepoPath]) -> Task<Result<GitStatus>> {
let git_binary_path = self.git_binary_path.clone();
let working_directory = self.working_directory();
let path_prefixes = path_prefixes.to_owned();
self.executor
.spawn(async move {
let output = new_std_command(&git_binary_path)
.current_dir(working_directory?)
.args(git_status_args(&path_prefixes))
.output()?;
if output.status.success() {
let stdout = String::from_utf8_lossy(&output.stdout);
stdout.parse()
} else {
let stderr = String::from_utf8_lossy(&output.stderr);
anyhow::bail!("git status failed: {stderr}");
}
})
.boxed()
let working_directory = match self.working_directory() {
Ok(working_directory) => working_directory,
Err(e) => return Task::ready(Err(e)),
};
let args = git_status_args(&path_prefixes);
log::debug!("Checking for git status in {path_prefixes:?}");
self.executor.spawn(async move {
let output = new_std_command(&git_binary_path)
.current_dir(working_directory)
.args(args)
.output()?;
if output.status.success() {
let stdout = String::from_utf8_lossy(&output.stdout);
stdout.parse()
} else {
let stderr = String::from_utf8_lossy(&output.stderr);
anyhow::bail!("git status failed: {stderr}");
}
})
}
fn branches(&self) -> BoxFuture<'_, Result<Vec<Branch>>> {

View File

@@ -2105,7 +2105,7 @@ impl GitPanel {
Ok(_) => cx.update(|window, cx| {
window.prompt(
PromptLevel::Info,
"Git Clone",
&format!("Git Clone: {}", repo_name),
None,
&["Add repo to project", "Open repo in new project"],
cx,

View File

@@ -181,10 +181,6 @@ pub fn init(cx: &mut App) {
workspace.toggle_modal(window, cx, |window, cx| {
GitCloneModal::show(panel, window, cx)
});
// panel.update(cx, |panel, cx| {
// panel.git_clone(window, cx);
// });
});
workspace.register_action(|workspace, _: &git::OpenModifiedFiles, window, cx| {
open_modified_files(workspace, window, cx);

View File

@@ -1,5 +1,6 @@
use gpui::{
App, Application, Context, Menu, MenuItem, Window, WindowOptions, actions, div, prelude::*, rgb,
App, Application, Context, Menu, MenuItem, SystemMenuType, Window, WindowOptions, actions, div,
prelude::*, rgb,
};
struct SetMenus;
@@ -27,7 +28,11 @@ fn main() {
// Add menu items
cx.set_menus(vec![Menu {
name: "set_menus".into(),
items: vec![MenuItem::action("Quit", Quit)],
items: vec![
MenuItem::os_submenu("Services", SystemMenuType::Services),
MenuItem::separator(),
MenuItem::action("Quit", Quit),
],
}]);
cx.open_window(WindowOptions::default(), |_, cx| cx.new(|_| SetMenus {}))
.unwrap();

View File

@@ -277,6 +277,8 @@ pub struct App {
pub(crate) release_listeners: SubscriberSet<EntityId, ReleaseListener>,
pub(crate) global_observers: SubscriberSet<TypeId, Handler>,
pub(crate) quit_observers: SubscriberSet<(), QuitHandler>,
pub(crate) restart_observers: SubscriberSet<(), Handler>,
pub(crate) restart_path: Option<PathBuf>,
pub(crate) window_closed_observers: SubscriberSet<(), WindowClosedHandler>,
pub(crate) layout_id_buffer: Vec<LayoutId>, // We recycle this memory across layout requests.
pub(crate) propagate_event: bool,
@@ -349,6 +351,8 @@ impl App {
keyboard_layout_observers: SubscriberSet::new(),
global_observers: SubscriberSet::new(),
quit_observers: SubscriberSet::new(),
restart_observers: SubscriberSet::new(),
restart_path: None,
window_closed_observers: SubscriberSet::new(),
layout_id_buffer: Default::default(),
propagate_event: true,
@@ -832,8 +836,16 @@ impl App {
}
/// Restarts the application.
pub fn restart(&self, binary_path: Option<PathBuf>) {
self.platform.restart(binary_path)
pub fn restart(&mut self) {
self.restart_observers
.clone()
.retain(&(), |observer| observer(self));
self.platform.restart(self.restart_path.take())
}
/// Sets the path to use when restarting the application.
pub fn set_restart_path(&mut self, path: PathBuf) {
self.restart_path = Some(path);
}
/// Returns the HTTP client for the application.
@@ -1466,6 +1478,21 @@ impl App {
subscription
}
/// Register a callback to be invoked when the application is about to restart.
///
/// These callbacks are called before any `on_app_quit` callbacks.
pub fn on_app_restart(&self, mut on_restart: impl 'static + FnMut(&mut App)) -> Subscription {
let (subscription, activate) = self.restart_observers.insert(
(),
Box::new(move |cx| {
on_restart(cx);
true
}),
);
activate();
subscription
}
/// Register a callback to be invoked when a window is closed
/// The window is no longer accessible at the point this callback is invoked.
pub fn on_window_closed(&self, mut on_closed: impl FnMut(&mut App) + 'static) -> Subscription {

View File

@@ -164,6 +164,20 @@ impl<'a, T: 'static> Context<'a, T> {
subscription
}
/// Register a callback to be invoked when the application is about to restart.
pub fn on_app_restart(
&self,
mut on_restart: impl FnMut(&mut T, &mut App) + 'static,
) -> Subscription
where
T: 'static,
{
let handle = self.weak_entity();
self.app.on_app_restart(move |cx| {
handle.update(cx, |entity, cx| on_restart(entity, cx)).ok();
})
}
/// Arrange for the given function to be invoked whenever the application is quit.
/// The future returned from this callback will be polled for up to [crate::SHUTDOWN_TIMEOUT] until the app fully quits.
pub fn on_app_quit<Fut>(
@@ -175,20 +189,15 @@ impl<'a, T: 'static> Context<'a, T> {
T: 'static,
{
let handle = self.weak_entity();
let (subscription, activate) = self.app.quit_observers.insert(
(),
Box::new(move |cx| {
let future = handle.update(cx, |entity, cx| on_quit(entity, cx)).ok();
async move {
if let Some(future) = future {
future.await;
}
self.app.on_app_quit(move |cx| {
let future = handle.update(cx, |entity, cx| on_quit(entity, cx)).ok();
async move {
if let Some(future) = future {
future.await;
}
.boxed_local()
}),
);
activate();
subscription
}
.boxed_local()
})
}
/// Tell GPUI that this entity has changed and observers of it should be notified.

View File

@@ -20,6 +20,34 @@ impl Menu {
}
}
/// OS menus are menus that are recognized by the operating system
/// This allows the operating system to provide specialized items for
/// these menus
pub struct OsMenu {
/// The name of the menu
pub name: SharedString,
/// The type of menu
pub menu_type: SystemMenuType,
}
impl OsMenu {
/// Create an OwnedOsMenu from this OsMenu
pub fn owned(self) -> OwnedOsMenu {
OwnedOsMenu {
name: self.name.to_string().into(),
menu_type: self.menu_type,
}
}
}
/// The type of system menu
#[derive(Copy, Clone, Eq, PartialEq)]
pub enum SystemMenuType {
/// The 'Services' menu in the Application menu on macOS
Services,
}
/// The different kinds of items that can be in a menu
pub enum MenuItem {
/// A separator between items
@@ -28,6 +56,9 @@ pub enum MenuItem {
/// A submenu
Submenu(Menu),
/// A menu, managed by the system (for example, the Services menu on macOS)
SystemMenu(OsMenu),
/// An action that can be performed
Action {
/// The name of this menu item
@@ -53,6 +84,14 @@ impl MenuItem {
Self::Submenu(menu)
}
/// Creates a new submenu that is populated by the OS
pub fn os_submenu(name: impl Into<SharedString>, menu_type: SystemMenuType) -> Self {
Self::SystemMenu(OsMenu {
name: name.into(),
menu_type,
})
}
/// Creates a new menu item that invokes an action
pub fn action(name: impl Into<SharedString>, action: impl Action) -> Self {
Self::Action {
@@ -89,10 +128,23 @@ impl MenuItem {
action,
os_action,
},
MenuItem::SystemMenu(os_menu) => OwnedMenuItem::SystemMenu(os_menu.owned()),
}
}
}
/// OS menus are menus that are recognized by the operating system
/// This allows the operating system to provide specialized items for
/// these menus
#[derive(Clone)]
pub struct OwnedOsMenu {
/// The name of the menu
pub name: SharedString,
/// The type of menu
pub menu_type: SystemMenuType,
}
/// A menu of the application, either a main menu or a submenu
#[derive(Clone)]
pub struct OwnedMenu {
@@ -111,6 +163,9 @@ pub enum OwnedMenuItem {
/// A submenu
Submenu(OwnedMenu),
/// A menu, managed by the system (for example, the Services menu on macOS)
SystemMenu(OwnedOsMenu),
/// An action that can be performed
Action {
/// The name of this menu item
@@ -139,6 +194,7 @@ impl Clone for OwnedMenuItem {
action: action.boxed_clone(),
os_action: *os_action,
},
OwnedMenuItem::SystemMenu(os_menu) => OwnedMenuItem::SystemMenu(os_menu.clone()),
}
}
}

View File

@@ -7,9 +7,9 @@ use super::{
use crate::{
Action, AnyWindowHandle, BackgroundExecutor, ClipboardEntry, ClipboardItem, ClipboardString,
CursorStyle, ForegroundExecutor, Image, ImageFormat, KeyContext, Keymap, MacDispatcher,
MacDisplay, MacWindow, Menu, MenuItem, OwnedMenu, PathPromptOptions, Platform, PlatformDisplay,
PlatformKeyboardLayout, PlatformTextSystem, PlatformWindow, Result, SemanticVersion, Task,
WindowAppearance, WindowParams, hash,
MacDisplay, MacWindow, Menu, MenuItem, OsMenu, OwnedMenu, PathPromptOptions, Platform,
PlatformDisplay, PlatformKeyboardLayout, PlatformTextSystem, PlatformWindow, Result,
SemanticVersion, SystemMenuType, Task, WindowAppearance, WindowParams, hash,
};
use anyhow::{Context as _, anyhow};
use block::ConcreteBlock;
@@ -413,9 +413,20 @@ impl MacPlatform {
}
item.setSubmenu_(submenu);
item.setTitle_(ns_string(&name));
if name == "Services" {
let app: id = msg_send![APP_CLASS, sharedApplication];
app.setServicesMenu_(item);
item
}
MenuItem::SystemMenu(OsMenu { name, menu_type }) => {
let item = NSMenuItem::new(nil).autorelease();
let submenu = NSMenu::new(nil).autorelease();
submenu.setDelegate_(delegate);
item.setSubmenu_(submenu);
item.setTitle_(ns_string(&name));
match menu_type {
SystemMenuType::Services => {
let app: id = msg_send![APP_CLASS, sharedApplication];
app.setServicesMenu_(item);
}
}
item

View File

@@ -10,6 +10,7 @@ mod keyboard;
mod platform;
mod system_settings;
mod util;
mod vsync;
mod window;
mod wrapper;
@@ -25,6 +26,7 @@ pub(crate) use keyboard::*;
pub(crate) use platform::*;
pub(crate) use system_settings::*;
pub(crate) use util::*;
pub(crate) use vsync::*;
pub(crate) use window::*;
pub(crate) use wrapper::*;

View File

@@ -4,16 +4,15 @@ use ::util::ResultExt;
use anyhow::{Context, Result};
use windows::{
Win32::{
Foundation::{FreeLibrary, HMODULE, HWND},
Foundation::{HMODULE, HWND},
Graphics::{
Direct3D::*,
Direct3D11::*,
DirectComposition::*,
Dxgi::{Common::*, *},
},
System::LibraryLoader::LoadLibraryA,
},
core::{Interface, PCSTR},
core::Interface,
};
use crate::{
@@ -208,7 +207,7 @@ impl DirectXRenderer {
fn present(&mut self) -> Result<()> {
unsafe {
let result = self.resources.swap_chain.Present(1, DXGI_PRESENT(0));
let result = self.resources.swap_chain.Present(0, DXGI_PRESENT(0));
// Presenting the swap chain can fail if the DirectX device was removed or reset.
if result == DXGI_ERROR_DEVICE_REMOVED || result == DXGI_ERROR_DEVICE_RESET {
let reason = self.devices.device.GetDeviceRemovedReason();
@@ -1619,22 +1618,6 @@ pub(crate) mod shader_resources {
}
}
fn with_dll_library<R, F>(dll_name: PCSTR, f: F) -> Result<R>
where
F: FnOnce(HMODULE) -> Result<R>,
{
let library = unsafe {
LoadLibraryA(dll_name).with_context(|| format!("Loading dll: {}", dll_name.display()))?
};
let result = f(library);
unsafe {
FreeLibrary(library)
.with_context(|| format!("Freeing dll: {}", dll_name.display()))
.log_err();
}
result
}
mod nvidia {
use std::{
ffi::CStr,
@@ -1644,7 +1627,7 @@ mod nvidia {
use anyhow::Result;
use windows::{Win32::System::LibraryLoader::GetProcAddress, core::s};
use crate::platform::windows::directx_renderer::with_dll_library;
use crate::with_dll_library;
// https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L180
const NVAPI_SHORT_STRING_MAX: usize = 64;
@@ -1711,7 +1694,7 @@ mod amd {
use anyhow::Result;
use windows::{Win32::System::LibraryLoader::GetProcAddress, core::s};
use crate::platform::windows::directx_renderer::with_dll_library;
use crate::with_dll_library;
// https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L145
const AGS_CURRENT_VERSION: i32 = (6 << 22) | (3 << 12);

View File

@@ -32,7 +32,7 @@ use crate::*;
pub(crate) struct WindowsPlatform {
state: RefCell<WindowsPlatformState>,
raw_window_handles: RwLock<SmallVec<[HWND; 4]>>,
raw_window_handles: Arc<RwLock<SmallVec<[SafeHwnd; 4]>>>,
// The below members will never change throughout the entire lifecycle of the app.
icon: HICON,
main_receiver: flume::Receiver<Runnable>,
@@ -114,7 +114,7 @@ impl WindowsPlatform {
};
let icon = load_icon().unwrap_or_default();
let state = RefCell::new(WindowsPlatformState::new());
let raw_window_handles = RwLock::new(SmallVec::new());
let raw_window_handles = Arc::new(RwLock::new(SmallVec::new()));
let windows_version = WindowsVersion::new().context("Error retrieve windows version")?;
Ok(Self {
@@ -134,22 +134,12 @@ impl WindowsPlatform {
})
}
fn redraw_all(&self) {
for handle in self.raw_window_handles.read().iter() {
unsafe {
RedrawWindow(Some(*handle), None, None, RDW_INVALIDATE | RDW_UPDATENOW)
.ok()
.log_err();
}
}
}
pub fn window_from_hwnd(&self, hwnd: HWND) -> Option<Rc<WindowsWindowInner>> {
self.raw_window_handles
.read()
.iter()
.find(|entry| *entry == &hwnd)
.and_then(|hwnd| window_from_hwnd(*hwnd))
.find(|entry| entry.as_raw() == hwnd)
.and_then(|hwnd| window_from_hwnd(hwnd.as_raw()))
}
#[inline]
@@ -158,7 +148,7 @@ impl WindowsPlatform {
.read()
.iter()
.for_each(|handle| unsafe {
PostMessageW(Some(*handle), message, wparam, lparam).log_err();
PostMessageW(Some(handle.as_raw()), message, wparam, lparam).log_err();
});
}
@@ -166,7 +156,7 @@ impl WindowsPlatform {
let mut lock = self.raw_window_handles.write();
let index = lock
.iter()
.position(|handle| *handle == target_window)
.position(|handle| handle.as_raw() == target_window)
.unwrap();
lock.remove(index);
@@ -226,19 +216,19 @@ impl WindowsPlatform {
}
}
// Returns true if the app should quit.
fn handle_events(&self) -> bool {
// Returns if the app should quit.
fn handle_events(&self) {
let mut msg = MSG::default();
unsafe {
while PeekMessageW(&mut msg, None, 0, 0, PM_REMOVE).as_bool() {
while GetMessageW(&mut msg, None, 0, 0).as_bool() {
match msg.message {
WM_QUIT => return true,
WM_QUIT => return,
WM_INPUTLANGCHANGE
| WM_GPUI_CLOSE_ONE_WINDOW
| WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD
| WM_GPUI_DOCK_MENU_ACTION => {
if self.handle_gpui_evnets(msg.message, msg.wParam, msg.lParam, &msg) {
return true;
return;
}
}
_ => {
@@ -247,7 +237,6 @@ impl WindowsPlatform {
}
}
}
false
}
// Returns true if the app should quit.
@@ -315,8 +304,28 @@ impl WindowsPlatform {
self.raw_window_handles
.read()
.iter()
.find(|&&hwnd| hwnd == active_window_hwnd)
.copied()
.find(|hwnd| hwnd.as_raw() == active_window_hwnd)
.map(|hwnd| hwnd.as_raw())
}
fn begin_vsync_thread(&self) {
let all_windows = Arc::downgrade(&self.raw_window_handles);
std::thread::spawn(move || {
let vsync_provider = VSyncProvider::new();
loop {
vsync_provider.wait_for_vsync();
let Some(all_windows) = all_windows.upgrade() else {
break;
};
for hwnd in all_windows.read().iter() {
unsafe {
RedrawWindow(Some(hwnd.as_raw()), None, None, RDW_INVALIDATE)
.ok()
.log_err();
}
}
}
});
}
}
@@ -347,12 +356,8 @@ impl Platform for WindowsPlatform {
fn run(&self, on_finish_launching: Box<dyn 'static + FnOnce()>) {
on_finish_launching();
loop {
if self.handle_events() {
break;
}
self.redraw_all();
}
self.begin_vsync_thread();
self.handle_events();
if let Some(ref mut callback) = self.state.borrow_mut().callbacks.quit {
callback();
@@ -365,9 +370,9 @@ impl Platform for WindowsPlatform {
.detach();
}
fn restart(&self, _: Option<PathBuf>) {
fn restart(&self, binary_path: Option<PathBuf>) {
let pid = std::process::id();
let Some(app_path) = self.app_path().log_err() else {
let Some(app_path) = binary_path.or(self.app_path().log_err()) else {
return;
};
let script = format!(
@@ -445,7 +450,7 @@ impl Platform for WindowsPlatform {
) -> Result<Box<dyn PlatformWindow>> {
let window = WindowsWindow::new(handle, options, self.generate_creation_info())?;
let handle = window.get_raw_handle();
self.raw_window_handles.write().push(handle);
self.raw_window_handles.write().push(handle.into());
Ok(Box::new(window))
}

View File

@@ -1,14 +1,18 @@
use std::sync::OnceLock;
use ::util::ResultExt;
use anyhow::Context;
use windows::{
UI::{
Color,
ViewManagement::{UIColorType, UISettings},
},
Wdk::System::SystemServices::RtlGetVersion,
Win32::{Foundation::*, Graphics::Dwm::*, UI::WindowsAndMessaging::*},
core::{BOOL, HSTRING},
Win32::{
Foundation::*, Graphics::Dwm::*, System::LibraryLoader::LoadLibraryA,
UI::WindowsAndMessaging::*,
},
core::{BOOL, HSTRING, PCSTR},
};
use crate::*;
@@ -197,3 +201,19 @@ pub(crate) fn show_error(title: &str, content: String) {
)
};
}
pub(crate) fn with_dll_library<R, F>(dll_name: PCSTR, f: F) -> Result<R>
where
F: FnOnce(HMODULE) -> Result<R>,
{
let library = unsafe {
LoadLibraryA(dll_name).with_context(|| format!("Loading dll: {}", dll_name.display()))?
};
let result = f(library);
unsafe {
FreeLibrary(library)
.with_context(|| format!("Freeing dll: {}", dll_name.display()))
.log_err();
}
result
}

View File

@@ -0,0 +1,174 @@
use std::{
sync::LazyLock,
time::{Duration, Instant},
};
use anyhow::{Context, Result};
use util::ResultExt;
use windows::{
Win32::{
Foundation::{HANDLE, HWND},
Graphics::{
DirectComposition::{
COMPOSITION_FRAME_ID_COMPLETED, COMPOSITION_FRAME_ID_TYPE, COMPOSITION_FRAME_STATS,
COMPOSITION_TARGET_ID,
},
Dwm::{DWM_TIMING_INFO, DwmFlush, DwmGetCompositionTimingInfo},
},
System::{
LibraryLoader::{GetModuleHandleA, GetProcAddress},
Performance::QueryPerformanceFrequency,
Threading::INFINITE,
},
},
core::{HRESULT, s},
};
static QPC_TICKS_PER_SECOND: LazyLock<u64> = LazyLock::new(|| {
let mut frequency = 0;
// On systems that run Windows XP or later, the function will always succeed and
// will thus never return zero.
unsafe { QueryPerformanceFrequency(&mut frequency).unwrap() };
frequency as u64
});
const VSYNC_INTERVAL_THRESHOLD: Duration = Duration::from_millis(1);
const DEFAULT_VSYNC_INTERVAL: Duration = Duration::from_micros(16_666); // ~60Hz
// Here we are using dynamic loading of DirectComposition functions,
// or the app will refuse to start on windows systems that do not support DirectComposition.
type DCompositionGetFrameId =
unsafe extern "system" fn(frameidtype: COMPOSITION_FRAME_ID_TYPE, frameid: *mut u64) -> HRESULT;
type DCompositionGetStatistics = unsafe extern "system" fn(
frameid: u64,
framestats: *mut COMPOSITION_FRAME_STATS,
targetidcount: u32,
targetids: *mut COMPOSITION_TARGET_ID,
actualtargetidcount: *mut u32,
) -> HRESULT;
type DCompositionWaitForCompositorClock =
unsafe extern "system" fn(count: u32, handles: *const HANDLE, timeoutinms: u32) -> u32;
pub(crate) struct VSyncProvider {
interval: Duration,
f: Box<dyn Fn() -> bool>,
}
impl VSyncProvider {
pub(crate) fn new() -> Self {
if let Some((get_frame_id, get_statistics, wait_for_comp_clock)) =
initialize_direct_composition()
.context("Retrieving DirectComposition functions")
.log_with_level(log::Level::Warn)
{
let interval = get_dwm_interval_from_direct_composition(get_frame_id, get_statistics)
.context("Failed to get DWM interval from DirectComposition")
.log_err()
.unwrap_or(DEFAULT_VSYNC_INTERVAL);
log::info!(
"DirectComposition is supported for VSync, interval: {:?}",
interval
);
let f = Box::new(move || unsafe {
wait_for_comp_clock(0, std::ptr::null(), INFINITE) == 0
});
Self { interval, f }
} else {
let interval = get_dwm_interval()
.context("Failed to get DWM interval")
.log_err()
.unwrap_or(DEFAULT_VSYNC_INTERVAL);
log::info!(
"DirectComposition is not supported for VSync, falling back to DWM, interval: {:?}",
interval
);
let f = Box::new(|| unsafe { DwmFlush().is_ok() });
Self { interval, f }
}
}
pub(crate) fn wait_for_vsync(&self) {
let vsync_start = Instant::now();
let wait_succeeded = (self.f)();
let elapsed = vsync_start.elapsed();
// DwmFlush and DCompositionWaitForCompositorClock returns very early
// instead of waiting until vblank when the monitor goes to sleep or is
// unplugged (nothing to present due to desktop occlusion). We use 1ms as
// a threshhold for the duration of the wait functions and fallback to
// Sleep() if it returns before that. This could happen during normal
// operation for the first call after the vsync thread becomes non-idle,
// but it shouldn't happen often.
if !wait_succeeded || elapsed < VSYNC_INTERVAL_THRESHOLD {
log::warn!("VSyncProvider::wait_for_vsync() took shorter than expected");
std::thread::sleep(self.interval);
}
}
}
fn initialize_direct_composition() -> Result<(
DCompositionGetFrameId,
DCompositionGetStatistics,
DCompositionWaitForCompositorClock,
)> {
unsafe {
// Load DLL at runtime since older Windows versions don't have dcomp.
let hmodule = GetModuleHandleA(s!("dcomp.dll")).context("Loading dcomp.dll")?;
let get_frame_id_addr = GetProcAddress(hmodule, s!("DCompositionGetFrameId"))
.context("Function DCompositionGetFrameId not found")?;
let get_statistics_addr = GetProcAddress(hmodule, s!("DCompositionGetStatistics"))
.context("Function DCompositionGetStatistics not found")?;
let wait_for_compositor_clock_addr =
GetProcAddress(hmodule, s!("DCompositionWaitForCompositorClock"))
.context("Function DCompositionWaitForCompositorClock not found")?;
let get_frame_id: DCompositionGetFrameId = std::mem::transmute(get_frame_id_addr);
let get_statistics: DCompositionGetStatistics = std::mem::transmute(get_statistics_addr);
let wait_for_compositor_clock: DCompositionWaitForCompositorClock =
std::mem::transmute(wait_for_compositor_clock_addr);
Ok((get_frame_id, get_statistics, wait_for_compositor_clock))
}
}
fn get_dwm_interval_from_direct_composition(
get_frame_id: DCompositionGetFrameId,
get_statistics: DCompositionGetStatistics,
) -> Result<Duration> {
let mut frame_id = 0;
unsafe { get_frame_id(COMPOSITION_FRAME_ID_COMPLETED, &mut frame_id) }.ok()?;
let mut stats = COMPOSITION_FRAME_STATS::default();
unsafe {
get_statistics(
frame_id,
&mut stats,
0,
std::ptr::null_mut(),
std::ptr::null_mut(),
)
}
.ok()?;
Ok(retrieve_duration(stats.framePeriod, *QPC_TICKS_PER_SECOND))
}
fn get_dwm_interval() -> Result<Duration> {
let mut timing_info = DWM_TIMING_INFO {
cbSize: std::mem::size_of::<DWM_TIMING_INFO>() as u32,
..Default::default()
};
unsafe { DwmGetCompositionTimingInfo(HWND::default(), &mut timing_info) }?;
let interval = retrieve_duration(timing_info.qpcRefreshPeriod, *QPC_TICKS_PER_SECOND);
// Check for interval values that are impossibly low. A 29 microsecond
// interval was seen (from a qpcRefreshPeriod of 60).
if interval < VSYNC_INTERVAL_THRESHOLD {
Ok(retrieve_duration(
timing_info.rateRefresh.uiDenominator as u64,
timing_info.rateRefresh.uiNumerator as u64,
))
} else {
Ok(interval)
}
}
#[inline]
fn retrieve_duration(counts: u64, ticks_per_second: u64) -> Duration {
let ticks_per_microsecond = ticks_per_second / 1_000_000;
Duration::from_micros(counts / ticks_per_microsecond)
}

View File

@@ -1,6 +1,6 @@
use std::ops::Deref;
use windows::Win32::UI::WindowsAndMessaging::HCURSOR;
use windows::Win32::{Foundation::HWND, UI::WindowsAndMessaging::HCURSOR};
#[derive(Debug, Clone, Copy)]
pub(crate) struct SafeCursor {
@@ -23,3 +23,31 @@ impl Deref for SafeCursor {
&self.raw
}
}
#[derive(Debug, Clone, Copy)]
pub(crate) struct SafeHwnd {
raw: HWND,
}
impl SafeHwnd {
pub(crate) fn as_raw(&self) -> HWND {
self.raw
}
}
unsafe impl Send for SafeHwnd {}
unsafe impl Sync for SafeHwnd {}
impl From<HWND> for SafeHwnd {
fn from(value: HWND) -> Self {
SafeHwnd { raw: value }
}
}
impl Deref for SafeHwnd {
type Target = HWND;
fn deref(&self) -> &Self::Target {
&self.raw
}
}

View File

@@ -167,6 +167,7 @@ fn generate_test_function(
));
cx_teardowns.extend(quote!(
dispatcher.run_until_parked();
#cx_varname.executor().forbid_parking();
#cx_varname.quit();
dispatcher.run_until_parked();
));
@@ -232,7 +233,7 @@ fn generate_test_function(
cx_teardowns.extend(quote!(
drop(#cx_varname_lock);
dispatcher.run_until_parked();
#cx_varname.update(|cx| { cx.quit() });
#cx_varname.update(|cx| { cx.background_executor().forbid_parking(); cx.quit(); });
dispatcher.run_until_parked();
));
continue;
@@ -247,6 +248,7 @@ fn generate_test_function(
));
cx_teardowns.extend(quote!(
dispatcher.run_until_parked();
#cx_varname.executor().forbid_parking();
#cx_varname.quit();
dispatcher.run_until_parked();
));

View File

@@ -942,6 +942,7 @@ impl LanguageModel for CloudLanguageModel {
model.id(),
model.supports_parallel_tool_calls(),
None,
None,
);
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {

View File

@@ -14,7 +14,7 @@ use language_model::{
RateLimiter, Role, StopReason, TokenUsage,
};
use menu;
use open_ai::{ImageUrl, Model, ResponseStreamEvent, stream_completion};
use open_ai::{ImageUrl, Model, ReasoningEffort, ResponseStreamEvent, stream_completion};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
@@ -45,6 +45,7 @@ pub struct AvailableModel {
pub max_tokens: u64,
pub max_output_tokens: Option<u64>,
pub max_completion_tokens: Option<u64>,
pub reasoning_effort: Option<ReasoningEffort>,
}
pub struct OpenAiLanguageModelProvider {
@@ -213,6 +214,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
max_tokens: model.max_tokens,
max_output_tokens: model.max_output_tokens,
max_completion_tokens: model.max_completion_tokens,
reasoning_effort: model.reasoning_effort.clone(),
},
);
}
@@ -301,7 +303,25 @@ impl LanguageModel for OpenAiLanguageModel {
}
fn supports_images(&self) -> bool {
false
use open_ai::Model;
match &self.model {
Model::FourOmni
| Model::FourOmniMini
| Model::FourPointOne
| Model::FourPointOneMini
| Model::FourPointOneNano
| Model::Five
| Model::FiveMini
| Model::FiveNano
| Model::O1
| Model::O3
| Model::O4Mini => true,
Model::ThreePointFiveTurbo
| Model::Four
| Model::FourTurbo
| Model::O3Mini
| Model::Custom { .. } => false,
}
}
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
@@ -351,6 +371,7 @@ impl LanguageModel for OpenAiLanguageModel {
self.model.id(),
self.model.supports_parallel_tool_calls(),
self.max_output_tokens(),
self.model.reasoning_effort(),
);
let completions = self.stream_completion(request, cx);
async move {
@@ -366,6 +387,7 @@ pub fn into_open_ai(
model_id: &str,
supports_parallel_tool_calls: bool,
max_output_tokens: Option<u64>,
reasoning_effort: Option<ReasoningEffort>,
) -> open_ai::Request {
let stream = !model_id.starts_with("o1-");
@@ -455,6 +477,7 @@ pub fn into_open_ai(
} else {
None
},
prompt_cache_key: request.thread_id,
tools: request
.tools
.into_iter()
@@ -471,6 +494,7 @@ pub fn into_open_ai(
LanguageModelToolChoice::Any => open_ai::ToolChoice::Required,
LanguageModelToolChoice::None => open_ai::ToolChoice::None,
}),
reasoning_effort,
}
}

View File

@@ -355,7 +355,13 @@ impl LanguageModel for OpenAiCompatibleLanguageModel {
LanguageModelCompletionError,
>,
> {
let request = into_open_ai(request, &self.model.name, true, self.max_output_tokens());
let request = into_open_ai(
request,
&self.model.name,
true,
self.max_output_tokens(),
None,
);
let completions = self.stream_completion(request, cx);
async move {
let mapper = OpenAiEventMapper::new();

View File

@@ -356,6 +356,7 @@ impl LanguageModel for VercelLanguageModel {
self.model.id(),
self.model.supports_parallel_tool_calls(),
self.max_output_tokens(),
None,
);
let completions = self.stream_completion(request, cx);
async move {

View File

@@ -360,6 +360,7 @@ impl LanguageModel for XAiLanguageModel {
self.model.id(),
self.model.supports_parallel_tool_calls(),
self.max_output_tokens(),
None,
);
let completions = self.stream_completion(request, cx);
async move {

View File

@@ -149,7 +149,9 @@
parameters: (parameter_list
"(" @context
")" @context)))
]
(type_qualifier)? @context) @item
; Fields declarations may define multiple fields, and so @item is on the
; declarator so they each get distinct ranges.
] @item
(type_qualifier)? @context)
(comment) @annotation

View File

@@ -103,7 +103,13 @@ impl LspAdapter for CssLspAdapter {
let should_install_language_server = self
.node
.should_install_npm_package(Self::PACKAGE_NAME, &server_path, &container_dir, &version)
.should_install_npm_package(
Self::PACKAGE_NAME,
&server_path,
&container_dir,
&version,
Default::default(),
)
.await;
if should_install_language_server {

View File

@@ -487,6 +487,8 @@ const GO_MODULE_ROOT_TASK_VARIABLE: VariableName =
VariableName::Custom(Cow::Borrowed("GO_MODULE_ROOT"));
const GO_SUBTEST_NAME_TASK_VARIABLE: VariableName =
VariableName::Custom(Cow::Borrowed("GO_SUBTEST_NAME"));
const GO_TABLE_TEST_CASE_NAME_TASK_VARIABLE: VariableName =
VariableName::Custom(Cow::Borrowed("GO_TABLE_TEST_CASE_NAME"));
impl ContextProvider for GoContextProvider {
fn build_context(
@@ -545,10 +547,19 @@ impl ContextProvider for GoContextProvider {
let go_subtest_variable = extract_subtest_name(_subtest_name.unwrap_or(""))
.map(|subtest_name| (GO_SUBTEST_NAME_TASK_VARIABLE.clone(), subtest_name));
let table_test_case_name = variables.get(&VariableName::Custom(Cow::Borrowed(
"_table_test_case_name",
)));
let go_table_test_case_variable = table_test_case_name
.and_then(extract_subtest_name)
.map(|case_name| (GO_TABLE_TEST_CASE_NAME_TASK_VARIABLE.clone(), case_name));
Task::ready(Ok(TaskVariables::from_iter(
[
go_package_variable,
go_subtest_variable,
go_table_test_case_variable,
go_module_root_variable,
]
.into_iter()
@@ -570,6 +581,28 @@ impl ContextProvider for GoContextProvider {
let module_cwd = Some(GO_MODULE_ROOT_TASK_VARIABLE.template_value());
Task::ready(Some(TaskTemplates(vec![
TaskTemplate {
label: format!(
"go test {} -v -run {}/{}",
GO_PACKAGE_TASK_VARIABLE.template_value(),
VariableName::Symbol.template_value(),
GO_TABLE_TEST_CASE_NAME_TASK_VARIABLE.template_value(),
),
command: "go".into(),
args: vec![
"test".into(),
"-v".into(),
"-run".into(),
format!(
"\\^{}\\$/\\^{}\\$",
VariableName::Symbol.template_value(),
GO_TABLE_TEST_CASE_NAME_TASK_VARIABLE.template_value(),
),
],
cwd: package_cwd.clone(),
tags: vec!["go-table-test-case".to_owned()],
..TaskTemplate::default()
},
TaskTemplate {
label: format!(
"go test {} -run {}",
@@ -842,10 +875,21 @@ mod tests {
.collect()
});
let tag_strings: Vec<String> = runnables
.iter()
.flat_map(|r| &r.runnable.tags)
.map(|tag| tag.0.to_string())
.collect();
assert!(
runnables.len() == 2,
"Should find test function and subtest with double quotes, found: {}",
runnables.len()
tag_strings.contains(&"go-test".to_string()),
"Should find go-test tag, found: {:?}",
tag_strings
);
assert!(
tag_strings.contains(&"go-subtest".to_string()),
"Should find go-subtest tag, found: {:?}",
tag_strings
);
let buffer = cx.new(|cx| {
@@ -860,10 +904,299 @@ mod tests {
.collect()
});
let tag_strings: Vec<String> = runnables
.iter()
.flat_map(|r| &r.runnable.tags)
.map(|tag| tag.0.to_string())
.collect();
assert!(
runnables.len() == 2,
"Should find test function and subtest with backticks, found: {}",
runnables.len()
tag_strings.contains(&"go-test".to_string()),
"Should find go-test tag, found: {:?}",
tag_strings
);
assert!(
tag_strings.contains(&"go-subtest".to_string()),
"Should find go-subtest tag, found: {:?}",
tag_strings
);
}
#[gpui::test]
fn test_go_table_test_slice_detection(cx: &mut TestAppContext) {
let language = language("go", tree_sitter_go::LANGUAGE.into());
let table_test = r#"
package main
import "testing"
func TestExample(t *testing.T) {
_ = "some random string"
testCases := []struct{
name string
anotherStr string
}{
{
name: "test case 1",
anotherStr: "foo",
},
{
name: "test case 2",
anotherStr: "bar",
},
}
notATableTest := []struct{
name string
}{
{
name: "some string",
},
{
name: "some other string",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// test code here
})
}
}
"#;
let buffer =
cx.new(|cx| crate::Buffer::local(table_test, cx).with_language(language.clone(), cx));
cx.executor().run_until_parked();
let runnables: Vec<_> = buffer.update(cx, |buffer, _| {
let snapshot = buffer.snapshot();
snapshot.runnable_ranges(0..table_test.len()).collect()
});
let tag_strings: Vec<String> = runnables
.iter()
.flat_map(|r| &r.runnable.tags)
.map(|tag| tag.0.to_string())
.collect();
assert!(
tag_strings.contains(&"go-test".to_string()),
"Should find go-test tag, found: {:?}",
tag_strings
);
assert!(
tag_strings.contains(&"go-table-test-case".to_string()),
"Should find go-table-test-case tag, found: {:?}",
tag_strings
);
let go_test_count = tag_strings.iter().filter(|&tag| tag == "go-test").count();
let go_table_test_count = tag_strings
.iter()
.filter(|&tag| tag == "go-table-test-case")
.count();
assert!(
go_test_count == 1,
"Should find exactly 1 go-test, found: {}",
go_test_count
);
assert!(
go_table_test_count == 2,
"Should find exactly 2 go-table-test-case, found: {}",
go_table_test_count
);
}
#[gpui::test]
fn test_go_table_test_slice_ignored(cx: &mut TestAppContext) {
let language = language("go", tree_sitter_go::LANGUAGE.into());
let table_test = r#"
package main
func Example() {
_ = "some random string"
notATableTest := []struct{
name string
}{
{
name: "some string",
},
{
name: "some other string",
},
}
}
"#;
let buffer =
cx.new(|cx| crate::Buffer::local(table_test, cx).with_language(language.clone(), cx));
cx.executor().run_until_parked();
let runnables: Vec<_> = buffer.update(cx, |buffer, _| {
let snapshot = buffer.snapshot();
snapshot.runnable_ranges(0..table_test.len()).collect()
});
let tag_strings: Vec<String> = runnables
.iter()
.flat_map(|r| &r.runnable.tags)
.map(|tag| tag.0.to_string())
.collect();
assert!(
!tag_strings.contains(&"go-test".to_string()),
"Should find go-test tag, found: {:?}",
tag_strings
);
assert!(
!tag_strings.contains(&"go-table-test-case".to_string()),
"Should find go-table-test-case tag, found: {:?}",
tag_strings
);
}
#[gpui::test]
fn test_go_table_test_map_detection(cx: &mut TestAppContext) {
let language = language("go", tree_sitter_go::LANGUAGE.into());
let table_test = r#"
package main
import "testing"
func TestExample(t *testing.T) {
_ = "some random string"
testCases := map[string]struct {
someStr string
fail bool
}{
"test failure": {
someStr: "foo",
fail: true,
},
"test success": {
someStr: "bar",
fail: false,
},
}
notATableTest := map[string]struct {
someStr string
}{
"some string": {
someStr: "foo",
},
"some other string": {
someStr: "bar",
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
// test code here
})
}
}
"#;
let buffer =
cx.new(|cx| crate::Buffer::local(table_test, cx).with_language(language.clone(), cx));
cx.executor().run_until_parked();
let runnables: Vec<_> = buffer.update(cx, |buffer, _| {
let snapshot = buffer.snapshot();
snapshot.runnable_ranges(0..table_test.len()).collect()
});
let tag_strings: Vec<String> = runnables
.iter()
.flat_map(|r| &r.runnable.tags)
.map(|tag| tag.0.to_string())
.collect();
assert!(
tag_strings.contains(&"go-test".to_string()),
"Should find go-test tag, found: {:?}",
tag_strings
);
assert!(
tag_strings.contains(&"go-table-test-case".to_string()),
"Should find go-table-test-case tag, found: {:?}",
tag_strings
);
let go_test_count = tag_strings.iter().filter(|&tag| tag == "go-test").count();
let go_table_test_count = tag_strings
.iter()
.filter(|&tag| tag == "go-table-test-case")
.count();
assert!(
go_test_count == 1,
"Should find exactly 1 go-test, found: {}",
go_test_count
);
assert!(
go_table_test_count == 2,
"Should find exactly 2 go-table-test-case, found: {}",
go_table_test_count
);
}
#[gpui::test]
fn test_go_table_test_map_ignored(cx: &mut TestAppContext) {
let language = language("go", tree_sitter_go::LANGUAGE.into());
let table_test = r#"
package main
func Example() {
_ = "some random string"
notATableTest := map[string]struct {
someStr string
}{
"some string": {
someStr: "foo",
},
"some other string": {
someStr: "bar",
},
}
}
"#;
let buffer =
cx.new(|cx| crate::Buffer::local(table_test, cx).with_language(language.clone(), cx));
cx.executor().run_until_parked();
let runnables: Vec<_> = buffer.update(cx, |buffer, _| {
let snapshot = buffer.snapshot();
snapshot.runnable_ranges(0..table_test.len()).collect()
});
let tag_strings: Vec<String> = runnables
.iter()
.flat_map(|r| &r.runnable.tags)
.map(|tag| tag.0.to_string())
.collect();
assert!(
!tag_strings.contains(&"go-test".to_string()),
"Should find go-test tag, found: {:?}",
tag_strings
);
assert!(
!tag_strings.contains(&"go-table-test-case".to_string()),
"Should find go-table-test-case tag, found: {:?}",
tag_strings
);
}

View File

@@ -1,4 +1,5 @@
(comment) @annotation
(type_declaration
"type" @context
[
@@ -42,13 +43,13 @@
(var_declaration
"var" @context
[
; The declaration may define multiple variables, and so @item is on
; the identifier so they get distinct ranges.
(var_spec
name: (identifier) @name) @item
name: (identifier) @name @item)
(var_spec_list
"("
(var_spec
name: (identifier) @name) @item
")"
name: (identifier) @name @item)
)
]
)
@@ -60,5 +61,7 @@
"(" @context
")" @context)) @item
; Fields declarations may define multiple fields, and so @item is on the
; declarator so they each get distinct ranges.
(field_declaration
name: (_) @name) @item
name: (_) @name @item)

View File

@@ -91,3 +91,103 @@
) @_
(#set! tag go-main)
)
; Table test cases - slice and map
(
(short_var_declaration
left: (expression_list (identifier) @_collection_var)
right: (expression_list
(composite_literal
type: [
(slice_type)
(map_type
key: (type_identifier) @_key_type
(#eq? @_key_type "string")
)
]
body: (literal_value
[
(literal_element
(literal_value
(keyed_element
(literal_element
(identifier) @_field_name
)
(literal_element
[
(interpreted_string_literal) @run @_table_test_case_name
(raw_string_literal) @run @_table_test_case_name
]
)
)
)
)
(keyed_element
(literal_element
[
(interpreted_string_literal) @run @_table_test_case_name
(raw_string_literal) @run @_table_test_case_name
]
)
)
]
)
)
)
)
(for_statement
(range_clause
left: (expression_list
[
(
(identifier)
(identifier) @_loop_var
)
(identifier) @_loop_var
]
)
right: (identifier) @_range_var
(#eq? @_range_var @_collection_var)
)
body: (block
(expression_statement
(call_expression
function: (selector_expression
operand: (identifier) @_t_var
field: (field_identifier) @_run_method
(#eq? @_run_method "Run")
)
arguments: (argument_list
.
[
(selector_expression
operand: (identifier) @_tc_var
(#eq? @_tc_var @_loop_var)
field: (field_identifier) @_field_check
(#eq? @_field_check @_field_name)
)
(identifier) @_arg_var
(#eq? @_arg_var @_loop_var)
]
.
(func_literal
parameters: (parameter_list
(parameter_declaration
type: (pointer_type
(qualified_type
package: (package_identifier) @_pkg
name: (type_identifier) @_type
(#eq? @_pkg "testing")
(#eq? @_type "T")
)
)
)
)
)
)
)
)
)
) @_
(#set! tag go-table-test-case)
)

View File

@@ -31,12 +31,16 @@
(export_statement
(lexical_declaration
["let" "const"] @context
; Multiple names may be exported - @item is on the declarator to keep
; ranges distinct.
(variable_declarator
name: (_) @name) @item)))
(program
(lexical_declaration
["let" "const"] @context
; Multiple names may be defined - @item is on the declarator to keep
; ranges distinct.
(variable_declarator
name: (_) @name) @item))

View File

@@ -340,7 +340,13 @@ impl LspAdapter for JsonLspAdapter {
let should_install_language_server = self
.node
.should_install_npm_package(Self::PACKAGE_NAME, &server_path, &container_dir, &version)
.should_install_npm_package(
Self::PACKAGE_NAME,
&server_path,
&container_dir,
&version,
Default::default(),
)
.await;
if should_install_language_server {

View File

@@ -206,6 +206,7 @@ impl LspAdapter for PythonLspAdapter {
&server_path,
&container_dir,
&version,
Default::default(),
)
.await;

View File

@@ -238,7 +238,7 @@ impl LspAdapter for RustLspAdapter {
)
.await?;
make_file_executable(&server_path).await?;
remove_matching(&container_dir, |path| server_path != path).await;
remove_matching(&container_dir, |path| path != destination_path).await;
GithubBinaryMetadata::write_to_file(
&GithubBinaryMetadata {
metadata_version: 1,
@@ -1023,8 +1023,14 @@ async fn get_cached_server_binary(container_dir: PathBuf) -> Option<LanguageServ
last = Some(path);
}
let path = last.context("no cached binary")?;
let path = match RustLspAdapter::GITHUB_ASSET_KIND {
AssetKind::TarGz | AssetKind::Gz => path.clone(), // Tar and gzip extract in place.
AssetKind::Zip => path.clone().join("rust-analyzer.exe"), // zip contains a .exe
};
anyhow::Ok(LanguageServerBinary {
path: last.context("no cached binary")?,
path,
env: None,
arguments: Default::default(),
})

View File

@@ -108,7 +108,13 @@ impl LspAdapter for TailwindLspAdapter {
let should_install_language_server = self
.node
.should_install_npm_package(Self::PACKAGE_NAME, &server_path, &container_dir, &version)
.should_install_npm_package(
Self::PACKAGE_NAME,
&server_path,
&container_dir,
&version,
Default::default(),
)
.await;
if should_install_language_server {

View File

@@ -34,12 +34,16 @@
(export_statement
(lexical_declaration
["let" "const"] @context
; Multiple names may be exported - @item is on the declarator to keep
; ranges distinct.
(variable_declarator
name: (_) @name) @item))
(program
(lexical_declaration
["let" "const"] @context
; Multiple names may be defined - @item is on the declarator to keep
; ranges distinct.
(variable_declarator
name: (_) @name) @item))

Some files were not shown because too many files have changed in this diff Show More