Compare commits

...

82 Commits

Author SHA1 Message Date
Joseph T. Lyons
a41f654a9a zed 0.182.10 2025-04-17 09:34:29 -04:00
Agus Zubiaga
9b66e23a3f edit prediction: Assign providers when client status changes (#28919)
There was recently a change that caused the Zed Edit Prediction provider
to only be assigned when the client was connected. However, this check
happened too early, resulting in restored buffers never getting
registered. We'll now subscribe to client status changes and reassign
providers accordingly.

Release Notes:

- edit prediction: Fixed bug disabling prediction in restored buffers
2025-04-17 08:56:45 -04:00
gcp-cherry-pick-bot[bot]
a5af8c47de Fix more panics when removing excerpts (cherry-pick #28836) (#28872)
Cherry-picked Fix more panics when removing excerpts (#28836)

Release Notes:

- Fixed a panic when an excerpt removed has an edit suggestion inlay in
it

---------

Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>

Co-authored-by: Kirill Bulatov <kirill@zed.dev>
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-04-16 10:06:26 -06:00
Joseph T. Lyons
68f3cf8c8b v0.182.x stable 2025-04-16 08:45:21 -04:00
gcp-cherry-pick-bot[bot]
55060fca0d Fix anchor comparison in multi buffer after expanding excerpts (cherry-pick #28828) (#28837)
Cherry-picked Fix anchor comparison in multi buffer after expanding
excerpts (#28828)

Release Notes:

- Fixed incorrect excerpt comparison when replacing them

Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
Co-authored-by: Conrad Irwin <conrad@zed.dev>

Co-authored-by: Kirill Bulatov <kirill@zed.dev>
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
Co-authored-by: Conrad Irwin <conrad@zed.dev>
2025-04-15 22:28:38 -06:00
Joseph T. Lyons
995a265a9c zed 0.182.9 2025-04-15 16:21:49 -04:00
Antonio Scandurra
cc2168ec52 Fix rejecting multiple hunks in AgentDiff (#28806)
Release Notes:

- Fixed a bug that caused `Reject All` to not always reject _all_ the
hunks.

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
2025-04-15 16:20:30 -04:00
Danilo Leal
4032e036af markdown: Add ability to customize individual heading level (#28733)
This PR adds a new field in the `MarkdownStyle` struct,
`heading_level_styles`, allowing, via the newly added function
`apply_heading_style` and struct `HeadingLevelStyles` to customize each
individual heading level in Markdown rendering/styling function.

Things like this should now be possible:

```rust
    MarkdownStyle {
        heading_level_styles: Some(HeadingLevelStyles {
            h1: Some(TextStyleRefinement {
                font_size: Some(rems(1.15).into()),
                ..Default::default()
            }),
        }),
        ..Default::default()
    }
```

Release Notes:

- N/A
2025-04-15 14:11:25 -04:00
Bennet Bo Fenner
54dea92e96 agent: Show a warning when some tools are incompatible with the selected model (#28755)
WIP

<img width="644" alt="image"
src="https://github.com/user-attachments/assets/b24e1a57-f82e-457c-b788-1b314ade7c84"
/>


<img width="644" alt="image"
src="https://github.com/user-attachments/assets/b158953c-2015-4cc8-b8ed-35c6fcbe162d"
/>


Release Notes:

- agent: Improve compatibility with Gemini Tool Calling APIs. When a
tool is incompatible with the Gemini APIs a warning indicator will be
displayed. Incompatible tools will be automatically excluded from the
conversation

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
2025-04-15 13:09:49 -04:00
Danilo Leal
5f2617da87 agent: Add ability to interrupt current generation with a new message (#28762)
If you wanted to interrupt the current LLM response that's generating to
send a follow up message, you'd need to stop it first, type your new
message, and then send it. Now, you can just type your new message while
there's a response generating and send it. This will interrupt the
previous response generation and kick off a new one.

Release Notes:

- agent: Allow to send a new message while a response is generating,
interrupting the LLM to focus instead on the most recent prompt.
2025-04-15 12:51:11 -04:00
Bennet Bo Fenner
4c572ee68c agent: Return ToolResult from run inside Tool (#28763)
This is just a refactor which adds no functionality.
We now return a `ToolResult` from `Tool > run(...)`. For now this just
wraps the output task in a struct. We'll use this to implement custom
rendering of tools, see #28621.

Release Notes:

- N/A
2025-04-15 11:15:47 -04:00
Smit Barmase
756fb87073 hover_popover: Fix markdown selection for info and diagnostic popovers (#28642)
Closes #28638

This PR fixes markdown selection for the info and diagnostic popovers.

In the editor popover, after the changes in
https://github.com/zed-industries/zed/pull/28255, the markdown selection
state updates correctly, but it no longer triggers the editor element to
repaint like it used to. This is fixed by adding a subscription to
listen for markdown entity changes and triggering a repaint for the
editor.

I assume markdown selection works elsewhere because:

1. Either the `Markdown` entity is directly part of a struct that
implements the `Render` trait, causing it to repaint whenever the
markdown state changes. See
[here](d1ffda9bfe/crates/ui_prompt/src/ui_prompt.rs (L65)).
2. OR it's wrapped around component like Popover which implements
`RenderOnce` trait. See
[here](d1ffda9bfe/crates/editor/src/code_context_menus.rs (L645)).

Whereas info and diagnostic popovers does not do both. I do think we can
change it to use `Popover` component, but for now this works as quick
fix.

Extras:
- Remove unnecessary struct cloning.
- Refactor rendering logic to use `when_some`.

Release Notes:

- Fixed issue where selection wasn't working for info and diagnostic
popovers.
2025-04-15 19:49:59 +05:30
Danilo Leal
255e5b602a agent: Adjust markdown heading sizes (#28759)
Adjust the heading sizes for the Agent Panel so they're not aggressively
huge.

Release Notes:

- N/A
2025-04-15 09:27:20 -04:00
Bennet Bo Fenner
6beb3add17 agent: Make ToolWorkingSet an Entity (#28757)
Motivation is to emit events when enabled tools change, want to use this
in #28755

Release Notes:

- N/A
2025-04-15 09:05:30 -04:00
Bennet Bo Fenner
5e0c0766ec agent: Improve compatibility when using MCP servers with Gemini models (#28700)
WIP

Release Notes:

- agent: Improve compatibility when using MCPs with Gemini models
2025-04-15 09:02:46 -04:00
Richard Feldman
ba008fc2b9 Add contents_tool (#28738)
This is a combination of the "read file" and "list directory contents"
tools as part of a push to reduce our quantity of builtin tools by
combining some of them.

The functionality is all there for this tool, although there's room for
improvement on the visuals side: it currently always shows the same icon
and always says "Read" - so you can't tell at a glance when it's reading
a directory vs an individual file. Changing this will require a change
to the `Tool` trait, which can be in a separate PR. (FYI @danilo-leal!)

<img width="606" alt="Screenshot 2025-04-14 at 11 56 27 PM"
src="https://github.com/user-attachments/assets/bded72af-6476-4469-97c6-2f344629b0e4"
/>

Release Notes:

- Added `contents` tool
2025-04-15 08:55:04 -04:00
Zed Bot
4dcf10f829 Bump to 0.182.8 for @bennetbo 2025-04-15 08:43:51 +00:00
gcp-cherry-pick-bot[bot]
42015f157e gemini: Fix "invalid argument" error when request contains no tools (cherry-pick #28747) (#28750)
Cherry-picked gemini: Fix "invalid argument" error when request contains
no tools (#28747)

When we do not have any tools, we want to set the `tools` field to
`None`

Release Notes:

- Fixed an issue where Gemini requests would sometimes return a Bad
Request ("Invalid argument...")

Co-authored-by: Bennet Bo Fenner <bennet@zed.dev>
2025-04-15 10:21:28 +02:00
Danilo Leal
c4265f4d87 agent: Add some design tweaks (#28726)
Fine-tuning some areas of the Agent Panel design.

Release Notes:

- N/A
2025-04-14 22:28:58 -04:00
Bennet Bo Fenner
abe558847b agent: Apply soft-wrap when message editor is expanded (#28716)
Release Notes:

- N/A
2025-04-14 22:28:44 -04:00
Peter Tripp
8adbe5a0b0 zed 0.182.7 2025-04-14 16:00:17 -04:00
duvetfall
68bc185349 assistant_tools: Fix code_action and rename schemas for Gemini (#28634)
Closes #28475

Updates `rename` and `code_action` `input_schema` methods to use
`json_schema_for<T>()` which transforms standard JSONSchema into the
subset required by Gemini.
Also makes `input_schema` implementations consistent.
Tested tools against Gemini 2.5 Pro Preview, Zed Claude 3.7 Sonnet
Thinking, o3-mini

Release Notes:

- Agent Beta: Fixed error 400 `INVALID_ARGUMENT` when using Gemini with
`code_actions` or `rename` tools enabled.
2025-04-14 15:45:15 -04:00
Umesh Yadav
252f48c774 Add support for OpenAI GPT-4.1 models (#28708)
Release Notes:

- Add support for OpenAI GPT-4.1 via Copilot Chat and OpenAI API

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
2025-04-14 15:35:15 -04:00
Richard Hao
9514fba6f7 copilot_chat: Add Gemini 2.5 Pro support to Copilot Chat (#28660) 2025-04-14 15:34:17 -04:00
Bennet Bo Fenner
70ff2dd596 agent: Check built-in tools schema compatibility in tests (#28691)
This ensures that we respect the `LanguageModelToolSchemaFormat` value
when we call `tool.input_schema`. This prevents us from breaking Gemini
compatibility when adding/changing built-in tools. See #28634.

The test suite will now fail with an error message like this, when
providing an incompatible input_schema:

```
thread 'tests::test_tool_schema_compatibility' panicked at crates/assistant_tools/src/assistant_tools.rs:108:17:
Tool schema for `code_actions` is not compatible with `language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset` (Gemini Models).
Are you using `schema::json_schema_for<T>(format)` to generate the schema?
```


Release Notes:

- N/A
2025-04-14 11:50:53 -04:00
Danilo Leal
2093e21cd2 agent: Add scrollbar to the history view (#28690)
Ended up not making this one visible only upon hover or something
because the layout alignment would be weird given the list item spans
the full width. So, experimenting with this design here:

<img
src="https://github.com/user-attachments/assets/62bf661e-1aae-4644-8a89-49cefb3e8130"
width="700" />

Release Notes:

- agent: Add scrollbar to the history view.
2025-04-14 11:50:53 -04:00
Agus Zubiaga
b183df1de3 agent: Handle context window exceeded errors from Anthropic (#28688)
![CleanShot 2025-04-14 at 11 15
38@2x](https://github.com/user-attachments/assets/9e803ffb-74fd-486b-bebc-2155a407a9fa)

Release Notes:

- agent: Handle context window exceeded errors from Anthropic
2025-04-14 10:51:23 -04:00
Marshall Bowers
a9a0ff23b4 agent: Clean up thread auto-capturing (#28550)
This PR cleans up the thread auto-capturing added in #28271.

- Removed usage of `unsafe`
- Fixed feature flag check
- We were incorrectly not respecting the feature flag in release builds
- Made sure the telemetry event was being run on the background executor

Release Notes:

- N/A
2025-04-14 10:51:01 -04:00
Thomas Mickley-Doyle
e234399b36 agent: Auto-capture telemetry feature flag (#28271)
Release Notes:

- N/A
2025-04-14 10:49:24 -04:00
Danilo Leal
73d6e2d780 agent: Move focus to the message editor after going back (#28686)
When you hit the back button in the agent panel toolbar, we were
returning the focus to the buffer instead to the panel's message editor,
which is likely where you want to be after quickly checking history or
settings.

Release Notes:

- N/A
2025-04-14 10:42:29 -04:00
5brian
e8930c5cc1 agent: Fix expand message editor while not focused (#28650)
Allow expanding the message editor while the agent panel is not focused,
right now there is no effect when you use the button from another focus

Release Notes:

- N/A
2025-04-14 10:42:24 -04:00
Joseph T. Lyons
c4e9bf84f1 zed 0.182.6 2025-04-14 07:59:19 -04:00
Danilo Leal
d71ef64438 agent: Display keybindings for "Reject All" and "Keep All" (#28620)
<img
src="https://github.com/user-attachments/assets/2cdc5121-dd7b-4f46-8d43-88d5152c77ea"
width="550" />


Release Notes:

- N/A
2025-04-12 07:48:52 -04:00
Danilo Leal
b71c31a057 agent: Increase message editor height (#28618)
It was too tiny, felt like we could use more breathing room.

Release Notes:

- N/A
2025-04-12 07:48:52 -04:00
Danilo Leal
ddf0f4b87c agent: Adjust MCP section in the settings view (#28615)
Mentioning "MCP" more prominently, adding tool descriptions in the icon
button tooltip, and other UI adjustments.

<img
src="https://github.com/user-attachments/assets/e021b3be-99b8-454c-b5fd-0221a7947a35"
width="600" />

Release Notes:

- N/A
2025-04-11 21:00:28 -04:00
Danilo Leal
93b05d3284 agent: Add "Install MCPs" to panel menu (#28616)
Also took the opportunity to rename the "Continue in New Thread" item to
a potentially clearer name.

<img
src="https://github.com/user-attachments/assets/024de07d-f215-4c41-8fbe-652a216b61d9"
width="300"/>

Release Notes:

- N/A
2025-04-11 21:00:28 -04:00
Bennet Bo Fenner
a9e3ce1d8f agent: Cleanup message_editor (#28614)
This PR splits up the rendering of the message editor into multiple
functions. Previously we had a single `h_flex()...` expression which
spanned across 550 lines, `cargo fmt` stopped working.

Release Notes:

- N/A
2025-04-11 21:00:28 -04:00
Bennet Bo Fenner
895b81d3e3 agent: Only show recommended models that are actually configured (#28613)
Release Notes:

- N/A
2025-04-11 21:00:28 -04:00
Bennet Bo Fenner
115a4628d6 agent: Remove unused code (#28552)
This code was used before we had a proper completion menu for
at-mentions

Release Notes:

- N/A
2025-04-11 21:00:28 -04:00
Bennet Bo Fenner
5c2f88eb8c agent: Refine language model selector (#28597)
Release Notes:

- agent: Show recommended models in the agent model selector and display
the provider in the model selector's trigger.

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Danilo Leal <67129314+danilo-leal@users.noreply.github.com>
2025-04-11 21:00:28 -04:00
Agus Zubiaga
d9810152d3 agent: Register tracked buffers with language servers (#28610)
Release Notes:

- agent: Start language servers when accessing files via tools

Co-authored-by: Michael <michael@zed.dev>
2025-04-11 21:00:28 -04:00
Joseph T. Lyons
ad92f318dd zed 0.182.5 2025-04-11 17:30:56 -04:00
Cole Miller
0aa2e3688f Backport git store bugfix to v0.182.x (#28593)
Closes #ISSUE

Release Notes:

- Fixed a bug that could cause the staged status of entries in the git
panel to be stale.
2025-04-11 17:08:46 -04:00
gcp-cherry-pick-bot[bot]
47b6eb1106 Try to weak-link ScreenCaptureKit always (cherry-pick #28585) (#28591)
Cherry-picked Try to weak-link ScreenCaptureKit always (#28585)

Release Notes:

- Fixed a crash on macOS (Catalina, Big Sur). Users may need to manually
redownload Zed if they updated to a broken release.
([#28585](https://github.com/zed-industries/zed/pull/28585))

Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-04-11 15:33:02 -04:00
gcp-cherry-pick-bot[bot]
bb151eb67f Fix a panic in the git store (cherry-pick #28590) (#28595)
Cherry-picked Fix a panic in the git store (#28590)

Closes #ISSUE

Release Notes:

- Fixed a panic that could occur when git statuses were updated.

Co-authored-by: Cole Miller <cole@zed.dev>
2025-04-11 14:39:20 -04:00
Joseph T. Lyons
1e2b9863fd zed 0.182.4 2025-04-11 10:17:27 -04:00
Bennet Bo Fenner
42c6cceb07 agent: Fix bug where wrong crease for @mention would be displayed (#28558)
Release Notes:

- agent: Fix a bug where an inserted @mention did not show up as the one
that was selected
2025-04-11 10:13:41 -04:00
Danilo Leal
70182ba97a agent: Fix "new text thread" action name (#28555)
Moving from "NewPromptEditor" to "NewTextThread". We recently re-named
that and this was missing.

Release Notes:

- N/A
2025-04-10 21:58:46 -04:00
Danilo Leal
aba9875439 agent: Make the message editor expandable (#28420)
This PR allows expanding the message editor textarea to fit almost the
total height of the Agent Panel. Stylistically, I'm also changing the
font family we use in the textarea to use the buffer font; want to
experiment with this for a bit.

Release Notes:

- agent: The Agent Panel textarea can now be expanded to fill almost the
total height of the panel.

---------

Co-authored-by: Bennet Bo Fenner <bennet@zed.dev>
Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
2025-04-10 21:58:28 -04:00
Danilo Leal
243b04676d editor: Refactor EditorMode::Full (#28546)
This PR lightly refactors the `EditorMode::Full` exposing two new
methods: `is_full` and `set_mode`.

Motivation is to expose fields that modify the behavior when the editor
is in `Full` mode. By using is `mode.is_full()` instead of
`EditorMode::Full` we can introduce new fields without breaking other
places in the code.

Release Notes:

- N/A

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
2025-04-10 21:58:15 -04:00
Zed Bot
47edd4a675 Bump to 0.182.3 for @maxdeviant 2025-04-11 00:18:30 +00:00
Antonio Scandurra
d487d464a1 Actually run the eval and fix a hang when retrieving outline (#28547)
Release Notes:

- Fixed a regression that caused the agent to hang sometimes.

---------

Co-authored-by: Thomas Mickley-Doyle <tmickleydoyle@gmail.com>
Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Michael Sloan <mgsloan@gmail.com>
2025-04-10 18:11:46 -06:00
Bennet Bo Fenner
d1dcbbac45 markdown: Track code block metadata in parser (#28543)
This allows us to not scan the codeblock content for newlines on every
frame in `active_thread`

Release Notes:

- N/A
2025-04-10 17:46:19 -06:00
Danilo Leal
3430d431cf Change zed.dev's default model to Claude 3.7 Sonnet (#28541)
From Claude 3.5 Sonnet to **Claude 3.7 Sonnet**.

Release Notes:

- Change the default model of Zed's hosted LLM service to Claude 3.7
Sonnet.
2025-04-10 17:46:12 -06:00
Marshall Bowers
16395bd78d language_models: Fix non-streaming Copilot Chat models (#28537)
This PR fixes usage of non-streaming Copilot Chat models.

Closes https://github.com/zed-industries/zed/issues/28528.

Release Notes:

- Fixed an issue with using non-streaming Copilot Chat models (e.g., o1,
o3-mini).
2025-04-10 17:46:01 -06:00
Marko Kungla
4261201eb9 Add --user-data-dir CLI flag and propose renaming support_dir to data_dir (#26886)
This PR introduces support for a `--user-data-dir` CLI flag to override
Zed's data directory and proposes renaming `support_dir` to `data_dir`
for better cross-platform clarity. It builds on the discussion in #25349
about custom data directories, aiming to provide a flexible
cross-platform solution.

### Changes

The PR is split into two commits:
1. **[feat(cli): add --user-data-dir to override data
directory](28e8889105)**
2. **[refactor(paths): rename support_dir to data_dir for cross-platform
clarity](affd2fc606)**


### Context
Inspired by the need for custom data directories discussed in #25349,
this PR provides an immediate implementation in the first commit, while
the second commit suggests a naming improvement for broader appeal.
@mikayla-maki, I’d appreciate your feedback, especially on the rename
proposal, given your involvement in the original discussion!

### Testing
- `cargo build `
- `./target/debug/zed --user-data-dir ~/custom-data-dir`

Release Notes:
- Added --user-data-dir CLI flag

---------

Signed-off-by: Marko Kungla <marko.kungla@gmail.com>
2025-04-10 17:32:12 -04:00
gcp-cherry-pick-bot[bot]
ff12554c07 Do not query for LSP tasks buffers that do not belong to the position given (cherry-pick #28536) (#28538)
Cherry-picked Do not query for LSP tasks buffers that do not belong to
the position given (#28536)

Follow-up of https://github.com/zed-industries/zed/pull/28359

Release Notes:

- Fixed a panic when LSP tasks are queried in certain multi buffer
excerpts

Co-authored-by: Kirill Bulatov <kirill@zed.dev>
2025-04-10 15:03:45 -06:00
Zed Bot
aae71f512e Bump to 0.182.2 for @ConradIrwin 2025-04-10 20:42:53 +00:00
gcp-cherry-pick-bot[bot]
793da80471 Bump rustls (cherry-pick #28531) (#28535)
Cherry-picked Bump rustls (#28531)

Closes #26699

Release Notes:

- Fixed a panic when enabling or disabling a VPN on macOS

Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-04-10 14:38:31 -06:00
Nate Butler
4e82a27503 Add progress bar component (#28518)
- Adds the progress bar component

Release Notes:

- N/A
2025-04-10 14:18:12 -04:00
Thomas Mickley-Doyle
69335b3c73 agent: Add selected tool names to agent panel telemetry (#28247)
Release Notes:

- N/A
2025-04-10 14:16:00 -04:00
gcp-cherry-pick-bot[bot]
0b13333ca3 Fix merge conflicts jumping (cherry-pick #28508) (#28512)
Cherry-picked Fix merge conflicts jumping (#28508)

This regressed in #27568, oops.

Release Notes:

- Fixed a bug causing conflicted files in the git panel to jump to the
"Tracked" section as soon as they were staged.

Co-authored-by: Cole Miller <cole@zed.dev>
2025-04-10 10:51:58 -04:00
gcp-cherry-pick-bot[bot]
f6ef8662d4 Downgrade environment-related logging (cherry-pick #28509) (#28514)
Cherry-picked Downgrade environment-related logging (#28509)

Closes #ISSUE

Release Notes:

- N/A *or* Added/Fixed/Improved ...

Co-authored-by: Cole Miller <cole@zed.dev>
2025-04-10 10:51:37 -04:00
Joseph T. Lyons
d5cc01abe7 zed 0.182.1 2025-04-10 09:21:04 -04:00
Agus Zubiaga
08fecda7ee agent: Use current shell (#28470)
Release Notes:

- agent: Replace `bash` tool with `terminal` tool which uses the current
shell

---------

Co-authored-by: Bennet <bennet@zed.dev>
Co-authored-by: Antonio <antonio@zed.dev>
2025-04-10 09:19:42 -04:00
Antonio Scandurra
538b88c260 Lay the groundwork for a Rust-based eval (#28488)
Also, we moved the logic for driving the agentic loop into `Thread` so
that we don't have to re-implement it.

Release Notes:

- N/A

---------

Co-authored-by: Nathan Sobo <nathan@zed.dev>
2025-04-10 09:19:42 -04:00
Bennet Bo Fenner
701e033fea agent: Optimize render_markdown_block function (#28487)
Co-Authored-by: Agus <agus@zed.dev>

Closes #ISSUE

Release Notes:

- N/A

Co-authored-by: Agus <agus@zed.dev>
2025-04-10 09:19:42 -04:00
Antonio Scandurra
9ed28c5e31 Revert "Add reminder message about system prompt" (#28482)
This breaks the agentic loop.
2025-04-10 09:19:42 -04:00
Danilo Leal
23b7a98152 agent: Fix toolbar spacing (#28485)
Release Notes:

- N/A
2025-04-10 09:19:42 -04:00
Danilo Leal
fe075ac273 agent: Add button to open thread as markdown (#28481)
<img
src="https://github.com/user-attachments/assets/92ca8f64-a949-4cc1-a657-3978a2c65839"
width="600"/>

Release Notes:

- agent: The action to open the current active thread in Markdown is now
exposed in the UI.
2025-04-09 23:12:03 -04:00
5brian
2fb6ec2d5a agent: Prevent sending whitespace only messages (#28409)
Prevent this from happening when sending a prompt with only spaces and
newlines:


![image](https://github.com/user-attachments/assets/b275f4c5-c013-4695-8fb4-e3ad75d41750)

Release Notes:

- agent: Prevent from sending messages containing only white space.
2025-04-09 23:12:03 -04:00
Danilo Leal
cef835c901 agent: Collapse code blocks in the active thread (#28467)
Release Notes:

- N/A

---------

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
2025-04-09 22:47:58 -04:00
Richard Feldman
0623a6a400 Add code action tool and rename tool (#28453)
Having a separate rename tool seems to make the agent more likely to use
it compared to having it be part of the code actions tool.

Release Notes:

- Added code action tool and rename tool.
2025-04-09 22:47:51 -04:00
Michael Sloan
a7722c9bc7 Fix directory context paths (#28459)
Release Notes:

- N/A
2025-04-09 22:47:30 -04:00
Bennet Bo Fenner
782b35aeb5 agent: Fuzzy match on paths and symbols when typing @ (#28357)
Release Notes:

- agent: Improve fuzzy matching when using @-mentions
2025-04-09 22:47:10 -04:00
Thomas Mickley-Doyle
612e30ea9c agent: Add reactions at the response level (#27958)
Release Notes:

- Added the user reaction (👍 or 👎) to each agent response.
- 👎 will trigger a comment box linked to the response

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Agus Zubiaga <hi@aguz.me>
2025-04-09 22:46:52 -04:00
Michael Sloan
8794d545c3 Reapply "Use Project instead of Workspace in ContextStore (#28402)" (#28441)
Motivation for this change is to use `ContextStore` in headless
assistant, which requires it to not depend on UI entities like
`Workspace`.

This reapplies a change that was revert was in #28428, and fixes the panic.

Release Notes:

- N/A
2025-04-09 22:46:45 -04:00
Richard Feldman
ec1ae32b8e Make regex search tool optionally case-sensitive (#28427)
Release Notes:

- The agent panel's regex search tool is now optionally case-sensitive.
2025-04-09 22:45:56 -04:00
Richard Feldman
17a271d4e2 Revert to fix panic in inline assistant (#28428)
This reverts commit f12a554f86, which
introduced a panic in inline assistant (cc @mgsloan) - I'm not sure what
the motivation was for that change, but I figure we can revert to fix
the inline assistant now and deal with that later. 😄

Panic was:

> Thread "main" panicked with "cannot read workspace::Workspace while it
is already being updated" at
/Users/rtfeldman/code/zed/crates/gpui/src/app/entity_map.rs:139:32


Release Notes:

- N/A
2025-04-09 11:25:13 -04:00
Agus Zubiaga
3d9dbfe902 Fix bash tool output (#28391) 2025-04-09 10:38:26 -04:00
Richard Feldman
b9aa296d4a Add reminder message about system prompt (#28344)
Trying out sending the model a reminder message about code blocks in the
system prompt. If this seems to work well, we can include more specific
reminder messages, e.g. tool-specific ones.

Release Notes:

- N/A
2025-04-09 10:38:19 -04:00
Joseph T. Lyons
305f113741 v0.182.x preview 2025-04-09 09:10:17 -04:00
171 changed files with 6999 additions and 4368 deletions

118
Cargo.lock generated
View File

@@ -52,7 +52,6 @@ dependencies = [
name = "agent"
version = "0.1.0"
dependencies = [
"agent_rules",
"anyhow",
"assistant_context_editor",
"assistant_settings",
@@ -65,6 +64,7 @@ dependencies = [
"clock",
"collections",
"command_palette_hooks",
"component",
"context_server",
"convert_case 0.8.0",
"db",
@@ -85,6 +85,7 @@ dependencies = [
"language",
"language_model",
"language_model_selector",
"linkme",
"log",
"lsp",
"markdown",
@@ -114,6 +115,7 @@ dependencies = [
"terminal_view",
"text",
"theme",
"thiserror 2.0.12",
"time",
"time_format",
"ui",
@@ -125,57 +127,6 @@ dependencies = [
"zed_actions",
]
[[package]]
name = "agent_eval"
version = "0.1.0"
dependencies = [
"agent",
"anyhow",
"assistant_tool",
"assistant_tools",
"clap",
"client",
"collections",
"context_server",
"dap",
"env_logger 0.11.8",
"fs",
"futures 0.3.31",
"gpui",
"gpui_tokio",
"language",
"language_model",
"language_models",
"node_runtime",
"project",
"prompt_store",
"release_channel",
"reqwest_client",
"serde",
"serde_json",
"serde_json_lenient",
"settings",
"smol",
"tempfile",
"util",
"walkdir",
"workspace-hack",
]
[[package]]
name = "agent_rules"
version = "0.1.0"
dependencies = [
"anyhow",
"fs",
"gpui",
"indoc",
"prompt_store",
"util",
"workspace-hack",
"worktree",
]
[[package]]
name = "ahash"
version = "0.7.8"
@@ -1639,7 +1590,7 @@ dependencies = [
"hyper-util",
"pin-project-lite",
"rustls 0.21.12",
"rustls 0.23.25",
"rustls 0.23.26",
"rustls-native-certs 0.8.1",
"rustls-pki-types",
"tokio",
@@ -4900,6 +4851,37 @@ dependencies = [
"num-traits",
]
[[package]]
name = "eval"
version = "0.1.0"
dependencies = [
"agent",
"anyhow",
"assistant_settings",
"assistant_tool",
"assistant_tools",
"client",
"context_server",
"dap",
"env_logger 0.11.8",
"fs",
"futures 0.3.31",
"gpui",
"gpui_tokio",
"language",
"language_model",
"language_models",
"node_runtime",
"project",
"prompt_store",
"release_channel",
"reqwest_client",
"serde",
"settings",
"toml 0.8.20",
"workspace-hack",
]
[[package]]
name = "evals"
version = "0.1.0"
@@ -6635,7 +6617,7 @@ dependencies = [
name = "http_client_tls"
version = "0.1.0"
dependencies = [
"rustls 0.23.25",
"rustls 0.23.26",
"rustls-platform-verifier",
"workspace-hack",
]
@@ -6734,7 +6716,7 @@ dependencies = [
"http 1.3.1",
"hyper 1.6.0",
"hyper-util",
"rustls 0.23.25",
"rustls 0.23.26",
"rustls-native-certs 0.8.1",
"rustls-pki-types",
"tokio",
@@ -7666,6 +7648,7 @@ dependencies = [
name = "language_model_selector"
version = "0.1.0"
dependencies = [
"collections",
"feature_flags",
"gpui",
"language_model",
@@ -7716,6 +7699,7 @@ dependencies = [
"smol",
"strum",
"theme",
"thiserror 2.0.12",
"tiktoken-rs",
"tokio",
"ui",
@@ -11262,7 +11246,7 @@ dependencies = [
"quinn-proto",
"quinn-udp",
"rustc-hash 2.1.1",
"rustls 0.23.25",
"rustls 0.23.26",
"socket2",
"thiserror 2.0.12",
"tokio",
@@ -11281,7 +11265,7 @@ dependencies = [
"rand 0.9.0",
"ring",
"rustc-hash 2.1.1",
"rustls 0.23.25",
"rustls 0.23.26",
"rustls-pki-types",
"slab",
"thiserror 2.0.12",
@@ -11889,7 +11873,7 @@ dependencies = [
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls 0.23.25",
"rustls 0.23.26",
"rustls-native-certs 0.8.1",
"rustls-pemfile 2.2.0",
"rustls-pki-types",
@@ -12280,9 +12264,9 @@ dependencies = [
[[package]]
name = "rustls"
version = "0.23.25"
version = "0.23.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "822ee9188ac4ec04a2f0531e55d035fb2de73f18b41a63c70c2712503b6fb13c"
checksum = "df51b5869f3a441595eac5e8ff14d486ff285f7b8c0df8770e49c3b56351f0f0"
dependencies = [
"aws-lc-rs",
"log",
@@ -12356,7 +12340,7 @@ dependencies = [
"jni",
"log",
"once_cell",
"rustls 0.23.25",
"rustls 0.23.26",
"rustls-native-certs 0.8.1",
"rustls-platform-verifier-android",
"rustls-webpki 0.103.1",
@@ -13454,7 +13438,7 @@ dependencies = [
"once_cell",
"percent-encoding",
"rust_decimal",
"rustls 0.23.25",
"rustls 0.23.26",
"rustls-pemfile 2.2.0",
"serde",
"serde_json",
@@ -14749,7 +14733,7 @@ version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b"
dependencies = [
"rustls 0.23.25",
"rustls 0.23.26",
"tokio",
]
@@ -14809,7 +14793,7 @@ checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084"
dependencies = [
"futures-util",
"log",
"rustls 0.23.25",
"rustls 0.23.26",
"rustls-pki-types",
"tokio",
"tokio-rustls 0.26.2",
@@ -15368,7 +15352,7 @@ dependencies = [
"httparse",
"log",
"rand 0.9.0",
"rustls 0.23.25",
"rustls 0.23.26",
"rustls-pki-types",
"sha1",
"thiserror 2.0.12",
@@ -17716,7 +17700,7 @@ dependencies = [
"rust_decimal",
"rustix 0.38.44",
"rustix 1.0.5",
"rustls 0.23.25",
"rustls 0.23.26",
"rustls-webpki 0.103.1",
"scopeguard",
"sea-orm",
@@ -18105,7 +18089,7 @@ dependencies = [
[[package]]
name = "zed"
version = "0.182.0"
version = "0.182.10"
dependencies = [
"activity_indicator",
"agent",

View File

@@ -3,13 +3,11 @@ resolver = "2"
members = [
"crates/activity_indicator",
"crates/agent",
"crates/agent_rules",
"crates/anthropic",
"crates/askpass",
"crates/assets",
"crates/assistant",
"crates/assistant_context_editor",
"crates/agent_eval",
"crates/assistant_settings",
"crates/assistant_slash_command",
"crates/assistant_slash_commands",
@@ -47,6 +45,7 @@ members = [
"crates/diagnostics",
"crates/docs_preprocessor",
"crates/editor",
"crates/eval",
"crates/evals",
"crates/extension",
"crates/extension_api",
@@ -210,14 +209,12 @@ edition = "2024"
activity_indicator = { path = "crates/activity_indicator" }
agent = { path = "crates/agent" }
agent_rules = { path = "crates/agent_rules" }
ai = { path = "crates/ai" }
anthropic = { path = "crates/anthropic" }
askpass = { path = "crates/askpass" }
assets = { path = "crates/assets" }
assistant = { path = "crates/assistant" }
assistant_context_editor = { path = "crates/assistant_context_editor" }
assistant_eval = { path = "crates/agent_eval" }
assistant_settings = { path = "crates/assistant_settings" }
assistant_slash_command = { path = "crates/assistant_slash_command" }
assistant_slash_commands = { path = "crates/assistant_slash_commands" }
@@ -509,7 +506,7 @@ runtimelib = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804
rustc-demangle = "0.1.23"
rust-embed = { version = "8.4", features = ["include-exclude"] }
rustc-hash = "2.1.0"
rustls = { version = "0.23.22" }
rustls = { version = "0.23.26" }
rustls-platform-verifier = "0.5.0"
scap = { git = "https://github.com/zed-industries/scap", rev = "08f0a01417505cc0990b9931a37e5120db92e0d0", default-features = false }
schemars = { version = "0.8", features = ["impl_json_schema", "indexmap2"] }

View File

@@ -1,12 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="16" height="16" rx="2" fill="black" fill-opacity="0.2"/>
<g clip-path="url(#clip0_1916_18)">
<path d="M10.652 3.79999H8.816L12.164 12.2H14L10.652 3.79999Z" fill="#1F1F1E"/>
<path d="M5.348 3.79999L2 12.2H3.872L4.55672 10.436H8.05927L8.744 12.2H10.616L7.268 3.79999H5.348ZM5.16224 8.87599L6.308 5.92399L7.45374 8.87599H5.16224Z" fill="#1F1F1E"/>
</g>
<defs>
<clipPath id="clip0_1916_18">
<rect width="12" height="8.4" fill="white" transform="translate(2 3.79999)"/>
</clipPath>
</defs>
</svg>

Before

Width:  |  Height:  |  Size: 601 B

View File

@@ -150,7 +150,9 @@
"context": "AgentDiff",
"bindings": {
"ctrl-y": "agent::Keep",
"ctrl-n": "agent::Reject"
"ctrl-n": "agent::Reject",
"ctrl-shift-y": "agent::KeepAll",
"ctrl-shift-n": "agent::RejectAll"
}
},
{
@@ -625,12 +627,13 @@
"context": "AgentPanel",
"bindings": {
"ctrl-n": "agent::NewThread",
"ctrl-alt-n": "agent::NewPromptEditor",
"ctrl-alt-n": "agent::NewTextThread",
"ctrl-shift-h": "agent::OpenHistory",
"ctrl-alt-c": "agent::OpenConfiguration",
"ctrl-i": "agent::ToggleProfileSelector",
"ctrl-alt-/": "assistant::ToggleModelSelector",
"ctrl-shift-a": "agent::ToggleContextPicker",
"shift-escape": "agent::ExpandMessageEditor",
"ctrl-e": "agent::ChatMode",
"ctrl-alt-e": "agent::RemoveAllContext"
}
@@ -646,7 +649,7 @@
"context": "AgentPanel && prompt_editor",
"use_key_equivalents": true,
"bindings": {
"cmd-n": "agent::NewPromptEditor",
"cmd-n": "agent::NewTextThread",
"cmd-alt-t": "agent::NewThread"
}
},

View File

@@ -242,7 +242,9 @@
"use_key_equivalents": true,
"bindings": {
"cmd-y": "agent::Keep",
"cmd-n": "agent::Reject"
"cmd-n": "agent::Reject",
"cmd-shift-y": "agent::KeepAll",
"cmd-shift-n": "agent::RejectAll"
}
},
{
@@ -281,12 +283,13 @@
"use_key_equivalents": true,
"bindings": {
"cmd-n": "agent::NewThread",
"cmd-alt-n": "agent::NewPromptEditor",
"cmd-alt-n": "agent::NewTextThread",
"cmd-shift-h": "agent::OpenHistory",
"cmd-alt-c": "agent::OpenConfiguration",
"cmd-i": "agent::ToggleProfileSelector",
"cmd-alt-/": "assistant::ToggleModelSelector",
"cmd-shift-a": "agent::ToggleContextPicker",
"shift-escape": "agent::ExpandMessageEditor",
"cmd-e": "agent::ChatMode",
"cmd-alt-e": "agent::RemoveAllContext"
}
@@ -302,7 +305,7 @@
"context": "AgentPanel && prompt_editor",
"use_key_equivalents": true,
"bindings": {
"cmd-n": "agent::NewPromptEditor",
"cmd-n": "agent::NewTextThread",
"cmd-alt-t": "agent::NewThread"
}
},

View File

@@ -163,3 +163,8 @@ There are rules that apply to these root directories:
{{/if}}
{{/each}}
{{/if}}
<user_environment>
Operating System: {{os}} ({{arch}})
Shell: {{shell}}
</user_environment>

View File

@@ -624,14 +624,14 @@
// The provider to use.
"provider": "zed.dev",
// The model to use.
"model": "claude-3-5-sonnet-latest"
"model": "claude-3-7-sonnet-latest"
},
// The model to use when applying edits from the assistant.
"editor_model": {
// The provider to use.
"provider": "zed.dev",
// The model to use.
"model": "claude-3-5-sonnet-latest"
"model": "claude-3-7-sonnet-latest"
},
// When enabled, the agent can run potentially destructive actions without asking for your confirmation.
"always_allow_tool_actions": false,
@@ -642,6 +642,7 @@
// We don't know which of the context server tools are safe for the "Ask" profile, so we don't enable them by default.
// "enable_all_context_servers": true,
"tools": {
"contents": true,
"diagnostics": true,
"fetch": true,
"list_directory": false,
@@ -656,9 +657,11 @@
"name": "Write",
"enable_all_context_servers": true,
"tools": {
"bash": true,
"terminal": true,
"batch_tool": true,
"code_actions": true,
"code_symbols": true,
"contents": true,
"copy_path": false,
"create_file": true,
"delete_path": false,
@@ -671,6 +674,7 @@
"path_search": true,
"read_file": true,
"regex_search": true,
"rename": true,
"symbol_info": true,
"thinking": true
}

View File

@@ -19,7 +19,6 @@ test-support = [
]
[dependencies]
agent_rules.workspace = true
anyhow.workspace = true
assistant_context_editor.workspace = true
assistant_settings.workspace = true
@@ -32,6 +31,7 @@ client.workspace = true
clock.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
component.workspace = true
context_server.workspace = true
convert_case.workspace = true
db.workspace = true
@@ -51,6 +51,7 @@ itertools.workspace = true
language.workspace = true
language_model.workspace = true
language_model_selector.workspace = true
linkme.workspace = true
log.workspace = true
lsp.workspace = true
markdown.workspace = true
@@ -79,15 +80,16 @@ terminal.workspace = true
terminal_view.workspace = true
text.workspace = true
theme.workspace = true
thiserror.workspace = true
time.workspace = true
time_format.workspace = true
ui.workspace = true
ui_input.workspace = true
util.workspace = true
uuid.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
zed_actions.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
buffer_diff = { workspace = true, features = ["test-support"] }

View File

@@ -1,40 +1,41 @@
use crate::AssistantPanel;
use crate::context::{AssistantContext, ContextId};
use crate::context_picker::MentionLink;
use crate::thread::{
LastRestoreCheckpoint, MessageId, MessageSegment, RequestKind, Thread, ThreadError,
ThreadEvent, ThreadFeedback,
};
use crate::thread_store::ThreadStore;
use crate::thread_store::{RulesLoadingError, ThreadStore};
use crate::tool_use::{PendingToolUseStatus, ToolUse, ToolUseStatus};
use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill};
use crate::{AssistantPanel, OpenActiveThreadAsMarkdown};
use anyhow::Context as _;
use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting};
use collections::{HashMap, HashSet};
use editor::scroll::Autoscroll;
use editor::{Editor, MultiBuffer};
use editor::{Editor, EditorElement, EditorStyle, MultiBuffer};
use gpui::{
AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, ClipboardItem,
DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Hsla, ListAlignment, ListState,
MouseButton, PlatformDisplay, ScrollHandle, Stateful, StyleRefinement, Subscription, Task,
TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle,
TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle,
linear_color_stop, linear_gradient, list, percentage, pulsating_between,
};
use language::{Buffer, LanguageRegistry};
use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelToolUseId, Role};
use markdown::parser::CodeBlockKind;
use markdown::{Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown, without_fences};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role, StopReason};
use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
use project::ProjectItem as _;
use rope::Point;
use settings::{Settings as _, update_settings_file};
use std::ops::Range;
use std::path::Path;
use std::rc::Rc;
use std::sync::Arc;
use std::time::Duration;
use text::ToPoint;
use theme::ThemeSettings;
use ui::{Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, Tooltip, prelude::*};
use ui::{
Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, TextSize, Tooltip, prelude::*,
};
use util::ResultExt as _;
use workspace::{OpenOptions, Workspace};
@@ -57,12 +58,13 @@ pub struct ActiveThread {
editing_message: Option<(MessageId, EditMessageState)>,
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
expanded_thinking_segments: HashMap<(MessageId, usize), bool>,
expanded_code_blocks: HashMap<(MessageId, usize), bool>,
last_error: Option<ThreadError>,
notifications: Vec<WindowHandle<AgentNotification>>,
copied_code_block_ids: HashSet<(MessageId, usize)>,
_subscriptions: Vec<Subscription>,
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
feedback_message_editor: Option<Entity<Editor>>,
open_feedback_editors: HashMap<MessageId, Entity<Editor>>,
}
struct RenderedMessage {
@@ -173,11 +175,37 @@ fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
});
MarkdownStyle {
base_text_style: text_style,
base_text_style: text_style.clone(),
syntax: cx.theme().syntax().clone(),
selection_background_color: cx.theme().players().local().selection,
code_block_overflow_x_scroll: true,
table_overflow_x_scroll: true,
heading_level_styles: Some(HeadingLevelStyles {
h1: Some(TextStyleRefinement {
font_size: Some(rems(1.15).into()),
..Default::default()
}),
h2: Some(TextStyleRefinement {
font_size: Some(rems(1.1).into()),
..Default::default()
}),
h3: Some(TextStyleRefinement {
font_size: Some(rems(1.05).into()),
..Default::default()
}),
h4: Some(TextStyleRefinement {
font_size: Some(rems(1.).into()),
..Default::default()
}),
h5: Some(TextStyleRefinement {
font_size: Some(rems(0.95).into()),
..Default::default()
}),
h6: Some(TextStyleRefinement {
font_size: Some(rems(0.875).into()),
..Default::default()
}),
}),
code_block: StyleRefinement {
padding: EdgesRefinement {
top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
@@ -289,15 +317,17 @@ fn tool_use_markdown_style(window: &Window, cx: &mut App) -> MarkdownStyle {
}
}
const MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK: usize = 10;
fn render_markdown_code_block(
message_id: MessageId,
ix: usize,
kind: &CodeBlockKind,
parsed_markdown: &ParsedMarkdown,
codeblock_range: Range<usize>,
metadata: CodeBlockMetadata,
active_thread: Entity<ActiveThread>,
workspace: WeakEntity<Workspace>,
_window: &mut Window,
_window: &Window,
cx: &App,
) -> Div {
let label = match kind {
@@ -377,16 +407,20 @@ fn render_markdown_code_block(
.rounded_sm()
.hover(|item| item.bg(cx.theme().colors().element_hover.opacity(0.5)))
.tooltip(Tooltip::text("Jump to File"))
.children(
file_icons::FileIcons::get_icon(&path_range.path, cx)
.map(Icon::from_path)
.map(|icon| icon.color(Color::Muted).size(IconSize::XSmall)),
)
.child(content)
.child(
Icon::new(IconName::ArrowUpRight)
.size(IconSize::XSmall)
.color(Color::Ignored),
h_flex()
.gap_0p5()
.children(
file_icons::FileIcons::get_icon(&path_range.path, cx)
.map(Icon::from_path)
.map(|icon| icon.color(Color::Muted).size(IconSize::XSmall)),
)
.child(content)
.child(
Icon::new(IconName::ArrowUpRight)
.size(IconSize::XSmall)
.color(Color::Ignored),
),
)
.on_click({
let path_range = path_range.clone();
@@ -444,17 +478,24 @@ fn render_markdown_code_block(
}),
};
let codeblock_was_copied = active_thread
.read(cx)
.copied_code_block_ids
.contains(&(message_id, ix));
let is_expanded = active_thread
.read(cx)
.expanded_code_blocks
.get(&(message_id, ix))
.copied()
.unwrap_or(false);
let codeblock_header_bg = cx
.theme()
.colors()
.element_background
.blend(cx.theme().colors().editor_foreground.opacity(0.01));
let codeblock_was_copied = active_thread
.read(cx)
.copied_code_block_ids
.contains(&(message_id, ix));
let codeblock_header = h_flex()
.group("codeblock_header")
.p_1()
@@ -466,57 +507,108 @@ fn render_markdown_code_block(
.rounded_t_md()
.children(label)
.child(
div().visible_on_hover("codeblock_header").child(
IconButton::new(
("copy-markdown-code", ix),
if codeblock_was_copied {
IconName::Check
} else {
IconName::Copy
},
)
.icon_color(Color::Muted)
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text("Copy Code"))
.on_click({
let active_thread = active_thread.clone();
let parsed_markdown = parsed_markdown.clone();
move |_event, _window, cx| {
active_thread.update(cx, |this, cx| {
this.copied_code_block_ids.insert((message_id, ix));
h_flex()
.gap_1()
.child(
div().visible_on_hover("codeblock_header").child(
IconButton::new(
("copy-markdown-code", ix),
if codeblock_was_copied {
IconName::Check
} else {
IconName::Copy
},
)
.icon_color(Color::Muted)
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text("Copy Code"))
.on_click({
let active_thread = active_thread.clone();
let parsed_markdown = parsed_markdown.clone();
let code_block_range = metadata.content_range.clone();
move |_event, _window, cx| {
active_thread.update(cx, |this, cx| {
this.copied_code_block_ids.insert((message_id, ix));
let code =
without_fences(&parsed_markdown.source()[codeblock_range.clone()])
.to_string();
let code = parsed_markdown.source()[code_block_range.clone()]
.to_string();
cx.write_to_clipboard(ClipboardItem::new_string(code));
cx.write_to_clipboard(ClipboardItem::new_string(code.clone()));
cx.spawn(async move |this, cx| {
cx.background_executor()
.timer(Duration::from_secs(2))
.await;
cx.spawn(async move |this, cx| {
cx.background_executor().timer(Duration::from_secs(2)).await;
cx.update(|cx| {
this.update(cx, |this, cx| {
this.copied_code_block_ids.remove(&(message_id, ix));
cx.notify();
cx.update(|cx| {
this.update(cx, |this, cx| {
this.copied_code_block_ids
.remove(&(message_id, ix));
cx.notify();
})
})
.ok();
})
})
.ok();
})
.detach();
});
}
}),
),
.detach();
});
}
}),
),
)
.when(
metadata.line_count > MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK,
|header| {
header.child(
IconButton::new(
("expand-collapse-code", ix),
if is_expanded {
IconName::ChevronUp
} else {
IconName::ChevronDown
},
)
.icon_color(Color::Muted)
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text(if is_expanded {
"Collapse Code"
} else {
"Expand Code"
}))
.on_click({
let active_thread = active_thread.clone();
move |_event, _window, cx| {
active_thread.update(cx, |this, cx| {
let is_expanded = this
.expanded_code_blocks
.entry((message_id, ix))
.or_insert(false);
*is_expanded = !*is_expanded;
cx.notify();
});
}
}),
)
},
),
);
v_flex()
.mb_2()
.relative()
.my_2()
.overflow_hidden()
.rounded_lg()
.border_1()
.border_color(cx.theme().colors().border_variant)
.bg(cx.theme().colors().editor_background)
.child(codeblock_header)
.when(
metadata.line_count > MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK,
|this| {
if is_expanded {
this.h_full()
} else {
this.max_h_80()
}
},
)
}
fn open_markdown_link(
@@ -604,6 +696,7 @@ impl ActiveThread {
let subscriptions = vec![
cx.observe(&thread, |_, _, cx| cx.notify()),
cx.subscribe_in(&thread, window, Self::handle_thread_event),
cx.subscribe(&thread_store, Self::handle_rules_loading_error),
];
let list_state = ListState::new(0, ListAlignment::Bottom, px(2048.), {
@@ -626,6 +719,7 @@ impl ActiveThread {
rendered_tool_uses: HashMap::default(),
expanded_tool_uses: HashMap::default(),
expanded_thinking_segments: HashMap::default(),
expanded_code_blocks: HashMap::default(),
list_state: list_state.clone(),
scrollbar_state: ScrollbarState::new(list_state),
show_scrollbar: false,
@@ -636,7 +730,7 @@ impl ActiveThread {
notifications: Vec::new(),
_subscriptions: subscriptions,
notification_subscriptions: HashMap::default(),
feedback_message_editor: None,
open_feedback_editors: HashMap::default(),
};
for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
@@ -768,10 +862,9 @@ impl ActiveThread {
| ThreadEvent::SummaryChanged => {
self.save_thread(cx);
}
ThreadEvent::DoneStreaming => {
let thread = self.thread.read(cx);
if !thread.is_generating() {
ThreadEvent::Stopped(reason) => match reason {
Ok(StopReason::EndTurn | StopReason::MaxTokens) => {
let thread = self.thread.read(cx);
self.show_notification(
if thread.used_tools_since_last_user_message() {
"Finished running tools"
@@ -783,7 +876,8 @@ impl ActiveThread {
cx,
);
}
}
_ => {}
},
ThreadEvent::ToolConfirmationNeeded => {
self.show_notification("Waiting for tool confirmation", IconName::Info, window, cx);
}
@@ -828,11 +922,7 @@ impl ActiveThread {
self.save_thread(cx);
cx.notify();
}
ThreadEvent::UsePendingTools => {
let tool_uses = self
.thread
.update(cx, |thread, cx| thread.use_pending_tools(cx));
ThreadEvent::UsePendingTools { tool_uses } => {
for tool_use in tool_uses {
self.render_tool_use_markdown(
tool_use.id.clone(),
@@ -844,11 +934,8 @@ impl ActiveThread {
}
}
ThreadEvent::ToolFinished {
pending_tool_use,
canceled,
..
pending_tool_use, ..
} => {
let canceled = *canceled;
if let Some(tool_use) = pending_tool_use {
self.render_tool_use_markdown(
tool_use.id.clone(),
@@ -862,23 +949,24 @@ impl ActiveThread {
cx,
);
}
if self.thread.read(cx).all_tools_finished() {
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
self.thread.update(cx, |thread, cx| {
thread.attach_tool_results(cx);
if !canceled {
thread.send_to_model(model, RequestKind::Chat, cx);
}
});
}
}
}
ThreadEvent::CheckpointChanged => cx.notify(),
}
}
fn handle_rules_loading_error(
&mut self,
_thread_store: Entity<ThreadStore>,
error: &RulesLoadingError,
cx: &mut Context<Self>,
) {
self.last_error = Some(ThreadError::Message {
header: "Error loading rules file".into(),
message: error.message.clone(),
});
cx.notify();
}
fn show_notification(
&mut self,
caption: impl Into<SharedString>,
@@ -939,7 +1027,7 @@ impl ActiveThread {
|this, _, event, window, cx| match event {
AgentNotificationEvent::Accepted => {
let handle = window.window_handle();
cx.activate(true); // Switch back to the Zed application
cx.activate(true);
let workspace_handle = this.workspace.clone();
@@ -1111,34 +1199,37 @@ impl ActiveThread {
fn handle_feedback_click(
&mut self,
message_id: MessageId,
feedback: ThreadFeedback,
window: &mut Window,
cx: &mut Context<Self>,
) {
let report = self.thread.update(cx, |thread, cx| {
thread.report_message_feedback(message_id, feedback, cx)
});
cx.spawn(async move |this, cx| {
report.await?;
this.update(cx, |_this, cx| cx.notify())
})
.detach_and_log_err(cx);
match feedback {
ThreadFeedback::Positive => {
let report = self
.thread
.update(cx, |thread, cx| thread.report_feedback(feedback, cx));
let this = cx.entity().downgrade();
cx.spawn(async move |_, cx| {
report.await?;
this.update(cx, |_this, cx| cx.notify())
})
.detach_and_log_err(cx);
self.open_feedback_editors.remove(&message_id);
}
ThreadFeedback::Negative => {
self.handle_show_feedback_comments(window, cx);
self.handle_show_feedback_comments(message_id, window, cx);
}
}
}
fn handle_show_feedback_comments(&mut self, window: &mut Window, cx: &mut Context<Self>) {
if self.feedback_message_editor.is_some() {
return;
}
fn handle_show_feedback_comments(
&mut self,
message_id: MessageId,
window: &mut Window,
cx: &mut Context<Self>,
) {
let buffer = cx.new(|cx| {
let empty_string = String::new();
MultiBuffer::singleton(cx.new(|cx| Buffer::local(empty_string, cx)), cx)
@@ -1160,34 +1251,47 @@ impl ActiveThread {
});
editor.read(cx).focus_handle(cx).focus(window);
self.feedback_message_editor = Some(editor);
self.open_feedback_editors.insert(message_id, editor);
cx.notify();
}
fn submit_feedback_message(&mut self, cx: &mut Context<Self>) {
let Some(editor) = self.feedback_message_editor.clone() else {
fn submit_feedback_message(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
let Some(editor) = self.open_feedback_editors.get(&message_id) else {
return;
};
let report_task = self.thread.update(cx, |thread, cx| {
thread.report_feedback(ThreadFeedback::Negative, cx)
thread.report_message_feedback(message_id, ThreadFeedback::Negative, cx)
});
let comments = editor.read(cx).text(cx);
if !comments.is_empty() {
let thread_id = self.thread.read(cx).id().clone();
let comments_value = String::from(comments.as_str());
telemetry::event!("Assistant Thread Feedback Comments", thread_id, comments);
let message_content = self
.thread
.read(cx)
.message(message_id)
.map(|msg| msg.to_string())
.unwrap_or_default();
telemetry::event!(
"Assistant Thread Feedback Comments",
thread_id,
message_id = message_id.0,
message_content,
comments = comments_value
);
self.open_feedback_editors.remove(&message_id);
cx.spawn(async move |this, cx| {
report_task.await?;
this.update(cx, |_this, cx| cx.notify())
})
.detach_and_log_err(cx);
}
self.feedback_message_editor = None;
let this = cx.entity().downgrade();
cx.spawn(async move |_, cx| {
report_task.await?;
this.update(cx, |_this, cx| cx.notify())
})
.detach_and_log_err(cx);
}
fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
@@ -1214,7 +1318,18 @@ impl ActiveThread {
let is_first_message = ix == 0;
let is_last_message = ix == self.messages.len() - 1;
let show_feedback = is_last_message && message.role != Role::User;
let show_feedback = (!is_generating && is_last_message && message.role != Role::User)
|| self.messages.get(ix + 1).map_or(false, |next_id| {
self.thread
.read(cx)
.message(*next_id)
.map_or(false, |next_message| {
next_message.role == Role::User
&& thread.tool_uses_for_message(*next_id, cx).is_empty()
&& thread.tool_results_for_message(*next_id).is_empty()
})
});
let needs_confirmation = tool_uses.iter().any(|tool_use| tool_use.needs_confirmation);
@@ -1287,8 +1402,17 @@ impl ActiveThread {
let editor_bg_color = colors.editor_background;
let bg_user_message_header = editor_bg_color.blend(active_color.opacity(0.25));
let feedback_container = h_flex().pt_2().pb_4().px_4().gap_1().justify_between();
let feedback_items = match self.thread.read(cx).feedback() {
let open_as_markdown = IconButton::new("open-as-markdown", IconName::FileCode)
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(Color::Ignored)
.tooltip(Tooltip::text("Open Thread as Markdown"))
.on_click(|_event, window, cx| {
window.dispatch_action(Box::new(OpenActiveThreadAsMarkdown), cx)
});
let feedback_container = h_flex().py_2().px_4().gap_1().justify_between();
let feedback_items = match self.thread.read(cx).message_feedback(message_id) {
Some(feedback) => feedback_container
.child(
Label::new(match feedback {
@@ -1302,18 +1426,20 @@ impl ActiveThread {
)
.child(
h_flex()
.pr_1()
.gap_1()
.child(
IconButton::new("feedback-thumbs-up", IconName::ThumbsUp)
IconButton::new(("feedback-thumbs-up", ix), IconName::ThumbsUp)
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(match feedback {
ThreadFeedback::Positive => Color::Accent,
ThreadFeedback::Negative => Color::Ignored,
})
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text("Helpful Response"))
.on_click(cx.listener(move |this, _, window, cx| {
this.handle_feedback_click(
message_id,
ThreadFeedback::Positive,
window,
cx,
@@ -1321,22 +1447,24 @@ impl ActiveThread {
})),
)
.child(
IconButton::new("feedback-thumbs-down", IconName::ThumbsDown)
IconButton::new(("feedback-thumbs-down", ix), IconName::ThumbsDown)
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(match feedback {
ThreadFeedback::Positive => Color::Ignored,
ThreadFeedback::Negative => Color::Accent,
})
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text("Not Helpful"))
.on_click(cx.listener(move |this, _, window, cx| {
this.handle_feedback_click(
message_id,
ThreadFeedback::Negative,
window,
cx,
);
})),
),
)
.child(open_as_markdown),
)
.into_any_element(),
None => feedback_container
@@ -1349,15 +1477,17 @@ impl ActiveThread {
)
.child(
h_flex()
.pr_1()
.gap_1()
.child(
IconButton::new("feedback-thumbs-up", IconName::ThumbsUp)
IconButton::new(("feedback-thumbs-up", ix), IconName::ThumbsUp)
.icon_size(IconSize::XSmall)
.icon_color(Color::Ignored)
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text("Helpful Response"))
.on_click(cx.listener(move |this, _, window, cx| {
this.handle_feedback_click(
message_id,
ThreadFeedback::Positive,
window,
cx,
@@ -1365,19 +1495,21 @@ impl ActiveThread {
})),
)
.child(
IconButton::new("feedback-thumbs-down", IconName::ThumbsDown)
IconButton::new(("feedback-thumbs-down", ix), IconName::ThumbsDown)
.icon_size(IconSize::XSmall)
.icon_color(Color::Ignored)
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text("Not Helpful"))
.on_click(cx.listener(move |this, _, window, cx| {
this.handle_feedback_click(
message_id,
ThreadFeedback::Negative,
window,
cx,
);
})),
),
)
.child(open_as_markdown),
)
.into_any_element(),
};
@@ -1392,12 +1524,36 @@ impl ActiveThread {
.when(!message_is_empty, |parent| {
parent.child(
if let Some(edit_message_editor) = edit_message_editor.clone() {
let settings = ThemeSettings::get_global(cx);
let font_size = TextSize::Small.rems(cx);
let line_height = font_size.to_pixels(window.rem_size()) * 1.5;
let text_style = TextStyle {
color: cx.theme().colors().text,
font_family: settings.buffer_font.family.clone(),
font_fallbacks: settings.buffer_font.fallbacks.clone(),
font_features: settings.buffer_font.features.clone(),
font_size: font_size.into(),
line_height: line_height.into(),
..Default::default()
};
div()
.key_context("EditMessageEditor")
.on_action(cx.listener(Self::cancel_editing_message))
.on_action(cx.listener(Self::confirm_editing_message))
.min_h_6()
.child(edit_message_editor)
.pt_1()
.child(EditorElement::new(
&edit_message_editor,
EditorStyle {
background: colors.editor_background,
local_player: cx.theme().players().local(),
text: text_style,
syntax: cx.theme().syntax().clone(),
..Default::default()
},
))
.into_any()
} else {
div()
@@ -1562,11 +1718,9 @@ impl ActiveThread {
),
Role::Assistant => v_flex()
.id(("message-container", ix))
.ml_2()
.ml_2p5()
.pl_2()
.pr_4()
.border_l_1()
.border_color(cx.theme().colors().border_variant)
.children(message_content)
.when(has_tool_uses, |parent| {
parent.children(
@@ -1669,31 +1823,31 @@ impl ActiveThread {
.child(generating_label.unwrap()),
)
})
.when(show_feedback && !is_generating, |parent| {
.when(show_feedback, move |parent| {
parent.child(feedback_items).when_some(
self.feedback_message_editor.clone(),
|parent, feedback_editor| {
self.open_feedback_editors.get(&message_id),
move |parent, feedback_editor| {
let focus_handle = feedback_editor.focus_handle(cx);
parent.child(
v_flex()
.key_context("AgentFeedbackMessageEditor")
.on_action(cx.listener(|this, _: &menu::Cancel, _, cx| {
this.feedback_message_editor = None;
.on_action(cx.listener(move |this, _: &menu::Cancel, _, cx| {
this.open_feedback_editors.remove(&message_id);
cx.notify();
}))
.on_action(cx.listener(|this, _: &menu::Confirm, _, cx| {
this.submit_feedback_message(cx);
.on_action(cx.listener(move |this, _: &menu::Confirm, _, cx| {
this.submit_feedback_message(message_id, cx);
cx.notify();
}))
.on_action(cx.listener(Self::confirm_editing_message))
.my_3()
.mb_2()
.mx_4()
.p_2()
.rounded_md()
.border_1()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().editor_background)
.child(feedback_editor)
.child(feedback_editor.clone())
.child(
h_flex()
.gap_1()
@@ -1710,10 +1864,13 @@ impl ActiveThread {
)
.map(|kb| kb.size(rems_from_px(10.))),
)
.on_click(cx.listener(|this, _, _, cx| {
this.feedback_message_editor = None;
cx.notify();
})),
.on_click(cx.listener(
move |this, _, _window, cx| {
this.open_feedback_editors
.remove(&message_id);
cx.notify();
},
)),
)
.child(
Button::new(
@@ -1732,9 +1889,9 @@ impl ActiveThread {
.map(|kb| kb.size(rems_from_px(10.))),
)
.on_click(
cx.listener(|this, _, _, cx| {
this.submit_feedback_message(cx);
cx.notify();
cx.listener(move |this, _, _window, cx| {
this.submit_feedback_message(message_id, cx);
cx.notify()
}),
),
),
@@ -1799,13 +1956,13 @@ impl ActiveThread {
render: Arc::new({
let workspace = workspace.clone();
let active_thread = cx.entity();
move |id, kind, parsed_markdown, range, window, cx| {
move |kind, parsed_markdown, range, metadata, window, cx| {
render_markdown_code_block(
message_id,
id,
range.start,
kind,
parsed_markdown,
range,
metadata,
active_thread.clone(),
workspace.clone(),
window,
@@ -1813,6 +1970,47 @@ impl ActiveThread {
)
}
}),
transform: Some(Arc::new({
let active_thread = cx.entity();
move |el, range, metadata, _, cx| {
let is_expanded = active_thread
.read(cx)
.expanded_code_blocks
.get(&(message_id, range.start))
.copied()
.unwrap_or(false);
if is_expanded
|| metadata.line_count
<= MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK
{
return el;
}
el.child(
div()
.absolute()
.bottom_0()
.left_0()
.w_full()
.h_1_4()
.rounded_b_lg()
.bg(gpui::linear_gradient(
0.,
gpui::linear_color_stop(
cx.theme().colors().editor_background,
0.,
),
gpui::linear_color_stop(
cx.theme()
.colors()
.editor_background
.opacity(0.),
1.,
),
)),
)
}
})),
})
.on_url_click({
let workspace = self.workspace.clone();
@@ -2567,12 +2765,13 @@ impl ActiveThread {
}
fn render_rules_item(&self, cx: &Context<Self>) -> AnyElement {
let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref()
else {
let project_context = self.thread.read(cx).project_context();
let project_context = project_context.borrow();
let Some(project_context) = project_context.as_ref() else {
return div().into_any();
};
let rules_files = system_prompt_context
let rules_files = project_context
.worktrees
.iter()
.filter_map(|worktree| worktree.rules_file.as_ref())
@@ -2662,12 +2861,13 @@ impl ActiveThread {
}
fn handle_open_rules(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref()
else {
let project_context = self.thread.read(cx).project_context();
let project_context = project_context.borrow();
let Some(project_context) = project_context.as_ref() else {
return;
};
let abs_paths = system_prompt_context
let abs_paths = project_context
.worktrees
.iter()
.flat_map(|worktree| worktree.rules_file.as_ref())
@@ -2804,10 +3004,10 @@ pub(crate) fn open_context(
}
}
AssistantContext::Directory(directory_context) => {
let path = directory_context.project_path.clone();
let project_path = directory_context.project_path(cx);
workspace.update(cx, |workspace, cx| {
workspace.project().update(cx, |project, cx| {
if let Some(entry) = project.entry_for_path(&path, cx) {
if let Some(entry) = project.entry_for_path(&project_path, cx) {
cx.emit(project::Event::RevealInProjectPanel(entry.id));
}
})

View File

@@ -1,7 +1,7 @@
use crate::{Keep, Reject, Thread, ThreadEvent};
use crate::{Keep, KeepAll, Reject, RejectAll, Thread, ThreadEvent};
use anyhow::Result;
use buffer_diff::DiffHunkStatus;
use collections::HashSet;
use collections::{HashMap, HashSet};
use editor::{
Direction, Editor, EditorEvent, MultiBuffer, ToPoint,
actions::{GoToHunk, GoToPreviousHunk},
@@ -355,16 +355,24 @@ impl AgentDiff {
self.update_selection(&diff_hunks_in_ranges, window, cx);
}
let mut ranges_by_buffer = HashMap::default();
for hunk in &diff_hunks_in_ranges {
let buffer = self.multibuffer.read(cx).buffer(hunk.buffer_id);
if let Some(buffer) = buffer {
self.thread
.update(cx, |thread, cx| {
thread.reject_edits_in_range(buffer, hunk.buffer_range.clone(), cx)
})
.detach_and_log_err(cx);
ranges_by_buffer
.entry(buffer.clone())
.or_insert_with(Vec::new)
.push(hunk.buffer_range.clone());
}
}
for (buffer, ranges) in ranges_by_buffer {
self.thread
.update(cx, |thread, cx| {
thread.reject_edits_in_ranges(buffer, ranges, cx)
})
.detach_and_log_err(cx);
}
}
fn update_selection(
@@ -843,7 +851,7 @@ impl ToolbarItemView for AgentDiffToolbar {
}
impl Render for AgentDiffToolbar {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let agent_diff = match self.agent_diff(cx) {
Some(ad) => ad,
None => return div(),
@@ -855,6 +863,8 @@ impl Render for AgentDiffToolbar {
return div();
}
let focus_handle = agent_diff.focus_handle(cx);
h_group_xl()
.my_neg_1()
.items_center()
@@ -864,15 +874,25 @@ impl Render for AgentDiffToolbar {
.child(
h_group_sm()
.child(
Button::new("reject-all", "Reject All").on_click(cx.listener(
|this, _, window, cx| {
this.dispatch_action(&crate::RejectAll, window, cx)
},
)),
Button::new("reject-all", "Reject All")
.key_binding({
KeyBinding::for_action_in(&RejectAll, &focus_handle, window, cx)
.map(|kb| kb.size(rems_from_px(12.)))
})
.on_click(cx.listener(|this, _, window, cx| {
this.dispatch_action(&RejectAll, window, cx)
})),
)
.child(Button::new("keep-all", "Keep All").on_click(cx.listener(
|this, _, window, cx| this.dispatch_action(&crate::KeepAll, window, cx),
))),
.child(
Button::new("keep-all", "Keep All")
.key_binding({
KeyBinding::for_action_in(&KeepAll, &focus_handle, window, cx)
.map(|kb| kb.size(rems_from_px(12.)))
})
.on_click(cx.listener(|this, _, window, cx| {
this.dispatch_action(&KeepAll, window, cx)
})),
),
)
}
}
@@ -882,6 +902,7 @@ mod tests {
use super::*;
use crate::{ThreadStore, thread_store};
use assistant_settings::AssistantSettings;
use assistant_tool::ToolWorkingSet;
use context_server::ContextServerSettings;
use editor::EditorSettings;
use gpui::TestAppContext;
@@ -921,15 +942,16 @@ mod tests {
})
.unwrap();
let thread_store = cx.update(|cx| {
ThreadStore::new(
project.clone(),
Arc::default(),
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)
.unwrap()
});
let thread_store = cx
.update(|cx| {
ThreadStore::load(
project.clone(),
cx.new(|_| ToolWorkingSet::default()),
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)
})
.await;
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());

View File

@@ -18,6 +18,7 @@ mod terminal_inline_assistant;
mod thread;
mod thread_history;
mod thread_store;
mod tool_compatibility;
mod tool_use;
mod ui;
@@ -46,10 +47,11 @@ pub use agent_diff::{AgentDiff, AgentDiffToolbar};
actions!(
agent,
[
NewPromptEditor,
NewTextThread,
ToggleContextPicker,
ToggleProfileSelector,
RemoveAllContext,
ExpandMessageEditor,
OpenHistory,
AddContextServer,
RemoveSelectedThread,

View File

@@ -12,7 +12,9 @@ use fs::Fs;
use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
use settings::{Settings, update_settings_file};
use ui::{Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch, prelude::*};
use ui::{
Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch, Tooltip, prelude::*,
};
use util::ResultExt as _;
use zed_actions::ExtensionCategoryFilter;
@@ -27,7 +29,7 @@ pub struct AssistantConfiguration {
configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
context_server_manager: Entity<ContextServerManager>,
expanded_context_server_tools: HashMap<Arc<str>, bool>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
_registry_subscription: Subscription,
}
@@ -35,7 +37,7 @@ impl AssistantConfiguration {
pub fn new(
fs: Arc<dyn Fs>,
context_server_manager: Entity<ContextServerManager>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -224,7 +226,7 @@ impl AssistantConfiguration {
fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
let context_servers = self.context_server_manager.read(cx).all_servers().clone();
let tools_by_source = self.tools.tools_by_source(cx);
let tools_by_source = self.tools.read(cx).tools_by_source(cx);
let empty = Vec::new();
const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
@@ -236,7 +238,10 @@ impl AssistantConfiguration {
.child(
v_flex()
.gap_0p5()
.child(Headline::new("Context Servers (MCP)").size(HeadlineSize::Small))
.child(
Headline::new("Model Context Protocol (MCP) Servers")
.size(HeadlineSize::Small),
)
.child(Label::new(SUBHEADING).color(Color::Muted)),
)
.children(context_servers.into_iter().map(|context_server| {
@@ -262,10 +267,9 @@ impl AssistantConfiguration {
.bg(cx.theme().colors().editor_background)
.child(
h_flex()
.p_1()
.justify_between()
.px_2()
.py_1()
.when(are_tools_expanded, |element| {
.when(are_tools_expanded && tool_count > 1, |element| {
element
.border_b_1()
.border_color(cx.theme().colors().border)
@@ -275,6 +279,7 @@ impl AssistantConfiguration {
.gap_2()
.child(
Disclosure::new("tool-list-disclosure", are_tools_expanded)
.disabled(tool_count == 0)
.on_click(cx.listener({
let context_server_id = context_server.id();
move |this, _event, _window, _cx| {
@@ -295,10 +300,11 @@ impl AssistantConfiguration {
.child(Label::new(context_server.id()))
.child(
Label::new(format!("{tool_count} tools"))
.color(Color::Muted),
.color(Color::Muted)
.size(LabelSize::Small),
),
)
.child(h_flex().child(
.child(
Switch::new("context-server-switch", is_running.into()).on_click({
let context_server_manager =
self.context_server_manager.clone();
@@ -334,7 +340,7 @@ impl AssistantConfiguration {
}
}
}),
)),
),
)
.map(|parent| {
if !are_tools_expanded {
@@ -344,14 +350,29 @@ impl AssistantConfiguration {
parent.child(v_flex().children(tools.into_iter().enumerate().map(
|(ix, tool)| {
h_flex()
.px_2()
.id("tool-item")
.pl_2()
.pr_1()
.py_1()
.gap_2()
.justify_between()
.when(ix < tool_count - 1, |element| {
element
.border_b_1()
.border_color(cx.theme().colors().border)
.border_color(cx.theme().colors().border_variant)
})
.child(Label::new(tool.name()))
.child(
Label::new(tool.name())
.buffer_font(cx)
.size(LabelSize::Small),
)
.child(
IconButton::new(("tool-description", ix), IconName::Info)
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::Small)
.icon_color(Color::Ignored)
.tooltip(Tooltip::text(tool.description())),
)
},
)))
})
@@ -362,7 +383,7 @@ impl AssistantConfiguration {
.gap_2()
.child(
h_flex().w_full().child(
Button::new("add-context-server", "Add Context Server")
Button::new("add-context-server", "Add MCPs Directly")
.style(ButtonStyle::Filled)
.layer(ElevationIndex::ModalSurface)
.full_width()
@@ -378,7 +399,7 @@ impl AssistantConfiguration {
h_flex().w_full().child(
Button::new(
"install-context-server-extensions",
"Install Context Server Extensions",
"Install MCP Extensions",
)
.style(ButtonStyle::Filled)
.layer(ElevationIndex::ModalSurface)

View File

@@ -84,7 +84,7 @@ pub struct NewProfileMode {
pub struct ManageProfilesModal {
fs: Arc<dyn Fs>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
thread_store: WeakEntity<ThreadStore>,
focus_handle: FocusHandle,
mode: Mode,
@@ -117,7 +117,7 @@ impl ManageProfilesModal {
pub fn new(
fs: Arc<dyn Fs>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
thread_store: WeakEntity<ThreadStore>,
window: &mut Window,
cx: &mut Context<Self>,

View File

@@ -60,7 +60,7 @@ pub struct ToolPickerDelegate {
impl ToolPickerDelegate {
pub fn new(
fs: Arc<dyn Fs>,
tool_set: Arc<ToolWorkingSet>,
tool_set: Entity<ToolWorkingSet>,
thread_store: WeakEntity<ThreadStore>,
profile_id: AgentProfileId,
profile: AgentProfile,
@@ -68,7 +68,7 @@ impl ToolPickerDelegate {
) -> Self {
let mut tool_entries = Vec::new();
for (source, tools) in tool_set.tools_by_source(cx) {
for (source, tools) in tool_set.read(cx).tools_by_source(cx) {
tool_entries.extend(tools.into_iter().map(|tool| ToolEntry {
name: tool.name().into(),
source: source.clone(),
@@ -192,7 +192,7 @@ impl PickerDelegate for ToolPickerDelegate {
if active_profile_id == &self.profile_id {
self.thread_store
.update(cx, |this, cx| {
this.load_profile(&self.profile, cx);
this.load_profile(self.profile.clone(), cx);
})
.log_err();
}

View File

@@ -80,17 +80,16 @@ impl AssistantModelSelector {
impl Render for AssistantModelSelector {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let model_registry = LanguageModelRegistry::read_global(cx);
let focus_handle = self.focus_handle.clone();
let model_registry = LanguageModelRegistry::read_global(cx);
let model = match self.model_type {
ModelType::Default => model_registry.default_model(),
ModelType::InlineAssistant => model_registry.inline_assistant_model(),
};
let focus_handle = self.focus_handle.clone();
let model_name = match model {
Some(model) => model.model.name().0,
_ => SharedString::from("No model selected"),
let (model_name, model_icon) = match model {
Some(model) => (model.model.name().0, Some(model.provider.icon())),
_ => (SharedString::from("No model selected"), None),
};
LanguageModelSelectorPopoverMenu::new(
@@ -100,10 +99,16 @@ impl Render for AssistantModelSelector {
.child(
h_flex()
.gap_0p5()
.children(
model_icon.map(|icon| {
Icon::new(icon).color(Color::Muted).size(IconSize::Small)
}),
)
.child(
Label::new(model_name)
.size(LabelSize::Small)
.color(Color::Muted),
.color(Color::Muted)
.ml_1(),
)
.child(
Icon::new(IconName::ChevronDown)

View File

@@ -44,8 +44,8 @@ use crate::thread::{Thread, ThreadError, ThreadId, TokenUsageRatio};
use crate::thread_history::{PastContext, PastThread, ThreadHistory};
use crate::thread_store::ThreadStore;
use crate::{
AgentDiff, InlineAssistant, NewPromptEditor, NewThread, OpenActiveThreadAsMarkdown,
OpenAgentDiff, OpenHistory, ThreadEvent, ToggleContextPicker,
AgentDiff, ExpandMessageEditor, InlineAssistant, NewTextThread, NewThread,
OpenActiveThreadAsMarkdown, OpenAgentDiff, OpenHistory, ThreadEvent, ToggleContextPicker,
};
pub fn init(cx: &mut App) {
@@ -70,7 +70,7 @@ pub fn init(cx: &mut App) {
panel.update(cx, |panel, cx| panel.open_configuration(window, cx));
}
})
.register_action(|workspace, _: &NewPromptEditor, window, cx| {
.register_action(|workspace, _: &NewTextThread, window, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
workspace.focus_panel::<AssistantPanel>(window, cx);
panel.update(cx, |panel, cx| panel.new_prompt_editor(window, cx));
@@ -90,6 +90,16 @@ pub fn init(cx: &mut App) {
let thread = panel.read(cx).thread.read(cx).thread().clone();
AgentDiff::deploy_in_workspace(thread, workspace, window, cx);
}
})
.register_action(|workspace, _: &ExpandMessageEditor, window, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
workspace.focus_panel::<AssistantPanel>(window, cx);
panel.update(cx, |panel, cx| {
panel.message_editor.update(cx, |editor, cx| {
editor.expand_message_editor(&ExpandMessageEditor, window, cx);
});
});
}
});
},
)
@@ -193,11 +203,13 @@ impl AssistantPanel {
cx: AsyncWindowContext,
) -> Task<Result<Entity<Self>>> {
cx.spawn(async move |cx| {
let tools = Arc::new(ToolWorkingSet::default());
let thread_store = workspace.update(cx, |workspace, cx| {
let project = workspace.project().clone();
ThreadStore::new(project, tools.clone(), prompt_builder.clone(), cx)
})??;
let tools = cx.new(|_| ToolWorkingSet::default())?;
let thread_store = workspace
.update(cx, |workspace, cx| {
let project = workspace.project().clone();
ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx)
})?
.await;
let slash_commands = Arc::new(SlashCommandWorkingSet::default());
let context_store = workspace
@@ -557,6 +569,7 @@ impl AssistantPanel {
ActiveView::Configuration | ActiveView::History => {
self.active_view =
ActiveView::thread(self.thread.read(cx).thread().clone(), window, cx);
self.message_editor.focus_handle(cx).focus(window);
cx.notify();
}
_ => {}
@@ -863,7 +876,11 @@ impl AssistantPanel {
.truncate()
.into_any_element()
} else {
change_title_editor.clone().into_any_element()
div()
.ml_2()
.w_full()
.child(change_title_editor.clone())
.into_any_element()
}
}
ActiveView::PromptEditor => {
@@ -1082,20 +1099,30 @@ impl AssistantPanel {
window,
cx,
|menu, _window, _cx| {
menu.action(
menu
.when(!is_empty, |menu| {
menu.action(
"Start New From Summary",
Box::new(NewThread {
from_thread_id: Some(thread_id.clone()),
}),
).separator()
})
.action(
"New Text Thread",
NewPromptEditor.boxed_clone(),
NewTextThread.boxed_clone(),
)
.when(!is_empty, |menu| {
menu.action(
"Continue in New Thread",
Box::new(NewThread {
from_thread_id: Some(thread_id.clone()),
}),
)
})
.separator()
.action("Settings", OpenConfiguration.boxed_clone())
.separator()
.action(
"Install MCPs",
zed_actions::Extensions {
category_filter: Some(
zed_actions::ExtensionCategoryFilter::ContextServers,
),
}
.boxed_clone(),
)
},
))
}),
@@ -1292,6 +1319,7 @@ impl AssistantPanel {
let configuration_error_ref = &configuration_error;
parent
.overflow_hidden()
.p_1p5()
.justify_end()
.gap_1()
@@ -1624,7 +1652,21 @@ impl prompt_library::InlineAssistDelegate for PromptLibraryInlineAssist {
cx: &mut Context<PromptLibrary>,
) {
InlineAssistant::update_global(cx, |assistant, cx| {
assistant.assist(&prompt_editor, self.workspace.clone(), None, window, cx)
let Some(project) = self
.workspace
.upgrade()
.map(|workspace| workspace.read(cx).project().downgrade())
else {
return;
};
assistant.assist(
&prompt_editor,
self.workspace.clone(),
project,
None,
window,
cx,
)
})
}

View File

@@ -1,9 +1,9 @@
use std::{ops::Range, sync::Arc};
use std::{ops::Range, path::Path, sync::Arc};
use gpui::{App, Entity, SharedString};
use language::{Buffer, File};
use language_model::LanguageModelRequestMessage;
use project::ProjectPath;
use project::{ProjectPath, Worktree};
use serde::{Deserialize, Serialize};
use text::{Anchor, BufferId};
use ui::IconName;
@@ -69,10 +69,21 @@ pub struct FileContext {
#[derive(Debug, Clone)]
pub struct DirectoryContext {
pub id: ContextId,
pub project_path: ProjectPath,
pub worktree: Entity<Worktree>,
pub path: Arc<Path>,
/// Buffers of the files within the directory.
pub context_buffers: Vec<ContextBuffer>,
}
impl DirectoryContext {
pub fn project_path(&self, cx: &App) -> ProjectPath {
ProjectPath {
worktree_id: self.worktree.read(cx).id(),
path: self.path.clone(),
}
}
}
#[derive(Debug, Clone)]
pub struct SymbolContext {
pub id: ContextId,
@@ -86,12 +97,11 @@ pub struct FetchedUrlContext {
pub text: SharedString,
}
// TODO: Model<Thread> holds onto the thread even if the thread is deleted. Can either handle this
// explicitly or have a WeakModel<Thread> and remove during snapshot.
#[derive(Debug, Clone)]
pub struct ThreadContext {
pub id: ContextId,
// TODO: Entity<Thread> holds onto the thread even if the thread is deleted. Should probably be
// a WeakEntity and handle removal from the UI when it has dropped.
pub thread: Entity<Thread>,
pub text: SharedString,
}
@@ -105,12 +115,11 @@ impl ThreadContext {
}
}
// TODO: Model<Buffer> holds onto the buffer even if the file is deleted and closed. Should remove
// the context from the message editor in this case.
#[derive(Clone)]
pub struct ContextBuffer {
pub id: BufferId,
// TODO: Entity<Buffer> holds onto the thread even if the thread is deleted. Should probably be
// a WeakEntity and handle removal from the UI when it has dropped.
pub buffer: Entity<Buffer>,
pub file: Arc<dyn File>,
pub version: clock::Global,

View File

@@ -34,12 +34,6 @@ use crate::context_store::ContextStore;
use crate::thread::ThreadId;
use crate::thread_store::ThreadStore;
#[derive(Debug, Clone, Copy)]
pub enum ConfirmBehavior {
KeepOpen,
Close,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ContextPickerMode {
File,
@@ -105,7 +99,6 @@ pub(super) struct ContextPicker {
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
thread_store: Option<WeakEntity<ThreadStore>>,
confirm_behavior: ConfirmBehavior,
_subscriptions: Vec<Subscription>,
}
@@ -114,7 +107,6 @@ impl ContextPicker {
workspace: WeakEntity<Workspace>,
thread_store: Option<WeakEntity<ThreadStore>>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -143,7 +135,6 @@ impl ContextPicker {
workspace,
context_store,
thread_store,
confirm_behavior,
_subscriptions: subscriptions,
}
}
@@ -166,37 +157,32 @@ impl ContextPicker {
let modes = supported_context_picker_modes(&self.thread_store);
let menu = menu
.when(has_recent, |menu| {
menu.custom_row(|_, _| {
div()
.mb_1()
.child(
Label::new("Recent")
.color(Color::Muted)
.size(LabelSize::Small),
)
.into_any_element()
})
menu.when(has_recent, |menu| {
menu.custom_row(|_, _| {
div()
.mb_1()
.child(
Label::new("Recent")
.color(Color::Muted)
.size(LabelSize::Small),
)
.into_any_element()
})
.extend(recent_entries)
.when(has_recent, |menu| menu.separator())
.extend(modes.into_iter().map(|mode| {
let context_picker = context_picker.clone();
})
.extend(recent_entries)
.when(has_recent, |menu| menu.separator())
.extend(modes.into_iter().map(|mode| {
let context_picker = context_picker.clone();
ContextMenuEntry::new(mode.label())
.icon(mode.icon())
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.handler(move |window, cx| {
context_picker.update(cx, |this, cx| this.select_mode(mode, window, cx))
})
}));
match self.confirm_behavior {
ConfirmBehavior::KeepOpen => menu.keep_open_on_confirm(),
ConfirmBehavior::Close => menu,
}
ContextMenuEntry::new(mode.label())
.icon(mode.icon())
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.handler(move |window, cx| {
context_picker.update(cx, |this, cx| this.select_mode(mode, window, cx))
})
}))
.keep_open_on_confirm()
});
cx.subscribe(&menu, move |_, _, _: &DismissEvent, cx| {
@@ -227,7 +213,6 @@ impl ContextPicker {
context_picker.clone(),
self.workspace.clone(),
self.context_store.clone(),
self.confirm_behavior,
window,
cx,
)
@@ -239,7 +224,6 @@ impl ContextPicker {
context_picker.clone(),
self.workspace.clone(),
self.context_store.clone(),
self.confirm_behavior,
window,
cx,
)
@@ -251,7 +235,6 @@ impl ContextPicker {
context_picker.clone(),
self.workspace.clone(),
self.context_store.clone(),
self.confirm_behavior,
window,
cx,
)
@@ -264,7 +247,6 @@ impl ContextPicker {
thread_store.clone(),
context_picker.clone(),
self.context_store.clone(),
self.confirm_behavior,
window,
cx,
)
@@ -289,12 +271,14 @@ impl ContextPicker {
path_prefix,
} => {
let context_store = self.context_store.clone();
let worktree_id = project_path.worktree_id;
let path = project_path.path.clone();
ContextMenuItem::custom_entry(
move |_window, cx| {
render_file_context_entry(
ElementId::NamedInteger("ctx-recent".into(), ix),
worktree_id,
&path,
&path_prefix,
false,
@@ -466,7 +450,7 @@ fn recent_context_picker_entries(
recent.extend(
workspace
.recent_navigation_history_iter(cx)
.filter(|(path, _)| !current_files.contains(&path.path.to_path_buf()))
.filter(|(path, _)| !current_files.contains(path))
.take(4)
.filter_map(|(project_path, _)| {
project
@@ -507,14 +491,13 @@ fn recent_context_picker_entries(
recent
}
pub(crate) fn insert_crease_for_mention(
pub(crate) fn insert_fold_for_mention(
excerpt_id: ExcerptId,
crease_start: text::Anchor,
content_len: usize,
crease_label: SharedString,
crease_icon_path: SharedString,
editor_entity: Entity<Editor>,
window: &mut Window,
cx: &mut App,
) {
editor_entity.update(cx, |editor, cx| {
@@ -533,6 +516,7 @@ pub(crate) fn insert_crease_for_mention(
crease_label,
editor_entity.downgrade(),
),
merge_adjacent: false,
..Default::default()
};
@@ -546,8 +530,9 @@ pub(crate) fn insert_crease_for_mention(
render_trailer,
);
editor.insert_creases(vec![crease.clone()], cx);
editor.fold_creases(vec![crease], false, window, cx);
editor.display_map.update(cx, |display_map, cx| {
display_map.fold(vec![crease], cx);
});
});
}
@@ -604,12 +589,13 @@ fn render_fold_icon_button(
.gap_1()
.child(
Icon::from_path(icon_path.clone())
.size(IconSize::Small)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(
Label::new(label.clone())
.size(LabelSize::Small)
.buffer_font(cx)
.single_line(),
),
)

View File

@@ -18,16 +18,133 @@ use text::{Anchor, ToPoint};
use ui::prelude::*;
use workspace::Workspace;
use crate::context::AssistantContext;
use crate::context_picker::file_context_picker::search_files;
use crate::context_picker::symbol_context_picker::search_symbols;
use crate::context_store::ContextStore;
use crate::thread_store::ThreadStore;
use super::fetch_context_picker::fetch_url_content;
use super::thread_context_picker::ThreadContextEntry;
use super::file_context_picker::FileMatch;
use super::symbol_context_picker::SymbolMatch;
use super::thread_context_picker::{ThreadContextEntry, ThreadMatch, search_threads};
use super::{
ContextPickerMode, MentionLink, recent_context_picker_entries, supported_context_picker_modes,
ContextPickerMode, MentionLink, RecentEntry, recent_context_picker_entries,
supported_context_picker_modes,
};
pub(crate) enum Match {
Symbol(SymbolMatch),
File(FileMatch),
Thread(ThreadMatch),
Fetch(SharedString),
Mode(ContextPickerMode),
}
fn search(
mode: Option<ContextPickerMode>,
query: String,
cancellation_flag: Arc<AtomicBool>,
recent_entries: Vec<RecentEntry>,
thread_store: Option<WeakEntity<ThreadStore>>,
workspace: Entity<Workspace>,
cx: &mut App,
) -> Task<Vec<Match>> {
match mode {
Some(ContextPickerMode::File) => {
let search_files_task =
search_files(query.clone(), cancellation_flag.clone(), &workspace, cx);
cx.background_spawn(async move {
search_files_task
.await
.into_iter()
.map(Match::File)
.collect()
})
}
Some(ContextPickerMode::Symbol) => {
let search_symbols_task =
search_symbols(query.clone(), cancellation_flag.clone(), &workspace, cx);
cx.background_spawn(async move {
search_symbols_task
.await
.into_iter()
.map(Match::Symbol)
.collect()
})
}
Some(ContextPickerMode::Thread) => {
if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) {
let search_threads_task =
search_threads(query.clone(), cancellation_flag.clone(), thread_store, cx);
cx.background_spawn(async move {
search_threads_task
.await
.into_iter()
.map(Match::Thread)
.collect()
})
} else {
Task::ready(Vec::new())
}
}
Some(ContextPickerMode::Fetch) => {
if !query.is_empty() {
Task::ready(vec![Match::Fetch(query.into())])
} else {
Task::ready(Vec::new())
}
}
None => {
if query.is_empty() {
let mut matches = recent_entries
.into_iter()
.map(|entry| match entry {
super::RecentEntry::File {
project_path,
path_prefix,
} => Match::File(FileMatch {
mat: fuzzy::PathMatch {
score: 1.,
positions: Vec::new(),
worktree_id: project_path.worktree_id.to_usize(),
path: project_path.path,
path_prefix,
is_dir: false,
distance_to_relative_ancestor: 0,
},
is_recent: true,
}),
super::RecentEntry::Thread(thread_context_entry) => {
Match::Thread(ThreadMatch {
thread: thread_context_entry,
is_recent: true,
})
}
})
.collect::<Vec<_>>();
matches.extend(
supported_context_picker_modes(&thread_store)
.into_iter()
.map(Match::Mode),
);
Task::ready(matches)
} else {
let search_files_task =
search_files(query.clone(), cancellation_flag.clone(), &workspace, cx);
cx.background_spawn(async move {
search_files_task
.await
.into_iter()
.map(Match::File)
.collect()
})
}
}
}
}
pub struct ContextPickerCompletionProvider {
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
@@ -50,97 +167,20 @@ impl ContextPickerCompletionProvider {
}
}
fn default_completions(
excerpt_id: ExcerptId,
source_range: Range<Anchor>,
context_store: Entity<ContextStore>,
thread_store: Option<WeakEntity<ThreadStore>>,
editor: Entity<Editor>,
workspace: Entity<Workspace>,
cx: &App,
) -> Vec<Completion> {
let mut completions = Vec::new();
completions.extend(
recent_context_picker_entries(
context_store.clone(),
thread_store.clone(),
workspace.clone(),
cx,
)
.iter()
.filter_map(|entry| match entry {
super::RecentEntry::File {
project_path,
path_prefix,
} => Some(Self::completion_for_path(
project_path.clone(),
path_prefix,
true,
false,
excerpt_id,
source_range.clone(),
editor.clone(),
context_store.clone(),
cx,
)),
super::RecentEntry::Thread(thread_context_entry) => {
let thread_store = thread_store
.as_ref()
.and_then(|thread_store| thread_store.upgrade())?;
Some(Self::completion_for_thread(
thread_context_entry.clone(),
excerpt_id,
source_range.clone(),
true,
editor.clone(),
context_store.clone(),
thread_store,
))
}
}),
);
completions.extend(
supported_context_picker_modes(&thread_store)
.iter()
.map(|mode| {
Completion {
replace_range: source_range.clone(),
new_text: format!("@{} ", mode.mention_prefix()),
label: CodeLabel::plain(mode.label().to_string(), None),
icon_path: Some(mode.icon().path().into()),
documentation: None,
source: project::CompletionSource::Custom,
insert_text_mode: None,
// This ensures that when a user accepts this completion, the
// completion menu will still be shown after "@category " is
// inserted
confirm: Some(Arc::new(|_, _, _| true)),
}
}),
);
completions
}
fn build_code_label_for_full_path(
file_name: &str,
directory: Option<&str>,
cx: &App,
) -> CodeLabel {
let comment_id = cx.theme().syntax().highlight_id("comment").map(HighlightId);
let mut label = CodeLabel::default();
label.push_str(&file_name, None);
label.push_str(" ", None);
if let Some(directory) = directory {
label.push_str(&directory, comment_id);
fn completion_for_mode(source_range: Range<Anchor>, mode: ContextPickerMode) -> Completion {
Completion {
replace_range: source_range.clone(),
new_text: format!("@{} ", mode.mention_prefix()),
label: CodeLabel::plain(mode.label().to_string(), None),
icon_path: Some(mode.icon().path().into()),
documentation: None,
source: project::CompletionSource::Custom,
insert_text_mode: None,
// This ensures that when a user accepts this completion, the
// completion menu will still be shown after "@category " is
// inserted
confirm: Some(Arc::new(|_, _, _| true)),
}
label.filter_range = 0..label.text().len();
label
}
fn completion_for_thread(
@@ -261,11 +301,8 @@ impl ContextPickerCompletionProvider {
path_prefix,
);
let label = Self::build_code_label_for_full_path(
&file_name,
directory.as_ref().map(|s| s.as_ref()),
cx,
);
let label =
build_code_label_for_full_path(&file_name, directory.as_ref().map(|s| s.as_ref()), cx);
let full_path = if let Some(directory) = directory {
format!("{}{}", directory, file_name)
} else {
@@ -382,6 +419,22 @@ impl ContextPickerCompletionProvider {
}
}
fn build_code_label_for_full_path(file_name: &str, directory: Option<&str>, cx: &App) -> CodeLabel {
let comment_id = cx.theme().syntax().highlight_id("comment").map(HighlightId);
let mut label = CodeLabel::default();
label.push_str(&file_name, None);
label.push_str(" ", None);
if let Some(directory) = directory {
label.push_str(&directory, comment_id);
}
label.filter_range = 0..label.text().len();
label
}
impl CompletionProvider for ContextPickerCompletionProvider {
fn completions(
&self,
@@ -404,10 +457,9 @@ impl CompletionProvider for ContextPickerCompletionProvider {
return Task::ready(Ok(None));
};
let Some(workspace) = self.workspace.upgrade() else {
return Task::ready(Ok(None));
};
let Some(context_store) = self.context_store.upgrade() else {
let Some((workspace, context_store)) =
self.workspace.upgrade().zip(self.context_store.upgrade())
else {
return Task::ready(Ok(None));
};
@@ -419,154 +471,89 @@ impl CompletionProvider for ContextPickerCompletionProvider {
let editor = self.editor.clone();
let http_client = workspace.read(cx).client().http_client().clone();
let MentionCompletion { mode, argument, .. } = state;
let query = argument.unwrap_or_else(|| "".to_string());
let recent_entries = recent_context_picker_entries(
context_store.clone(),
thread_store.clone(),
workspace.clone(),
cx,
);
let search_task = search(
mode,
query,
Arc::<AtomicBool>::default(),
recent_entries,
thread_store.clone(),
workspace.clone(),
cx,
);
cx.spawn(async move |_, cx| {
let mut completions = Vec::new();
let matches = search_task.await;
let Some(editor) = editor.upgrade() else {
return Ok(None);
};
let MentionCompletion { mode, argument, .. } = state;
let query = argument.unwrap_or_else(|| "".to_string());
match mode {
Some(ContextPickerMode::File) => {
let path_matches = cx
.update(|cx| {
super::file_context_picker::search_paths(
query,
Arc::<AtomicBool>::default(),
&workspace,
cx,
)
})?
.await;
if let Some(editor) = editor.upgrade() {
completions.reserve(path_matches.len());
cx.update(|cx| {
completions.extend(path_matches.iter().map(|mat| {
Self::completion_for_path(
ProjectPath {
worktree_id: WorktreeId::from_usize(mat.worktree_id),
path: mat.path.clone(),
},
&mat.path_prefix,
false,
mat.is_dir,
excerpt_id,
source_range.clone(),
editor.clone(),
context_store.clone(),
cx,
)
}));
})?;
}
}
Some(ContextPickerMode::Symbol) => {
if let Some(editor) = editor.upgrade() {
let symbol_matches = cx
.update(|cx| {
super::symbol_context_picker::search_symbols(
query,
Arc::new(AtomicBool::default()),
&workspace,
cx,
)
})?
.await?;
cx.update(|cx| {
completions.extend(symbol_matches.into_iter().filter_map(
|(_, symbol)| {
Self::completion_for_symbol(
symbol,
excerpt_id,
source_range.clone(),
editor.clone(),
context_store.clone(),
workspace.clone(),
cx,
)
Ok(Some(cx.update(|cx| {
matches
.into_iter()
.filter_map(|mat| match mat {
Match::File(FileMatch { mat, is_recent }) => {
Some(Self::completion_for_path(
ProjectPath {
worktree_id: WorktreeId::from_usize(mat.worktree_id),
path: mat.path.clone(),
},
));
})?;
}
}
Some(ContextPickerMode::Fetch) => {
if let Some(editor) = editor.upgrade() {
if !query.is_empty() {
completions.push(Self::completion_for_fetch(
source_range.clone(),
query.into(),
&mat.path_prefix,
is_recent,
mat.is_dir,
excerpt_id,
source_range.clone(),
editor.clone(),
context_store.clone(),
http_client.clone(),
));
}
context_store.update(cx, |store, _| {
let urls = store.context().iter().filter_map(|context| {
if let AssistantContext::FetchedUrl(context) = context {
Some(context.url.clone())
} else {
None
}
});
for url in urls {
completions.push(Self::completion_for_fetch(
source_range.clone(),
url,
excerpt_id,
editor.clone(),
context_store.clone(),
http_client.clone(),
));
}
})?;
}
}
Some(ContextPickerMode::Thread) => {
if let Some((thread_store, editor)) = thread_store
.and_then(|thread_store| thread_store.upgrade())
.zip(editor.upgrade())
{
let threads = cx
.update(|cx| {
super::thread_context_picker::search_threads(
query,
thread_store.clone(),
cx,
)
})?
.await;
for thread in threads {
completions.push(Self::completion_for_thread(
thread.clone(),
excerpt_id,
source_range.clone(),
false,
editor.clone(),
context_store.clone(),
thread_store.clone(),
));
}
}
}
None => {
cx.update(|cx| {
if let Some(editor) = editor.upgrade() {
completions.extend(Self::default_completions(
excerpt_id,
source_range.clone(),
context_store.clone(),
thread_store.clone(),
editor,
workspace.clone(),
cx,
));
))
}
})?;
}
}
Ok(Some(completions))
Match::Symbol(SymbolMatch { symbol, .. }) => Self::completion_for_symbol(
symbol,
excerpt_id,
source_range.clone(),
editor.clone(),
context_store.clone(),
workspace.clone(),
cx,
),
Match::Thread(ThreadMatch {
thread, is_recent, ..
}) => {
let thread_store = thread_store.as_ref().and_then(|t| t.upgrade())?;
Some(Self::completion_for_thread(
thread,
excerpt_id,
source_range.clone(),
is_recent,
editor.clone(),
context_store.clone(),
thread_store,
))
}
Match::Fetch(url) => Some(Self::completion_for_fetch(
source_range.clone(),
url,
excerpt_id,
editor.clone(),
context_store.clone(),
http_client.clone(),
)),
Match::Mode(mode) => {
Some(Self::completion_for_mode(source_range.clone(), mode))
}
})
.collect()
})?))
})
}
@@ -623,21 +610,20 @@ fn confirm_completion_callback(
editor: Entity<Editor>,
add_context_fn: impl Fn(&mut App) -> () + Send + Sync + 'static,
) -> Arc<dyn Fn(CompletionIntent, &mut Window, &mut App) -> bool + Send + Sync> {
Arc::new(move |_, window, cx| {
Arc::new(move |_, _, cx| {
add_context_fn(cx);
let crease_text = crease_text.clone();
let crease_icon_path = crease_icon_path.clone();
let editor = editor.clone();
window.defer(cx, move |window, cx| {
crate::context_picker::insert_crease_for_mention(
cx.defer(move |cx| {
crate::context_picker::insert_fold_for_mention(
excerpt_id,
start,
content_len,
crease_text,
crease_icon_path,
editor,
window,
cx,
);
});
@@ -676,7 +662,12 @@ impl MentionCompletion {
let mut end = last_mention_start + 1;
if let Some(mode_text) = parts.next() {
end += mode_text.len();
mode = ContextPickerMode::try_from(mode_text).ok();
if let Some(parsed_mode) = ContextPickerMode::try_from(mode_text).ok() {
mode = Some(parsed_mode);
} else {
argument = Some(mode_text.to_string());
}
match rest_of_line[mode_text.len()..].find(|c: char| !c.is_whitespace()) {
Some(whitespace_count) => {
if let Some(argument_text) = parts.next() {
@@ -702,13 +693,14 @@ impl MentionCompletion {
#[cfg(test)]
mod tests {
use super::*;
use gpui::{Focusable, TestAppContext, VisualTestContext};
use editor::AnchorRangeExt;
use gpui::{EventEmitter, FocusHandle, Focusable, TestAppContext, VisualTestContext};
use project::{Project, ProjectPath};
use serde_json::json;
use settings::SettingsStore;
use std::{ops::Deref, path::PathBuf};
use std::ops::Deref;
use util::{path, separator};
use workspace::AppState;
use workspace::{AppState, Item};
#[test]
fn test_mention_completion_parse() {
@@ -768,9 +760,42 @@ mod tests {
})
);
assert_eq!(
MentionCompletion::try_parse("Lorem @main", 0),
Some(MentionCompletion {
source_range: 6..11,
mode: None,
argument: Some("main".to_string()),
})
);
assert_eq!(MentionCompletion::try_parse("test@", 0), None);
}
struct AtMentionEditor(Entity<Editor>);
impl Item for AtMentionEditor {
type Event = ();
fn include_in_nav_history() -> bool {
false
}
}
impl EventEmitter<()> for AtMentionEditor {}
impl Focusable for AtMentionEditor {
fn focus_handle(&self, cx: &App) -> FocusHandle {
self.0.read(cx).focus_handle(cx).clone()
}
}
impl Render for AtMentionEditor {
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
self.0.clone().into_any_element()
}
}
#[gpui::test]
async fn test_context_completion_provider(cx: &mut TestAppContext) {
init_test(cx);
@@ -846,25 +871,27 @@ mod tests {
.unwrap();
}
let item = workspace
.update_in(&mut cx, |workspace, window, cx| {
workspace.open_path(
ProjectPath {
worktree_id,
path: PathBuf::from("editor").into(),
},
let editor = workspace.update_in(&mut cx, |workspace, window, cx| {
let editor = cx.new(|cx| {
Editor::new(
editor::EditorMode::full(),
multi_buffer::MultiBuffer::build_simple("", cx),
None,
true,
window,
cx,
)
})
.await
.expect("Could not open test file");
let editor = cx.update(|_, cx| {
item.act_as::<Editor>(cx)
.expect("Opened test file wasn't an editor")
});
workspace.active_pane().update(cx, |pane, cx| {
pane.add_item(
Box::new(cx.new(|_| AtMentionEditor(editor.clone()))),
true,
true,
None,
window,
cx,
);
});
editor
});
let context_store = cx.new(|_| ContextStore::new(project.downgrade(), None));
@@ -895,10 +922,10 @@ mod tests {
assert_eq!(
current_completion_labels(editor),
&[
"editor dir/",
"seven.txt dir/b/",
"six.txt dir/b/",
"five.txt dir/b/",
"four.txt dir/a/",
"Files & Directories",
"Symbols",
"Fetch"
@@ -940,7 +967,7 @@ mod tests {
assert_eq!(editor.text(cx), "Lorem [@one.txt](@file:dir/a/one.txt)",);
assert!(!editor.has_visible_completions_menu());
assert_eq!(
crease_ranges(editor, cx),
fold_ranges(editor, cx),
vec![Point::new(0, 6)..Point::new(0, 37)]
);
});
@@ -951,7 +978,7 @@ mod tests {
assert_eq!(editor.text(cx), "Lorem [@one.txt](@file:dir/a/one.txt) ",);
assert!(!editor.has_visible_completions_menu());
assert_eq!(
crease_ranges(editor, cx),
fold_ranges(editor, cx),
vec![Point::new(0, 6)..Point::new(0, 37)]
);
});
@@ -965,7 +992,7 @@ mod tests {
);
assert!(!editor.has_visible_completions_menu());
assert_eq!(
crease_ranges(editor, cx),
fold_ranges(editor, cx),
vec![Point::new(0, 6)..Point::new(0, 37)]
);
});
@@ -979,7 +1006,7 @@ mod tests {
);
assert!(editor.has_visible_completions_menu());
assert_eq!(
crease_ranges(editor, cx),
fold_ranges(editor, cx),
vec![Point::new(0, 6)..Point::new(0, 37)]
);
});
@@ -993,14 +1020,14 @@ mod tests {
editor.update(&mut cx, |editor, cx| {
assert_eq!(
editor.text(cx),
"Lorem [@one.txt](@file:dir/a/one.txt) Ipsum [@editor](@file:dir/editor)"
"Lorem [@one.txt](@file:dir/a/one.txt) Ipsum [@seven.txt](@file:dir/b/seven.txt)"
);
assert!(!editor.has_visible_completions_menu());
assert_eq!(
crease_ranges(editor, cx),
fold_ranges(editor, cx),
vec![
Point::new(0, 6)..Point::new(0, 37),
Point::new(0, 44)..Point::new(0, 71)
Point::new(0, 44)..Point::new(0, 79)
]
);
});
@@ -1010,14 +1037,14 @@ mod tests {
editor.update(&mut cx, |editor, cx| {
assert_eq!(
editor.text(cx),
"Lorem [@one.txt](@file:dir/a/one.txt) Ipsum [@editor](@file:dir/editor)\n@"
"Lorem [@one.txt](@file:dir/a/one.txt) Ipsum [@seven.txt](@file:dir/b/seven.txt)\n@"
);
assert!(editor.has_visible_completions_menu());
assert_eq!(
crease_ranges(editor, cx),
fold_ranges(editor, cx),
vec![
Point::new(0, 6)..Point::new(0, 37),
Point::new(0, 44)..Point::new(0, 71)
Point::new(0, 44)..Point::new(0, 79)
]
);
});
@@ -1031,29 +1058,27 @@ mod tests {
editor.update(&mut cx, |editor, cx| {
assert_eq!(
editor.text(cx),
"Lorem [@one.txt](@file:dir/a/one.txt) Ipsum [@editor](@file:dir/editor)\n[@seven.txt](@file:dir/b/seven.txt)"
"Lorem [@one.txt](@file:dir/a/one.txt) Ipsum [@seven.txt](@file:dir/b/seven.txt)\n[@six.txt](@file:dir/b/six.txt)"
);
assert!(!editor.has_visible_completions_menu());
assert_eq!(
crease_ranges(editor, cx),
fold_ranges(editor, cx),
vec![
Point::new(0, 6)..Point::new(0, 37),
Point::new(0, 44)..Point::new(0, 71),
Point::new(1, 0)..Point::new(1, 35)
Point::new(0, 44)..Point::new(0, 79),
Point::new(1, 0)..Point::new(1, 31)
]
);
});
}
fn crease_ranges(editor: &Editor, cx: &mut App) -> Vec<Range<Point>> {
fn fold_ranges(editor: &Editor, cx: &mut App) -> Vec<Range<Point>> {
let snapshot = editor.buffer().read(cx).snapshot(cx);
editor.display_map.update(cx, |display_map, cx| {
display_map
.snapshot(cx)
.crease_snapshot
.crease_items_with_offsets(&snapshot)
.into_iter()
.map(|(_, range)| range)
.folds_in_range(0..snapshot.len())
.map(|fold| fold.range.to_point(&snapshot))
.collect()
})
}

View File

@@ -11,7 +11,7 @@ use picker::{Picker, PickerDelegate};
use ui::{Context, ListItem, Window, prelude::*};
use workspace::Workspace;
use crate::context_picker::{ConfirmBehavior, ContextPicker};
use crate::context_picker::ContextPicker;
use crate::context_store::ContextStore;
pub struct FetchContextPicker {
@@ -23,16 +23,10 @@ impl FetchContextPicker {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let delegate = FetchContextPickerDelegate::new(
context_picker,
workspace,
context_store,
confirm_behavior,
);
let delegate = FetchContextPickerDelegate::new(context_picker, workspace, context_store);
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
Self { picker }
@@ -62,7 +56,6 @@ pub struct FetchContextPickerDelegate {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
url: String,
}
@@ -71,13 +64,11 @@ impl FetchContextPickerDelegate {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
) -> Self {
FetchContextPickerDelegate {
context_picker,
workspace,
context_store,
confirm_behavior,
url: String::new(),
}
}
@@ -204,25 +195,15 @@ impl PickerDelegate for FetchContextPickerDelegate {
let http_client = workspace.read(cx).client().http_client().clone();
let url = self.url.clone();
let confirm_behavior = self.confirm_behavior;
cx.spawn_in(window, async move |this, cx| {
let text = cx
.background_spawn(fetch_url_content(http_client, url.clone()))
.await?;
this.update_in(cx, |this, window, cx| {
this.delegate
.context_store
.update(cx, |context_store, cx| {
context_store.add_fetched_url(url, text, cx)
})?;
match confirm_behavior {
ConfirmBehavior::KeepOpen => {}
ConfirmBehavior::Close => this.delegate.dismissed(window, cx),
}
anyhow::Ok(())
this.update(cx, |this, cx| {
this.delegate.context_store.update(cx, |context_store, cx| {
context_store.add_fetched_url(url, text, cx)
})
})??;
anyhow::Ok(())

View File

@@ -11,9 +11,9 @@ use picker::{Picker, PickerDelegate};
use project::{PathMatchCandidateSet, ProjectPath, WorktreeId};
use ui::{ListItem, Tooltip, prelude::*};
use util::ResultExt as _;
use workspace::{Workspace, notifications::NotifyResultExt};
use workspace::Workspace;
use crate::context_picker::{ConfirmBehavior, ContextPicker};
use crate::context_picker::ContextPicker;
use crate::context_store::{ContextStore, FileInclusion};
pub struct FileContextPicker {
@@ -25,16 +25,10 @@ impl FileContextPicker {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let delegate = FileContextPickerDelegate::new(
context_picker,
workspace,
context_store,
confirm_behavior,
);
let delegate = FileContextPickerDelegate::new(context_picker, workspace, context_store);
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
Self { picker }
@@ -57,8 +51,7 @@ pub struct FileContextPickerDelegate {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
matches: Vec<PathMatch>,
matches: Vec<FileMatch>,
selected_index: usize,
}
@@ -67,13 +60,11 @@ impl FileContextPickerDelegate {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
) -> Self {
Self {
context_picker,
workspace,
context_store,
confirm_behavior,
matches: Vec::new(),
selected_index: 0,
}
@@ -114,7 +105,7 @@ impl PickerDelegate for FileContextPickerDelegate {
return Task::ready(());
};
let search_task = search_paths(query, Arc::<AtomicBool>::default(), &workspace, cx);
let search_task = search_files(query, Arc::<AtomicBool>::default(), &workspace, cx);
cx.spawn_in(window, async move |this, cx| {
// TODO: This should be probably be run in the background.
@@ -127,8 +118,8 @@ impl PickerDelegate for FileContextPickerDelegate {
})
}
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
let Some(mat) = self.matches.get(self.selected_index) else {
fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
let Some(FileMatch { mat, .. }) = self.matches.get(self.selected_index) else {
return;
};
@@ -153,17 +144,7 @@ impl PickerDelegate for FileContextPickerDelegate {
return;
};
let confirm_behavior = self.confirm_behavior;
cx.spawn_in(window, async move |this, cx| {
match task.await.notify_async_err(cx) {
None => anyhow::Ok(()),
Some(()) => this.update_in(cx, |this, window, cx| match confirm_behavior {
ConfirmBehavior::KeepOpen => {}
ConfirmBehavior::Close => this.delegate.dismissed(window, cx),
}),
}
})
.detach_and_log_err(cx);
task.detach_and_log_err(cx);
}
fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
@@ -181,7 +162,7 @@ impl PickerDelegate for FileContextPickerDelegate {
_window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Option<Self::ListItem> {
let path_match = &self.matches[ix];
let FileMatch { mat, .. } = &self.matches[ix];
Some(
ListItem::new(ix)
@@ -189,9 +170,10 @@ impl PickerDelegate for FileContextPickerDelegate {
.toggle_state(selected)
.child(render_file_context_entry(
ElementId::NamedInteger("file-ctx-picker".into(), ix),
&path_match.path,
&path_match.path_prefix,
path_match.is_dir,
WorktreeId::from_usize(mat.worktree_id),
&mat.path,
&mat.path_prefix,
mat.is_dir,
self.context_store.clone(),
cx,
)),
@@ -199,12 +181,17 @@ impl PickerDelegate for FileContextPickerDelegate {
}
}
pub(crate) fn search_paths(
pub struct FileMatch {
pub mat: PathMatch,
pub is_recent: bool,
}
pub(crate) fn search_files(
query: String,
cancellation_flag: Arc<AtomicBool>,
workspace: &Entity<Workspace>,
cx: &App,
) -> Task<Vec<PathMatch>> {
) -> Task<Vec<FileMatch>> {
if query.is_empty() {
let workspace = workspace.read(cx);
let project = workspace.project().read(cx);
@@ -213,28 +200,34 @@ pub(crate) fn search_paths(
.into_iter()
.filter_map(|(project_path, _)| {
let worktree = project.worktree_for_id(project_path.worktree_id, cx)?;
Some(PathMatch {
score: 0.,
positions: Vec::new(),
worktree_id: project_path.worktree_id.to_usize(),
path: project_path.path,
path_prefix: worktree.read(cx).root_name().into(),
distance_to_relative_ancestor: 0,
is_dir: false,
Some(FileMatch {
mat: PathMatch {
score: 0.,
positions: Vec::new(),
worktree_id: project_path.worktree_id.to_usize(),
path: project_path.path,
path_prefix: worktree.read(cx).root_name().into(),
distance_to_relative_ancestor: 0,
is_dir: false,
},
is_recent: true,
})
});
let file_matches = project.worktrees(cx).flat_map(|worktree| {
let worktree = worktree.read(cx);
let path_prefix: Arc<str> = worktree.root_name().into();
worktree.entries(false, 0).map(move |entry| PathMatch {
score: 0.,
positions: Vec::new(),
worktree_id: worktree.id().to_usize(),
path: entry.path.clone(),
path_prefix: path_prefix.clone(),
distance_to_relative_ancestor: 0,
is_dir: entry.is_dir(),
worktree.entries(false, 0).map(move |entry| FileMatch {
mat: PathMatch {
score: 0.,
positions: Vec::new(),
worktree_id: worktree.id().to_usize(),
path: entry.path.clone(),
path_prefix: path_prefix.clone(),
distance_to_relative_ancestor: 0,
is_dir: entry.is_dir(),
},
is_recent: false,
})
});
@@ -269,6 +262,12 @@ pub(crate) fn search_paths(
executor,
)
.await
.into_iter()
.map(|mat| FileMatch {
mat,
is_recent: false,
})
.collect::<Vec<_>>()
})
}
}
@@ -311,19 +310,26 @@ pub fn extract_file_name_and_directory(
pub fn render_file_context_entry(
id: ElementId,
path: &Path,
worktree_id: WorktreeId,
path: &Arc<Path>,
path_prefix: &Arc<str>,
is_directory: bool,
context_store: WeakEntity<ContextStore>,
cx: &App,
) -> Stateful<Div> {
let (file_name, directory) = extract_file_name_and_directory(path, path_prefix);
let (file_name, directory) = extract_file_name_and_directory(&path, path_prefix);
let added = context_store.upgrade().and_then(|context_store| {
let project_path = ProjectPath {
worktree_id,
path: path.clone(),
};
if is_directory {
context_store.read(cx).includes_directory(path)
context_store.read(cx).includes_directory(&project_path)
} else {
context_store.read(cx).will_include_file_path(path, cx)
context_store
.read(cx)
.will_include_file_path(&project_path, cx)
}
});
@@ -363,8 +369,9 @@ pub fn render_file_context_entry(
)
.child(Label::new("Added").size(LabelSize::Small)),
),
FileInclusion::InDirectory(dir_name) => {
let dir_name = dir_name.to_string_lossy().into_owned();
FileInclusion::InDirectory(directory_project_path) => {
// TODO: Consider using worktree full_path to include worktree name.
let directory_path = directory_project_path.path.to_string_lossy().into_owned();
el.child(
h_flex()
@@ -378,7 +385,7 @@ pub fn render_file_context_entry(
)
.child(Label::new("Included").size(LabelSize::Small)),
)
.tooltip(Tooltip::text(format!("in {dir_name}")))
.tooltip(Tooltip::text(format!("in {directory_path}")))
}
})
}

View File

@@ -2,7 +2,7 @@ use std::cmp::Reverse;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use anyhow::{Context as _, Result};
use anyhow::Result;
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{
App, AppContext, DismissEvent, Entity, FocusHandle, Focusable, Stateful, Task, WeakEntity,
@@ -15,7 +15,7 @@ use ui::{ListItem, prelude::*};
use util::ResultExt as _;
use workspace::Workspace;
use crate::context_picker::{ConfirmBehavior, ContextPicker};
use crate::context_picker::ContextPicker;
use crate::context_store::ContextStore;
pub struct SymbolContextPicker {
@@ -27,16 +27,10 @@ impl SymbolContextPicker {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let delegate = SymbolContextPickerDelegate::new(
context_picker,
workspace,
context_store,
confirm_behavior,
);
let delegate = SymbolContextPickerDelegate::new(context_picker, workspace, context_store);
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
Self { picker }
@@ -59,7 +53,6 @@ pub struct SymbolContextPickerDelegate {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
matches: Vec<SymbolEntry>,
selected_index: usize,
}
@@ -69,13 +62,11 @@ impl SymbolContextPickerDelegate {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
) -> Self {
Self {
context_picker,
workspace,
context_store,
confirm_behavior,
matches: Vec::new(),
selected_index: 0,
}
@@ -119,11 +110,7 @@ impl PickerDelegate for SymbolContextPickerDelegate {
let search_task = search_symbols(query, Arc::<AtomicBool>::default(), &workspace, cx);
let context_store = self.context_store.clone();
cx.spawn_in(window, async move |this, cx| {
let symbols = search_task
.await
.context("Failed to load symbols")
.log_err()
.unwrap_or_default();
let symbols = search_task.await;
let symbol_entries = context_store
.read_with(cx, |context_store, cx| {
@@ -139,7 +126,7 @@ impl PickerDelegate for SymbolContextPickerDelegate {
})
}
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
let Some(mat) = self.matches.get(self.selected_index) else {
return;
};
@@ -147,7 +134,6 @@ impl PickerDelegate for SymbolContextPickerDelegate {
return;
};
let confirm_behavior = self.confirm_behavior;
let add_symbol_task = add_symbol(
mat.symbol.clone(),
true,
@@ -157,16 +143,12 @@ impl PickerDelegate for SymbolContextPickerDelegate {
);
let selected_index = self.selected_index;
cx.spawn_in(window, async move |this, cx| {
cx.spawn(async move |this, cx| {
let included = add_symbol_task.await?;
this.update_in(cx, |this, window, cx| {
this.update(cx, |this, _| {
if let Some(mat) = this.delegate.matches.get_mut(selected_index) {
mat.is_included = included;
}
match confirm_behavior {
ConfirmBehavior::KeepOpen => {}
ConfirmBehavior::Close => this.delegate.dismissed(window, cx),
}
})
})
.detach_and_log_err(cx);
@@ -285,12 +267,16 @@ fn find_matching_symbol(symbol: &Symbol, candidates: &[DocumentSymbol]) -> Optio
}
}
pub struct SymbolMatch {
pub symbol: Symbol,
}
pub(crate) fn search_symbols(
query: String,
cancellation_flag: Arc<AtomicBool>,
workspace: &Entity<Workspace>,
cx: &mut App,
) -> Task<Result<Vec<(StringMatch, Symbol)>>> {
) -> Task<Vec<SymbolMatch>> {
let symbols_task = workspace.update(cx, |workspace, cx| {
workspace
.project()
@@ -298,19 +284,28 @@ pub(crate) fn search_symbols(
});
let project = workspace.read(cx).project().clone();
cx.spawn(async move |cx| {
let symbols = symbols_task.await?;
let (visible_match_candidates, external_match_candidates): (Vec<_>, Vec<_>) = project
.update(cx, |project, cx| {
symbols
.iter()
.enumerate()
.map(|(id, symbol)| StringMatchCandidate::new(id, &symbol.label.filter_text()))
.partition(|candidate| {
project
.entry_for_path(&symbols[candidate.id].path, cx)
.map_or(false, |e| !e.is_ignored)
})
})?;
let Some(symbols) = symbols_task.await.log_err() else {
return Vec::new();
};
let Some((visible_match_candidates, external_match_candidates)): Option<(Vec<_>, Vec<_>)> =
project
.update(cx, |project, cx| {
symbols
.iter()
.enumerate()
.map(|(id, symbol)| {
StringMatchCandidate::new(id, &symbol.label.filter_text())
})
.partition(|candidate| {
project
.entry_for_path(&symbols[candidate.id].path, cx)
.map_or(false, |e| !e.is_ignored)
})
})
.log_err()
else {
return Vec::new();
};
const MAX_MATCHES: usize = 100;
let mut visible_matches = cx.background_executor().block(fuzzy::match_strings(
@@ -339,7 +334,7 @@ pub(crate) fn search_symbols(
let mut matches = visible_matches;
matches.append(&mut external_matches);
Ok(matches
matches
.into_iter()
.map(|mut mat| {
let symbol = symbols[mat.candidate_id].clone();
@@ -347,19 +342,19 @@ pub(crate) fn search_symbols(
for position in &mut mat.positions {
*position += filter_start;
}
(mat, symbol)
SymbolMatch { symbol }
})
.collect())
.collect()
})
}
fn compute_symbol_entries(
symbols: Vec<(StringMatch, Symbol)>,
symbols: Vec<SymbolMatch>,
context_store: &ContextStore,
cx: &App,
) -> Vec<SymbolEntry> {
let mut symbol_entries = Vec::with_capacity(symbols.len());
for (_, symbol) in symbols {
for SymbolMatch { symbol, .. } in symbols {
let symbols_for_path = context_store.included_symbols_by_path().get(&symbol.path);
let is_included = if let Some(symbols_for_path) = symbols_for_path {
let mut is_included = false;

View File

@@ -1,11 +1,12 @@
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use fuzzy::StringMatchCandidate;
use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity};
use picker::{Picker, PickerDelegate};
use ui::{ListItem, prelude::*};
use crate::context_picker::{ConfirmBehavior, ContextPicker};
use crate::context_picker::ContextPicker;
use crate::context_store::{self, ContextStore};
use crate::thread::ThreadId;
use crate::thread_store::ThreadStore;
@@ -19,16 +20,11 @@ impl ThreadContextPicker {
thread_store: WeakEntity<ThreadStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
confirm_behavior: ConfirmBehavior,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let delegate = ThreadContextPickerDelegate::new(
thread_store,
context_picker,
context_store,
confirm_behavior,
);
let delegate =
ThreadContextPickerDelegate::new(thread_store, context_picker, context_store);
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
ThreadContextPicker { picker }
@@ -57,7 +53,6 @@ pub struct ThreadContextPickerDelegate {
thread_store: WeakEntity<ThreadStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
confirm_behavior: ConfirmBehavior,
matches: Vec<ThreadContextEntry>,
selected_index: usize,
}
@@ -67,13 +62,11 @@ impl ThreadContextPickerDelegate {
thread_store: WeakEntity<ThreadStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
confirm_behavior: ConfirmBehavior,
) -> Self {
ThreadContextPickerDelegate {
thread_store,
context_picker,
context_store,
confirm_behavior,
matches: Vec::new(),
selected_index: 0,
}
@@ -114,11 +107,11 @@ impl PickerDelegate for ThreadContextPickerDelegate {
return Task::ready(());
};
let search_task = search_threads(query, threads, cx);
let search_task = search_threads(query, Arc::new(AtomicBool::default()), threads, cx);
cx.spawn_in(window, async move |this, cx| {
let matches = search_task.await;
this.update(cx, |this, cx| {
this.delegate.matches = matches;
this.delegate.matches = matches.into_iter().map(|mat| mat.thread).collect();
this.delegate.selected_index = 0;
cx.notify();
})
@@ -126,7 +119,7 @@ impl PickerDelegate for ThreadContextPickerDelegate {
})
}
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
let Some(entry) = self.matches.get(self.selected_index) else {
return;
};
@@ -137,20 +130,15 @@ impl PickerDelegate for ThreadContextPickerDelegate {
let open_thread_task = thread_store.update(cx, |this, cx| this.open_thread(&entry.id, cx));
cx.spawn_in(window, async move |this, cx| {
cx.spawn(async move |this, cx| {
let thread = open_thread_task.await?;
this.update_in(cx, |this, window, cx| {
this.update(cx, |this, cx| {
this.delegate
.context_store
.update(cx, |context_store, cx| {
context_store.add_thread(thread, true, cx)
})
.ok();
match this.delegate.confirm_behavior {
ConfirmBehavior::KeepOpen => {}
ConfirmBehavior::Close => this.delegate.dismissed(window, cx),
}
})
})
.detach_and_log_err(cx);
@@ -217,11 +205,18 @@ pub fn render_thread_context_entry(
})
}
#[derive(Clone)]
pub struct ThreadMatch {
pub thread: ThreadContextEntry,
pub is_recent: bool,
}
pub(crate) fn search_threads(
query: String,
cancellation_flag: Arc<AtomicBool>,
thread_store: Entity<ThreadStore>,
cx: &mut App,
) -> Task<Vec<ThreadContextEntry>> {
) -> Task<Vec<ThreadMatch>> {
let threads = thread_store.update(cx, |this, _cx| {
this.threads()
.into_iter()
@@ -236,6 +231,12 @@ pub(crate) fn search_threads(
cx.background_spawn(async move {
if query.is_empty() {
threads
.into_iter()
.map(|thread| ThreadMatch {
thread,
is_recent: false,
})
.collect()
} else {
let candidates = threads
.iter()
@@ -247,14 +248,17 @@ pub(crate) fn search_threads(
&query,
false,
100,
&Default::default(),
&cancellation_flag,
executor,
)
.await;
matches
.into_iter()
.map(|mat| threads[mat.candidate_id].clone())
.map(|mat| ThreadMatch {
thread: threads[mat.candidate_id].clone(),
is_recent: false,
})
.collect()
}
})

View File

@@ -1,5 +1,5 @@
use std::ops::Range;
use std::path::{Path, PathBuf};
use std::path::Path;
use std::sync::Arc;
use anyhow::{Context as _, Result, anyhow};
@@ -28,7 +28,7 @@ pub struct ContextStore {
// TODO: If an EntityId is used for all context types (like BufferId), can remove ContextId.
next_context_id: ContextId,
files: BTreeMap<BufferId, ContextId>,
directories: HashMap<PathBuf, ContextId>,
directories: HashMap<ProjectPath, ContextId>,
symbols: HashMap<ContextSymbolId, ContextId>,
symbol_buffers: HashMap<ContextSymbolId, Entity<Buffer>>,
symbols_by_path: HashMap<ProjectPath, Vec<ContextSymbolId>>,
@@ -93,7 +93,7 @@ impl ContextStore {
let buffer_id = this.update(cx, |_, cx| buffer.read(cx).remote_id())?;
let already_included = this.update(cx, |this, cx| {
match this.will_include_buffer(buffer_id, &project_path.path) {
match this.will_include_buffer(buffer_id, &project_path) {
Some(FileInclusion::Direct(context_id)) => {
if remove_if_exists {
this.remove_context(context_id, cx);
@@ -159,7 +159,7 @@ impl ContextStore {
return Task::ready(Err(anyhow!("failed to read project")));
};
let already_included = match self.includes_directory(&project_path.path) {
let already_included = match self.includes_directory(&project_path) {
Some(FileInclusion::Direct(context_id)) => {
if remove_if_exists {
self.remove_context(context_id, cx);
@@ -223,14 +223,12 @@ impl ContextStore {
.collect::<Vec<_>>();
if context_buffers.is_empty() {
return Err(anyhow!(
"No text files found in {}",
&project_path.path.display()
));
let full_path = cx.update(|cx| worktree.read(cx).full_path(&project_path.path))?;
return Err(anyhow!("No text files found in {}", &full_path.display()));
}
this.update(cx, |this, cx| {
this.insert_directory(project_path, context_buffers, cx);
this.insert_directory(worktree, project_path, context_buffers, cx);
})?;
anyhow::Ok(())
@@ -239,17 +237,20 @@ impl ContextStore {
fn insert_directory(
&mut self,
worktree: Entity<Worktree>,
project_path: ProjectPath,
context_buffers: Vec<ContextBuffer>,
cx: &mut Context<Self>,
) {
let id = self.next_context_id.post_inc();
self.directories.insert(project_path.path.to_path_buf(), id);
let path = project_path.path.clone();
self.directories.insert(project_path, id);
self.context
.push(AssistantContext::Directory(DirectoryContext {
id,
project_path,
worktree,
path,
context_buffers,
}));
cx.notify();
@@ -478,23 +479,31 @@ impl ContextStore {
/// Returns whether the buffer is already included directly in the context, or if it will be
/// included in the context via a directory. Directory inclusion is based on paths rather than
/// buffer IDs as the directory will be re-scanned.
pub fn will_include_buffer(&self, buffer_id: BufferId, path: &Path) -> Option<FileInclusion> {
pub fn will_include_buffer(
&self,
buffer_id: BufferId,
project_path: &ProjectPath,
) -> Option<FileInclusion> {
if let Some(context_id) = self.files.get(&buffer_id) {
return Some(FileInclusion::Direct(*context_id));
}
self.will_include_file_path_via_directory(path)
self.will_include_file_path_via_directory(project_path)
}
/// Returns whether this file path is already included directly in the context, or if it will be
/// included in the context via a directory.
pub fn will_include_file_path(&self, path: &Path, cx: &App) -> Option<FileInclusion> {
pub fn will_include_file_path(
&self,
project_path: &ProjectPath,
cx: &App,
) -> Option<FileInclusion> {
if !self.files.is_empty() {
let found_file_context = self.context.iter().find(|context| match &context {
AssistantContext::File(file_context) => {
let buffer = file_context.context_buffer.buffer.read(cx);
if let Some(file_path) = buffer_path_log_err(buffer, cx) {
*file_path == *path
if let Some(context_path) = buffer.project_path(cx) {
&context_path == project_path
} else {
false
}
@@ -506,31 +515,40 @@ impl ContextStore {
}
}
self.will_include_file_path_via_directory(path)
self.will_include_file_path_via_directory(project_path)
}
fn will_include_file_path_via_directory(&self, path: &Path) -> Option<FileInclusion> {
fn will_include_file_path_via_directory(
&self,
project_path: &ProjectPath,
) -> Option<FileInclusion> {
if self.directories.is_empty() {
return None;
}
let mut buf = path.to_path_buf();
let mut path_buf = project_path.path.to_path_buf();
while buf.pop() {
if let Some(_) = self.directories.get(&buf) {
return Some(FileInclusion::InDirectory(buf));
while path_buf.pop() {
// TODO: This isn't very efficient. Consider using a better representation of the
// directories map.
let directory_project_path = ProjectPath {
worktree_id: project_path.worktree_id,
path: path_buf.clone().into(),
};
if let Some(_) = self.directories.get(&directory_project_path) {
return Some(FileInclusion::InDirectory(directory_project_path));
}
}
None
}
pub fn includes_directory(&self, path: &Path) -> Option<FileInclusion> {
if let Some(context_id) = self.directories.get(path) {
pub fn includes_directory(&self, project_path: &ProjectPath) -> Option<FileInclusion> {
if let Some(context_id) = self.directories.get(project_path) {
return Some(FileInclusion::Direct(*context_id));
}
self.will_include_file_path_via_directory(path)
self.will_include_file_path_via_directory(project_path)
}
pub fn included_symbol(&self, symbol_id: &ContextSymbolId) -> Option<ContextId> {
@@ -564,13 +582,13 @@ impl ContextStore {
}
}
pub fn file_paths(&self, cx: &App) -> HashSet<PathBuf> {
pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
self.context
.iter()
.filter_map(|context| match context {
AssistantContext::File(file) => {
let buffer = file.context_buffer.buffer.read(cx);
buffer_path_log_err(buffer, cx).map(|p| p.to_path_buf())
buffer.project_path(cx)
}
AssistantContext::Directory(_)
| AssistantContext::Symbol(_)
@@ -587,7 +605,7 @@ impl ContextStore {
pub enum FileInclusion {
Direct(ContextId),
InDirectory(PathBuf),
InDirectory(ProjectPath),
}
// ContextBuffer without text.
@@ -654,19 +672,6 @@ fn collect_buffer_info_and_text(
Ok((buffer_info, text_task))
}
pub fn buffer_path_log_err(buffer: &Buffer, cx: &App) -> Option<Arc<Path>> {
if let Some(file) = buffer.file() {
let mut path = file.path().clone();
if path.as_os_str().is_empty() {
path = file.full_path(cx).into();
}
Some(path)
} else {
log::error!("Buffer that had a path unexpectedly no longer has a path.");
None
}
}
fn to_fenced_codeblock(path: &Path, content: Rope) -> SharedString {
let path_extension = path.extension().and_then(|ext| ext.to_str());
let path_string = path.to_string_lossy();
@@ -742,13 +747,13 @@ pub fn refresh_context_store_text(
}
}
AssistantContext::Directory(directory_context) => {
let directory_path = directory_context.project_path(cx);
let should_refresh = changed_buffers.is_empty()
|| changed_buffers.iter().any(|buffer| {
let buffer = buffer.read(cx);
buffer_path_log_err(&buffer, cx).map_or(false, |path| {
path.starts_with(&directory_context.project_path.path)
})
let Some(buffer_path) = buffer.read(cx).project_path(cx) else {
return false;
};
buffer_path.starts_with(&directory_path)
});
if should_refresh {
@@ -835,14 +840,16 @@ fn refresh_directory_text(
let context_buffers = future::join_all(futures);
let id = directory_context.id;
let project_path = directory_context.project_path.clone();
let worktree = directory_context.worktree.clone();
let path = directory_context.path.clone();
Some(cx.spawn(async move |cx| {
let context_buffers = context_buffers.await;
context_store
.update(cx, |context_store, _| {
let new_directory_context = DirectoryContext {
id,
project_path,
worktree,
path,
context_buffers,
};
context_store.replace_context(AssistantContext::Directory(new_directory_context));

View File

@@ -1,3 +1,4 @@
use std::path::Path;
use std::rc::Rc;
use collections::HashSet;
@@ -9,11 +10,12 @@ use gpui::{
};
use itertools::Itertools;
use language::Buffer;
use project::ProjectItem;
use ui::{KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*};
use workspace::{Workspace, notifications::NotifyResultExt};
use crate::context::{ContextId, ContextKind};
use crate::context_picker::{ConfirmBehavior, ContextPicker};
use crate::context_picker::ContextPicker;
use crate::context_store::ContextStore;
use crate::thread::Thread;
use crate::thread_store::ThreadStore;
@@ -50,7 +52,6 @@ impl ContextStrip {
workspace.clone(),
thread_store.clone(),
context_store.downgrade(),
ConfirmBehavior::KeepOpen,
window,
cx,
)
@@ -93,26 +94,23 @@ impl ContextStrip {
let active_buffer_entity = editor.buffer().read(cx).as_singleton()?;
let active_buffer = active_buffer_entity.read(cx);
let path = active_buffer.file()?.full_path(cx);
let project_path = active_buffer.project_path(cx)?;
if self
.context_store
.read(cx)
.will_include_buffer(active_buffer.remote_id(), &path)
.will_include_buffer(active_buffer.remote_id(), &project_path)
.is_some()
{
return None;
}
let name = match path.file_name() {
Some(name) => name.to_string_lossy().into_owned().into(),
None => path.to_string_lossy().into_owned().into(),
};
let file_name = active_buffer.file()?.file_name(cx);
let icon_path = FileIcons::get_icon(&path, cx);
let icon_path = FileIcons::get_icon(&Path::new(&file_name), cx);
Some(SuggestedContext::File {
name,
name: file_name.to_string_lossy().into_owned().into(),
buffer: active_buffer_entity.downgrade(),
icon_path,
})

View File

@@ -28,6 +28,7 @@ use language_model::{LanguageModelRegistry, report_assistant_event};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
use project::LspAction;
use project::Project;
use project::{CodeAction, ProjectTransaction};
use prompt_store::PromptBuilder;
use settings::{Settings, SettingsStore};
@@ -254,6 +255,7 @@ impl InlineAssistant {
assistant.assist(
&active_editor,
cx.entity().downgrade(),
workspace.project().downgrade(),
thread_store,
window,
cx,
@@ -262,7 +264,14 @@ impl InlineAssistant {
}
InlineAssistTarget::Terminal(active_terminal) => {
TerminalInlineAssistant::update_global(cx, |assistant, cx| {
assistant.assist(&active_terminal, cx.entity(), thread_store, window, cx)
assistant.assist(
&active_terminal,
cx.entity().downgrade(),
workspace.project().downgrade(),
thread_store,
window,
cx,
)
})
}
};
@@ -312,17 +321,11 @@ impl InlineAssistant {
&mut self,
editor: &Entity<Editor>,
workspace: WeakEntity<Workspace>,
project: WeakEntity<Project>,
thread_store: Option<WeakEntity<ThreadStore>>,
window: &mut Window,
cx: &mut App,
) {
let Some(project) = workspace
.upgrade()
.map(|workspace| workspace.read(cx).project().downgrade())
else {
return;
};
let (snapshot, initial_selections) = editor.update(cx, |editor, cx| {
(
editor.snapshot(window, cx),

File diff suppressed because it is too large Load Diff

View File

@@ -86,7 +86,7 @@ impl ProfileSelector {
thread_store
.update(cx, |this, cx| {
this.load_profile_by_id(&profile_id, cx);
this.load_profile_by_id(profile_id.clone(), cx);
})
.log_err();
}

View File

@@ -16,6 +16,7 @@ use language_model::{
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
Role, report_assistant_event,
};
use project::Project;
use prompt_store::PromptBuilder;
use std::sync::Arc;
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
@@ -66,7 +67,8 @@ impl TerminalInlineAssistant {
pub fn assist(
&mut self,
terminal_view: &Entity<TerminalView>,
workspace: Entity<Workspace>,
workspace: WeakEntity<Workspace>,
project: WeakEntity<Project>,
thread_store: Option<WeakEntity<ThreadStore>>,
window: &mut Window,
cx: &mut App,
@@ -75,7 +77,6 @@ impl TerminalInlineAssistant {
let assist_id = self.next_assist_id.post_inc();
let prompt_buffer =
cx.new(|cx| MultiBuffer::singleton(cx.new(|cx| Buffer::local(String::new(), cx)), cx));
let project = workspace.read(cx).project().downgrade();
let context_store = cx.new(|_cx| ContextStore::new(project, thread_store.clone()));
let codegen = cx.new(|_| TerminalCodegen::new(terminal, self.telemetry.clone()));
@@ -87,7 +88,7 @@ impl TerminalInlineAssistant {
codegen,
self.fs.clone(),
context_store.clone(),
workspace.downgrade(),
workspace.clone(),
thread_store.clone(),
window,
cx,
@@ -106,7 +107,7 @@ impl TerminalInlineAssistant {
assist_id,
terminal_view,
prompt_editor,
workspace.downgrade(),
workspace.clone(),
context_store,
window,
cx,

View File

@@ -2,37 +2,39 @@ use std::fmt::Write as _;
use std::io::Write;
use std::ops::Range;
use std::sync::Arc;
use std::time::Instant;
use agent_rules::load_worktree_rules_file;
use anyhow::{Context as _, Result, anyhow};
use assistant_settings::AssistantSettings;
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap};
use fs::Fs;
use feature_flags::{self, FeatureFlagAppExt};
use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
use git::repository::DiffType;
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
PaymentRequiredError, Role, StopReason, TokenUsage,
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
Role, StopReason, TokenUsage,
};
use project::Project;
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
use project::{Project, Worktree};
use prompt_store::{AssistantSystemPromptContext, PromptBuilder, WorktreeInfoForSystemPrompt};
use prompt_store::PromptBuilder;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use thiserror::Error;
use util::{ResultExt as _, TryFutureExt as _, post_inc};
use uuid::Uuid;
use crate::context::{AssistantContext, ContextId, format_context_as_string};
use crate::thread_store::{
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
SerializedToolUse,
SerializedToolUse, SharedProjectContext,
};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
@@ -182,7 +184,7 @@ pub struct ThreadCheckpoint {
git_checkpoint: GitStoreCheckpoint,
}
#[derive(Copy, Clone, Debug)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
Positive,
Negative,
@@ -227,7 +229,7 @@ pub struct TotalTokenUsage {
pub ratio: TokenUsageRatio,
}
#[derive(Default, PartialEq, Eq)]
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
#[default]
Normal,
@@ -246,27 +248,39 @@ pub struct Thread {
next_message_id: MessageId,
context: BTreeMap<ContextId, AssistantContext>,
context_by_message: HashMap<MessageId, Vec<ContextId>>,
system_prompt_context: Option<AssistantSystemPromptContext>,
project_context: SharedProjectContext,
checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
project: Entity<Project>,
prompt_builder: Arc<PromptBuilder>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
tool_use: ToolUseState,
action_log: Entity<ActionLog>,
last_restore_checkpoint: Option<LastRestoreCheckpoint>,
pending_checkpoint: Option<ThreadCheckpoint>,
initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
cumulative_token_usage: TokenUsage,
exceeded_window_error: Option<ExceededWindowError>,
feedback: Option<ThreadFeedback>,
message_feedback: HashMap<MessageId, ThreadFeedback>,
last_auto_capture_at: Option<Instant>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
/// Model used when last message exceeded context window
model_id: LanguageModelId,
/// Token count including last message
token_count: usize,
}
impl Thread {
pub fn new(
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
system_prompt: SharedProjectContext,
cx: &mut Context<Self>,
) -> Self {
Self {
@@ -279,7 +293,7 @@ impl Thread {
next_message_id: MessageId(0),
context: BTreeMap::default(),
context_by_message: HashMap::default(),
system_prompt_context: None,
project_context: system_prompt,
checkpoints_by_message: HashMap::default(),
completion_count: 0,
pending_completions: Vec::new(),
@@ -297,7 +311,10 @@ impl Thread {
.shared()
},
cumulative_token_usage: TokenUsage::default(),
exceeded_window_error: None,
feedback: None,
message_feedback: HashMap::default(),
last_auto_capture_at: None,
}
}
@@ -305,8 +322,9 @@ impl Thread {
id: ThreadId,
serialized: SerializedThread,
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
project_context: SharedProjectContext,
cx: &mut Context<Self>,
) -> Self {
let next_message_id = MessageId(
@@ -347,7 +365,7 @@ impl Thread {
next_message_id,
context: BTreeMap::default(),
context_by_message: HashMap::default(),
system_prompt_context: None,
project_context,
checkpoints_by_message: HashMap::default(),
completion_count: 0,
pending_completions: Vec::new(),
@@ -360,7 +378,10 @@ impl Thread {
action_log: cx.new(|_| ActionLog::new(project)),
initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
cumulative_token_usage: serialized.cumulative_token_usage,
exceeded_window_error: None,
feedback: None,
message_feedback: HashMap::default(),
last_auto_capture_at: None,
}
}
@@ -384,6 +405,10 @@ impl Thread {
self.summary.clone()
}
pub fn project_context(&self) -> SharedProjectContext {
self.project_context.clone()
}
pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
pub fn summary_or_default(&self) -> SharedString {
@@ -433,7 +458,7 @@ impl Thread {
!self.pending_completions.is_empty() || !self.all_tools_finished()
}
pub fn tools(&self) -> &Arc<ToolWorkingSet> {
pub fn tools(&self) -> &Entity<ToolWorkingSet> {
&self.tools
}
@@ -674,6 +699,9 @@ impl Thread {
git_checkpoint,
});
}
self.auto_capture_telemetry(cx);
message_id
}
@@ -801,90 +829,11 @@ impl Thread {
initial_project_snapshot,
cumulative_token_usage: this.cumulative_token_usage.clone(),
detailed_summary_state: this.detailed_summary_state.clone(),
exceeded_window_error: this.exceeded_window_error.clone(),
})
})
}
pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
self.system_prompt_context = Some(context);
}
pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
&self.system_prompt_context
}
pub fn load_system_prompt_context(
&self,
cx: &App,
) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
let project = self.project.read(cx);
let tasks = project
.visible_worktrees(cx)
.map(|worktree| {
Self::load_worktree_info_for_system_prompt(
project.fs().clone(),
worktree.read(cx),
cx,
)
})
.collect::<Vec<_>>();
cx.spawn(async |_cx| {
let results = futures::future::join_all(tasks).await;
let mut first_err = None;
let worktrees = results
.into_iter()
.map(|(worktree, err)| {
if first_err.is_none() && err.is_some() {
first_err = err;
}
worktree
})
.collect::<Vec<_>>();
(AssistantSystemPromptContext::new(worktrees), first_err)
})
}
fn load_worktree_info_for_system_prompt(
fs: Arc<dyn Fs>,
worktree: &Worktree,
cx: &App,
) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
let root_name = worktree.root_name().into();
let abs_path = worktree.abs_path();
let rules_task = load_worktree_rules_file(fs, worktree, cx);
let Some(rules_task) = rules_task else {
return Task::ready((
WorktreeInfoForSystemPrompt {
root_name,
abs_path,
rules_file: None,
},
None,
));
};
cx.spawn(async move |_| {
let (rules_file, rules_file_error) = match rules_task.await {
Ok(rules_file) => (Some(rules_file), None),
Err(err) => (
None,
Some(ThreadError::Message {
header: "Error loading rules file".into(),
message: format!("{err}").into(),
}),
),
};
let worktree_info = WorktreeInfoForSystemPrompt {
root_name,
abs_path,
rules_file,
};
(worktree_info, rules_file_error)
})
}
pub fn send_to_model(
&mut self,
model: Arc<dyn LanguageModel>,
@@ -895,13 +844,21 @@ impl Thread {
if model.supports_tools() {
request.tools = {
let mut tools = Vec::new();
tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
LanguageModelRequestTool {
name: tool.name(),
description: tool.description(),
input_schema: tool.input_schema(model.tool_input_format()),
}
}));
tools.extend(
self.tools()
.read(cx)
.enabled_tools(cx)
.into_iter()
.filter_map(|tool| {
// Skip tools that cannot be supported
let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
Some(LanguageModelRequestTool {
name: tool.name(),
description: tool.description(),
input_schema,
})
}),
);
tools
};
@@ -934,10 +891,10 @@ impl Thread {
temperature: None,
};
if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
if let Some(project_context) = self.project_context.borrow().as_ref() {
if let Some(system_prompt) = self
.prompt_builder
.generate_assistant_system_prompt(system_prompt_context)
.generate_assistant_system_prompt(project_context)
.context("failed to generate assistant system prompt")
.log_err()
{
@@ -948,7 +905,7 @@ impl Thread {
});
}
} else {
log::error!("system_prompt_context not set.")
log::error!("project_context not set.")
}
for message in &self.messages {
@@ -1152,6 +1109,8 @@ impl Thread {
thread.touch_updated_at();
cx.emit(ThreadEvent::StreamedCompletion);
cx.notify();
thread.auto_capture_telemetry(cx);
})?;
smol::future::yield_now().await;
@@ -1178,7 +1137,8 @@ impl Thread {
match result.as_ref() {
Ok(stop_reason) => match stop_reason {
StopReason::ToolUse => {
cx.emit(ThreadEvent::UsePendingTools);
let tool_uses = thread.use_pending_tools(cx);
cx.emit(ThreadEvent::UsePendingTools { tool_uses });
}
StopReason::EndTurn => {}
StopReason::MaxTokens => {}
@@ -1190,6 +1150,20 @@ impl Thread {
cx.emit(ThreadEvent::ShowError(
ThreadError::MaxMonthlySpendReached,
));
} else if let Some(known_error) =
error.downcast_ref::<LanguageModelKnownError>()
{
match known_error {
LanguageModelKnownError::ContextWindowLimitExceeded {
tokens,
} => {
thread.exceeded_window_error = Some(ExceededWindowError {
model_id: model.id(),
token_count: *tokens,
});
cx.notify();
}
}
} else {
let error_message = error
.chain()
@@ -1205,7 +1179,9 @@ impl Thread {
thread.cancel_last_completion(cx);
}
}
cx.emit(ThreadEvent::DoneStreaming);
cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
thread.auto_capture_telemetry(cx);
if let Ok(initial_usage) = initial_token_usage {
let usage = thread.cumulative_token_usage.clone() - initial_usage;
@@ -1366,10 +1342,8 @@ impl Thread {
)
}
pub fn use_pending_tools(
&mut self,
cx: &mut Context<Self>,
) -> impl IntoIterator<Item = PendingToolUse> + use<> {
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
self.auto_capture_telemetry(cx);
let request = self.to_completion_request(RequestKind::Chat, cx);
let messages = Arc::new(request.messages);
let pending_tool_uses = self
@@ -1381,7 +1355,7 @@ impl Thread {
.collect::<Vec<_>>();
for tool_use in pending_tool_uses.iter() {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
if tool.needs_confirmation(&tool_use.input, cx)
&& !AssistantSettings::get_global(cx).always_allow_tool_actions
{
@@ -1433,8 +1407,8 @@ impl Thread {
) -> Task<()> {
let tool_name: Arc<str> = tool.name().into();
let run_tool = if self.tools.is_disabled(&tool.source(), &tool_name) {
Task::ready(Err(anyhow!("tool is disabled: {tool_name}")))
let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
} else {
tool.run(
input,
@@ -1447,7 +1421,7 @@ impl Thread {
cx.spawn({
async move |thread: WeakEntity<Thread>, cx| {
let output = run_tool.await;
let output = tool_result.output.await;
thread
.update(cx, |thread, cx| {
@@ -1457,18 +1431,36 @@ impl Thread {
output,
cx,
);
cx.emit(ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,
canceled: false,
});
thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
})
.ok();
}
})
}
fn tool_finished(
&mut self,
tool_use_id: LanguageModelToolUseId,
pending_tool_use: Option<PendingToolUse>,
canceled: bool,
cx: &mut Context<Self>,
) {
if self.all_tools_finished() {
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
self.attach_tool_results(cx);
if !canceled {
self.send_to_model(model, RequestKind::Chat, cx);
}
}
}
cx.emit(ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,
});
}
pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
// Insert a user message to contain the tool results.
self.insert_user_message(
@@ -1492,11 +1484,12 @@ impl Thread {
let mut canceled = false;
for pending_tool_use in self.tool_use.cancel_pending() {
canceled = true;
cx.emit(ThreadEvent::ToolFinished {
tool_use_id: pending_tool_use.id.clone(),
pending_tool_use: Some(pending_tool_use),
canceled: true,
});
self.tool_finished(
pending_tool_use.id.clone(),
Some(pending_tool_use),
true,
cx,
);
}
canceled
};
@@ -1504,24 +1497,46 @@ impl Thread {
canceled
}
/// Returns the feedback given to the thread, if any.
pub fn feedback(&self) -> Option<ThreadFeedback> {
self.feedback
}
/// Reports feedback about the thread and stores it in our telemetry backend.
pub fn report_feedback(
pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
self.message_feedback.get(&message_id).copied()
}
pub fn report_message_feedback(
&mut self,
message_id: MessageId,
feedback: ThreadFeedback,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
if self.message_feedback.get(&message_id) == Some(&feedback) {
return Task::ready(Ok(()));
}
let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
let serialized_thread = self.serialize(cx);
let thread_id = self.id().clone();
let client = self.project.read(cx).client();
self.feedback = Some(feedback);
let enabled_tool_names: Vec<String> = self
.tools()
.read(cx)
.enabled_tools(cx)
.iter()
.map(|tool| tool.name().to_string())
.collect();
self.message_feedback.insert(message_id, feedback);
cx.notify();
let message_content = self
.message(message_id)
.map(|msg| msg.to_string())
.unwrap_or_default();
cx.background_spawn(async move {
let final_project_snapshot = final_project_snapshot.await;
let serialized_thread = serialized_thread.await?;
@@ -1536,6 +1551,9 @@ impl Thread {
"Assistant Thread Rated",
rating,
thread_id,
enabled_tool_names,
message_id = message_id.0,
message_content,
thread_data,
final_project_snapshot
);
@@ -1545,6 +1563,52 @@ impl Thread {
})
}
pub fn report_feedback(
&mut self,
feedback: ThreadFeedback,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let last_assistant_message_id = self
.messages
.iter()
.rev()
.find(|msg| msg.role == Role::Assistant)
.map(|msg| msg.id);
if let Some(message_id) = last_assistant_message_id {
self.report_message_feedback(message_id, feedback, cx)
} else {
let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
let serialized_thread = self.serialize(cx);
let thread_id = self.id().clone();
let client = self.project.read(cx).client();
self.feedback = Some(feedback);
cx.notify();
cx.background_spawn(async move {
let final_project_snapshot = final_project_snapshot.await;
let serialized_thread = serialized_thread.await?;
let thread_data = serde_json::to_value(serialized_thread)
.unwrap_or_else(|_| serde_json::Value::Null);
let rating = match feedback {
ThreadFeedback::Positive => "positive",
ThreadFeedback::Negative => "negative",
};
telemetry::event!(
"Assistant Thread Rated",
rating,
thread_id,
thread_data,
final_project_snapshot
);
client.telemetry().flush_events();
Ok(())
})
}
}
/// Create a snapshot of the current project state including git information and unsaved buffers.
fn project_snapshot(
project: Entity<Project>,
@@ -1737,14 +1801,14 @@ impl Thread {
.update(cx, |action_log, cx| action_log.keep_all_edits(cx));
}
pub fn reject_edits_in_range(
pub fn reject_edits_in_ranges(
&mut self,
buffer: Entity<language::Buffer>,
buffer_range: Range<language::Anchor>,
buffer_ranges: Vec<Range<language::Anchor>>,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
self.action_log.update(cx, |action_log, cx| {
action_log.reject_edits_in_range(buffer, buffer_range, cx)
action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
})
}
@@ -1756,8 +1820,48 @@ impl Thread {
&self.project
}
pub fn cumulative_token_usage(&self) -> TokenUsage {
self.cumulative_token_usage.clone()
pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
return;
}
let now = Instant::now();
if let Some(last) = self.last_auto_capture_at {
if now.duration_since(last).as_secs() < 10 {
return;
}
}
self.last_auto_capture_at = Some(now);
let thread_id = self.id().clone();
let github_login = self
.project
.read(cx)
.user_store()
.read(cx)
.current_user()
.map(|user| user.github_login.clone());
let client = self.project.read(cx).client().clone();
let serialize_task = self.serialize(cx);
cx.background_executor()
.spawn(async move {
if let Ok(serialized_thread) = serialize_task.await {
if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
telemetry::event!(
"Agent Thread Auto-Captured",
thread_id = thread_id.to_string(),
thread_data = thread_data,
auto_capture_reason = "tracked_user",
github_login = github_login
);
client.telemetry().flush_events();
}
}
})
.detach();
}
pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
@@ -1768,6 +1872,16 @@ impl Thread {
let max = model.model.max_token_count();
if let Some(exceeded_error) = &self.exceeded_window_error {
if model.model.id() == exceeded_error.model_id {
return TotalTokenUsage {
total: exceeded_error.token_count,
max,
ratio: TokenUsageRatio::Exceeded,
};
}
}
#[cfg(debug_assertions)]
let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
.unwrap_or("0.8".to_string())
@@ -1801,19 +1915,17 @@ impl Thread {
self.tool_use
.insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
cx.emit(ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use: None,
canceled: true,
});
self.tool_finished(tool_use_id.clone(), None, true, cx);
}
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
#[error("Payment required")]
PaymentRequired,
#[error("Max monthly spend reached")]
MaxMonthlySpendReached,
#[error("Message {header}: {message}")]
Message {
header: SharedString,
message: SharedString,
@@ -1826,20 +1938,20 @@ pub enum ThreadEvent {
StreamedCompletion,
StreamedAssistantText(MessageId, String),
StreamedAssistantThinking(MessageId, String),
DoneStreaming,
Stopped(Result<StopReason, Arc<anyhow::Error>>),
MessageAdded(MessageId),
MessageEdited(MessageId),
MessageDeleted(MessageId),
SummaryGenerated,
SummaryChanged,
UsePendingTools,
UsePendingTools {
tool_uses: Vec<PendingToolUse>,
},
ToolFinished {
#[allow(unused)]
tool_use_id: LanguageModelToolUseId,
/// The pending tool use that corresponds to this tool.
pending_tool_use: Option<PendingToolUse>,
/// Whether the tool was canceled by the user.
canceled: bool,
},
CheckpointChanged,
ToolConfirmationNeeded,
@@ -1932,9 +2044,9 @@ fn main() {{
thread.to_completion_request(RequestKind::Chat, cx)
});
assert_eq!(request.messages.len(), 1);
assert_eq!(request.messages.len(), 2);
let expected_full_message = format!("{}Please explain this code", expected_context);
assert_eq!(request.messages[0].string_contents(), expected_full_message);
assert_eq!(request.messages[1].string_contents(), expected_full_message);
}
#[gpui::test]
@@ -2025,20 +2137,20 @@ fn main() {{
});
// The request should contain all 3 messages
assert_eq!(request.messages.len(), 3);
assert_eq!(request.messages.len(), 4);
// Check that the contexts are properly formatted in each message
assert!(request.messages[0].string_contents().contains("file1.rs"));
assert!(!request.messages[0].string_contents().contains("file2.rs"));
assert!(!request.messages[0].string_contents().contains("file3.rs"));
assert!(!request.messages[1].string_contents().contains("file1.rs"));
assert!(request.messages[1].string_contents().contains("file2.rs"));
assert!(request.messages[1].string_contents().contains("file1.rs"));
assert!(!request.messages[1].string_contents().contains("file2.rs"));
assert!(!request.messages[1].string_contents().contains("file3.rs"));
assert!(!request.messages[2].string_contents().contains("file1.rs"));
assert!(!request.messages[2].string_contents().contains("file2.rs"));
assert!(request.messages[2].string_contents().contains("file3.rs"));
assert!(request.messages[2].string_contents().contains("file2.rs"));
assert!(!request.messages[2].string_contents().contains("file3.rs"));
assert!(!request.messages[3].string_contents().contains("file1.rs"));
assert!(!request.messages[3].string_contents().contains("file2.rs"));
assert!(request.messages[3].string_contents().contains("file3.rs"));
}
#[gpui::test]
@@ -2076,9 +2188,9 @@ fn main() {{
thread.to_completion_request(RequestKind::Chat, cx)
});
assert_eq!(request.messages.len(), 1);
assert_eq!(request.messages.len(), 2);
assert_eq!(
request.messages[0].string_contents(),
request.messages[1].string_contents(),
"What is the best way to learn Rust?"
);
@@ -2096,13 +2208,13 @@ fn main() {{
thread.to_completion_request(RequestKind::Chat, cx)
});
assert_eq!(request.messages.len(), 2);
assert_eq!(request.messages.len(), 3);
assert_eq!(
request.messages[0].string_contents(),
request.messages[1].string_contents(),
"What is the best way to learn Rust?"
);
assert_eq!(
request.messages[1].string_contents(),
request.messages[2].string_contents(),
"Are there any good books?"
);
}
@@ -2223,15 +2335,16 @@ fn main() {{
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
let thread_store = cx.update(|_, cx| {
ThreadStore::new(
project.clone(),
Arc::default(),
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)
.unwrap()
});
let thread_store = cx
.update(|_, cx| {
ThreadStore::load(
project.clone(),
cx.new(|_| ToolWorkingSet::default()),
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)
})
.await;
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

View File

@@ -4,11 +4,14 @@ use assistant_context_editor::SavedContextMetadata;
use editor::{Editor, EditorEvent};
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{
App, Entity, FocusHandle, Focusable, ScrollStrategy, Task, UniformListScrollHandle, WeakEntity,
Window, uniform_list,
App, Entity, FocusHandle, Focusable, ScrollStrategy, Stateful, Task, UniformListScrollHandle,
WeakEntity, Window, uniform_list,
};
use time::{OffsetDateTime, UtcOffset};
use ui::{HighlightedLabel, IconButtonShape, ListItem, ListItemSpacing, Tooltip, prelude::*};
use ui::{
HighlightedLabel, IconButtonShape, ListItem, ListItemSpacing, Scrollbar, ScrollbarState,
Tooltip, prelude::*,
};
use util::ResultExt;
use crate::history_store::{HistoryEntry, HistoryStore};
@@ -26,6 +29,8 @@ pub struct ThreadHistory {
matches: Vec<StringMatch>,
_subscriptions: Vec<gpui::Subscription>,
_search_task: Option<Task<()>>,
scrollbar_visibility: bool,
scrollbar_state: ScrollbarState,
}
impl ThreadHistory {
@@ -58,10 +63,13 @@ impl ThreadHistory {
this.update_all_entries(cx);
});
let scroll_handle = UniformListScrollHandle::default();
let scrollbar_state = ScrollbarState::new(scroll_handle.clone());
Self {
assistant_panel,
history_store,
scroll_handle: UniformListScrollHandle::default(),
scroll_handle,
selected_index: 0,
search_query: SharedString::new_static(""),
all_entries: entries,
@@ -69,6 +77,8 @@ impl ThreadHistory {
search_editor,
_subscriptions: vec![search_editor_subscription, history_store_subscription],
_search_task: None,
scrollbar_visibility: true,
scrollbar_state,
}
}
@@ -220,6 +230,43 @@ impl ThreadHistory {
cx.notify();
}
fn render_scrollbar(&self, cx: &mut Context<Self>) -> Option<Stateful<Div>> {
if !(self.scrollbar_visibility || self.scrollbar_state.is_dragging()) {
return None;
}
Some(
div()
.occlude()
.id("thread-history-scroll")
.h_full()
.bg(cx.theme().colors().panel_background.opacity(0.8))
.border_l_1()
.border_color(cx.theme().colors().border_variant)
.absolute()
.right_1()
.top_0()
.bottom_0()
.w_4()
.pl_1()
.cursor_default()
.on_mouse_move(cx.listener(|_, _, _window, cx| {
cx.notify();
cx.stop_propagation()
}))
.on_hover(|_, _window, cx| {
cx.stop_propagation();
})
.on_any_mouse_down(|_, _window, cx| {
cx.stop_propagation();
})
.on_scroll_wheel(cx.listener(|_, _, _window, cx| {
cx.notify();
}))
.children(Scrollbar::vertical(self.scrollbar_state.clone())),
)
}
fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
if let Some(entry) = self.get_match(self.selected_index) {
let task_result = match entry {
@@ -305,7 +352,11 @@ impl Render for ThreadHistory {
)
})
.child({
let view = v_flex().overflow_hidden().flex_grow();
let view = v_flex()
.id("list-container")
.relative()
.overflow_hidden()
.flex_grow();
if self.all_entries.is_empty() {
view.justify_center()
@@ -322,59 +373,70 @@ impl Render for ThreadHistory {
),
)
} else {
view.p_1().child(
uniform_list(
cx.entity().clone(),
"thread-history",
self.matched_count(),
move |history, range, _window, _cx| {
let range_start = range.start;
let assistant_panel = history.assistant_panel.clone();
view.pr_5()
.child(
uniform_list(
cx.entity().clone(),
"thread-history",
self.matched_count(),
move |history, range, _window, _cx| {
let range_start = range.start;
let assistant_panel = history.assistant_panel.clone();
let render_item = |index: usize,
entry: &HistoryEntry,
highlight_positions: Vec<usize>|
-> Div {
h_flex().w_full().pb_1().child(match entry {
HistoryEntry::Thread(thread) => PastThread::new(
thread.clone(),
assistant_panel.clone(),
selected_index == index + range_start,
highlight_positions,
)
.into_any_element(),
HistoryEntry::Context(context) => PastContext::new(
context.clone(),
assistant_panel.clone(),
selected_index == index + range_start,
highlight_positions,
)
.into_any_element(),
})
};
if history.has_search_query() {
history.matches[range]
.iter()
.enumerate()
.filter_map(|(index, m)| {
history.all_entries.get(m.candidate_id).map(|entry| {
render_item(index, entry, m.positions.clone())
})
let render_item = |index: usize,
entry: &HistoryEntry,
highlight_positions: Vec<usize>|
-> Div {
h_flex().w_full().pb_1().child(match entry {
HistoryEntry::Thread(thread) => PastThread::new(
thread.clone(),
assistant_panel.clone(),
selected_index == index + range_start,
highlight_positions,
)
.into_any_element(),
HistoryEntry::Context(context) => PastContext::new(
context.clone(),
assistant_panel.clone(),
selected_index == index + range_start,
highlight_positions,
)
.into_any_element(),
})
.collect()
} else {
history.all_entries[range]
.iter()
.enumerate()
.map(|(index, entry)| render_item(index, entry, vec![]))
.collect()
}
},
};
if history.has_search_query() {
history.matches[range]
.iter()
.enumerate()
.filter_map(|(index, m)| {
history.all_entries.get(m.candidate_id).map(
|entry| {
render_item(
index,
entry,
m.positions.clone(),
)
},
)
})
.collect()
} else {
history.all_entries[range]
.iter()
.enumerate()
.map(|(index, entry)| render_item(index, entry, vec![]))
.collect()
}
},
)
.p_1()
.track_scroll(self.scroll_handle.clone())
.flex_grow(),
)
.track_scroll(self.scroll_handle.clone())
.flex_grow(),
)
.when_some(self.render_scrollbar(cx), |div, scrollbar| {
div.child(scrollbar)
})
}
})
}
@@ -440,6 +502,7 @@ impl RenderOnce for PastThread {
IconButton::new("delete", IconName::TrashAlt)
.shape(IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.tooltip(move |window, cx| {
Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx)
})
@@ -531,6 +594,7 @@ impl RenderOnce for PastContext {
IconButton::new("delete", IconName::TrashAlt)
.shape(IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.tooltip(move |window, cx| {
Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx)
})

View File

@@ -1,88 +1,254 @@
use std::borrow::Cow;
use std::path::PathBuf;
use std::cell::{Ref, RefCell};
use std::path::{Path, PathBuf};
use std::rc::Rc;
use std::sync::Arc;
use anyhow::{Result, anyhow};
use anyhow::{Context as _, Result, anyhow};
use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::HashMap;
use context_server::manager::ContextServerManager;
use context_server::{ContextServerFactoryRegistry, ContextServerTool};
use fs::Fs;
use futures::FutureExt as _;
use futures::future::{self, BoxFuture, Shared};
use gpui::{
App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Subscription, Task,
prelude::*,
App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
Subscription, Task, prelude::*,
};
use heed::Database;
use heed::types::SerdeBincode;
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
use project::Project;
use prompt_store::PromptBuilder;
use project::{Project, Worktree};
use prompt_store::{ProjectContext, PromptBuilder, RulesFileContext, WorktreeContext};
use serde::{Deserialize, Serialize};
use settings::{Settings as _, SettingsStore};
use util::ResultExt as _;
use crate::thread::{
DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId,
DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
};
const RULES_FILE_NAMES: [&'static str; 6] = [
".rules",
".cursorrules",
".windsurfrules",
".clinerules",
".github/copilot-instructions.md",
"CLAUDE.md",
];
pub fn init(cx: &mut App) {
ThreadsDatabase::init(cx);
}
/// A system prompt shared by all threads created by this ThreadStore
#[derive(Clone, Default)]
pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
impl SharedProjectContext {
pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
self.0.borrow()
}
}
pub struct ThreadStore {
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
context_server_manager: Entity<ContextServerManager>,
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
threads: Vec<SerializedThreadMetadata>,
project_context: SharedProjectContext,
_subscriptions: Vec<Subscription>,
}
pub struct RulesLoadingError {
pub message: SharedString,
}
impl EventEmitter<RulesLoadingError> for ThreadStore {}
impl ThreadStore {
pub fn new(
pub fn load(
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut App,
) -> Result<Entity<Self>> {
let this = cx.new(|cx| {
let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
let context_server_manager = cx.new(|cx| {
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
});
let settings_subscription =
cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
this.load_default_profile(cx);
});
) -> Task<Entity<Self>> {
let thread_store = cx.new(|cx| Self::new(project, tools, prompt_builder, cx));
let reload = thread_store.update(cx, |store, cx| store.reload_system_prompt(cx));
cx.foreground_executor().spawn(async move {
reload.await;
thread_store
})
}
let this = Self {
project,
tools,
prompt_builder,
context_server_manager,
context_server_tool_ids: HashMap::default(),
threads: Vec::new(),
_subscriptions: vec![settings_subscription],
};
this.load_default_profile(cx);
this.register_context_server_handlers(cx);
this.reload(cx).detach_and_log_err(cx);
this
fn new(
project: Entity<Project>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut Context<Self>,
) -> Self {
let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
let context_server_manager = cx.new(|cx| {
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
});
let settings_subscription =
cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
this.load_default_profile(cx);
});
let project_subscription = cx.subscribe(&project, Self::handle_project_event);
Ok(this)
let this = Self {
project,
tools,
prompt_builder,
context_server_manager,
context_server_tool_ids: HashMap::default(),
threads: Vec::new(),
project_context: SharedProjectContext::default(),
_subscriptions: vec![settings_subscription, project_subscription],
};
this.load_default_profile(cx);
this.register_context_server_handlers(cx);
this.reload(cx).detach_and_log_err(cx);
this
}
fn handle_project_event(
&mut self,
_project: Entity<Project>,
event: &project::Event,
cx: &mut Context<Self>,
) {
match event {
project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
self.reload_system_prompt(cx).detach();
}
project::Event::WorktreeUpdatedEntries(_, items) => {
if items.iter().any(|(path, _, _)| {
RULES_FILE_NAMES
.iter()
.any(|name| path.as_ref() == Path::new(name))
}) {
self.reload_system_prompt(cx).detach();
}
}
_ => {}
}
}
pub fn reload_system_prompt(&self, cx: &mut Context<Self>) -> Task<()> {
let project = self.project.read(cx);
let tasks = project
.visible_worktrees(cx)
.map(|worktree| {
Self::load_worktree_info_for_system_prompt(
project.fs().clone(),
worktree.read(cx),
cx,
)
})
.collect::<Vec<_>>();
cx.spawn(async move |this, cx| {
let results = futures::future::join_all(tasks).await;
let worktrees = results
.into_iter()
.map(|(worktree, rules_error)| {
if let Some(rules_error) = rules_error {
this.update(cx, |_, cx| cx.emit(rules_error)).ok();
}
worktree
})
.collect::<Vec<_>>();
this.update(cx, |this, _cx| {
*this.project_context.0.borrow_mut() = Some(ProjectContext::new(worktrees));
})
.ok();
})
}
fn load_worktree_info_for_system_prompt(
fs: Arc<dyn Fs>,
worktree: &Worktree,
cx: &App,
) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
let root_name = worktree.root_name().into();
let abs_path = worktree.abs_path();
let rules_task = Self::load_worktree_rules_file(fs, worktree, cx);
let Some(rules_task) = rules_task else {
return Task::ready((
WorktreeContext {
root_name,
abs_path,
rules_file: None,
},
None,
));
};
cx.spawn(async move |_| {
let (rules_file, rules_file_error) = match rules_task.await {
Ok(rules_file) => (Some(rules_file), None),
Err(err) => (
None,
Some(RulesLoadingError {
message: format!("{err}").into(),
}),
),
};
let worktree_info = WorktreeContext {
root_name,
abs_path,
rules_file,
};
(worktree_info, rules_file_error)
})
}
fn load_worktree_rules_file(
fs: Arc<dyn Fs>,
worktree: &Worktree,
cx: &App,
) -> Option<Task<Result<RulesFileContext>>> {
let selected_rules_file = RULES_FILE_NAMES
.into_iter()
.filter_map(|name| {
worktree
.entry_for_path(name)
.filter(|entry| entry.is_file())
.map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
})
.next();
// Note that Cline supports `.clinerules` being a directory, but that is not currently
// supported. This doesn't seem to occur often in GitHub repositories.
selected_rules_file.map(|(path_in_worktree, abs_path)| {
let fs = fs.clone();
cx.background_spawn(async move {
let abs_path = abs_path?;
let text = fs.load(&abs_path).await.with_context(|| {
format!("Failed to load assistant rules file {:?}", abs_path)
})?;
anyhow::Ok(RulesFileContext {
path_in_worktree,
abs_path: abs_path.into(),
text: text.trim().to_string(),
})
})
})
}
pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
self.context_server_manager.clone()
}
pub fn tools(&self) -> Arc<ToolWorkingSet> {
pub fn tools(&self) -> Entity<ToolWorkingSet> {
self.tools.clone()
}
@@ -107,6 +273,7 @@ impl ThreadStore {
self.project.clone(),
self.tools.clone(),
self.prompt_builder.clone(),
self.project_context.clone(),
cx,
)
})
@@ -134,21 +301,12 @@ impl ThreadStore {
this.project.clone(),
this.tools.clone(),
this.prompt_builder.clone(),
this.project_context.clone(),
cx,
)
})
})?;
let (system_prompt_context, load_error) = thread
.update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
.await;
thread.update(cx, |thread, cx| {
thread.set_system_prompt_context(system_prompt_context);
if let Some(load_error) = load_error {
cx.emit(ThreadEvent::ShowError(load_error));
}
})?;
Ok(thread)
})
}
@@ -197,52 +355,60 @@ impl ThreadStore {
})
}
fn load_default_profile(&self, cx: &Context<Self>) {
fn load_default_profile(&self, cx: &mut Context<Self>) {
let assistant_settings = AssistantSettings::get_global(cx);
self.load_profile_by_id(&assistant_settings.default_profile, cx);
self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
}
pub fn load_profile_by_id(&self, profile_id: &AgentProfileId, cx: &Context<Self>) {
pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
let assistant_settings = AssistantSettings::get_global(cx);
if let Some(profile) = assistant_settings.profiles.get(profile_id) {
self.load_profile(profile, cx);
if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
self.load_profile(profile.clone(), cx);
}
}
pub fn load_profile(&self, profile: &AgentProfile, cx: &Context<Self>) {
self.tools.disable_all_tools();
self.tools.enable(
ToolSource::Native,
&profile
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
);
pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
self.tools.update(cx, |tools, cx| {
tools.disable_all_tools(cx);
tools.enable(
ToolSource::Native,
&profile
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
cx,
);
});
if profile.enable_all_context_servers {
for context_server in self.context_server_manager.read(cx).all_servers() {
self.tools.enable_source(
ToolSource::ContextServer {
id: context_server.id().into(),
},
cx,
);
self.tools.update(cx, |tools, cx| {
tools.enable_source(
ToolSource::ContextServer {
id: context_server.id().into(),
},
cx,
);
});
}
} else {
for (context_server_id, preset) in &profile.context_servers {
self.tools.enable(
ToolSource::ContextServer {
id: context_server_id.clone().into(),
},
&preset
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
)
self.tools.update(cx, |tools, cx| {
tools.enable(
ToolSource::ContextServer {
id: context_server_id.clone().into(),
},
&preset
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
cx,
)
})
}
}
}
@@ -276,29 +442,36 @@ impl ThreadStore {
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
if let Some(tools) = protocol.list_tools().await.log_err() {
let tool_ids = tools
.tools
.into_iter()
.map(|tool| {
log::info!(
"registering context server tool: {:?}",
tool.name
);
tool_working_set.insert(Arc::new(
ContextServerTool::new(
context_server_manager.clone(),
server.id(),
tool,
),
))
let tool_ids = tool_working_set
.update(cx, |tool_working_set, _| {
tools
.tools
.into_iter()
.map(|tool| {
log::info!(
"registering context server tool: {:?}",
tool.name
);
tool_working_set.insert(Arc::new(
ContextServerTool::new(
context_server_manager.clone(),
server.id(),
tool,
),
))
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
.log_err();
this.update(cx, |this, cx| {
this.context_server_tool_ids.insert(server_id, tool_ids);
this.load_default_profile(cx);
})
.log_err();
if let Some(tool_ids) = tool_ids {
this.update(cx, |this, cx| {
this.context_server_tool_ids
.insert(server_id, tool_ids);
this.load_default_profile(cx);
})
.log_err();
}
}
}
}
@@ -308,7 +481,9 @@ impl ThreadStore {
}
context_server::manager::Event::ServerStopped { server_id } => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.remove(&tool_ids);
tool_working_set.update(cx, |tool_working_set, _| {
tool_working_set.remove(&tool_ids);
});
self.load_default_profile(cx);
}
}
@@ -335,6 +510,8 @@ pub struct SerializedThread {
pub cumulative_token_usage: TokenUsage,
#[serde(default)]
pub detailed_summary_state: DetailedSummaryState,
#[serde(default)]
pub exceeded_window_error: Option<ExceededWindowError>,
}
impl SerializedThread {
@@ -421,6 +598,7 @@ impl LegacySerializedThread {
initial_project_snapshot: self.initial_project_snapshot,
cumulative_token_usage: TokenUsage::default(),
detailed_summary_state: DetailedSummaryState::default(),
exceeded_window_error: None,
}
}
}
@@ -491,7 +669,7 @@ impl ThreadsDatabase {
let database_future = executor
.spawn({
let executor = executor.clone();
let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
async move { ThreadsDatabase::new(database_path, executor) }
})
.then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))

View File

@@ -0,0 +1,89 @@
use std::sync::Arc;
use assistant_tool::{Tool, ToolWorkingSet, ToolWorkingSetEvent};
use collections::HashMap;
use gpui::{App, Context, Entity, IntoElement, Render, Subscription, Window};
use language_model::{LanguageModel, LanguageModelToolSchemaFormat};
use ui::prelude::*;
pub struct IncompatibleToolsState {
cache: HashMap<LanguageModelToolSchemaFormat, Vec<Arc<dyn Tool>>>,
tool_working_set: Entity<ToolWorkingSet>,
_tool_working_set_subscription: Subscription,
}
impl IncompatibleToolsState {
pub fn new(tool_working_set: Entity<ToolWorkingSet>, cx: &mut Context<Self>) -> Self {
let _tool_working_set_subscription =
cx.subscribe(&tool_working_set, |this, _, event, _| match event {
ToolWorkingSetEvent::EnabledToolsChanged => {
this.cache.clear();
}
});
Self {
cache: HashMap::default(),
tool_working_set,
_tool_working_set_subscription,
}
}
pub fn incompatible_tools(
&mut self,
model: &Arc<dyn LanguageModel>,
cx: &App,
) -> &[Arc<dyn Tool>] {
self.cache
.entry(model.tool_input_format())
.or_insert_with(|| {
self.tool_working_set
.read(cx)
.enabled_tools(cx)
.iter()
.filter(|tool| tool.input_schema(model.tool_input_format()).is_err())
.cloned()
.collect()
})
}
}
pub struct IncompatibleToolsTooltip {
pub incompatible_tools: Vec<Arc<dyn Tool>>,
}
impl Render for IncompatibleToolsTooltip {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
ui::tooltip_container(window, cx, |container, _, cx| {
container
.w_72()
.child(Label::new("Incompatible Tools").size(LabelSize::Small))
.child(
Label::new(
"This model is incompatible with the following tools from your MCPs:",
)
.size(LabelSize::Small)
.color(Color::Muted),
)
.child(
v_flex()
.my_1p5()
.py_0p5()
.border_b_1()
.border_color(cx.theme().colors().border_variant)
.children(
self.incompatible_tools
.iter()
.map(|tool| Label::new(tool.name()).size(LabelSize::Small).buffer_font(cx)),
),
)
.child(Label::new("What To Do Instead").size(LabelSize::Small))
.child(
Label::new(
"Every other tool continues to work with this model, but to specifically use those, switch to another model.",
)
.size(LabelSize::Small)
.color(Color::Muted),
)
})
}
}

View File

@@ -5,7 +5,7 @@ use assistant_tool::{Tool, ToolWorkingSet};
use collections::HashMap;
use futures::FutureExt as _;
use futures::future::Shared;
use gpui::{App, SharedString, Task};
use gpui::{App, Entity, SharedString, Task};
use language_model::{
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
@@ -49,7 +49,7 @@ impl ToolUseStatus {
}
pub struct ToolUseState {
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
@@ -59,7 +59,7 @@ pub struct ToolUseState {
pub const USING_TOOL_MARKER: &str = "<using_tool>";
impl ToolUseState {
pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
Self {
tools,
tool_uses_by_assistant_message: HashMap::default(),
@@ -73,7 +73,7 @@ impl ToolUseState {
///
/// Accepts a function to filter the tools that should be used to populate the state.
pub fn from_serialized_messages(
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
messages: &[SerializedMessage],
mut filter_by_tool_name: impl FnMut(&str) -> bool,
) -> Self {
@@ -199,12 +199,12 @@ impl ToolUseState {
}
})();
let (icon, needs_confirmation) = if let Some(tool) = self.tools.tool(&tool_use.name, cx)
{
(tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
} else {
(IconName::Cog, false)
};
let (icon, needs_confirmation) =
if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
(tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
} else {
(IconName::Cog, false)
};
tool_uses.push(ToolUse {
id: tool_use.id.clone(),
@@ -226,7 +226,7 @@ impl ToolUseState {
input: &serde_json::Value,
cx: &App,
) -> SharedString {
if let Some(tool) = self.tools.tool(tool_name, cx) {
if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
tool.ui_text(input).into()
} else {
format!("Unknown tool {tool_name:?}").into()

View File

@@ -1,5 +1,7 @@
mod agent_notification;
mod context_pill;
mod user_spending;
pub use agent_notification::*;
pub use context_pill::*;
// pub use user_spending::*;

View File

@@ -280,9 +280,10 @@ impl AddedContext {
}
AssistantContext::Directory(directory_context) => {
// TODO: handle worktree disambiguation. Maybe by storing an `Arc<dyn File>` to also
// handle renames?
let full_path = &directory_context.project_path.path;
let full_path = directory_context
.worktree
.read(cx)
.full_path(&directory_context.path);
let full_path_string: SharedString =
full_path.to_string_lossy().into_owned().into();
let name = full_path

View File

@@ -0,0 +1,186 @@
use gpui::{Entity, Render};
use ui::{ProgressBar, prelude::*};
#[derive(RegisterComponent)]
pub struct UserSpending {
free_tier_current: u32,
free_tier_cap: u32,
over_tier_current: u32,
over_tier_cap: u32,
free_tier_progress: Entity<ProgressBar>,
over_tier_progress: Entity<ProgressBar>,
}
impl UserSpending {
pub fn new(
free_tier_current: u32,
free_tier_cap: u32,
over_tier_current: u32,
over_tier_cap: u32,
cx: &mut App,
) -> Self {
let free_tier_capped = free_tier_current == free_tier_cap;
let free_tier_near_capped =
free_tier_current as f32 / 100.0 >= free_tier_cap as f32 / 100.0 * 0.9;
let over_tier_capped = over_tier_current == over_tier_cap;
let over_tier_near_capped =
over_tier_current as f32 / 100.0 >= over_tier_cap as f32 / 100.0 * 0.9;
let free_tier_progress = cx.new(|cx| {
ProgressBar::new(
"free_tier",
free_tier_current as f32,
free_tier_cap as f32,
cx,
)
});
let over_tier_progress = cx.new(|cx| {
ProgressBar::new(
"over_tier",
over_tier_current as f32,
over_tier_cap as f32,
cx,
)
});
if free_tier_capped {
free_tier_progress.update(cx, |progress_bar, cx| {
progress_bar.fg_color(cx.theme().status().error);
});
} else if free_tier_near_capped {
free_tier_progress.update(cx, |progress_bar, cx| {
progress_bar.fg_color(cx.theme().status().warning);
});
}
if over_tier_capped {
over_tier_progress.update(cx, |progress_bar, cx| {
progress_bar.fg_color(cx.theme().status().error);
});
} else if over_tier_near_capped {
over_tier_progress.update(cx, |progress_bar, cx| {
progress_bar.fg_color(cx.theme().status().warning);
});
}
Self {
free_tier_current,
free_tier_cap,
over_tier_current,
over_tier_cap,
free_tier_progress,
over_tier_progress,
}
}
}
impl Render for UserSpending {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let formatted_free_tier = format!(
"${} / ${}",
self.free_tier_current as f32 / 100.0,
self.free_tier_cap as f32 / 100.0
);
let formatted_over_tier = format!(
"${} / ${}",
self.over_tier_current as f32 / 100.0,
self.over_tier_cap as f32 / 100.0
);
v_group()
.elevation_2(cx)
.py_1p5()
.px_2p5()
.w(px(360.))
.child(
v_flex()
.child(
v_flex()
.p_1p5()
.gap_0p5()
.child(
h_flex()
.justify_between()
.child(Label::new("Free Tier Usage").size(LabelSize::Small))
.child(
Label::new(formatted_free_tier)
.size(LabelSize::Small)
.color(Color::Muted),
),
)
.child(self.free_tier_progress.clone()),
)
.child(
v_flex()
.p_1p5()
.gap_0p5()
.child(
h_flex()
.justify_between()
.child(Label::new("Current Spending").size(LabelSize::Small))
.child(
Label::new(formatted_over_tier)
.size(LabelSize::Small)
.color(Color::Muted),
),
)
.child(self.over_tier_progress.clone()),
),
)
}
}
impl Component for UserSpending {
fn scope() -> ComponentScope {
ComponentScope::None
}
fn preview(_window: &mut Window, cx: &mut App) -> Option<AnyElement> {
let new_user = cx.new(|cx| UserSpending::new(0, 2000, 0, 2000, cx));
let free_capped = cx.new(|cx| UserSpending::new(2000, 2000, 0, 2000, cx));
let free_near_capped = cx.new(|cx| UserSpending::new(1800, 2000, 0, 2000, cx));
let over_near_capped = cx.new(|cx| UserSpending::new(2000, 2000, 1800, 2000, cx));
let over_capped = cx.new(|cx| UserSpending::new(1000, 2000, 2000, 2000, cx));
Some(
v_flex()
.gap_6()
.p_4()
.children(vec![example_group(vec![
single_example(
"New User",
div().size_full().child(new_user.clone()).into_any_element(),
),
single_example(
"Free Tier Capped",
div()
.size_full()
.child(free_capped.clone())
.into_any_element(),
),
single_example(
"Free Tier Near Capped",
div()
.size_full()
.child(free_near_capped.clone())
.into_any_element(),
),
single_example(
"Over Tier Near Capped",
div()
.size_full()
.child(over_near_capped.clone())
.into_any_element(),
),
single_example(
"Over Tier Capped",
div()
.size_full()
.child(over_capped.clone())
.into_any_element(),
),
])])
.into_any_element(),
)
}
}

View File

@@ -1,52 +0,0 @@
// Copied from `crates/zed/build.rs`, with removal of code for including the zed icon on windows.
use std::process::Command;
fn main() {
if cfg!(target_os = "macos") {
println!("cargo:rustc-env=MACOSX_DEPLOYMENT_TARGET=10.15.7");
// Weakly link ReplayKit to ensure Zed can be used on macOS 10.15+.
println!("cargo:rustc-link-arg=-Wl,-weak_framework,ReplayKit");
// Seems to be required to enable Swift concurrency
println!("cargo:rustc-link-arg=-Wl,-rpath,/usr/lib/swift");
// Register exported Objective-C selectors, protocols, etc
println!("cargo:rustc-link-arg=-Wl,-ObjC");
}
// Populate git sha environment variable if git is available
println!("cargo:rerun-if-changed=../../.git/logs/HEAD");
println!(
"cargo:rustc-env=TARGET={}",
std::env::var("TARGET").unwrap()
);
if let Ok(output) = Command::new("git").args(["rev-parse", "HEAD"]).output() {
if output.status.success() {
let git_sha = String::from_utf8_lossy(&output.stdout);
let git_sha = git_sha.trim();
println!("cargo:rustc-env=ZED_COMMIT_SHA={git_sha}");
if let Ok(build_profile) = std::env::var("PROFILE") {
if build_profile == "release" {
// This is currently the best way to make `cargo build ...`'s build script
// to print something to stdout without extra verbosity.
println!(
"cargo:warning=Info: using '{git_sha}' hash for ZED_COMMIT_SHA env var"
);
}
}
}
}
#[cfg(target_os = "windows")]
{
#[cfg(target_env = "msvc")]
{
// todo(windows): This is to avoid stack overflow. Remove it when solved.
println!("cargo:rustc-link-arg=/stack:{}", 8 * 1024 * 1024);
}
}
}

View File

@@ -1,384 +0,0 @@
use crate::git_commands::{run_git, setup_temp_repo};
use crate::headless_assistant::{HeadlessAppState, HeadlessAssistant};
use crate::{get_exercise_language, get_exercise_name};
use agent::RequestKind;
use anyhow::{Result, anyhow};
use collections::HashMap;
use gpui::{App, Task};
use language_model::{LanguageModel, TokenUsage};
use serde::{Deserialize, Serialize};
use std::{
fs,
io::Write,
path::{Path, PathBuf},
sync::Arc,
time::{Duration, SystemTime},
};
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct EvalResult {
pub exercise_name: String,
pub diff: String,
pub assistant_response: String,
pub elapsed_time_ms: u128,
pub timestamp: u128,
// Token usage fields
pub input_tokens: usize,
pub output_tokens: usize,
pub total_tokens: usize,
pub tool_use_counts: usize,
}
pub struct EvalOutput {
pub diff: String,
pub last_message: String,
pub elapsed_time: Duration,
pub assistant_response_count: usize,
pub tool_use_counts: HashMap<Arc<str>, u32>,
pub token_usage: TokenUsage,
}
#[derive(Deserialize)]
pub struct EvalSetup {
pub url: String,
pub base_sha: String,
}
pub struct Eval {
pub repo_path: PathBuf,
pub eval_setup: EvalSetup,
pub user_prompt: String,
}
impl Eval {
// Keep this method for potential future use, but mark it as intentionally unused
#[allow(dead_code)]
pub async fn load(_name: String, path: PathBuf, repos_dir: &Path) -> Result<Self> {
let prompt_path = path.join("prompt.txt");
let user_prompt = smol::unblock(|| std::fs::read_to_string(prompt_path)).await?;
let setup_path = path.join("setup.json");
let setup_contents = smol::unblock(|| std::fs::read_to_string(setup_path)).await?;
let eval_setup = serde_json_lenient::from_str_lenient::<EvalSetup>(&setup_contents)?;
// Move this internal function inside the load method since it's only used here
fn repo_dir_name(url: &str) -> String {
url.trim_start_matches("https://")
.replace(|c: char| !c.is_alphanumeric(), "_")
}
let repo_path = repos_dir.join(repo_dir_name(&eval_setup.url));
Ok(Eval {
repo_path,
eval_setup,
user_prompt,
})
}
pub fn run(
self,
app_state: Arc<HeadlessAppState>,
model: Arc<dyn LanguageModel>,
cx: &mut App,
) -> Task<Result<EvalOutput>> {
cx.spawn(async move |cx| {
run_git(&self.repo_path, &["checkout", &self.eval_setup.base_sha]).await?;
let (assistant, done_rx) =
cx.update(|cx| HeadlessAssistant::new(app_state.clone(), cx))??;
let _worktree = assistant
.update(cx, |assistant, cx| {
assistant.project.update(cx, |project, cx| {
project.create_worktree(&self.repo_path, true, cx)
})
})?
.await?;
let start_time = std::time::SystemTime::now();
let (system_prompt_context, load_error) = cx
.update(|cx| {
assistant
.read(cx)
.thread
.read(cx)
.load_system_prompt_context(cx)
})?
.await;
if let Some(load_error) = load_error {
return Err(anyhow!("{:?}", load_error));
};
assistant.update(cx, |assistant, cx| {
assistant.thread.update(cx, |thread, cx| {
let context = vec![];
thread.insert_user_message(self.user_prompt.clone(), context, None, cx);
thread.set_system_prompt_context(system_prompt_context);
thread.send_to_model(model, RequestKind::Chat, cx);
});
})?;
done_rx.recv().await??;
// Add this section to check untracked files
println!("Checking for untracked files:");
let untracked = run_git(
&self.repo_path,
&["ls-files", "--others", "--exclude-standard"],
)
.await?;
if untracked.is_empty() {
println!("No untracked files found");
} else {
// Add all files to git so they appear in the diff
println!("Adding untracked files to git");
run_git(&self.repo_path, &["add", "."]).await?;
}
// get git status
let _status = run_git(&self.repo_path, &["status", "--short"]).await?;
let elapsed_time = start_time.elapsed()?;
// Get diff of staged changes (the files we just added)
let staged_diff = run_git(&self.repo_path, &["diff", "--staged"]).await?;
// Get diff of unstaged changes
let unstaged_diff = run_git(&self.repo_path, &["diff"]).await?;
// Combine both diffs
let diff = if unstaged_diff.is_empty() {
staged_diff
} else if staged_diff.is_empty() {
unstaged_diff
} else {
format!(
"# Staged changes\n{}\n\n# Unstaged changes\n{}",
staged_diff, unstaged_diff
)
};
assistant.update(cx, |assistant, cx| {
let thread = assistant.thread.read(cx);
let last_message = thread.messages().last().unwrap();
if last_message.role != language_model::Role::Assistant {
return Err(anyhow!("Last message is not from assistant"));
}
let assistant_response_count = thread
.messages()
.filter(|message| message.role == language_model::Role::Assistant)
.count();
Ok(EvalOutput {
diff,
last_message: last_message.to_string(),
elapsed_time,
assistant_response_count,
tool_use_counts: assistant.tool_use_counts.clone(),
token_usage: thread.cumulative_token_usage(),
})
})?
})
}
}
impl EvalOutput {
// Keep this method for potential future use, but mark it as intentionally unused
#[allow(dead_code)]
pub fn save_to_directory(&self, output_dir: &Path, eval_output_value: String) -> Result<()> {
// Create the output directory if it doesn't exist
fs::create_dir_all(&output_dir)?;
// Save the diff to a file
let diff_path = output_dir.join("diff.patch");
let mut diff_file = fs::File::create(&diff_path)?;
diff_file.write_all(self.diff.as_bytes())?;
// Save the last message to a file
let message_path = output_dir.join("assistant_response.txt");
let mut message_file = fs::File::create(&message_path)?;
message_file.write_all(self.last_message.as_bytes())?;
// Current metrics for this run
let current_metrics = serde_json::json!({
"elapsed_time_ms": self.elapsed_time.as_millis(),
"assistant_response_count": self.assistant_response_count,
"tool_use_counts": self.tool_use_counts,
"token_usage": self.token_usage,
"eval_output_value": eval_output_value,
});
// Get current timestamp in milliseconds
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)?
.as_millis()
.to_string();
// Path to metrics file
let metrics_path = output_dir.join("metrics.json");
// Load existing metrics if the file exists, or create a new object
let mut historical_metrics = if metrics_path.exists() {
let metrics_content = fs::read_to_string(&metrics_path)?;
serde_json::from_str::<serde_json::Value>(&metrics_content)
.unwrap_or_else(|_| serde_json::json!({}))
} else {
serde_json::json!({})
};
// Add new run with timestamp as key
if let serde_json::Value::Object(ref mut map) = historical_metrics {
map.insert(timestamp, current_metrics);
}
// Write updated metrics back to file
let metrics_json = serde_json::to_string_pretty(&historical_metrics)?;
let mut metrics_file = fs::File::create(&metrics_path)?;
metrics_file.write_all(metrics_json.as_bytes())?;
Ok(())
}
}
pub async fn read_instructions(exercise_path: &Path) -> Result<String> {
let instructions_path = exercise_path.join(".docs").join("instructions.md");
println!("Reading instructions from: {}", instructions_path.display());
let instructions = smol::unblock(move || std::fs::read_to_string(&instructions_path)).await?;
Ok(instructions)
}
pub async fn save_eval_results(exercise_path: &Path, results: Vec<EvalResult>) -> Result<()> {
let eval_dir = exercise_path.join("evaluation");
fs::create_dir_all(&eval_dir)?;
let eval_file = eval_dir.join("evals.json");
println!("Saving evaluation results to: {}", eval_file.display());
println!(
"Results to save: {} evaluations for exercise path: {}",
results.len(),
exercise_path.display()
);
// Check file existence before reading/writing
if eval_file.exists() {
println!("Existing evals.json file found, will update it");
} else {
println!("No existing evals.json file found, will create new one");
}
// Structure to organize evaluations by test name and timestamp
let mut eval_data: serde_json::Value = if eval_file.exists() {
let content = fs::read_to_string(&eval_file)?;
serde_json::from_str(&content).unwrap_or_else(|_| serde_json::json!({}))
} else {
serde_json::json!({})
};
// Get current timestamp for this batch of results
let timestamp = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)?
.as_millis()
.to_string();
// Group the new results by test name (exercise name)
for result in results {
let exercise_name = &result.exercise_name;
println!("Adding result: exercise={}", exercise_name);
// Ensure the exercise entry exists
if eval_data.get(exercise_name).is_none() {
eval_data[exercise_name] = serde_json::json!({});
}
// Ensure the timestamp entry exists as an object
if eval_data[exercise_name].get(&timestamp).is_none() {
eval_data[exercise_name][&timestamp] = serde_json::json!({});
}
// Add this result under the timestamp with template name as key
eval_data[exercise_name][&timestamp] = serde_json::to_value(&result)?;
}
// Write back to file with pretty formatting
let json_content = serde_json::to_string_pretty(&eval_data)?;
match fs::write(&eval_file, json_content) {
Ok(_) => println!("✓ Successfully saved results to {}", eval_file.display()),
Err(e) => println!("✗ Failed to write results file: {}", e),
}
Ok(())
}
pub async fn run_exercise_eval(
exercise_path: PathBuf,
model: Arc<dyn LanguageModel>,
app_state: Arc<HeadlessAppState>,
base_sha: String,
_framework_path: PathBuf,
cx: gpui::AsyncApp,
) -> Result<EvalResult> {
let exercise_name = get_exercise_name(&exercise_path);
let language = get_exercise_language(&exercise_path)?;
let mut instructions = read_instructions(&exercise_path).await?;
instructions.push_str(&format!(
"\n\nWhen writing the code for this prompt, use {} to achieve the goal.",
language
));
println!("Running evaluation for exercise: {}", exercise_name);
// Create temporary directory with exercise files
let temp_dir = setup_temp_repo(&exercise_path, &base_sha).await?;
let temp_path = temp_dir.path().to_path_buf();
let local_commit_sha = run_git(&temp_path, &["rev-parse", "HEAD"]).await?;
let start_time = SystemTime::now();
// Create a basic eval struct to work with the existing system
let eval = Eval {
repo_path: temp_path.clone(),
eval_setup: EvalSetup {
url: format!("file://{}", temp_path.display()),
base_sha: local_commit_sha, // Use the local commit SHA instead of the framework base SHA
},
user_prompt: instructions.clone(),
};
// Run the evaluation
let eval_output = cx
.update(|cx| eval.run(app_state.clone(), model.clone(), cx))?
.await?;
// Get diff from git
let diff = eval_output.diff.clone();
let elapsed_time = start_time.elapsed()?;
// Calculate total tokens as the sum of input and output tokens
let input_tokens = eval_output.token_usage.input_tokens;
let output_tokens = eval_output.token_usage.output_tokens;
let tool_use_counts = eval_output.tool_use_counts.values().sum::<u32>();
let total_tokens = input_tokens + output_tokens;
// Save results to evaluation directory
let result = EvalResult {
exercise_name: exercise_name.clone(),
diff,
assistant_response: eval_output.last_message.clone(),
elapsed_time_ms: elapsed_time.as_millis(),
timestamp: SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)?
.as_millis(),
// Convert u32 token counts to usize
input_tokens: input_tokens.try_into().unwrap(),
output_tokens: output_tokens.try_into().unwrap(),
total_tokens: total_tokens.try_into().unwrap(),
tool_use_counts: tool_use_counts.try_into().unwrap(),
};
Ok(result)
}

View File

@@ -1,149 +0,0 @@
use anyhow::{Result, anyhow};
use std::{
fs,
path::{Path, PathBuf},
};
pub fn get_exercise_name(exercise_path: &Path) -> String {
exercise_path
.file_name()
.unwrap_or_default()
.to_string_lossy()
.to_string()
}
pub fn get_exercise_language(exercise_path: &Path) -> Result<String> {
// Extract the language from path (data/python/exercises/... => python)
let parts: Vec<_> = exercise_path.components().collect();
for (i, part) in parts.iter().enumerate() {
if i > 0 && part.as_os_str() == "eval_code" {
if i + 1 < parts.len() {
let language = parts[i + 1].as_os_str().to_string_lossy().to_string();
return Ok(language);
}
}
}
Err(anyhow!(
"Could not determine language from path: {:?}",
exercise_path
))
}
pub fn find_exercises(
framework_path: &Path,
languages: &[&str],
max_per_language: Option<usize>,
) -> Result<Vec<PathBuf>> {
let mut all_exercises = Vec::new();
println!("Searching for exercises in languages: {:?}", languages);
for language in languages {
let language_dir = framework_path
.join("eval_code")
.join(language)
.join("exercises")
.join("practice");
println!("Checking language directory: {:?}", language_dir);
if !language_dir.exists() {
println!("Warning: Language directory not found: {:?}", language_dir);
continue;
}
let mut exercises = Vec::new();
match fs::read_dir(&language_dir) {
Ok(entries) => {
for entry_result in entries {
match entry_result {
Ok(entry) => {
let path = entry.path();
if path.is_dir() {
// Special handling for "internal" directory
if *language == "internal" {
// Check for repo_info.json to validate it's an internal exercise
let repo_info_path = path.join(".meta").join("repo_info.json");
let instructions_path =
path.join(".docs").join("instructions.md");
if repo_info_path.exists() && instructions_path.exists() {
exercises.push(path);
}
} else {
// Map the language to the file extension - original code
let language_extension = match *language {
"python" => "py",
"go" => "go",
"rust" => "rs",
"typescript" => "ts",
"javascript" => "js",
"ruby" => "rb",
"php" => "php",
"bash" => "sh",
"multi" => "diff",
_ => continue, // Skip unsupported languages
};
// Check if this is a valid exercise with instructions and example
let instructions_path =
path.join(".docs").join("instructions.md");
let has_instructions = instructions_path.exists();
let example_path = path
.join(".meta")
.join(format!("example.{}", language_extension));
let has_example = example_path.exists();
if has_instructions && has_example {
exercises.push(path);
}
}
}
}
Err(err) => println!("Error reading directory entry: {}", err),
}
}
}
Err(err) => println!(
"Error reading directory {}: {}",
language_dir.display(),
err
),
}
// Sort exercises by name for consistent selection
exercises.sort_by(|a, b| {
let a_name = a.file_name().unwrap_or_default().to_string_lossy();
let b_name = b.file_name().unwrap_or_default().to_string_lossy();
a_name.cmp(&b_name)
});
// Apply the limit if specified
if let Some(limit) = max_per_language {
if exercises.len() > limit {
println!(
"Limiting {} exercises to {} for language {}",
exercises.len(),
limit,
language
);
exercises.truncate(limit);
}
}
println!(
"Found {} exercises for language {}: {:?}",
exercises.len(),
language,
exercises
.iter()
.map(|p| p.file_name().unwrap_or_default().to_string_lossy())
.collect::<Vec<_>>()
);
all_exercises.extend(exercises);
}
Ok(all_exercises)
}

View File

@@ -1,125 +0,0 @@
use anyhow::{Result, anyhow};
use serde::Deserialize;
use std::{fs, path::Path};
use tempfile::TempDir;
use util::command::new_smol_command;
use walkdir::WalkDir;
#[derive(Debug, Deserialize)]
pub struct SetupConfig {
#[serde(rename = "base.sha")]
pub base_sha: String,
}
#[derive(Debug, Deserialize)]
pub struct RepoInfo {
pub remote_url: String,
pub head_sha: String,
}
pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
let output = new_smol_command("git")
.current_dir(repo_path)
.args(args)
.output()
.await?;
if output.status.success() {
Ok(String::from_utf8(output.stdout)?.trim().to_string())
} else {
Err(anyhow!(
"Git command failed: {} with status: {}",
args.join(" "),
output.status
))
}
}
pub async fn read_base_sha(framework_path: &Path) -> Result<String> {
let setup_path = framework_path.join("setup.json");
let setup_content = smol::unblock(move || std::fs::read_to_string(&setup_path)).await?;
let setup_config: SetupConfig = serde_json_lenient::from_str_lenient(&setup_content)?;
Ok(setup_config.base_sha)
}
pub async fn read_repo_info(exercise_path: &Path) -> Result<RepoInfo> {
let repo_info_path = exercise_path.join(".meta").join("repo_info.json");
println!("Reading repo info from: {}", repo_info_path.display());
let repo_info_content = smol::unblock(move || std::fs::read_to_string(&repo_info_path)).await?;
let repo_info: RepoInfo = serde_json_lenient::from_str_lenient(&repo_info_content)?;
// Remove any quotes from the strings
let remote_url = repo_info.remote_url.trim_matches('"').to_string();
let head_sha = repo_info.head_sha.trim_matches('"').to_string();
Ok(RepoInfo {
remote_url,
head_sha,
})
}
pub async fn setup_temp_repo(exercise_path: &Path, _base_sha: &str) -> Result<TempDir> {
let temp_dir = TempDir::new()?;
// Check if this is an internal exercise by looking for repo_info.json
let repo_info_path = exercise_path.join(".meta").join("repo_info.json");
if repo_info_path.exists() {
// This is an internal exercise, handle it differently
let repo_info = read_repo_info(exercise_path).await?;
// Clone the repository to the temp directory
let url = repo_info.remote_url;
let clone_path = temp_dir.path();
println!(
"Cloning repository from {} to {}",
url,
clone_path.display()
);
run_git(
&std::env::current_dir()?,
&["clone", &url, &clone_path.to_string_lossy()],
)
.await?;
// Checkout the specified commit
println!("Checking out commit: {}", repo_info.head_sha);
run_git(temp_dir.path(), &["checkout", &repo_info.head_sha]).await?;
println!("Successfully set up internal repository");
} else {
// Original code for regular exercises
// Copy the exercise files to the temp directory, excluding .docs and .meta
for entry in WalkDir::new(exercise_path).min_depth(0).max_depth(10) {
let entry = entry?;
let source_path = entry.path();
// Skip .docs and .meta directories completely
if source_path.starts_with(exercise_path.join(".docs"))
|| source_path.starts_with(exercise_path.join(".meta"))
{
continue;
}
if source_path.is_file() {
let relative_path = source_path.strip_prefix(exercise_path)?;
let dest_path = temp_dir.path().join(relative_path);
// Make sure parent directories exist
if let Some(parent) = dest_path.parent() {
fs::create_dir_all(parent)?;
}
fs::copy(source_path, dest_path)?;
}
}
// Initialize git repo in the temp directory
run_git(temp_dir.path(), &["init"]).await?;
run_git(temp_dir.path(), &["add", "."]).await?;
run_git(temp_dir.path(), &["commit", "-m", "Initial commit"]).await?;
println!("Created temp repo without .docs and .meta directories");
}
Ok(temp_dir)
}

View File

@@ -1,246 +0,0 @@
use agent::{RequestKind, Thread, ThreadEvent, ThreadStore};
use anyhow::anyhow;
use assistant_tool::ToolWorkingSet;
use client::{Client, UserStore};
use collections::HashMap;
use dap::DapRegistry;
use gpui::{App, Entity, SemanticVersion, Subscription, Task, prelude::*};
use language::LanguageRegistry;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
};
use node_runtime::NodeRuntime;
use project::{Project, RealFs};
use prompt_store::PromptBuilder;
use settings::SettingsStore;
use smol::channel;
use std::sync::Arc;
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
pub struct HeadlessAppState {
pub languages: Arc<LanguageRegistry>,
pub client: Arc<Client>,
pub user_store: Entity<UserStore>,
pub fs: Arc<dyn fs::Fs>,
pub node_runtime: NodeRuntime,
// Additional fields not present in `workspace::AppState`.
pub prompt_builder: Arc<PromptBuilder>,
}
pub struct HeadlessAssistant {
pub thread: Entity<Thread>,
pub project: Entity<Project>,
#[allow(dead_code)]
pub thread_store: Entity<ThreadStore>,
pub tool_use_counts: HashMap<Arc<str>, u32>,
pub done_tx: channel::Sender<anyhow::Result<()>>,
_subscription: Subscription,
}
impl HeadlessAssistant {
pub fn new(
app_state: Arc<HeadlessAppState>,
cx: &mut App,
) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
let env = None;
let project = Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
Arc::new(DapRegistry::default()),
app_state.fs.clone(),
env,
cx,
);
let tools = Arc::new(ToolWorkingSet::default());
let thread_store =
ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
let headless_thread = cx.new(move |cx| Self {
_subscription: cx.subscribe(&thread, Self::handle_thread_event),
thread,
project,
thread_store,
tool_use_counts: HashMap::default(),
done_tx,
});
Ok((headless_thread, done_rx))
}
fn handle_thread_event(
&mut self,
thread: Entity<Thread>,
event: &ThreadEvent,
cx: &mut Context<Self>,
) {
match event {
ThreadEvent::ShowError(err) => self
.done_tx
.send_blocking(Err(anyhow!("{:?}", err)))
.unwrap(),
ThreadEvent::DoneStreaming => {
let thread = thread.read(cx);
if let Some(message) = thread.messages().last() {
println!("Message: {}", message.to_string());
}
if thread.all_tools_finished() {
self.done_tx.send_blocking(Ok(())).unwrap()
}
}
ThreadEvent::UsePendingTools => {
thread.update(cx, |thread, cx| {
thread.use_pending_tools(cx);
});
}
ThreadEvent::ToolConfirmationNeeded => {
// Automatically approve all tools that need confirmation in headless mode
println!("Tool confirmation needed - automatically approving in headless mode");
// Get the tools needing confirmation
let tools_needing_confirmation: Vec<_> = thread
.read(cx)
.tools_needing_confirmation()
.cloned()
.collect();
// Run each tool that needs confirmation
for tool_use in tools_needing_confirmation {
if let Some(tool) = thread.read(cx).tools().tool(&tool_use.name, cx) {
thread.update(cx, |thread, cx| {
println!("Auto-approving tool: {}", tool_use.name);
// Create a request to send to the tool
let request = thread.to_completion_request(RequestKind::Chat, cx);
let messages = Arc::new(request.messages);
// Run the tool
thread.run_tool(
tool_use.id.clone(),
tool_use.ui_text.clone(),
tool_use.input.clone(),
&messages,
tool,
cx,
);
});
}
}
}
ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,
..
} => {
if let Some(pending_tool_use) = pending_tool_use {
println!(
"Used tool {} with input: {}",
pending_tool_use.name, pending_tool_use.input
);
*self
.tool_use_counts
.entry(pending_tool_use.name.clone())
.or_insert(0) += 1;
}
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
println!("Tool result: {:?}", tool_result);
}
if thread.read(cx).all_tools_finished() {
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(model) = model_registry.default_model() {
thread.update(cx, |thread, cx| {
thread.attach_tool_results(cx);
thread.send_to_model(model.model, RequestKind::Chat, cx);
});
} else {
println!(
"Warning: No active language model available to continue conversation"
);
}
}
}
_ => {}
}
}
}
pub fn init(cx: &mut App) -> Arc<HeadlessAppState> {
release_channel::init(SemanticVersion::default(), cx);
gpui_tokio::init(cx);
let mut settings_store = SettingsStore::new(cx);
settings_store
.set_default_settings(settings::default_settings().as_ref(), cx)
.unwrap();
cx.set_global(settings_store);
client::init_settings(cx);
Project::init_settings(cx);
let client = Client::production(cx);
cx.set_http_client(client.http_client().clone());
let git_binary_path = None;
let fs = Arc::new(RealFs::new(
git_binary_path,
cx.background_executor().clone(),
));
let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language::init(cx);
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
assistant_tools::init(client.http_client().clone(), cx);
context_server::init(cx);
let stdout_is_a_pty = false;
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
Arc::new(HeadlessAppState {
languages,
client,
user_store,
fs,
node_runtime: NodeRuntime::unavailable(),
prompt_builder,
})
}
pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
let model_registry = LanguageModelRegistry::read_global(cx);
let model = model_registry
.available_models(cx)
.find(|model| model.id().0 == model_name);
let Some(model) = model else {
return Err(anyhow!(
"No language model named {} was available. Available models: {}",
model_name,
model_registry
.available_models(cx)
.map(|model| model.id().0.clone())
.collect::<Vec<_>>()
.join(", ")
));
};
Ok(model)
}
pub fn authenticate_model_provider(
provider_id: LanguageModelProviderId,
cx: &mut App,
) -> Task<std::result::Result<(), AuthenticateError>> {
let model_registry = LanguageModelRegistry::read_global(cx);
let model_provider = model_registry.provider(&provider_id).unwrap();
model_provider.authenticate(cx)
}

View File

@@ -1,205 +0,0 @@
mod eval;
mod get_exercise;
mod git_commands;
mod headless_assistant;
use clap::Parser;
use eval::{run_exercise_eval, save_eval_results};
use futures::stream::{self, StreamExt};
use get_exercise::{find_exercises, get_exercise_language, get_exercise_name};
use git_commands::read_base_sha;
use gpui::Application;
use headless_assistant::{authenticate_model_provider, find_model};
use language_model::LanguageModelRegistry;
use reqwest_client::ReqwestClient;
use std::{path::PathBuf, sync::Arc};
#[derive(Parser, Debug)]
#[command(
name = "agent_eval",
disable_version_flag = true,
before_help = "Tool eval runner"
)]
struct Args {
/// Match the names of evals to run.
#[arg(long)]
exercise_names: Vec<String>,
/// Runs all exercises, causes the exercise_names to be ignored.
#[arg(long)]
all: bool,
/// Supported language types to evaluate (default: internal).
/// Internal is data generated from the agent panel
#[arg(long, default_value = "internal")]
languages: String,
/// Name of the model (default: "claude-3-7-sonnet-latest")
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
model_name: String,
/// Name of the editor model (default: value of `--model_name`).
#[arg(long)]
editor_model_name: Option<String>,
/// Number of evaluations to run concurrently (default: 3)
#[arg(short, long, default_value = "5")]
concurrency: usize,
/// Maximum number of exercises to evaluate per language
#[arg(long)]
max_exercises_per_language: Option<usize>,
}
fn main() {
env_logger::init();
let args = Args::parse();
let http_client = Arc::new(ReqwestClient::new());
let app = Application::headless().with_http_client(http_client.clone());
// Path to the zed-ace-framework repo
let framework_path = PathBuf::from("../zed-ace-framework")
.canonicalize()
.unwrap();
// Fix the 'languages' lifetime issue by creating owned Strings instead of slices
let languages: Vec<String> = args.languages.split(',').map(|s| s.to_string()).collect();
println!("Using zed-ace-framework at: {:?}", framework_path);
println!("Evaluating languages: {:?}", languages);
app.run(move |cx| {
let app_state = headless_assistant::init(cx);
let model = find_model(&args.model_name, cx).unwrap();
let editor_model = if let Some(model_name) = &args.editor_model_name {
find_model(model_name, cx).unwrap()
} else {
model.clone()
};
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.set_default_model(Some(model.clone()), cx);
});
let model_provider_id = model.provider_id();
let editor_model_provider_id = editor_model.provider_id();
let framework_path_clone = framework_path.clone();
let languages_clone = languages.clone();
let exercise_names = args.exercise_names.clone();
let all_flag = args.all;
cx.spawn(async move |cx| {
// Authenticate all model providers first
cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
.unwrap()
.await
.unwrap();
cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx))
.unwrap()
.await
.unwrap();
println!("framework path: {}", framework_path_clone.display());
let base_sha = read_base_sha(&framework_path_clone).await.unwrap();
println!("base sha: {}", base_sha);
let all_exercises = find_exercises(
&framework_path_clone,
&languages_clone
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>(),
args.max_exercises_per_language,
)
.unwrap();
println!("Found {} exercises total", all_exercises.len());
// Filter exercises if specific ones were requested
let exercises_to_run = if !exercise_names.is_empty() {
// If exercise names are specified, filter by them regardless of --all flag
all_exercises
.into_iter()
.filter(|path| {
let name = get_exercise_name(path);
exercise_names.iter().any(|filter| name.contains(filter))
})
.collect()
} else if all_flag {
// Only use all_flag if no exercise names are specified
all_exercises
} else {
// Default behavior (no filters)
all_exercises
};
println!("Will run {} exercises", exercises_to_run.len());
// Create exercise eval tasks - each exercise is a single task that will run templates sequentially
let exercise_tasks: Vec<_> = exercises_to_run
.into_iter()
.map(|exercise_path| {
let exercise_name = get_exercise_name(&exercise_path);
let model_clone = model.clone();
let app_state_clone = app_state.clone();
let base_sha_clone = base_sha.clone();
let framework_path_clone = framework_path_clone.clone();
let cx_clone = cx.clone();
async move {
println!("Processing exercise: {}", exercise_name);
let mut exercise_results = Vec::new();
match run_exercise_eval(
exercise_path.clone(),
model_clone.clone(),
app_state_clone.clone(),
base_sha_clone.clone(),
framework_path_clone.clone(),
cx_clone.clone(),
)
.await
{
Ok(result) => {
println!("Completed {}", exercise_name);
exercise_results.push(result);
}
Err(err) => {
println!("Error running {}: {}", exercise_name, err);
}
}
// Save results for this exercise
if !exercise_results.is_empty() {
if let Err(err) =
save_eval_results(&exercise_path, exercise_results.clone()).await
{
println!("Error saving results for {}: {}", exercise_name, err);
} else {
println!("Saved results for {}", exercise_name);
}
}
exercise_results
}
})
.collect();
println!(
"Running {} exercises with concurrency: {}",
exercise_tasks.len(),
args.concurrency
);
// Run exercises concurrently, with each exercise running its templates sequentially
let all_results = stream::iter(exercise_tasks)
.buffer_unordered(args.concurrency)
.flat_map(stream::iter)
.collect::<Vec<_>>()
.await;
println!("Completed {} evaluation runs", all_results.len());
cx.update(|cx| cx.quit()).unwrap();
})
.detach();
});
println!("Done running evals");
}

View File

@@ -1,25 +0,0 @@
[package]
name = "agent_rules"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/agent_rules.rs"
doctest = false
[dependencies]
anyhow.workspace = true
fs.workspace = true
gpui.workspace = true
prompt_store.workspace = true
util.workspace = true
worktree.workspace = true
workspace-hack = { version = "0.1", path = "../../tooling/workspace-hack" }
[dev-dependencies]
indoc.workspace = true

View File

@@ -1 +0,0 @@
../../LICENSE-GPL

View File

@@ -1,51 +0,0 @@
use std::sync::Arc;
use anyhow::{Context as _, Result};
use fs::Fs;
use gpui::{App, AppContext, Task};
use prompt_store::SystemPromptRulesFile;
use util::maybe;
use worktree::Worktree;
const RULES_FILE_NAMES: [&'static str; 6] = [
".rules",
".cursorrules",
".windsurfrules",
".clinerules",
".github/copilot-instructions.md",
"CLAUDE.md",
];
pub fn load_worktree_rules_file(
fs: Arc<dyn Fs>,
worktree: &Worktree,
cx: &App,
) -> Option<Task<Result<SystemPromptRulesFile>>> {
let selected_rules_file = RULES_FILE_NAMES
.into_iter()
.filter_map(|name| {
worktree
.entry_for_path(name)
.filter(|entry| entry.is_file())
.map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
})
.next();
// Note that Cline supports `.clinerules` being a directory, but that is not currently
// supported. This doesn't seem to occur often in GitHub repositories.
selected_rules_file.map(|(path_in_worktree, abs_path)| {
let fs = fs.clone();
cx.background_spawn(maybe!(async move {
let abs_path = abs_path?;
let text = fs
.load(&abs_path)
.await
.with_context(|| format!("Failed to load assistant rules file {:?}", abs_path))?;
anyhow::Ok(SystemPromptRulesFile {
path_in_worktree,
abs_path: abs_path.into(),
text: text.trim().to_string(),
})
}))
})
}

View File

@@ -37,9 +37,9 @@ pub enum AnthropicModelMode {
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
#[default]
#[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-latest")]
Claude3_5Sonnet,
#[default]
#[serde(rename = "claude-3-7-sonnet", alias = "claude-3-7-sonnet-latest")]
Claude3_7Sonnet,
#[serde(
@@ -724,4 +724,54 @@ impl ApiError {
pub fn is_rate_limit_error(&self) -> bool {
matches!(self.error_type.as_str(), "rate_limit_error")
}
pub fn match_window_exceeded(&self) -> Option<usize> {
let Some(ApiErrorCode::InvalidRequestError) = self.code() else {
return None;
};
parse_prompt_too_long(&self.message)
}
}
pub fn parse_prompt_too_long(message: &str) -> Option<usize> {
message
.strip_prefix("prompt is too long: ")?
.split_once(" tokens")?
.0
.parse::<usize>()
.ok()
}
#[test]
fn test_match_window_exceeded() {
let error = ApiError {
error_type: "invalid_request_error".to_string(),
message: "prompt is too long: 220000 tokens > 200000".to_string(),
};
assert_eq!(error.match_window_exceeded(), Some(220_000));
let error = ApiError {
error_type: "invalid_request_error".to_string(),
message: "prompt is too long: 1234953 tokens".to_string(),
};
assert_eq!(error.match_window_exceeded(), Some(1234953));
let error = ApiError {
error_type: "invalid_request_error".to_string(),
message: "not a prompt length error".to_string(),
};
assert_eq!(error.match_window_exceeded(), None);
let error = ApiError {
error_type: "rate_limit_error".to_string(),
message: "prompt is too long: 12345 tokens".to_string(),
};
assert_eq!(error.match_window_exceeded(), None);
let error = ApiError {
error_type: "invalid_request_error".to_string(),
message: "prompt is too long: invalid tokens".to_string(),
};
assert_eq!(error.match_window_exceeded(), None);
}

View File

@@ -69,7 +69,7 @@ pub enum AssistantProviderContentV1 {
},
}
#[derive(Debug, Default)]
#[derive(Clone, Debug, Default)]
pub struct AssistantSettings {
pub enabled: bool,
pub button: bool,
@@ -742,7 +742,7 @@ mod tests {
AssistantSettings::get_global(cx).default_model,
LanguageModelSelection {
provider: "zed.dev".into(),
model: "claude-3-5-sonnet-latest".into(),
model: "claude-3-7-sonnet-latest".into(),
}
);
});

View File

@@ -3,8 +3,8 @@ use buffer_diff::BufferDiff;
use collections::BTreeMap;
use futures::{StreamExt, channel::mpsc};
use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task, WeakEntity};
use language::{Anchor, Buffer, BufferEvent, DiskState, Point};
use project::{Project, ProjectItem};
use language::{Anchor, Buffer, BufferEvent, DiskState, Point, ToPoint};
use project::{Project, ProjectItem, lsp_store::OpenLspBufferHandle};
use std::{cmp, ops::Range, sync::Arc};
use text::{Edit, Patch, Rope};
use util::RangeExt;
@@ -49,6 +49,10 @@ impl ActionLog {
.tracked_buffers
.entry(buffer.clone())
.or_insert_with(|| {
let open_lsp_handle = self.project.update(cx, |project, cx| {
project.register_buffer_with_language_servers(&buffer, cx)
});
let text_snapshot = buffer.read(cx).text_snapshot();
let diff = cx.new(|cx| BufferDiff::new(&text_snapshot, cx));
let (diff_update_tx, diff_update_rx) = mpsc::unbounded();
@@ -76,6 +80,7 @@ impl ActionLog {
version: buffer.read(cx).version(),
diff,
diff_update: diff_update_tx,
_open_lsp_handle: open_lsp_handle,
_maintain_diff: cx.spawn({
let buffer = buffer.clone();
async move |this, cx| {
@@ -358,10 +363,10 @@ impl ActionLog {
}
}
pub fn reject_edits_in_range(
pub fn reject_edits_in_ranges(
&mut self,
buffer: Entity<Buffer>,
buffer_range: Range<impl language::ToPoint>,
buffer_ranges: Vec<Range<impl language::ToPoint>>,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let Some(tracked_buffer) = self.tracked_buffers.get_mut(&buffer) else {
@@ -398,29 +403,15 @@ impl ActionLog {
}
TrackedBufferStatus::Modified => {
buffer.update(cx, |buffer, cx| {
let buffer_range =
buffer_range.start.to_point(buffer)..buffer_range.end.to_point(buffer);
let mut buffer_row_ranges = buffer_ranges
.into_iter()
.map(|range| {
range.start.to_point(buffer).row..range.end.to_point(buffer).row
})
.peekable();
let mut edits_to_revert = Vec::new();
for edit in tracked_buffer.unreviewed_changes.edits() {
if buffer_range.end.row < edit.new.start {
break;
} else if buffer_range.start.row > edit.new.end {
continue;
}
let old_range = tracked_buffer
.base_text
.point_to_offset(Point::new(edit.old.start, 0))
..tracked_buffer.base_text.point_to_offset(cmp::min(
Point::new(edit.old.end, 0),
tracked_buffer.base_text.max_point(),
));
let old_text = tracked_buffer
.base_text
.chunks_in_range(old_range)
.collect::<String>();
let new_range = tracked_buffer
.snapshot
.anchor_before(Point::new(edit.new.start, 0))
@@ -428,7 +419,35 @@ impl ActionLog {
Point::new(edit.new.end, 0),
tracked_buffer.snapshot.max_point(),
));
edits_to_revert.push((new_range, old_text));
let new_row_range = new_range.start.to_point(buffer).row
..new_range.end.to_point(buffer).row;
let mut revert = false;
while let Some(buffer_row_range) = buffer_row_ranges.peek() {
if buffer_row_range.end < new_row_range.start {
buffer_row_ranges.next();
} else if buffer_row_range.start > new_row_range.end {
break;
} else {
revert = true;
break;
}
}
if revert {
let old_range = tracked_buffer
.base_text
.point_to_offset(Point::new(edit.old.start, 0))
..tracked_buffer.base_text.point_to_offset(cmp::min(
Point::new(edit.old.end, 0),
tracked_buffer.base_text.max_point(),
));
let old_text = tracked_buffer
.base_text
.chunks_in_range(old_range)
.collect::<String>();
edits_to_revert.push((new_range, old_text));
}
}
buffer.edit(edits_to_revert, None, cx);
@@ -594,6 +613,7 @@ fn point_to_row_edit(edit: Edit<Point>, old_text: &Rope, new_text: &Rope) -> Edi
}
}
#[derive(Copy, Clone, Debug)]
enum ChangeAuthor {
User,
Agent,
@@ -615,6 +635,7 @@ struct TrackedBuffer {
diff: Entity<BufferDiff>,
snapshot: text::BufferSnapshot,
diff_update: mpsc::UnboundedSender<(ChangeAuthor, text::BufferSnapshot)>,
_open_lsp_handle: OpenLspBufferHandle,
_maintain_diff: Task<()>,
_subscription: Subscription,
}
@@ -1129,9 +1150,48 @@ mod tests {
)]
);
// If the rejected range doesn't overlap with any hunk, we ignore it.
action_log
.update(cx, |log, cx| {
log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(1, 0), cx)
log.reject_edits_in_ranges(
buffer.clone(),
vec![Point::new(4, 0)..Point::new(4, 0)],
cx,
)
})
.await
.unwrap();
cx.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.text()),
"abc\ndE\nXYZf\nghi\njkl\nmnO"
);
assert_eq!(
unreviewed_hunks(&action_log, cx),
vec![(
buffer.clone(),
vec![
HunkStatus {
range: Point::new(1, 0)..Point::new(3, 0),
diff_status: DiffHunkStatusKind::Modified,
old_text: "def\n".into(),
},
HunkStatus {
range: Point::new(5, 0)..Point::new(5, 3),
diff_status: DiffHunkStatusKind::Modified,
old_text: "mno".into(),
}
],
)]
);
action_log
.update(cx, |log, cx| {
log.reject_edits_in_ranges(
buffer.clone(),
vec![Point::new(0, 0)..Point::new(1, 0)],
cx,
)
})
.await
.unwrap();
@@ -1154,7 +1214,11 @@ mod tests {
action_log
.update(cx, |log, cx| {
log.reject_edits_in_range(buffer.clone(), Point::new(4, 0)..Point::new(4, 0), cx)
log.reject_edits_in_ranges(
buffer.clone(),
vec![Point::new(4, 0)..Point::new(4, 0)],
cx,
)
})
.await
.unwrap();
@@ -1166,6 +1230,82 @@ mod tests {
assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
}
#[gpui::test(iterations = 10)]
async fn test_reject_multiple_edits(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/dir"), json!({"file": "abc\ndef\nghi\njkl\nmno"}))
.await;
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let file_path = project
.read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(file_path, cx))
.await
.unwrap();
cx.update(|cx| {
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| {
buffer
.edit([(Point::new(1, 1)..Point::new(1, 2), "E\nXYZ")], None, cx)
.unwrap()
});
buffer.update(cx, |buffer, cx| {
buffer
.edit([(Point::new(5, 2)..Point::new(5, 3), "O")], None, cx)
.unwrap()
});
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
});
cx.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.text()),
"abc\ndE\nXYZf\nghi\njkl\nmnO"
);
assert_eq!(
unreviewed_hunks(&action_log, cx),
vec![(
buffer.clone(),
vec![
HunkStatus {
range: Point::new(1, 0)..Point::new(3, 0),
diff_status: DiffHunkStatusKind::Modified,
old_text: "def\n".into(),
},
HunkStatus {
range: Point::new(5, 0)..Point::new(5, 3),
diff_status: DiffHunkStatusKind::Modified,
old_text: "mno".into(),
}
],
)]
);
action_log.update(cx, |log, cx| {
let range_1 = buffer.read(cx).anchor_before(Point::new(0, 0))
..buffer.read(cx).anchor_before(Point::new(1, 0));
let range_2 = buffer.read(cx).anchor_before(Point::new(5, 0))
..buffer.read(cx).anchor_before(Point::new(5, 3));
log.reject_edits_in_ranges(buffer.clone(), vec![range_1, range_2], cx)
.detach();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.text()),
"abc\ndef\nghi\njkl\nmno"
);
});
cx.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.text()),
"abc\ndef\nghi\njkl\nmno"
);
assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
}
#[gpui::test(iterations = 10)]
async fn test_reject_deleted_file(cx: &mut TestAppContext) {
init_test(cx);
@@ -1209,7 +1349,11 @@ mod tests {
action_log
.update(cx, |log, cx| {
log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(0, 0), cx)
log.reject_edits_in_ranges(
buffer.clone(),
vec![Point::new(0, 0)..Point::new(0, 0)],
cx,
)
})
.await
.unwrap();
@@ -1260,7 +1404,11 @@ mod tests {
action_log
.update(cx, |log, cx| {
log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(0, 11), cx)
log.reject_edits_in_ranges(
buffer.clone(),
vec![Point::new(0, 0)..Point::new(0, 11)],
cx,
)
})
.await
.unwrap();
@@ -1306,7 +1454,7 @@ mod tests {
.update(cx, |log, cx| {
let range = buffer.read(cx).random_byte_range(0, &mut rng);
log::info!("rejecting edits in range {:?}", range);
log.reject_edits_in_range(buffer.clone(), range, cx)
log.reject_edits_in_ranges(buffer.clone(), vec![range], cx)
})
.await
.unwrap();

View File

@@ -1,5 +1,6 @@
mod action_log;
mod tool_registry;
mod tool_schema;
mod tool_working_set;
use std::fmt;
@@ -16,12 +17,26 @@ use project::Project;
pub use crate::action_log::*;
pub use crate::tool_registry::*;
pub use crate::tool_schema::*;
pub use crate::tool_working_set::*;
pub fn init(cx: &mut App) {
ToolRegistry::default_global(cx);
}
/// The result of running a tool
pub struct ToolResult {
/// The asynchronous task that will eventually resolve to the tool's output
pub output: Task<Result<String>>,
}
impl From<Task<Result<String>>> for ToolResult {
/// Convert from a task to a ToolResult
fn from(output: Task<Result<String>>) -> Self {
Self { output }
}
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
pub enum ToolSource {
/// A native tool built-in to Zed.
@@ -51,8 +66,8 @@ pub trait Tool: 'static + Send + Sync {
fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
/// Returns the JSON schema that describes the tool's input.
fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> serde_json::Value {
serde_json::Value::Object(serde_json::Map::default())
fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
Ok(serde_json::Value::Object(serde_json::Map::default()))
}
/// Returns markdown to be displayed in the UI for this tool.
@@ -66,7 +81,7 @@ pub trait Tool: 'static + Send + Sync {
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>>;
) -> ToolResult;
}
impl Debug for dyn Tool {

View File

@@ -0,0 +1,236 @@
use anyhow::Result;
use serde_json::Value;
use crate::LanguageModelToolSchemaFormat;
/// Tries to adapt a JSON schema representation to be compatible with the specified format.
///
/// If the json cannot be made compatible with the specified format, an error is returned.
pub fn adapt_schema_to_format(
json: &mut Value,
format: LanguageModelToolSchemaFormat,
) -> Result<()> {
match format {
LanguageModelToolSchemaFormat::JsonSchema => Ok(()),
LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
}
}
/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
if let Value::Object(obj) = json {
const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
for key in UNSUPPORTED_KEYS {
if obj.contains_key(key) {
return Err(anyhow::anyhow!(
"Schema cannot be made compatible because it contains \"{}\" ",
key
));
}
}
const KEYS_TO_REMOVE: [&str; 2] = ["format", "$schema"];
for key in KEYS_TO_REMOVE {
obj.remove(key);
}
if let Some(default) = obj.get("default") {
let is_null = default.is_null();
// Default is not supported, so we need to remove it
obj.remove("default");
if is_null {
obj.insert("nullable".to_string(), Value::Bool(true));
}
}
// If a type is not specified for an input parameter, add a default type
if obj.contains_key("description")
&& !obj.contains_key("type")
&& !(obj.contains_key("anyOf")
|| obj.contains_key("oneOf")
|| obj.contains_key("allOf"))
{
obj.insert("type".to_string(), Value::String("string".to_string()));
}
// Handle oneOf -> anyOf conversion
if let Some(subschemas) = obj.get_mut("oneOf") {
if subschemas.is_array() {
let subschemas_clone = subschemas.clone();
obj.remove("oneOf");
obj.insert("anyOf".to_string(), subschemas_clone);
}
}
// Recursively process all nested objects and arrays
for (_, value) in obj.iter_mut() {
if let Value::Object(_) | Value::Array(_) = value {
adapt_to_json_schema_subset(value)?;
}
}
} else if let Value::Array(arr) = json {
for item in arr.iter_mut() {
adapt_to_json_schema_subset(item)?;
}
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_transform_default_null_to_nullable() {
let mut json = json!({
"description": "A test field",
"type": "string",
"default": null
});
adapt_to_json_schema_subset(&mut json).unwrap();
assert_eq!(
json,
json!({
"description": "A test field",
"type": "string",
"nullable": true
})
);
}
#[test]
fn test_transform_adds_type_when_missing() {
let mut json = json!({
"description": "A test field without type"
});
adapt_to_json_schema_subset(&mut json).unwrap();
assert_eq!(
json,
json!({
"description": "A test field without type",
"type": "string"
})
);
}
#[test]
fn test_transform_removes_format() {
let mut json = json!({
"description": "A test field",
"type": "integer",
"format": "uint32"
});
adapt_to_json_schema_subset(&mut json).unwrap();
assert_eq!(
json,
json!({
"description": "A test field",
"type": "integer"
})
);
}
#[test]
fn test_transform_one_of_to_any_of() {
let mut json = json!({
"description": "A test field",
"oneOf": [
{ "type": "string" },
{ "type": "integer" }
]
});
adapt_to_json_schema_subset(&mut json).unwrap();
assert_eq!(
json,
json!({
"description": "A test field",
"anyOf": [
{ "type": "string" },
{ "type": "integer" }
]
})
);
}
#[test]
fn test_transform_nested_objects() {
let mut json = json!({
"type": "object",
"properties": {
"nested": {
"oneOf": [
{ "type": "string" },
{ "type": "null" }
],
"format": "email"
}
}
});
adapt_to_json_schema_subset(&mut json).unwrap();
assert_eq!(
json,
json!({
"type": "object",
"properties": {
"nested": {
"anyOf": [
{ "type": "string" },
{ "type": "null" }
]
}
}
})
);
}
#[test]
fn test_transform_fails_if_unsupported_keys_exist() {
let mut json = json!({
"type": "object",
"properties": {
"$ref": "#/definitions/User",
}
});
assert!(adapt_to_json_schema_subset(&mut json).is_err());
let mut json = json!({
"type": "object",
"properties": {
"if": "...",
}
});
assert!(adapt_to_json_schema_subset(&mut json).is_err());
let mut json = json!({
"type": "object",
"properties": {
"then": "...",
}
});
assert!(adapt_to_json_schema_subset(&mut json).is_err());
let mut json = json!({
"type": "object",
"properties": {
"else": "...",
}
});
assert!(adapt_to_json_schema_subset(&mut json).is_err());
}
}

View File

@@ -1,8 +1,7 @@
use std::sync::Arc;
use collections::{HashMap, HashSet, IndexMap};
use gpui::App;
use parking_lot::Mutex;
use gpui::{App, Context, EventEmitter};
use crate::{Tool, ToolRegistry, ToolSource};
@@ -12,11 +11,6 @@ pub struct ToolId(usize);
/// A working set of tools for use in one instance of the Assistant Panel.
#[derive(Default)]
pub struct ToolWorkingSet {
state: Mutex<WorkingSetState>,
}
#[derive(Default)]
struct WorkingSetState {
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
enabled_sources: HashSet<ToolSource>,
@@ -24,99 +18,27 @@ struct WorkingSetState {
next_tool_id: ToolId,
}
pub enum ToolWorkingSetEvent {
EnabledToolsChanged,
}
impl EventEmitter<ToolWorkingSetEvent> for ToolWorkingSet {}
impl ToolWorkingSet {
pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
self.state
.lock()
.context_server_tools_by_name
self.context_server_tools_by_name
.get(name)
.cloned()
.or_else(|| ToolRegistry::global(cx).tool(name))
}
pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
self.state.lock().tools(cx)
}
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
self.state.lock().tools_by_source(cx)
}
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
self.state.lock().enabled_tools(cx)
}
pub fn disable_all_tools(&self) {
let mut state = self.state.lock();
state.disable_all_tools();
}
pub fn enable_source(&self, source: ToolSource, cx: &App) {
let mut state = self.state.lock();
state.enable_source(source, cx);
}
pub fn disable_source(&self, source: &ToolSource) {
let mut state = self.state.lock();
state.disable_source(source);
}
pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
let mut state = self.state.lock();
let tool_id = state.next_tool_id;
state.next_tool_id.0 += 1;
state
.context_server_tools_by_id
.insert(tool_id, tool.clone());
state.tools_changed();
tool_id
}
pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
self.state.lock().is_enabled(source, name)
}
pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
self.state.lock().is_disabled(source, name)
}
pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
let mut state = self.state.lock();
state.enable(source, tools_to_enable);
}
pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
let mut state = self.state.lock();
state.disable(source, tools_to_disable);
}
pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
let mut state = self.state.lock();
state
.context_server_tools_by_id
.retain(|id, _| !tool_ids_to_remove.contains(id));
state.tools_changed();
}
}
impl WorkingSetState {
fn tools_changed(&mut self) {
self.context_server_tools_by_name.clear();
self.context_server_tools_by_name.extend(
self.context_server_tools_by_id
.values()
.map(|tool| (tool.name(), tool.clone())),
);
}
fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
let mut tools = ToolRegistry::global(cx).tools();
tools.extend(self.context_server_tools_by_id.values().cloned());
tools
}
fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
let mut tools_by_source = IndexMap::default();
for tool in self.tools(cx) {
@@ -135,7 +57,7 @@ impl WorkingSetState {
tools_by_source
}
fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
let all_tools = self.tools(cx);
all_tools
@@ -144,31 +66,12 @@ impl WorkingSetState {
.collect()
}
fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
self.enabled_tools_by_source
.get(source)
.map_or(false, |enabled_tools| enabled_tools.contains(name))
pub fn disable_all_tools(&mut self, cx: &mut Context<Self>) {
self.enabled_tools_by_source.clear();
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
}
fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
!self.is_enabled(source, name)
}
fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
self.enabled_tools_by_source
.entry(source)
.or_default()
.extend(tools_to_enable.into_iter().cloned());
}
fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
self.enabled_tools_by_source
.entry(source)
.or_default()
.retain(|name| !tools_to_disable.contains(name));
}
fn enable_source(&mut self, source: ToolSource, cx: &App) {
pub fn enable_source(&mut self, source: ToolSource, cx: &mut Context<Self>) {
self.enabled_sources.insert(source.clone());
let tools_by_source = self.tools_by_source(cx);
@@ -181,14 +84,72 @@ impl WorkingSetState {
.collect::<HashSet<_>>(),
);
}
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
}
fn disable_source(&mut self, source: &ToolSource) {
pub fn disable_source(&mut self, source: &ToolSource, cx: &mut Context<Self>) {
self.enabled_sources.remove(source);
self.enabled_tools_by_source.remove(source);
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
}
fn disable_all_tools(&mut self) {
self.enabled_tools_by_source.clear();
pub fn insert(&mut self, tool: Arc<dyn Tool>) -> ToolId {
let tool_id = self.next_tool_id;
self.next_tool_id.0 += 1;
self.context_server_tools_by_id
.insert(tool_id, tool.clone());
self.tools_changed();
tool_id
}
pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
self.enabled_tools_by_source
.get(source)
.map_or(false, |enabled_tools| enabled_tools.contains(name))
}
pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
!self.is_enabled(source, name)
}
pub fn enable(
&mut self,
source: ToolSource,
tools_to_enable: &[Arc<str>],
cx: &mut Context<Self>,
) {
self.enabled_tools_by_source
.entry(source)
.or_default()
.extend(tools_to_enable.into_iter().cloned());
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
}
pub fn disable(
&mut self,
source: ToolSource,
tools_to_disable: &[Arc<str>],
cx: &mut Context<Self>,
) {
self.enabled_tools_by_source
.entry(source)
.or_default()
.retain(|name| !tools_to_disable.contains(name));
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
}
pub fn remove(&mut self, tool_ids_to_remove: &[ToolId]) {
self.context_server_tools_by_id
.retain(|id, _| !tool_ids_to_remove.contains(id));
self.tools_changed();
}
fn tools_changed(&mut self) {
self.context_server_tools_by_name.clear();
self.context_server_tools_by_name.extend(
self.context_server_tools_by_id
.values()
.map(|tool| (tool.name(), tool.clone())),
);
}
}

View File

@@ -1,6 +1,7 @@
mod bash_tool;
mod batch_tool;
mod code_action_tool;
mod code_symbols_tool;
mod contents_tool;
mod copy_path_tool;
mod create_directory_tool;
mod create_file_tool;
@@ -15,9 +16,11 @@ mod open_tool;
mod path_search_tool;
mod read_file_tool;
mod regex_search_tool;
mod rename_tool;
mod replace;
mod schema;
mod symbol_info_tool;
mod terminal_tool;
mod thinking_tool;
use std::sync::Arc;
@@ -28,9 +31,10 @@ use gpui::App;
use http_client::HttpClientWithUrl;
use move_path_tool::MovePathTool;
use crate::bash_tool::BashTool;
use crate::batch_tool::BatchTool;
use crate::code_action_tool::CodeActionTool;
use crate::code_symbols_tool::CodeSymbolsTool;
use crate::contents_tool::ContentsTool;
use crate::create_directory_tool::CreateDirectoryTool;
use crate::create_file_tool::CreateFileTool;
use crate::delete_path_tool::DeletePathTool;
@@ -43,14 +47,16 @@ use crate::open_tool::OpenTool;
use crate::path_search_tool::PathSearchTool;
use crate::read_file_tool::ReadFileTool;
use crate::regex_search_tool::RegexSearchTool;
use crate::rename_tool::RenameTool;
use crate::symbol_info_tool::SymbolInfoTool;
use crate::terminal_tool::TerminalTool;
use crate::thinking_tool::ThinkingTool;
pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
assistant_tool::init(cx);
let registry = ToolRegistry::global(cx);
registry.register_tool(BashTool);
registry.register_tool(TerminalTool);
registry.register_tool(BatchTool);
registry.register_tool(CreateDirectoryTool);
registry.register_tool(CreateFileTool);
@@ -58,15 +64,57 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(DeletePathTool);
registry.register_tool(FindReplaceFileTool);
registry.register_tool(SymbolInfoTool);
registry.register_tool(CodeActionTool);
registry.register_tool(MovePathTool);
registry.register_tool(DiagnosticsTool);
registry.register_tool(ListDirectoryTool);
registry.register_tool(NowTool);
registry.register_tool(OpenTool);
registry.register_tool(CodeSymbolsTool);
registry.register_tool(ContentsTool);
registry.register_tool(PathSearchTool);
registry.register_tool(ReadFileTool);
registry.register_tool(RegexSearchTool);
registry.register_tool(RenameTool);
registry.register_tool(ThinkingTool);
registry.register_tool(FetchTool::new(http_client));
}
#[cfg(test)]
mod tests {
use http_client::FakeHttpClient;
use super::*;
#[gpui::test]
fn test_builtin_tool_schema_compatibility(cx: &mut App) {
crate::init(
Arc::new(http_client::HttpClientWithUrl::new(
FakeHttpClient::with_200_response(),
"https://zed.dev",
None,
)),
cx,
);
for tool in ToolRegistry::global(cx).tools() {
let actual_schema = tool
.input_schema(language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset)
.unwrap();
let mut expected_schema = actual_schema.clone();
assistant_tool::adapt_schema_to_format(
&mut expected_schema,
language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
)
.unwrap();
let error_message = format!(
"Tool schema for `{}` is not compatible with `language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset` (Gemini Models).\n\
Are you using `schema::json_schema_for<T>(format)` to generate the schema?",
tool.name(),
);
assert_eq!(actual_schema, expected_schema, "{}", error_message)
}
}
}

View File

@@ -1,230 +0,0 @@
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use futures::io::BufReader;
use futures::{AsyncBufReadExt, AsyncReadExt};
use gpui::{App, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::sync::Arc;
use ui::IconName;
use util::command::new_smol_command;
use util::markdown::MarkdownString;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct BashToolInput {
/// The bash one-liner command to execute.
command: String,
/// Working directory for the command. This must be one of the root directories of the project.
cd: String,
}
pub struct BashTool;
impl Tool for BashTool {
fn name(&self) -> String {
"bash".to_string()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
true
}
fn description(&self) -> String {
include_str!("./bash_tool/description.md").to_string()
}
fn icon(&self) -> IconName {
IconName::Terminal
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<BashToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<BashToolInput>(input.clone()) {
Ok(input) => {
let mut lines = input.command.lines();
let first_line = lines.next().unwrap_or_default();
let remaining_line_count = lines.count();
match remaining_line_count {
0 => MarkdownString::inline_code(&first_line).0,
1 => {
MarkdownString::inline_code(&format!(
"{} - {} more line",
first_line, remaining_line_count
))
.0
}
n => {
MarkdownString::inline_code(&format!("{} - {} more lines", first_line, n)).0
}
}
}
Err(_) => "Run bash command".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let input: BashToolInput = match serde_json::from_value(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let project = project.read(cx);
let input_path = Path::new(&input.cd);
let working_dir = if input.cd == "." {
// Accept "." as meaning "the one worktree" if we only have one worktree.
let mut worktrees = project.worktrees(cx);
let only_worktree = match worktrees.next() {
Some(worktree) => worktree,
None => return Task::ready(Err(anyhow!("No worktrees found in the project"))),
};
if worktrees.next().is_some() {
return Task::ready(Err(anyhow!(
"'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly."
)));
}
only_worktree.read(cx).abs_path()
} else if input_path.is_absolute() {
// Absolute paths are allowed, but only if they're in one of the project's worktrees.
if !project
.worktrees(cx)
.any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path()))
{
return Task::ready(Err(anyhow!(
"The absolute path must be within one of the project's worktrees"
)));
}
input_path.into()
} else {
let Some(worktree) = project.worktree_for_root_name(&input.cd, cx) else {
return Task::ready(Err(anyhow!(
"`cd` directory {} not found in the project",
&input.cd
)));
};
worktree.read(cx).abs_path()
};
cx.spawn(async move |_| {
// Add 2>&1 to merge stderr into stdout for proper interleaving.
let command = format!("({}) 2>&1", input.command);
let mut cmd = new_smol_command("bash")
.arg("-c")
.arg(&command)
.current_dir(working_dir)
.stdout(std::process::Stdio::piped())
.spawn()
.context("Failed to execute bash command")?;
// Capture stdout with a limit
let stdout = cmd.stdout.take().unwrap();
let mut reader = BufReader::new(stdout);
const MESSAGE_1: &str = "Command output too long. The first ";
const MESSAGE_2: &str = " bytes:\n\n";
const ERR_MESSAGE_1: &str = "Command failed with exit code ";
const ERR_MESSAGE_2: &str = "\n\n";
const STDOUT_LIMIT: usize = 8192;
const LIMIT: usize = STDOUT_LIMIT
- (MESSAGE_1.len()
+ (STDOUT_LIMIT.ilog10() as usize + 1) // byte count
+ MESSAGE_2.len()
+ ERR_MESSAGE_1.len()
+ 3 // status code
+ ERR_MESSAGE_2.len());
// Read one more byte to determine whether the output was truncated
let mut buffer = vec![0; LIMIT + 1];
let mut bytes_read = 0;
// Read until we reach the limit
loop {
let read = reader.read(&mut buffer).await?;
if read == 0 {
break;
}
bytes_read += read;
if bytes_read > LIMIT {
bytes_read = LIMIT + 1;
break;
}
}
// Repeatedly fill the output reader's buffer without copying it.
loop {
let skipped_bytes = reader.fill_buf().await?;
if skipped_bytes.is_empty() {
break;
}
let skipped_bytes_len = skipped_bytes.len();
reader.consume_unpin(skipped_bytes_len);
}
let output_bytes = &buffer[..bytes_read];
// Let the process continue running
let status = cmd.status().await.context("Failed to get command status")?;
let output_string = if bytes_read > LIMIT {
// Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
// multi-byte characters.
let last_line_ix = output_bytes.iter().rposition(|b| *b == b'\n');
let output_string = String::from_utf8_lossy(
&output_bytes[..last_line_ix.unwrap_or(output_bytes.len())],
);
format!(
"{}{}{}{}",
MESSAGE_1,
output_string.len(),
MESSAGE_2,
output_string
)
} else {
String::from_utf8_lossy(&output_bytes).into()
};
let output_with_status = if status.success() {
if output_string.is_empty() {
"Command executed successfully.".to_string()
} else {
output_string.to_string()
}
} else {
format!(
"{}{}{}{}",
ERR_MESSAGE_1,
status.code().unwrap_or(-1),
ERR_MESSAGE_2,
output_string,
)
};
debug_assert!(output_with_status.len() <= STDOUT_LIMIT);
Ok(output_with_status)
})
}
}

View File

@@ -1,7 +0,0 @@
Executes a bash one-liner and returns the combined output.
This tool spawns a bash process, combines stdout and stderr into one interleaved stream as they are produced (preserving the order of writes), and captures that stream into a string which is returned.
Make sure you use the `cd` parameter to navigate to one of the root directories of the project. NEVER do it as part of the `command` itself, otherwise it will error.
Remember that each invocation of this tool will spawn a new bash process, so you can't rely on any state from previous invocations.

View File

@@ -1,6 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
use assistant_tool::{ActionLog, Tool, ToolResult, ToolWorkingSet};
use futures::future::join_all;
use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
@@ -172,7 +172,7 @@ impl Tool for BatchTool {
IconName::Cog
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<BatchToolInput>(format)
}
@@ -219,14 +219,14 @@ impl Tool for BatchTool {
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
) -> ToolResult {
let input = match serde_json::from_value::<BatchToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
if input.invocations.is_empty() {
return Task::ready(Err(anyhow!("No tool invocations provided")));
return Task::ready(Err(anyhow!("No tool invocations provided"))).into();
}
let run_tools_concurrently = input.run_tools_concurrently;
@@ -257,11 +257,11 @@ impl Tool for BatchTool {
let project = project.clone();
let action_log = action_log.clone();
let messages = messages.clone();
let task = cx
let tool_result = cx
.update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))
.map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;
tasks.push(task);
tasks.push(tool_result.output);
}
Ok((tasks, tool_names))
@@ -306,5 +306,6 @@ impl Tool for BatchTool {
Ok(formatted_results.trim().to_string())
})
.into()
}
}

View File

@@ -0,0 +1,387 @@
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use language::{self, Anchor, Buffer, ToPointUtf16};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{self, LspAction, Project};
use regex::Regex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{ops::Range, sync::Arc};
use ui::IconName;
use crate::schema::json_schema_for;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct CodeActionToolInput {
/// The relative path to the file containing the text range.
///
/// WARNING: you MUST start this path with one of the project's root directories.
pub path: String,
/// The specific code action to execute.
///
/// If this field is provided, the tool will execute the specified action.
/// If omitted, the tool will list all available code actions for the text range.
///
/// Here are some actions that are commonly supported (but may not be for this particular
/// text range; you can omit this field to list all the actions, if you want to know
/// what your options are, or you can just try an action and if it fails I'll tell you
/// what the available actions were instead):
/// - "quickfix.all" - applies all available quick fixes in the range
/// - "source.organizeImports" - sorts and cleans up import statements
/// - "source.fixAll" - applies all available auto fixes
/// - "refactor.extract" - extracts selected code into a new function or variable
/// - "refactor.inline" - inlines a variable by replacing references with its value
/// - "refactor.rewrite" - general code rewriting operations
/// - "source.addMissingImports" - adds imports for references that lack them
/// - "source.removeUnusedImports" - removes imports that aren't being used
/// - "source.implementInterface" - generates methods required by an interface/trait
/// - "source.generateAccessors" - creates getter/setter methods
/// - "source.convertToAsyncFunction" - converts callback-style code to async/await
///
/// Also, there is a special case: if you specify exactly "textDocument/rename" as the action,
/// then this will rename the symbol to whatever string you specified for the `arguments` field.
pub action: Option<String>,
/// Optional arguments to pass to the code action.
///
/// For rename operations (when action="textDocument/rename"), this should contain the new name.
/// For other code actions, these arguments may be passed to the language server.
pub arguments: Option<serde_json::Value>,
/// The text that comes immediately before the text range in the file.
pub context_before_range: String,
/// The text range. This text must appear in the file right between `context_before_range`
/// and `context_after_range`.
///
/// The file must contain exactly one occurrence of `context_before_range` followed by
/// `text_range` followed by `context_after_range`. If the file contains zero occurrences,
/// or if it contains more than one occurrence, the tool will fail, so it is absolutely
/// critical that you verify ahead of time that the string is unique. You can search
/// the file's contents to verify this ahead of time.
///
/// To make the string more likely to be unique, include a minimum of 1 line of context
/// before the text range, as well as a minimum of 1 line of context after the text range.
/// If these lines of context are not enough to obtain a string that appears only once
/// in the file, then double the number of context lines until the string becomes unique.
/// (Start with 1 line before and 1 line after though, because too much context is
/// needlessly costly.)
///
/// Do not alter the context lines of code in any way, and make sure to preserve all
/// whitespace and indentation for all lines of code. The combined string must be exactly
/// as it appears in the file, or else this tool call will fail.
pub text_range: String,
/// The text that comes immediately after the text range in the file.
pub context_after_range: String,
}
pub struct CodeActionTool;
impl Tool for CodeActionTool {
fn name(&self) -> String {
"code_actions".into()
}
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
false
}
fn description(&self) -> String {
include_str!("./code_action_tool/description.md").into()
}
fn icon(&self) -> IconName {
IconName::Wand
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<CodeActionToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<CodeActionToolInput>(input.clone()) {
Ok(input) => {
if let Some(action) = &input.action {
if action == "textDocument/rename" {
let new_name = match &input.arguments {
Some(serde_json::Value::String(new_name)) => new_name.clone(),
Some(value) => {
if let Ok(new_name) =
serde_json::from_value::<String>(value.clone())
{
new_name
} else {
"invalid name".to_string()
}
}
None => "missing name".to_string(),
};
format!("Rename '{}' to '{}'", input.text_range, new_name)
} else {
format!(
"Execute code action '{}' for '{}'",
action, input.text_range
)
}
} else {
format!("List available code actions for '{}'", input.text_range)
}
}
Err(_) => "Perform code action".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CodeActionToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
cx.spawn(async move |cx| {
let buffer = {
let project_path = project.read_with(cx, |project, cx| {
project
.find_project_path(&input.path, cx)
.context("Path not found in project")
})??;
project.update(cx, |project, cx| project.open_buffer(project_path, cx))?.await?
};
action_log.update(cx, |action_log, cx| {
action_log.buffer_read(buffer.clone(), cx);
})?;
let range = {
let Some(range) = buffer.read_with(cx, |buffer, _cx| {
find_text_range(&buffer, &input.context_before_range, &input.text_range, &input.context_after_range)
})? else {
return Err(anyhow!(
"Failed to locate the text specified by context_before_range, text_range, and context_after_range. Make sure context_before_range and context_after_range each match exactly once in the file."
));
};
range
};
if let Some(action_type) = &input.action {
// Special-case the `rename` operation
let response = if action_type == "textDocument/rename" {
let Some(new_name) = input.arguments.and_then(|args| serde_json::from_value::<String>(args).ok()) else {
return Err(anyhow!("For rename operations, 'arguments' must be a string containing the new name"));
};
let position = buffer.read_with(cx, |buffer, _| {
range.start.to_point_utf16(&buffer.snapshot())
})?;
project
.update(cx, |project, cx| {
project.perform_rename(buffer.clone(), position, new_name.clone(), cx)
})?
.await?;
format!("Renamed '{}' to '{}'", input.text_range, new_name)
} else {
// Get code actions for the range
let actions = project
.update(cx, |project, cx| {
project.code_actions(&buffer, range.clone(), None, cx)
})?
.await?;
if actions.is_empty() {
return Err(anyhow!("No code actions available for this range"));
}
// Find all matching actions
let regex = match Regex::new(action_type) {
Ok(regex) => regex,
Err(err) => return Err(anyhow!("Invalid regex pattern: {}", err)),
};
let mut matching_actions = actions
.into_iter()
.filter(|action| { regex.is_match(action.lsp_action.title()) });
let Some(action) = matching_actions.next() else {
return Err(anyhow!("No code actions match the pattern: {}", action_type));
};
// There should have been exactly one matching action.
if let Some(second) = matching_actions.next() {
let mut all_matches = vec![action, second];
all_matches.extend(matching_actions);
return Err(anyhow!(
"Pattern '{}' matches multiple code actions: {}",
action_type,
all_matches.into_iter().map(|action| action.lsp_action.title().to_string()).collect::<Vec<_>>().join(", ")
));
}
let title = action.lsp_action.title().to_string();
project
.update(cx, |project, cx| {
project.apply_code_action(buffer.clone(), action, true, cx)
})?
.await?;
format!("Completed code action: {}", title)
};
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.await?;
action_log.update(cx, |log, cx| {
log.buffer_edited(buffer.clone(), cx)
})?;
Ok(response)
} else {
// No action specified, so list the available ones.
let (position_start, position_end) = buffer.read_with(cx, |buffer, _| {
let snapshot = buffer.snapshot();
(
range.start.to_point_utf16(&snapshot),
range.end.to_point_utf16(&snapshot)
)
})?;
// Convert position to display coordinates (1-based)
let position_start_display = language::Point {
row: position_start.row + 1,
column: position_start.column + 1,
};
let position_end_display = language::Point {
row: position_end.row + 1,
column: position_end.column + 1,
};
// Get code actions for the range
let actions = project
.update(cx, |project, cx| {
project.code_actions(&buffer, range.clone(), None, cx)
})?
.await?;
let mut response = format!(
"Available code actions for text range '{}' at position {}:{} to {}:{} (UTF-16 coordinates):\n\n",
input.text_range,
position_start_display.row, position_start_display.column,
position_end_display.row, position_end_display.column
);
if actions.is_empty() {
response.push_str("No code actions available for this range.");
} else {
for (i, action) in actions.iter().enumerate() {
let title = match &action.lsp_action {
LspAction::Action(code_action) => code_action.title.as_str(),
LspAction::Command(command) => command.title.as_str(),
LspAction::CodeLens(code_lens) => {
if let Some(cmd) = &code_lens.command {
cmd.title.as_str()
} else {
"Unknown code lens"
}
},
};
let kind = match &action.lsp_action {
LspAction::Action(code_action) => {
if let Some(kind) = &code_action.kind {
kind.as_str()
} else {
"unknown"
}
},
LspAction::Command(_) => "command",
LspAction::CodeLens(_) => "code_lens",
};
response.push_str(&format!("{}. {title} ({kind})\n", i + 1));
}
}
Ok(response)
}
}).into()
}
}
/// Finds the range of the text in the buffer, if it appears between context_before_range
/// and context_after_range, and if that combined string has one unique result in the buffer.
///
/// If an exact match fails, it tries adding a newline to the end of context_before_range and
/// to the beginning of context_after_range to accommodate line-based context matching.
fn find_text_range(
buffer: &Buffer,
context_before_range: &str,
text_range: &str,
context_after_range: &str,
) -> Option<Range<Anchor>> {
let snapshot = buffer.snapshot();
let text = snapshot.text();
// First try with exact match
let search_string = format!("{context_before_range}{text_range}{context_after_range}");
let mut positions = text.match_indices(&search_string);
let position_result = positions.next();
if let Some(position) = position_result {
// Check if the matched string is unique
if positions.next().is_none() {
let range_start = position.0 + context_before_range.len();
let range_end = range_start + text_range.len();
let range_start_anchor = snapshot.anchor_before(snapshot.offset_to_point(range_start));
let range_end_anchor = snapshot.anchor_before(snapshot.offset_to_point(range_end));
return Some(range_start_anchor..range_end_anchor);
}
}
// If exact match fails or is not unique, try with line-based context
// Add a newline to the end of before context and beginning of after context
let line_based_before = if context_before_range.ends_with('\n') {
context_before_range.to_string()
} else {
format!("{context_before_range}\n")
};
let line_based_after = if context_after_range.starts_with('\n') {
context_after_range.to_string()
} else {
format!("\n{context_after_range}")
};
let line_search_string = format!("{line_based_before}{text_range}{line_based_after}");
let mut line_positions = text.match_indices(&line_search_string);
let line_position = line_positions.next()?;
// The line-based search string must also appear exactly once
if line_positions.next().is_some() {
return None;
}
let line_range_start = line_position.0 + line_based_before.len();
let line_range_end = line_range_start + text_range.len();
let line_range_start_anchor =
snapshot.anchor_before(snapshot.offset_to_point(line_range_start));
let line_range_end_anchor = snapshot.anchor_before(snapshot.offset_to_point(line_range_end));
Some(line_range_start_anchor..line_range_end_anchor)
}

View File

@@ -0,0 +1,19 @@
A tool for applying code actions to specific sections of your code. It uses language servers to provide refactoring capabilities similar to what you'd find in an IDE.
This tool can:
- List all available code actions for a selected text range
- Execute a specific code action on that range
- Rename symbols across your codebase. This tool is the preferred way to rename things, and you should always prefer to rename code symbols using this tool rather than using textual find/replace when both are available.
Use this tool when you want to:
- Discover what code actions are available for a piece of code
- Apply automatic fixes and code transformations
- Rename variables, functions, or other symbols consistently throughout your project
- Clean up imports, implement interfaces, or perform other language-specific operations
- If unsure what actions are available, call the tool without specifying an action to get a list
- For common operations, you can directly specify actions like "quickfix.all" or "source.organizeImports"
- For renaming, use the special "textDocument/rename" action and provide the new name in the arguments field
- Be specific with your text range and context to ensure the tool identifies the correct code location
The tool will automatically save any changes it makes to your files.

View File

@@ -4,7 +4,7 @@ use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::{ActionLog, Tool, ToolResult};
use collections::IndexMap;
use gpui::{App, AsyncApp, Entity, Task};
use language::{OutlineItem, ParseStatus, Point};
@@ -91,7 +91,7 @@ impl Tool for CodeSymbolsTool {
IconName::Code
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<CodeSymbolsInput>(format)
}
@@ -129,10 +129,10 @@ impl Tool for CodeSymbolsTool {
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
) -> ToolResult {
let input = match serde_json::from_value::<CodeSymbolsInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let regex = match input.regex {
@@ -141,7 +141,7 @@ impl Tool for CodeSymbolsTool {
.build()
{
Ok(regex) => Some(regex),
Err(err) => return Task::ready(Err(anyhow!("Invalid regex: {err}"))),
Err(err) => return Task::ready(Err(anyhow!("Invalid regex: {err}"))).into(),
},
None => None,
};
@@ -150,6 +150,7 @@ impl Tool for CodeSymbolsTool {
Some(path) => file_outline(project, path, action_log, regex, input.offset, cx).await,
None => project_symbols(project, regex, input.offset, cx).await,
})
.into()
}
}
@@ -179,11 +180,9 @@ pub async fn file_outline(
// Wait until the buffer has been fully parsed, so that we can read its outline.
let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
while parse_status
.recv()
.await
.map_or(false, |status| status != ParseStatus::Idle)
{}
while *parse_status.borrow() != ParseStatus::Idle {
parse_status.changed().await?;
}
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
let Some(outline) = snapshot.outline(None) else {

View File

@@ -0,0 +1,239 @@
use std::sync::Arc;
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use itertools::Itertools;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{fmt::Write, path::Path};
use ui::IconName;
use util::markdown::MarkdownString;
/// If the model requests to read a file whose size exceeds this, then
/// the tool will return the file's symbol outline instead of its contents,
/// and suggest trying again using line ranges from the outline.
const MAX_FILE_SIZE_TO_READ: usize = 16384;
/// If the model requests to list the entries in a directory with more
/// entries than this, then the tool will return a subset of the entries
/// and suggest trying again.
const MAX_DIR_ENTRIES: usize = 1024;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ContentsToolInput {
/// The relative path of the file or directory to access.
///
/// This path should never be absolute, and the first component
/// of the path should always be a root directory in a project.
///
/// <example>
/// If the project has the following root directories:
///
/// - directory1
/// - directory2
///
/// If you want to access `file.txt` in `directory1`, you should use the path `directory1/file.txt`.
/// If you want to list contents in the directory `directory2/subfolder`, you should use the path `directory2/subfolder`.
/// </example>
pub path: String,
/// Optional position (1-based index) to start reading on, if you want to read a subset of the contents.
/// When reading a file, this refers to a line number in the file (e.g. 1 is the first line).
/// When reading a directory, this refers to the number of the directory entry (e.g. 1 is the first entry).
///
/// Defaults to 1.
pub start: Option<u32>,
/// Optional position (1-based index) to end reading on, if you want to read a subset of the contents.
/// When reading a file, this refers to a line number in the file (e.g. 1 is the first line).
/// When reading a directory, this refers to the number of the directory entry (e.g. 1 is the first entry).
///
/// Defaults to reading until the end of the file or directory.
pub end: Option<u32>,
}
pub struct ContentsTool;
impl Tool for ContentsTool {
fn name(&self) -> String {
"contents".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
fn description(&self) -> String {
include_str!("./contents_tool/description.md").into()
}
fn icon(&self) -> IconName {
IconName::FileSearch
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<ContentsToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<ContentsToolInput>(input.clone()) {
Ok(input) => {
let path = MarkdownString::inline_code(&input.path);
match (input.start, input.end) {
(Some(start), None) => format!("Read {path} (from line {start})"),
(Some(start), Some(end)) => {
format!("Read {path} (lines {start}-{end})")
}
_ => format!("Read {path}"),
}
}
Err(_) => "Read file or directory".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<ContentsToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
// Sometimes models will return these even though we tell it to give a path and not a glob.
// When this happens, just list the root worktree directories.
if matches!(input.path.as_str(), "." | "" | "./" | "*") {
let output = project
.read(cx)
.worktrees(cx)
.filter_map(|worktree| {
worktree.read(cx).root_entry().and_then(|entry| {
if entry.is_dir() {
entry.path.to_str()
} else {
None
}
})
})
.collect::<Vec<_>>()
.join("\n");
return Task::ready(Ok(output)).into();
}
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into();
};
let Some(worktree) = project
.read(cx)
.worktree_for_id(project_path.worktree_id, cx)
else {
return Task::ready(Err(anyhow!("Worktree not found"))).into();
};
let worktree = worktree.read(cx);
let Some(entry) = worktree.entry_for_path(&project_path.path) else {
return Task::ready(Err(anyhow!("Path not found: {}", input.path))).into();
};
// If it's a directory, list its contents
if entry.is_dir() {
let mut output = String::new();
let start_index = input
.start
.map(|line| (line as usize).saturating_sub(1))
.unwrap_or(0);
let end_index = input
.end
.map(|line| (line as usize).saturating_sub(1))
.unwrap_or(MAX_DIR_ENTRIES);
let mut skipped = 0;
for (index, entry) in worktree.child_entries(&project_path.path).enumerate() {
if index >= start_index && index <= end_index {
writeln!(
output,
"{}",
Path::new(worktree.root_name()).join(&entry.path).display(),
)
.unwrap();
} else {
skipped += 1;
}
}
if output.is_empty() {
output.push_str(&input.path);
output.push_str(" is empty.");
}
if skipped > 0 {
write!(
output,
"\n\nNote: Skipped {skipped} entries. Adjust start and end to see other entries.",
).ok();
}
Task::ready(Ok(output)).into()
} else {
// It's a file, so read its contents
let file_path = input.path.clone();
cx.spawn(async move |cx| {
let buffer = cx
.update(|cx| {
project.update(cx, |project, cx| project.open_buffer(project_path, cx))
})?
.await?;
if input.start.is_some() || input.end.is_some() {
let result = buffer.read_with(cx, |buffer, _cx| {
let text = buffer.text();
let start = input.start.unwrap_or(1);
let lines = text.split('\n').skip(start as usize - 1);
if let Some(end) = input.end {
let count = end.saturating_sub(start).max(1); // Ensure at least 1 line
Itertools::intersperse(lines.take(count as usize), "\n").collect()
} else {
Itertools::intersperse(lines, "\n").collect()
}
})?;
action_log.update(cx, |log, cx| {
log.buffer_read(buffer, cx);
})?;
Ok(result)
} else {
// No line ranges specified, so check file size to see if it's too big.
let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
if file_size <= MAX_FILE_SIZE_TO_READ {
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
action_log.update(cx, |log, cx| {
log.buffer_read(buffer, cx);
})?;
Ok(result)
} else {
// File is too big, so return its outline and a suggestion to
// read again with a line number range specified.
let outline = file_outline(project, file_path, action_log, None, 0, cx).await?;
Ok(format!("This file was too big to read all at once. Here is an outline of its symbols:\n\n{outline}\n\nUsing the line numbers in this outline, you can call this tool again while specifying the start and end fields to see the implementations of symbols in the outline."))
}
}
}).into()
}
}
}

View File

@@ -0,0 +1,9 @@
Reads the contents of a path on the filesystem.
If the path is a directory, this lists all files and directories within that path.
If the path is a file, this returns the file's contents.
When reading a file, if the file is too big and no line range is specified, an outline of the file's code symbols is listed instead, which can be used to request specific line ranges in a subsequent call.
Similarly, if a directory has too many entries to show at once, a subset of entries will be shown,
and subsequent requests can use starting and ending line numbers to get other subsets.

View File

@@ -1,6 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, AppContext, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
@@ -55,7 +55,7 @@ impl Tool for CopyPathTool {
IconName::Clipboard
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<CopyPathToolInput>(format)
}
@@ -77,10 +77,10 @@ impl Tool for CopyPathTool {
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
) -> ToolResult {
let input = match serde_json::from_value::<CopyPathToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let copy_task = project.update(cx, |project, cx| {
match project
@@ -117,5 +117,6 @@ impl Tool for CopyPathTool {
)),
}
})
.into()
}
}

View File

@@ -1,6 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
@@ -45,7 +45,7 @@ impl Tool for CreateDirectoryTool {
IconName::Folder
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<CreateDirectoryToolInput>(format)
}
@@ -68,14 +68,16 @@ impl Tool for CreateDirectoryTool {
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
) -> ToolResult {
let input = match serde_json::from_value::<CreateDirectoryToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let project_path = match project.read(cx).find_project_path(&input.path, cx) {
Some(project_path) => project_path,
None => return Task::ready(Err(anyhow!("Path to create was outside the project"))),
None => {
return Task::ready(Err(anyhow!("Path to create was outside the project"))).into();
}
};
let destination_path: Arc<str> = input.path.as_str().into();
@@ -89,5 +91,6 @@ impl Tool for CreateDirectoryTool {
Ok(format!("Created directory {destination_path}"))
})
.into()
}
}

View File

@@ -1,6 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
@@ -52,7 +52,7 @@ impl Tool for CreateFileTool {
IconName::FileCreate
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<CreateFileToolInput>(format)
}
@@ -73,14 +73,16 @@ impl Tool for CreateFileTool {
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
) -> ToolResult {
let input = match serde_json::from_value::<CreateFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let project_path = match project.read(cx).find_project_path(&input.path, cx) {
Some(project_path) => project_path,
None => return Task::ready(Err(anyhow!("Path to create was outside the project"))),
None => {
return Task::ready(Err(anyhow!("Path to create was outside the project"))).into();
}
};
let contents: Arc<str> = input.contents.as_str().into();
let destination_path: Arc<str> = input.path.as_str().into();
@@ -106,5 +108,6 @@ impl Tool for CreateFileTool {
Ok(format!("Created file {destination_path}"))
})
.into()
}
}

View File

@@ -1,6 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::{SinkExt, StreamExt, channel::mpsc};
use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
@@ -45,7 +45,7 @@ impl Tool for DeletePathTool {
IconName::FileDelete
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<DeletePathToolInput>(format)
}
@@ -63,15 +63,16 @@ impl Tool for DeletePathTool {
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
) -> ToolResult {
let path_str = match serde_json::from_value::<DeletePathToolInput>(input) {
Ok(input) => input.path,
Err(err) => return Task::ready(Err(anyhow!(err))),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let Some(project_path) = project.read(cx).find_project_path(&path_str, cx) else {
return Task::ready(Err(anyhow!(
"Couldn't delete {path_str} because that path isn't in this project."
)));
)))
.into();
};
let Some(worktree) = project
@@ -80,7 +81,8 @@ impl Tool for DeletePathTool {
else {
return Task::ready(Err(anyhow!(
"Couldn't delete {path_str} because that path isn't in this project."
)));
)))
.into();
};
let worktree_snapshot = worktree.read(cx).snapshot();
@@ -132,5 +134,6 @@ impl Tool for DeletePathTool {
)),
}
})
.into()
}
}

View File

@@ -1,6 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use language::{DiagnosticSeverity, OffsetRangeExt};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
@@ -58,7 +58,7 @@ impl Tool for DiagnosticsTool {
IconName::XCircle
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<DiagnosticsToolInput>(format)
}
@@ -83,14 +83,15 @@ impl Tool for DiagnosticsTool {
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
) -> ToolResult {
match serde_json::from_value::<DiagnosticsToolInput>(input)
.ok()
.and_then(|input| input.path)
{
Some(path) if !path.is_empty() => {
let Some(project_path) = project.read(cx).find_project_path(&path, cx) else {
return Task::ready(Err(anyhow!("Could not find path {path} in project",)));
return Task::ready(Err(anyhow!("Could not find path {path} in project",)))
.into();
};
let buffer =
@@ -125,6 +126,7 @@ impl Tool for DiagnosticsTool {
Ok(output)
}
})
.into()
}
_ => {
let project = project.read(cx);
@@ -155,9 +157,10 @@ impl Tool for DiagnosticsTool {
});
if has_diagnostics {
Task::ready(Ok(output))
Task::ready(Ok(output)).into()
} else {
Task::ready(Ok("No errors or warnings found in the project.".to_string()))
.into()
}
}
}

View File

@@ -4,7 +4,7 @@ use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow, bail};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::AsyncReadExt as _;
use gpui::{App, AppContext as _, Entity, Task};
use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
@@ -128,7 +128,7 @@ impl Tool for FetchTool {
IconName::Globe
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<FetchToolInput>(format)
}
@@ -146,10 +146,10 @@ impl Tool for FetchTool {
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
) -> ToolResult {
let input = match serde_json::from_value::<FetchToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let text = cx.background_spawn({
@@ -158,13 +158,15 @@ impl Tool for FetchTool {
async move { Self::build_message(http_client, &url).await }
});
cx.foreground_executor().spawn(async move {
let text = text.await?;
if text.trim().is_empty() {
bail!("no textual content found");
}
cx.foreground_executor()
.spawn(async move {
let text = text.await?;
if text.trim().is_empty() {
bail!("no textual content found");
}
Ok(text)
})
Ok(text)
})
.into()
}
}

View File

@@ -1,6 +1,6 @@
use crate::{replace::replace_with_flexible_indent, schema::json_schema_for};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -63,6 +63,16 @@ pub struct FindReplaceFileToolInput {
/// even one character in this string is different in any way from how it appears
/// in the file, then the tool call will fail.
///
/// If you get an error that the `find` string was not found, this means that either
/// you made a mistake, or that the file has changed since you last looked at it.
/// Either way, when this happens, you should retry doing this tool call until it
/// succeeds, up to 3 times. Each time you retry, you should take another look at
/// the exact text of the file in question, to make sure that you are searching for
/// exactly the right string. Regardless of whether it was because you made a mistake
/// or because the file changed since you last looked at it, you should be extra
/// careful when retrying in this way. It's a bad experience for the user if
/// this `find` string isn't found, so be super careful to get it exactly right!
///
/// <example>
/// If a file contains this code:
///
@@ -141,7 +151,7 @@ impl Tool for FindReplaceFileTool {
IconName::Pencil
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<FindReplaceFileToolInput>(format)
}
@@ -159,10 +169,10 @@ impl Tool for FindReplaceFileTool {
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
) -> ToolResult {
let input = match serde_json::from_value::<FindReplaceFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
cx.spawn(async move |cx: &mut AsyncApp| {
@@ -253,6 +263,6 @@ impl Tool for FindReplaceFileTool {
Ok(format!("Edited {}:\n\n```diff\n{}\n```", input.path.display(), diff_str))
})
}).into()
}
}

View File

@@ -1,8 +1,10 @@
Find one unique part of a file in the project and replace that text with new text.
This tool is the preferred way to make edits to files. If you have multiple edits to make, including edits across multiple files, then make a plan to respond with a single message containing multiple calls to this tool - one call for each find/replace operation.
This tool is the preferred way to make edits to files *except* when making a rename. When making a rename specifically, the rename tool must always be used instead.
You should use this tool when you want to edit a subset of a file's contents, but not the entire file. You should not use this tool when you want to replace the entire contents of a file with completely different contents. You also should not use this tool when you want to move or rename a file. You absolutely must NEVER use this tool to create new files from scratch. If you ever consider using this tool to create a new file from scratch, for any reason, instead you must reconsider and choose a different approach.
If you have multiple edits to make, including edits across multiple files, then make a plan to respond with a single message containing a batch of calls to this tool - one call for each find/replace operation.
You should only use this tool when you want to edit a subset of a file's contents, but not the entire file. You should not use this tool when you want to replace the entire contents of a file with completely different contents. You also should not use this tool when you want to move or rename a file. You absolutely must NEVER use this tool to create new files from scratch. If you ever consider using this tool to create a new file from scratch, for any reason, instead you must reconsider and choose a different approach.
DO NOT call this tool until the code to be edited appears in the conversation! You must use another tool to read the file's contents into the conversation, or ask the user to add it to context first.

View File

@@ -1,6 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -56,7 +56,7 @@ impl Tool for ListDirectoryTool {
IconName::Folder
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<ListDirectoryToolInput>(format)
}
@@ -77,10 +77,10 @@ impl Tool for ListDirectoryTool {
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
) -> ToolResult {
let input = match serde_json::from_value::<ListDirectoryToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
// Sometimes models will return these even though we tell it to give a path and not a glob.
@@ -101,26 +101,26 @@ impl Tool for ListDirectoryTool {
.collect::<Vec<_>>()
.join("\n");
return Task::ready(Ok(output));
return Task::ready(Ok(output)).into();
}
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path {} not found in project", input.path)));
return Task::ready(Err(anyhow!("Path {} not found in project", input.path))).into();
};
let Some(worktree) = project
.read(cx)
.worktree_for_id(project_path.worktree_id, cx)
else {
return Task::ready(Err(anyhow!("Worktree not found")));
return Task::ready(Err(anyhow!("Worktree not found"))).into();
};
let worktree = worktree.read(cx);
let Some(entry) = worktree.entry_for_path(&project_path.path) else {
return Task::ready(Err(anyhow!("Path not found: {}", input.path)));
return Task::ready(Err(anyhow!("Path not found: {}", input.path))).into();
};
if !entry.is_dir() {
return Task::ready(Err(anyhow!("{} is not a directory.", input.path)));
return Task::ready(Err(anyhow!("{} is not a directory.", input.path))).into();
}
let mut output = String::new();
@@ -133,8 +133,8 @@ impl Tool for ListDirectoryTool {
.unwrap();
}
if output.is_empty() {
return Task::ready(Ok(format!("{} is empty.", input.path)));
return Task::ready(Ok(format!("{} is empty.", input.path))).into();
}
Task::ready(Ok(output))
Task::ready(Ok(output)).into()
}
}

View File

@@ -1,6 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -54,7 +54,7 @@ impl Tool for MovePathTool {
IconName::ArrowRightLeft
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<MovePathToolInput>(format)
}
@@ -90,10 +90,10 @@ impl Tool for MovePathTool {
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
) -> ToolResult {
let input = match serde_json::from_value::<MovePathToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let rename_task = project.update(cx, |project, cx| {
match project
@@ -128,5 +128,6 @@ impl Tool for MovePathTool {
)),
}
})
.into()
}
}

View File

@@ -2,7 +2,7 @@ use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::{ActionLog, Tool, ToolResult};
use chrono::{Local, Utc};
use gpui::{App, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
@@ -45,7 +45,7 @@ impl Tool for NowTool {
IconName::Info
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<NowToolInput>(format)
}
@@ -60,10 +60,10 @@ impl Tool for NowTool {
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_cx: &mut App,
) -> Task<Result<String>> {
) -> ToolResult {
let input: NowToolInput = match serde_json::from_value(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let now = match input.timezone {
@@ -72,6 +72,6 @@ impl Tool for NowTool {
};
let text = format!("The current datetime is {now}.");
Task::ready(Ok(text))
Task::ready(Ok(text)).into()
}
}

View File

@@ -1,6 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -35,7 +35,7 @@ impl Tool for OpenTool {
IconName::ArrowUpRight
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<OpenToolInput>(format)
}
@@ -53,10 +53,10 @@ impl Tool for OpenTool {
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
) -> ToolResult {
let input: OpenToolInput = match serde_json::from_value(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
cx.background_spawn(async move {
@@ -64,5 +64,6 @@ impl Tool for OpenTool {
Ok(format!("Successfully opened {}", input.path_or_url))
})
.into()
}
}

View File

@@ -1,6 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -53,7 +53,7 @@ impl Tool for PathSearchTool {
IconName::SearchCode
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<PathSearchToolInput>(format)
}
@@ -71,10 +71,10 @@ impl Tool for PathSearchTool {
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
) -> ToolResult {
let (offset, glob) = match serde_json::from_value::<PathSearchToolInput>(input) {
Ok(input) => (input.offset, input.glob),
Err(err) => return Task::ready(Err(anyhow!(err))),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let path_matcher = match PathMatcher::new([
@@ -82,7 +82,7 @@ impl Tool for PathSearchTool {
if glob.is_empty() { "*" } else { &glob },
]) {
Ok(matcher) => matcher,
Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))),
Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))).into(),
};
let snapshots: Vec<Snapshot> = project
.read(cx)
@@ -136,6 +136,6 @@ impl Tool for PathSearchTool {
Ok(response)
}
})
}).into()
}
}

View File

@@ -2,7 +2,7 @@ use std::sync::Arc;
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use itertools::Itertools;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
@@ -63,7 +63,7 @@ impl Tool for ReadFileTool {
IconName::FileSearch
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<ReadFileToolInput>(format)
}
@@ -88,14 +88,14 @@ impl Tool for ReadFileTool {
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
) -> ToolResult {
let input = match serde_json::from_value::<ReadFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path,)));
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path,))).into();
};
let file_path = input.path.clone();
@@ -146,6 +146,6 @@ impl Tool for ReadFileTool {
Ok(format!("This file was too big to read all at once. Here is an outline of its symbols:\n\n{outline}\n\nUsing the line numbers in this outline, you can call this tool again while specifying the start_line and end_line fields to see the implementations of symbols in the outline."))
}
}
})
}).into()
}
}

View File

@@ -1,6 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::StreamExt;
use gpui::{App, Entity, Task};
use language::OffsetRangeExt;
@@ -26,6 +26,10 @@ pub struct RegexSearchToolInput {
/// When not provided, starts from the beginning.
#[serde(default)]
pub offset: u32,
/// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
#[serde(default)]
pub case_sensitive: bool,
}
impl RegexSearchToolInput {
@@ -56,7 +60,7 @@ impl Tool for RegexSearchTool {
IconName::Regex
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<RegexSearchToolInput>(format)
}
@@ -64,12 +68,17 @@ impl Tool for RegexSearchTool {
match serde_json::from_value::<RegexSearchToolInput>(input.clone()) {
Ok(input) => {
let page = input.page();
let regex = MarkdownString::inline_code(&input.regex);
let regex_str = MarkdownString::inline_code(&input.regex);
let case_info = if input.case_sensitive {
" (case-sensitive)"
} else {
""
};
if page > 1 {
format!("Get page {page} of search results for regex {regex}")
format!("Get page {page} of search results for regex {regex_str}{case_info}")
} else {
format!("Search files for regex {regex}")
format!("Search files for regex {regex_str}{case_info}")
}
}
Err(_) => "Search with regex".to_string(),
@@ -83,25 +92,26 @@ impl Tool for RegexSearchTool {
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
) -> ToolResult {
const CONTEXT_LINES: u32 = 2;
let (offset, regex) = match serde_json::from_value::<RegexSearchToolInput>(input) {
Ok(input) => (input.offset, input.regex),
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let (offset, regex, case_sensitive) =
match serde_json::from_value::<RegexSearchToolInput>(input) {
Ok(input) => (input.offset, input.regex, input.case_sensitive),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let query = match SearchQuery::regex(
&regex,
false,
false,
case_sensitive,
false,
PathMatcher::default(),
PathMatcher::default(),
None,
) {
Ok(query) => query,
Err(error) => return Task::ready(Err(error)),
Err(error) => return Task::ready(Err(error)).into(),
};
let results = project.update(cx, |project, cx| project.search(query, cx));
@@ -190,6 +200,6 @@ impl Tool for RegexSearchTool {
} else {
Ok(format!("Found {matches_found} matches:\n{output}"))
}
})
}).into()
}
}

View File

@@ -0,0 +1,203 @@
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use language::{self, Buffer, ToPointUtf16};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use ui::IconName;
use crate::schema::json_schema_for;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RenameToolInput {
/// The relative path to the file containing the symbol to rename.
///
/// WARNING: you MUST start this path with one of the project's root directories.
pub path: String,
/// The new name to give to the symbol.
pub new_name: String,
/// The text that comes immediately before the symbol in the file.
pub context_before_symbol: String,
/// The symbol to rename. This text must appear in the file right between
/// `context_before_symbol` and `context_after_symbol`.
///
/// The file must contain exactly one occurrence of `context_before_symbol` followed by
/// `symbol` followed by `context_after_symbol`. If the file contains zero occurrences,
/// or if it contains more than one occurrence, the tool will fail, so it is absolutely
/// critical that you verify ahead of time that the string is unique. You can search
/// the file's contents to verify this ahead of time.
///
/// To make the string more likely to be unique, include a minimum of 1 line of context
/// before the symbol, as well as a minimum of 1 line of context after the symbol.
/// If these lines of context are not enough to obtain a string that appears only once
/// in the file, then double the number of context lines until the string becomes unique.
/// (Start with 1 line before and 1 line after though, because too much context is
/// needlessly costly.)
///
/// Do not alter the context lines of code in any way, and make sure to preserve all
/// whitespace and indentation for all lines of code. The combined string must be exactly
/// as it appears in the file, or else this tool call will fail.
pub symbol: String,
/// The text that comes immediately after the symbol in the file.
pub context_after_symbol: String,
}
pub struct RenameTool;
impl Tool for RenameTool {
fn name(&self) -> String {
"rename".into()
}
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
false
}
fn description(&self) -> String {
include_str!("./rename_tool/description.md").into()
}
fn icon(&self) -> IconName {
IconName::Pencil
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<RenameToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<RenameToolInput>(input.clone()) {
Ok(input) => {
format!("Rename '{}' to '{}'", input.symbol, input.new_name)
}
Err(_) => "Rename symbol".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<RenameToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
cx.spawn(async move |cx| {
let buffer = {
let project_path = project.read_with(cx, |project, cx| {
project
.find_project_path(&input.path, cx)
.context("Path not found in project")
})??;
project.update(cx, |project, cx| project.open_buffer(project_path, cx))?.await?
};
action_log.update(cx, |action_log, cx| {
action_log.buffer_read(buffer.clone(), cx);
})?;
let position = {
let Some(position) = buffer.read_with(cx, |buffer, _cx| {
find_symbol_position(&buffer, &input.context_before_symbol, &input.symbol, &input.context_after_symbol)
})? else {
return Err(anyhow!(
"Failed to locate the symbol specified by context_before_symbol, symbol, and context_after_symbol. Make sure context_before_symbol and context_after_symbol each match exactly once in the file."
));
};
buffer.read_with(cx, |buffer, _| {
position.to_point_utf16(&buffer.snapshot())
})?
};
project
.update(cx, |project, cx| {
project.perform_rename(buffer.clone(), position, input.new_name.clone(), cx)
})?
.await?;
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.await?;
action_log.update(cx, |log, cx| {
log.buffer_edited(buffer.clone(), cx)
})?;
Ok(format!("Renamed '{}' to '{}'", input.symbol, input.new_name))
}).into()
}
}
/// Finds the position of the symbol in the buffer, if it appears between context_before_symbol
/// and context_after_symbol, and if that combined string has one unique result in the buffer.
///
/// If an exact match fails, it tries adding a newline to the end of context_before_symbol and
/// to the beginning of context_after_symbol to accommodate line-based context matching.
fn find_symbol_position(
buffer: &Buffer,
context_before_symbol: &str,
symbol: &str,
context_after_symbol: &str,
) -> Option<language::Anchor> {
let snapshot = buffer.snapshot();
let text = snapshot.text();
// First try with exact match
let search_string = format!("{context_before_symbol}{symbol}{context_after_symbol}");
let mut positions = text.match_indices(&search_string);
let position_result = positions.next();
if let Some(position) = position_result {
// Check if the matched string is unique
if positions.next().is_none() {
let symbol_start = position.0 + context_before_symbol.len();
let symbol_start_anchor =
snapshot.anchor_before(snapshot.offset_to_point(symbol_start));
return Some(symbol_start_anchor);
}
}
// If exact match fails or is not unique, try with line-based context
// Add a newline to the end of before context and beginning of after context
let line_based_before = if context_before_symbol.ends_with('\n') {
context_before_symbol.to_string()
} else {
format!("{context_before_symbol}\n")
};
let line_based_after = if context_after_symbol.starts_with('\n') {
context_after_symbol.to_string()
} else {
format!("\n{context_after_symbol}")
};
let line_search_string = format!("{line_based_before}{symbol}{line_based_after}");
let mut line_positions = text.match_indices(&line_search_string);
let line_position = line_positions.next()?;
// The line-based search string must also appear exactly once
if line_positions.next().is_some() {
return None;
}
let line_symbol_start = line_position.0 + line_based_before.len();
let line_symbol_start_anchor =
snapshot.anchor_before(snapshot.offset_to_point(line_symbol_start));
Some(line_symbol_start_anchor)
}

View File

@@ -0,0 +1,15 @@
Renames a symbol across your codebase using the language server's semantic knowledge.
This tool performs a rename refactoring operation on a specified symbol. It uses the project's language server to analyze the code and perform the rename correctly across all files where the symbol is referenced.
Unlike a simple find and replace, this tool understands the semantic meaning of the code, so it only renames the specific symbol you specify and not unrelated text that happens to have the same name.
Examples of symbols you can rename:
- Variables
- Functions
- Classes/structs
- Fields/properties
- Methods
- Interfaces/traits
The language server handles updating all references to the renamed symbol throughout the codebase.

View File

@@ -5,23 +5,20 @@ use schemars::{
schema::{RootSchema, Schema, SchemaObject},
};
pub fn json_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> serde_json::Value {
pub fn json_schema_for<T: JsonSchema>(
format: LanguageModelToolSchemaFormat,
) -> Result<serde_json::Value> {
let schema = root_schema_for::<T>(format);
schema_to_json(&schema, format).expect("Failed to convert tool calling schema to JSON")
schema_to_json(&schema, format)
}
pub fn schema_to_json(
fn schema_to_json(
schema: &RootSchema,
format: LanguageModelToolSchemaFormat,
) -> Result<serde_json::Value> {
let mut value = serde_json::to_value(schema)?;
match format {
LanguageModelToolSchemaFormat::JsonSchema => Ok(value),
LanguageModelToolSchemaFormat::JsonSchemaSubset => {
transform_fields_to_json_schema_subset(&mut value);
Ok(value)
}
}
assistant_tool::adapt_schema_to_format(&mut value, format)?;
Ok(value)
}
fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> RootSchema {
@@ -79,42 +76,3 @@ impl schemars::visit::Visitor for TransformToJsonSchemaSubsetVisitor {
schemars::visit::visit_schema_object(self, schema)
}
}
fn transform_fields_to_json_schema_subset(json: &mut serde_json::Value) {
if let serde_json::Value::Object(obj) = json {
if let Some(default) = obj.get("default") {
let is_null = default.is_null();
//Default is not supported, so we need to remove it.
obj.remove("default");
if is_null {
obj.insert("nullable".to_string(), serde_json::Value::Bool(true));
}
}
// If a type is not specified for an input parameter we need to add it.
if obj.contains_key("description")
&& !obj.contains_key("type")
&& !(obj.contains_key("anyOf")
|| obj.contains_key("oneOf")
|| obj.contains_key("allOf"))
{
obj.insert(
"type".to_string(),
serde_json::Value::String("string".to_string()),
);
}
//Format field is only partially supported (e.g. not uint compatibility)
obj.remove("format");
for (_, value) in obj.iter_mut() {
if let serde_json::Value::Object(_) | serde_json::Value::Array(_) = value {
transform_fields_to_json_schema_subset(value);
}
}
} else if let serde_json::Value::Array(arr) = json {
for item in arr.iter_mut() {
transform_fields_to_json_schema_subset(item);
}
}
}

View File

@@ -1,5 +1,5 @@
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, AsyncApp, Entity, Task};
use language::{self, Anchor, Buffer, BufferSnapshot, Location, Point, ToPoint, ToPointUtf16};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
@@ -84,7 +84,7 @@ impl Tool for SymbolInfoTool {
IconName::Code
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<SymbolInfoToolInput>(format)
}
@@ -122,10 +122,10 @@ impl Tool for SymbolInfoTool {
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
) -> ToolResult {
let input = match serde_json::from_value::<SymbolInfoToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
cx.spawn(async move |cx| {
@@ -205,7 +205,7 @@ impl Tool for SymbolInfoTool {
} else {
Ok(output)
}
})
}).into()
}
}

View File

@@ -0,0 +1,371 @@
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::io::BufReader;
use futures::{AsyncBufReadExt, AsyncReadExt, FutureExt};
use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::future;
use util::get_system_shell;
use std::path::Path;
use std::sync::Arc;
use ui::IconName;
use util::command::new_smol_command;
use util::markdown::MarkdownString;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct TerminalToolInput {
/// The one-liner command to execute.
command: String,
/// Working directory for the command. This must be one of the root directories of the project.
cd: String,
}
pub struct TerminalTool;
impl Tool for TerminalTool {
fn name(&self) -> String {
"terminal".to_string()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
true
}
fn description(&self) -> String {
include_str!("./terminal_tool/description.md").to_string()
}
fn icon(&self) -> IconName {
IconName::Terminal
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<TerminalToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<TerminalToolInput>(input.clone()) {
Ok(input) => {
let mut lines = input.command.lines();
let first_line = lines.next().unwrap_or_default();
let remaining_line_count = lines.count();
match remaining_line_count {
0 => MarkdownString::inline_code(&first_line).0,
1 => {
MarkdownString::inline_code(&format!(
"{} - {} more line",
first_line, remaining_line_count
))
.0
}
n => {
MarkdownString::inline_code(&format!("{} - {} more lines", first_line, n)).0
}
}
}
Err(_) => "Run terminal command".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
let input: TerminalToolInput = match serde_json::from_value(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let project = project.read(cx);
let input_path = Path::new(&input.cd);
let working_dir = if input.cd == "." {
// Accept "." as meaning "the one worktree" if we only have one worktree.
let mut worktrees = project.worktrees(cx);
let only_worktree = match worktrees.next() {
Some(worktree) => worktree,
None => {
return Task::ready(Err(anyhow!("No worktrees found in the project"))).into();
}
};
if worktrees.next().is_some() {
return Task::ready(Err(anyhow!(
"'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly."
))).into();
}
only_worktree.read(cx).abs_path()
} else if input_path.is_absolute() {
// Absolute paths are allowed, but only if they're in one of the project's worktrees.
if !project
.worktrees(cx)
.any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path()))
{
return Task::ready(Err(anyhow!(
"The absolute path must be within one of the project's worktrees"
)))
.into();
}
input_path.into()
} else {
let Some(worktree) = project.worktree_for_root_name(&input.cd, cx) else {
return Task::ready(Err(anyhow!(
"`cd` directory {} not found in the project",
&input.cd
)))
.into();
};
worktree.read(cx).abs_path()
};
cx.background_spawn(run_command_limited(working_dir, input.command))
.into()
}
}
const LIMIT: usize = 16 * 1024;
async fn run_command_limited(working_dir: Arc<Path>, command: String) -> Result<String> {
let shell = get_system_shell();
let mut cmd = new_smol_command(&shell)
.arg("-c")
.arg(&command)
.current_dir(working_dir)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.context("Failed to execute terminal command")?;
let mut combined_buffer = String::with_capacity(LIMIT + 1);
let mut out_reader = BufReader::new(cmd.stdout.take().context("Failed to get stdout")?);
let mut out_tmp_buffer = String::with_capacity(512);
let mut err_reader = BufReader::new(cmd.stderr.take().context("Failed to get stderr")?);
let mut err_tmp_buffer = String::with_capacity(512);
let mut out_line = Box::pin(
out_reader
.read_line(&mut out_tmp_buffer)
.left_future()
.fuse(),
);
let mut err_line = Box::pin(
err_reader
.read_line(&mut err_tmp_buffer)
.left_future()
.fuse(),
);
let mut has_stdout = true;
let mut has_stderr = true;
while (has_stdout || has_stderr) && combined_buffer.len() < LIMIT + 1 {
futures::select_biased! {
read = out_line => {
drop(out_line);
combined_buffer.extend(out_tmp_buffer.drain(..));
if read? == 0 {
out_line = Box::pin(future::pending().right_future().fuse());
has_stdout = false;
} else {
out_line = Box::pin(out_reader.read_line(&mut out_tmp_buffer).left_future().fuse());
}
}
read = err_line => {
drop(err_line);
combined_buffer.extend(err_tmp_buffer.drain(..));
if read? == 0 {
err_line = Box::pin(future::pending().right_future().fuse());
has_stderr = false;
} else {
err_line = Box::pin(err_reader.read_line(&mut err_tmp_buffer).left_future().fuse());
}
}
};
}
drop((out_line, err_line));
let truncated = combined_buffer.len() > LIMIT;
combined_buffer.truncate(LIMIT);
consume_reader(out_reader, truncated).await?;
consume_reader(err_reader, truncated).await?;
let status = cmd.status().await.context("Failed to get command status")?;
let output_string = if truncated {
// Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
// multi-byte characters.
let last_line_ix = combined_buffer.bytes().rposition(|b| b == b'\n');
let combined_buffer = &combined_buffer[..last_line_ix.unwrap_or(combined_buffer.len())];
format!(
"Command output too long. The first {} bytes:\n\n{}",
combined_buffer.len(),
output_block(&combined_buffer),
)
} else {
output_block(&combined_buffer)
};
let output_with_status = if status.success() {
if output_string.is_empty() {
"Command executed successfully.".to_string()
} else {
output_string.to_string()
}
} else {
format!(
"Command failed with exit code {} (shell: {}).\n\n{}",
status.code().unwrap_or(-1),
shell,
output_string,
)
};
Ok(output_with_status)
}
async fn consume_reader<T: AsyncReadExt + Unpin>(
mut reader: BufReader<T>,
truncated: bool,
) -> Result<(), std::io::Error> {
loop {
let skipped_bytes = reader.fill_buf().await?;
if skipped_bytes.is_empty() {
break;
}
let skipped_bytes_len = skipped_bytes.len();
reader.consume_unpin(skipped_bytes_len);
// Should only skip if we went over the limit
debug_assert!(truncated);
}
Ok(())
}
fn output_block(output: &str) -> String {
format!(
"```\n{}{}```",
output,
if output.ends_with('\n') { "" } else { "\n" }
)
}
#[cfg(test)]
#[cfg(not(windows))]
mod tests {
use gpui::TestAppContext;
use super::*;
#[gpui::test(iterations = 10)]
async fn test_run_command_simple(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let result =
run_command_limited(Path::new(".").into(), "echo 'Hello, World!'".to_string()).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "```\nHello, World!\n```");
}
#[gpui::test(iterations = 10)]
async fn test_interleaved_stdout_stderr(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let command = "echo 'stdout 1' && sleep 0.01 && echo 'stderr 1' >&2 && sleep 0.01 && echo 'stdout 2' && sleep 0.01 && echo 'stderr 2' >&2";
let result = run_command_limited(Path::new(".").into(), command.to_string()).await;
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
"```\nstdout 1\nstderr 1\nstdout 2\nstderr 2\n```"
);
}
#[gpui::test(iterations = 10)]
async fn test_multiple_output_reads(cx: &mut TestAppContext) {
cx.executor().allow_parking();
// Command with multiple outputs that might require multiple reads
let result = run_command_limited(
Path::new(".").into(),
"echo '1'; sleep 0.01; echo '2'; sleep 0.01; echo '3'".to_string(),
)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "```\n1\n2\n3\n```");
}
#[gpui::test(iterations = 10)]
async fn test_output_truncation_single_line(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let cmd = format!("echo '{}'; sleep 0.01;", "X".repeat(LIMIT * 2));
let result = run_command_limited(Path::new(".").into(), cmd).await;
assert!(result.is_ok());
let output = result.unwrap();
let content_start = output.find("```\n").map(|i| i + 4).unwrap_or(0);
let content_end = output.rfind("\n```").unwrap_or(output.len());
let content_length = content_end - content_start;
// Output should be exactly the limit
assert_eq!(content_length, LIMIT);
}
#[gpui::test(iterations = 10)]
async fn test_output_truncation_multiline(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let cmd = format!("echo '{}'; ", "X".repeat(120)).repeat(160);
let result = run_command_limited(Path::new(".").into(), cmd).await;
assert!(result.is_ok());
let output = result.unwrap();
assert!(output.starts_with("Command output too long. The first 16334 bytes:\n\n"));
let content_start = output.find("```\n").map(|i| i + 4).unwrap_or(0);
let content_end = output.rfind("\n```").unwrap_or(output.len());
let content_length = content_end - content_start;
assert!(content_length <= LIMIT);
}
#[gpui::test(iterations = 10)]
async fn test_command_failure(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let result = run_command_limited(Path::new(".").into(), "exit 42".to_string()).await;
assert!(result.is_ok());
let output = result.unwrap();
// Extract the shell name from path for cleaner test output
let shell_path = std::env::var("SHELL").unwrap_or("bash".to_string());
let expected_output = format!(
"Command failed with exit code 42 (shell: {}).\n\n```\n\n```",
shell_path
);
assert_eq!(output, expected_output);
}
}

View File

@@ -0,0 +1,9 @@
Executes a shell one-liner and returns the combined output.
This tool spawns a process using the user's current shell, combines stdout and stderr into one interleaved stream as they are produced (preserving the order of writes), and captures that stream into a string which is returned.
Make sure you use the `cd` parameter to navigate to one of the root directories of the project. NEVER do it as part of the `command` itself, otherwise it will error.
Do not use this tool for commands that run indefinitely, such as servers (e.g., `python -m http.server`) or file watchers that don't terminate on their own.
Remember that each invocation of this tool will spawn a new shell process, so you can't rely on any state from previous invocations.

View File

@@ -2,7 +2,7 @@ use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -36,7 +36,7 @@ impl Tool for ThinkingTool {
IconName::LightBulb
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<ThinkingToolInput>(format)
}
@@ -51,11 +51,12 @@ impl Tool for ThinkingTool {
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_cx: &mut App,
) -> Task<Result<String>> {
) -> ToolResult {
// This tool just "thinks out loud" and doesn't perform any actions.
Task::ready(match serde_json::from_value::<ThinkingToolInput>(input) {
Ok(_input) => Ok("Finished thinking.".to_string()),
Err(err) => Err(anyhow!(err)),
})
.into()
}
}

View File

@@ -7,8 +7,6 @@ fn main() {
if cfg!(target_os = "macos") {
println!("cargo:rustc-env=MACOSX_DEPLOYMENT_TARGET=10.15.7");
// Weakly link ScreenCaptureKit to ensure can be used on macOS 10.15+.
println!("cargo:rustc-link-arg=-Wl,-weak_framework,ScreenCaptureKit");
}
// Populate git sha environment variable if git is available

View File

@@ -16,6 +16,7 @@ pub enum CliRequest {
wait: bool,
open_new_workspace: Option<bool>,
env: Option<HashMap<String, String>>,
user_data_dir: Option<String>,
},
}

View File

@@ -26,7 +26,11 @@ struct Detect;
trait InstalledApp {
fn zed_version_string(&self) -> String;
fn launch(&self, ipc_url: String) -> anyhow::Result<()>;
fn run_foreground(&self, ipc_url: String) -> io::Result<ExitStatus>;
fn run_foreground(
&self,
ipc_url: String,
user_data_dir: Option<&str>,
) -> io::Result<ExitStatus>;
fn path(&self) -> PathBuf;
}
@@ -58,6 +62,13 @@ struct Args {
/// Create a new workspace
#[arg(short, long, overrides_with = "add")]
new: bool,
/// Sets a custom directory for all user data (e.g., database, extensions, logs).
/// This overrides the default platform-specific data directory location.
/// On macOS, the default is `~/Library/Application Support/Zed`.
/// On Linux/FreeBSD, the default is `$XDG_DATA_HOME/zed`.
/// On Windows, the default is `%LOCALAPPDATA%\Zed`.
#[arg(long, value_name = "DIR")]
user_data_dir: Option<String>,
/// The paths to open in Zed (space-separated).
///
/// Use `path:line:column` syntax to open a file at the given line and column.
@@ -135,6 +146,12 @@ fn main() -> Result<()> {
}
let args = Args::parse();
// Set custom data directory before any path operations
let user_data_dir = args.user_data_dir.clone();
if let Some(dir) = &user_data_dir {
paths::set_custom_data_dir(dir);
}
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
let args = flatpak::set_bin_if_no_escape(args);
@@ -246,6 +263,7 @@ fn main() -> Result<()> {
let sender: JoinHandle<anyhow::Result<()>> = thread::spawn({
let exit_status = exit_status.clone();
let user_data_dir_for_thread = user_data_dir.clone();
move || {
let (_, handshake) = server.accept().context("Handshake after Zed spawn")?;
let (tx, rx) = (handshake.requests, handshake.responses);
@@ -256,6 +274,7 @@ fn main() -> Result<()> {
wait: args.wait,
open_new_workspace,
env,
user_data_dir: user_data_dir_for_thread,
})?;
while let Ok(response) = rx.recv() {
@@ -291,7 +310,7 @@ fn main() -> Result<()> {
.collect();
if args.foreground {
app.run_foreground(url)?;
app.run_foreground(url, user_data_dir.as_deref())?;
} else {
app.launch(url)?;
sender.join().unwrap()?;
@@ -437,7 +456,7 @@ mod linux {
}
fn launch(&self, ipc_url: String) -> anyhow::Result<()> {
let sock_path = paths::support_dir().join(format!("zed-{}.sock", *RELEASE_CHANNEL));
let sock_path = paths::data_dir().join(format!("zed-{}.sock", *RELEASE_CHANNEL));
let sock = UnixDatagram::unbound()?;
if sock.connect(&sock_path).is_err() {
self.boot_background(ipc_url)?;
@@ -447,10 +466,17 @@ mod linux {
Ok(())
}
fn run_foreground(&self, ipc_url: String) -> io::Result<ExitStatus> {
std::process::Command::new(self.0.clone())
.arg(ipc_url)
.status()
fn run_foreground(
&self,
ipc_url: String,
user_data_dir: Option<&str>,
) -> io::Result<ExitStatus> {
let mut cmd = std::process::Command::new(self.0.clone());
cmd.arg(ipc_url);
if let Some(dir) = user_data_dir {
cmd.arg("--user-data-dir").arg(dir);
}
cmd.status()
}
fn path(&self) -> PathBuf {
@@ -688,12 +714,17 @@ mod windows {
Ok(())
}
fn run_foreground(&self, ipc_url: String) -> io::Result<ExitStatus> {
std::process::Command::new(self.0.clone())
.arg(ipc_url)
.arg("--foreground")
.spawn()?
.wait()
fn run_foreground(
&self,
ipc_url: String,
user_data_dir: Option<&str>,
) -> io::Result<ExitStatus> {
let mut cmd = std::process::Command::new(self.0.clone());
cmd.arg(ipc_url).arg("--foreground");
if let Some(dir) = user_data_dir {
cmd.arg("--user-data-dir").arg(dir);
}
cmd.spawn()?.wait()
}
fn path(&self) -> PathBuf {
@@ -875,13 +906,22 @@ mod mac_os {
Ok(())
}
fn run_foreground(&self, ipc_url: String) -> io::Result<ExitStatus> {
fn run_foreground(
&self,
ipc_url: String,
user_data_dir: Option<&str>,
) -> io::Result<ExitStatus> {
let path = match self {
Bundle::App { app_bundle, .. } => app_bundle.join("Contents/MacOS/zed"),
Bundle::LocalPath { executable, .. } => executable.clone(),
};
std::process::Command::new(path).arg(ipc_url).status()
let mut cmd = std::process::Command::new(path);
cmd.arg(ipc_url);
if let Some(dir) = user_data_dir {
cmd.arg("--user-data-dir").arg(dir);
}
cmd.status()
}
fn path(&self) -> PathBuf {

View File

@@ -274,7 +274,7 @@ async fn create_billing_subscription(
customer.id
};
let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?;
let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?;
let stripe_model = stripe_billing.register_model(default_model).await?;
let success_url = format!(
"{}/account?checkout_complete=1",

View File

@@ -187,22 +187,20 @@ impl ComponentPreview {
let mut entries = Vec::new();
let known_scopes = [
ComponentScope::Layout,
ComponentScope::Input,
ComponentScope::Editor,
ComponentScope::Notification,
ComponentScope::Collaboration,
ComponentScope::VersionControl,
ComponentScope::None,
];
// Always show all components first
entries.push(PreviewEntry::AllComponents);
entries.push(PreviewEntry::Separator);
for scope in known_scopes.iter() {
if let Some(components) = scope_groups.remove(scope) {
let mut scopes: Vec<_> = scope_groups
.keys()
.filter(|scope| !matches!(**scope, ComponentScope::None))
.cloned()
.collect();
scopes.sort_by_key(|s| s.to_string());
for scope in scopes {
if let Some(components) = scope_groups.remove(&scope) {
if !components.is_empty() {
entries.push(PreviewEntry::SectionHeader(scope.to_string().into()));
let mut sorted_components = components;
@@ -215,6 +213,7 @@ impl ComponentPreview {
}
}
// Add uncategorized components last
if let Some(components) = scope_groups.get(&ComponentScope::None) {
if !components.is_empty() {
entries.push(PreviewEntry::Separator);
@@ -272,7 +271,12 @@ impl ComponentPreview {
.into_any_element()
}
PreviewEntry::Separator => ListItem::new(ix)
.child(h_flex().pt_3().child(Divider::horizontal_dashed()))
.child(
h_flex()
.occlude()
.pt_3()
.child(Divider::horizontal_dashed()),
)
.into_any_element(),
}
}

View File

@@ -1,7 +1,7 @@
use std::sync::Arc;
use anyhow::{Result, anyhow, bail};
use assistant_tool::{ActionLog, Tool, ToolSource};
use assistant_tool::{ActionLog, Tool, ToolResult, ToolSource};
use gpui::{App, Entity, Task};
use icons::IconName;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
@@ -53,16 +53,18 @@ impl Tool for ContextServerTool {
true
}
fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> serde_json::Value {
match &self.tool.input_schema {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
let mut schema = self.tool.input_schema.clone();
assistant_tool::adapt_schema_to_format(&mut schema, format)?;
Ok(match schema {
serde_json::Value::Null => {
serde_json::json!({ "type": "object", "properties": [] })
}
serde_json::Value::Object(map) if map.is_empty() => {
serde_json::json!({ "type": "object", "properties": [] })
}
_ => self.tool.input_schema.clone(),
}
_ => schema,
})
}
fn ui_text(&self, _input: &serde_json::Value) -> String {
@@ -76,7 +78,7 @@ impl Tool for ContextServerTool {
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
) -> ToolResult {
if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) {
let tool_name = self.tool.name.clone();
let server_clone = server.clone();
@@ -116,8 +118,9 @@ impl Tool for ContextServerTool {
}
Ok(result)
})
.into()
} else {
Task::ready(Err(anyhow!("Context server not found")))
Task::ready(Err(anyhow!("Context server not found"))).into()
}
}
}

View File

@@ -33,6 +33,8 @@ pub enum Model {
Gpt4o,
#[serde(alias = "gpt-4", rename = "gpt-4")]
Gpt4,
#[serde(alias = "gpt-4.1", rename = "gpt-4.1")]
Gpt4_1,
#[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")]
Gpt3_5Turbo,
#[serde(alias = "o1", rename = "o1")]
@@ -50,6 +52,8 @@ pub enum Model {
Claude3_7SonnetThinking,
#[serde(alias = "gemini-2.0-flash", rename = "gemini-2.0-flash-001")]
Gemini20Flash,
#[serde(alias = "gemini-2.5-pro", rename = "gemini-2.5-pro")]
Gemini25Pro,
}
impl Model {
@@ -57,11 +61,12 @@ impl Model {
match self {
Self::Gpt4o
| Self::Gpt4
| Self::Gpt4_1
| Self::Gpt3_5Turbo
| Self::Claude3_5Sonnet
| Self::Claude3_7Sonnet
| Self::Claude3_7SonnetThinking => true,
Self::O3Mini | Self::O1 | Self::Gemini20Flash => false,
Self::O3Mini | Self::O1 | Self::Gemini20Flash | Self::Gemini25Pro => false,
}
}
@@ -69,6 +74,7 @@ impl Model {
match id {
"gpt-4o" => Ok(Self::Gpt4o),
"gpt-4" => Ok(Self::Gpt4),
"gpt-4.1" => Ok(Self::Gpt4_1),
"gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo),
"o1" => Ok(Self::O1),
"o3-mini" => Ok(Self::O3Mini),
@@ -76,6 +82,7 @@ impl Model {
"claude-3-7-sonnet" => Ok(Self::Claude3_7Sonnet),
"claude-3.7-sonnet-thought" => Ok(Self::Claude3_7SonnetThinking),
"gemini-2.0-flash-001" => Ok(Self::Gemini20Flash),
"gemini-2.5-pro" => Ok(Self::Gemini25Pro),
_ => Err(anyhow!("Invalid model id: {}", id)),
}
}
@@ -84,6 +91,7 @@ impl Model {
match self {
Self::Gpt3_5Turbo => "gpt-3.5-turbo",
Self::Gpt4 => "gpt-4",
Self::Gpt4_1 => "gpt-4.1",
Self::Gpt4o => "gpt-4o",
Self::O3Mini => "o3-mini",
Self::O1 => "o1",
@@ -91,6 +99,7 @@ impl Model {
Self::Claude3_7Sonnet => "claude-3-7-sonnet",
Self::Claude3_7SonnetThinking => "claude-3.7-sonnet-thought",
Self::Gemini20Flash => "gemini-2.0-flash-001",
Self::Gemini25Pro => "gemini-2.5-pro",
}
}
@@ -98,6 +107,7 @@ impl Model {
match self {
Self::Gpt3_5Turbo => "GPT-3.5",
Self::Gpt4 => "GPT-4",
Self::Gpt4_1 => "GPT-4.1",
Self::Gpt4o => "GPT-4o",
Self::O3Mini => "o3-mini",
Self::O1 => "o1",
@@ -105,6 +115,7 @@ impl Model {
Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
Self::Gemini20Flash => "Gemini 2.0 Flash",
Self::Gemini25Pro => "Gemini 2.5 Pro",
}
}
@@ -112,13 +123,15 @@ impl Model {
match self {
Self::Gpt4o => 64_000,
Self::Gpt4 => 32_768,
Self::Gpt4_1 => 1_047_576,
Self::Gpt3_5Turbo => 12_288,
Self::O3Mini => 64_000,
Self::O1 => 20_000,
Self::Claude3_5Sonnet => 200_000,
Self::Claude3_7Sonnet => 90_000,
Self::Claude3_7SonnetThinking => 90_000,
Model::Gemini20Flash => 128_000,
Self::Gemini20Flash => 128_000,
Self::Gemini25Pro => 128_000,
}
}
}

View File

@@ -1172,7 +1172,7 @@ async fn test_send_breakpoints_when_editor_has_been_saved(
let (editor, cx) = cx.add_window_view(|window, cx| {
Editor::new(
EditorMode::Full,
EditorMode::full(),
MultiBuffer::build_from_buffer(buffer, cx),
Some(project.clone()),
window,
@@ -1347,7 +1347,7 @@ async fn test_unsetting_breakpoints_on_clear_breakpoint_action(
let (first_editor, cx) = cx.add_window_view(|window, cx| {
Editor::new(
EditorMode::Full,
EditorMode::full(),
MultiBuffer::build_from_buffer(first, cx),
Some(project.clone()),
window,
@@ -1357,7 +1357,7 @@ async fn test_unsetting_breakpoints_on_clear_breakpoint_action(
let (second_editor, cx) = cx.add_window_view(|window, cx| {
Editor::new(
EditorMode::Full,
EditorMode::full(),
MultiBuffer::build_from_buffer(second, cx),
Some(project.clone()),
window,

View File

@@ -49,8 +49,8 @@ use language::{
};
use lsp::DiagnosticSeverity;
use multi_buffer::{
Anchor, AnchorRangeExt, MultiBuffer, MultiBufferPoint, MultiBufferRow, MultiBufferSnapshot,
RowInfo, ToOffset, ToPoint,
Anchor, AnchorRangeExt, ExcerptId, MultiBuffer, MultiBufferPoint, MultiBufferRow,
MultiBufferSnapshot, RowInfo, ToOffset, ToPoint,
};
use serde::Deserialize;
use std::{
@@ -574,6 +574,21 @@ impl DisplayMap {
self.block_map.read(snapshot, edits);
}
pub fn remove_inlays_for_excerpts(&mut self, excerpts_removed: &[ExcerptId]) {
let to_remove = self
.inlay_map
.current_inlays()
.filter_map(|inlay| {
if excerpts_removed.contains(&inlay.position.excerpt_id) {
Some(inlay.id)
} else {
None
}
})
.collect::<Vec<_>>();
self.inlay_map.splice(&to_remove, Vec::new());
}
fn tab_size(buffer: &Entity<MultiBuffer>, cx: &App) -> NonZeroU32 {
let buffer = buffer.read(cx).as_singleton().map(|buffer| buffer.read(cx));
let language = buffer

View File

@@ -396,9 +396,31 @@ pub enum SelectMode {
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum EditorMode {
SingleLine { auto_width: bool },
AutoHeight { max_lines: usize },
Full,
SingleLine {
auto_width: bool,
},
AutoHeight {
max_lines: usize,
},
Full {
/// When set to `true`, the editor will scale its UI elements with the buffer font size.
scale_ui_elements_with_buffer_font_size: bool,
/// When set to `true`, the editor will render a background for the active line.
show_active_line_background: bool,
},
}
impl EditorMode {
pub fn full() -> Self {
Self::Full {
scale_ui_elements_with_buffer_font_size: true,
show_active_line_background: true,
}
}
pub fn is_full(&self) -> bool {
matches!(self, Self::Full { .. })
}
}
#[derive(Copy, Clone, Debug)]
@@ -1198,7 +1220,7 @@ impl Editor {
pub fn multi_line(window: &mut Window, cx: &mut Context<Self>) -> Self {
let buffer = cx.new(|cx| Buffer::local("", cx));
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
Self::new(EditorMode::Full, buffer, None, window, cx)
Self::new(EditorMode::full(), buffer, None, window, cx)
}
pub fn auto_width(window: &mut Window, cx: &mut Context<Self>) -> Self {
@@ -1232,7 +1254,7 @@ impl Editor {
cx: &mut Context<Self>,
) -> Self {
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
Self::new(EditorMode::Full, buffer, project, window, cx)
Self::new(EditorMode::full(), buffer, project, window, cx)
}
pub fn for_multibuffer(
@@ -1241,7 +1263,7 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
Self::new(EditorMode::Full, buffer, project, window, cx)
Self::new(EditorMode::full(), buffer, project, window, cx)
}
pub fn clone(&self, window: &mut Window, cx: &mut Context<Self>) -> Self {
@@ -1329,7 +1351,7 @@ impl Editor {
.then(|| language_settings::SoftWrap::None);
let mut project_subscriptions = Vec::new();
if mode == EditorMode::Full {
if mode.is_full() {
if let Some(project) = project.as_ref() {
project_subscriptions.push(cx.subscribe_in(
project,
@@ -1417,7 +1439,7 @@ impl Editor {
};
let breakpoint_store = match (mode, project.as_ref()) {
(EditorMode::Full, Some(project)) => Some(project.read(cx).breakpoint_store()),
(EditorMode::Full { .. }, Some(project)) => Some(project.read(cx).breakpoint_store()),
_ => None,
};
@@ -1468,7 +1490,7 @@ impl Editor {
show_scrollbars: true,
mode,
show_breadcrumbs: EditorSettings::get_global(cx).toolbar.breadcrumbs,
show_gutter: mode == EditorMode::Full,
show_gutter: mode.is_full(),
show_line_numbers: None,
use_relative_line_numbers: None,
show_git_diff_gutter: None,
@@ -1510,7 +1532,7 @@ impl Editor {
collapse_matches: false,
workspace: None,
input_enabled: true,
use_modal_editing: mode == EditorMode::Full,
use_modal_editing: mode.is_full(),
read_only: false,
use_autoclose: true,
use_auto_surround: true,
@@ -1527,7 +1549,7 @@ impl Editor {
edit_prediction_preview: EditPredictionPreview::Inactive {
released_too_fast: false,
},
inline_diagnostics_enabled: mode == EditorMode::Full,
inline_diagnostics_enabled: mode.is_full(),
inlay_hint_cache: InlayHintCache::new(inlay_hint_settings),
gutter_hovered: false,
@@ -1635,7 +1657,7 @@ impl Editor {
this.scroll_manager.show_scrollbars(window, cx);
jsx_tag_auto_close::refresh_enabled_in_any_buffer(&mut this, &buffer, cx);
if mode == EditorMode::Full {
if mode.is_full() {
let should_auto_hide_scrollbars = cx.should_auto_hide_scrollbars();
cx.set_global(ScrollbarAutoHide(should_auto_hide_scrollbars));
@@ -1697,7 +1719,7 @@ impl Editor {
let mode = match self.mode {
EditorMode::SingleLine { .. } => "single_line",
EditorMode::AutoHeight { .. } => "auto_height",
EditorMode::Full => "full",
EditorMode::Full { .. } => "full",
};
if EditorSettings::jupyter_enabled(cx) {
@@ -1984,6 +2006,10 @@ impl Editor {
self.mode
}
pub fn set_mode(&mut self, mode: EditorMode) {
self.mode = mode;
}
pub fn collaboration_hub(&self) -> Option<&dyn CollaborationHub> {
self.collaboration_hub.as_deref()
}
@@ -3006,7 +3032,7 @@ impl Editor {
return;
}
if self.mode == EditorMode::Full
if self.mode.is_full()
&& self.change_selections(Some(Autoscroll::fit()), window, cx, |s| s.try_cancel())
{
return;
@@ -3049,7 +3075,7 @@ impl Editor {
return true;
}
if self.mode == EditorMode::Full && self.active_diagnostics.is_some() {
if self.mode.is_full() && self.active_diagnostics.is_some() {
self.dismiss_diagnostics(cx);
return true;
}
@@ -4031,7 +4057,7 @@ impl Editor {
}
fn refresh_inlay_hints(&mut self, reason: InlayHintRefreshReason, cx: &mut Context<Self>) {
if self.semantics_provider.is_none() || self.mode != EditorMode::Full {
if self.semantics_provider.is_none() || !self.mode.is_full() {
return;
}
@@ -4107,10 +4133,13 @@ impl Editor {
if let Some(InlaySplice {
to_remove,
to_insert,
}) = self.inlay_hint_cache.remove_excerpts(excerpts_removed)
}) = self.inlay_hint_cache.remove_excerpts(&excerpts_removed)
{
self.splice_inlays(&to_remove, to_insert, cx);
}
self.display_map.update(cx, |display_map, _| {
display_map.remove_inlays_for_excerpts(&excerpts_removed)
});
return;
}
InlayHintRefreshReason::NewLinesShown => (InvalidationStrategy::None, None),
@@ -5536,7 +5565,7 @@ impl Editor {
buffer_position: language::Anchor,
cx: &App,
) -> EditPredictionSettings {
if self.mode != EditorMode::Full
if !self.mode.is_full()
|| !self.show_inline_completions_override.unwrap_or(true)
|| self.inline_completions_disabled_in_scope(buffer, buffer_position, cx)
{
@@ -17225,7 +17254,7 @@ impl Editor {
let project_settings = ProjectSettings::get_global(cx);
self.serialize_dirty_buffers = project_settings.session.restore_unsaved_buffers;
if self.mode == EditorMode::Full {
if self.mode.is_full() {
let show_inline_diagnostics = project_settings.diagnostics.inline.enabled;
let inline_blame_enabled = project_settings.git.inline_blame_enabled();
if self.show_inline_diagnostics != show_inline_diagnostics {
@@ -19513,7 +19542,7 @@ impl Render for Editor {
line_height: relative(settings.buffer_line_height.value()),
..Default::default()
},
EditorMode::Full => TextStyle {
EditorMode::Full { .. } => TextStyle {
color: cx.theme().colors().editor_foreground,
font_family: settings.buffer_font.family.clone(),
font_features: settings.buffer_font.features.clone(),
@@ -19531,7 +19560,7 @@ impl Render for Editor {
let background = match self.mode {
EditorMode::SingleLine { .. } => cx.theme().system().transparent,
EditorMode::AutoHeight { max_lines: _ } => cx.theme().system().transparent,
EditorMode::Full => cx.theme().colors().editor_background,
EditorMode::Full { .. } => cx.theme().colors().editor_background,
};
EditorElement::new(

View File

@@ -7916,7 +7916,7 @@ async fn test_multibuffer_format_during_save(cx: &mut TestAppContext) {
});
let multi_buffer_editor = cx.new_window_entity(|window, cx| {
Editor::new(
EditorMode::Full,
EditorMode::full(),
multi_buffer,
Some(project.clone()),
window,
@@ -14097,7 +14097,7 @@ async fn test_mutlibuffer_in_navigation_history(cx: &mut TestAppContext) {
let cx = &mut VisualTestContext::from_window(*workspace.deref(), cx);
let multi_buffer_editor = cx.new_window_entity(|window, cx| {
Editor::new(
EditorMode::Full,
EditorMode::full(),
multi_buffer,
Some(project.clone()),
window,
@@ -14556,7 +14556,7 @@ async fn test_toggle_diff_expand_in_multi_buffer(cx: &mut TestAppContext) {
});
let editor =
cx.add_window(|window, cx| Editor::new(EditorMode::Full, multi_buffer, None, window, cx));
cx.add_window(|window, cx| Editor::new(EditorMode::full(), multi_buffer, None, window, cx));
editor
.update(cx, |editor, _window, cx| {
for (buffer, diff_base) in [
@@ -14667,7 +14667,7 @@ async fn test_expand_diff_hunk_at_excerpt_boundary(cx: &mut TestAppContext) {
});
let editor =
cx.add_window(|window, cx| Editor::new(EditorMode::Full, multi_buffer, None, window, cx));
cx.add_window(|window, cx| Editor::new(EditorMode::full(), multi_buffer, None, window, cx));
editor
.update(cx, |editor, _window, cx| {
let diff = cx.new(|cx| BufferDiff::new_with_base_text(base, &buffer, cx));
@@ -16223,7 +16223,7 @@ async fn test_display_diff_hunks(cx: &mut TestAppContext) {
});
let editor = cx.add_window(|window, cx| {
Editor::new(EditorMode::Full, multibuffer, Some(project), window, cx)
Editor::new(EditorMode::full(), multibuffer, Some(project), window, cx)
});
cx.run_until_parked();
@@ -16740,7 +16740,7 @@ async fn test_find_enclosing_node_with_task(cx: &mut TestAppContext) {
let editor = cx.new_window_entity(|window, cx| {
Editor::new(
EditorMode::Full,
EditorMode::full(),
multi_buffer,
Some(project.clone()),
window,
@@ -16867,7 +16867,7 @@ async fn test_folding_buffers(cx: &mut TestAppContext) {
});
let multi_buffer_editor = cx.new_window_entity(|window, cx| {
Editor::new(
EditorMode::Full,
EditorMode::full(),
multi_buffer.clone(),
Some(project.clone()),
window,
@@ -17024,7 +17024,7 @@ async fn test_folding_buffers_with_one_excerpt(cx: &mut TestAppContext) {
let multi_buffer_editor = cx.new_window_entity(|window, cx| {
Editor::new(
EditorMode::Full,
EditorMode::full(),
multi_buffer,
Some(project.clone()),
window,
@@ -17142,7 +17142,7 @@ async fn test_folding_buffer_when_multibuffer_has_only_one_excerpt(cx: &mut Test
});
let multi_buffer_editor = cx.new_window_entity(|window, cx| {
Editor::new(
EditorMode::Full,
EditorMode::full(),
multi_buffer,
Some(project.clone()),
window,
@@ -17192,7 +17192,7 @@ async fn test_multi_buffer_navigation_with_folded_buffers(cx: &mut TestAppContex
],
cx,
);
let mut editor = Editor::new(EditorMode::Full, multi_buffer.clone(), None, window, cx);
let mut editor = Editor::new(EditorMode::full(), multi_buffer.clone(), None, window, cx);
let buffer_ids = multi_buffer.read(cx).excerpt_buffer_ids();
// fold all but the second buffer, so that we test navigating between two
@@ -17504,7 +17504,7 @@ async fn assert_highlighted_edits(
) {
let window = cx.add_window(|window, cx| {
let buffer = MultiBuffer::build_simple(text, cx);
Editor::new(EditorMode::Full, buffer, None, window, cx)
Editor::new(EditorMode::full(), buffer, None, window, cx)
});
let cx = &mut VisualTestContext::from_window(*window, cx);
@@ -17662,7 +17662,7 @@ async fn test_breakpoint_toggling(cx: &mut TestAppContext) {
let (editor, cx) = cx.add_window_view(|window, cx| {
Editor::new(
EditorMode::Full,
EditorMode::full(),
MultiBuffer::build_from_buffer(buffer, cx),
Some(project.clone()),
window,
@@ -17779,7 +17779,7 @@ async fn test_log_breakpoint_editing(cx: &mut TestAppContext) {
let (editor, cx) = cx.add_window_view(|window, cx| {
Editor::new(
EditorMode::Full,
EditorMode::full(),
MultiBuffer::build_from_buffer(buffer, cx),
Some(project.clone()),
window,
@@ -17954,7 +17954,7 @@ async fn test_breakpoint_enabling_and_disabling(cx: &mut TestAppContext) {
let (editor, cx) = cx.add_window_view(|window, cx| {
Editor::new(
EditorMode::Full,
EditorMode::full(),
MultiBuffer::build_from_buffer(buffer, cx),
Some(project.clone()),
window,

View File

@@ -1404,7 +1404,7 @@ impl EditorElement {
window: &mut Window,
cx: &mut App,
) -> Option<EditorScrollbars> {
if snapshot.mode != EditorMode::Full {
if !snapshot.mode.is_full() {
return None;
}
@@ -2373,7 +2373,7 @@ impl EditorElement {
cx: &mut App,
) -> Arc<HashMap<MultiBufferRow, LineNumberLayout>> {
let include_line_numbers = snapshot.show_line_numbers.unwrap_or_else(|| {
EditorSettings::get_global(cx).gutter.line_numbers && snapshot.mode == EditorMode::Full
EditorSettings::get_global(cx).gutter.line_numbers && snapshot.mode.is_full()
});
if !include_line_numbers {
return Arc::default();
@@ -2479,7 +2479,7 @@ impl EditorElement {
cx: &mut App,
) -> Vec<Option<AnyElement>> {
let include_fold_statuses = EditorSettings::get_global(cx).gutter.folds
&& snapshot.mode == EditorMode::Full
&& snapshot.mode.is_full()
&& self.editor.read(cx).is_singleton(cx);
if include_fold_statuses {
row_infos
@@ -4178,7 +4178,11 @@ impl EditorElement {
self.style.background,
));
if let EditorMode::Full = layout.mode {
if let EditorMode::Full {
show_active_line_background,
..
} = layout.mode
{
let mut active_rows = layout.active_rows.iter().peekable();
while let Some((start_row, contains_non_empty_selection)) = active_rows.next() {
let mut end_row = start_row.0;
@@ -4193,7 +4197,7 @@ impl EditorElement {
end_row += 1;
}
if !contains_non_empty_selection.selection {
if show_active_line_background && !contains_non_empty_selection.selection {
let highlight_h_range =
match layout.position_map.snapshot.current_line_highlight {
CurrentLineHighlight::Gutter => Some(Range {
@@ -6061,7 +6065,7 @@ impl LineWithInvisibles {
strikethrough: text_style.strikethrough,
});
if editor_mode == EditorMode::Full {
if editor_mode.is_full() {
// Line wrap pads its contents with fake whitespaces,
// avoid printing them
let is_soft_wrapped = is_row_soft_wrapped(row);
@@ -6416,7 +6420,13 @@ impl EditorElement {
/// This allows UI elements to scale based on the `buffer_font_size`.
fn rem_size(&self, cx: &mut App) -> Option<Pixels> {
match self.editor.read(cx).mode {
EditorMode::Full => {
EditorMode::Full {
scale_ui_elements_with_buffer_font_size,
..
} => {
if !scale_ui_elements_with_buffer_font_size {
return None;
}
let buffer_font_size = self.style.text.font_size;
match buffer_font_size {
AbsoluteLength::Pixels(pixels) => {
@@ -6533,7 +6543,7 @@ impl Element for EditorElement {
},
)
}
EditorMode::Full => {
EditorMode::Full { .. } => {
let mut style = Style::default();
style.size.width = relative(1.).into();
style.size.height = relative(1.).into();
@@ -8509,7 +8519,7 @@ mod tests {
init_test(cx, |_| {});
let window = cx.add_window(|window, cx| {
let buffer = MultiBuffer::build_simple(&sample_text(6, 6, 'a'), cx);
Editor::new(EditorMode::Full, buffer, None, window, cx)
Editor::new(EditorMode::full(), buffer, None, window, cx)
});
let editor = window.root(cx).unwrap();
@@ -8610,7 +8620,7 @@ mod tests {
let window = cx.add_window(|window, cx| {
let buffer = MultiBuffer::build_simple(&(sample_text(6, 6, 'a') + "\n"), cx);
Editor::new(EditorMode::Full, buffer, None, window, cx)
Editor::new(EditorMode::full(), buffer, None, window, cx)
});
let cx = &mut VisualTestContext::from_window(*window, cx);
let editor = window.root(cx).unwrap();
@@ -8681,7 +8691,7 @@ mod tests {
let window = cx.add_window(|window, cx| {
let buffer = MultiBuffer::build_simple("", cx);
Editor::new(EditorMode::Full, buffer, None, window, cx)
Editor::new(EditorMode::full(), buffer, None, window, cx)
});
let cx = &mut VisualTestContext::from_window(*window, cx);
let editor = window.root(cx).unwrap();
@@ -8767,7 +8777,7 @@ mod tests {
let actual_invisibles = collect_invisibles_from_new_editor(
cx,
EditorMode::Full,
EditorMode::full(),
input_text,
px(500.0),
show_line_numbers,
@@ -8861,7 +8871,7 @@ mod tests {
let actual_invisibles = collect_invisibles_from_new_editor(
cx,
EditorMode::Full,
EditorMode::full(),
&input_text,
px(editor_width),
show_line_numbers,

View File

@@ -8,8 +8,8 @@ use crate::{
use gpui::{
AnyElement, AsyncWindowContext, Context, Entity, Focusable as _, FontWeight, Hsla,
InteractiveElement, IntoElement, MouseButton, ParentElement, Pixels, ScrollHandle, Size,
Stateful, StatefulInteractiveElement, StyleRefinement, Styled, Task, TextStyleRefinement,
Window, div, px,
Stateful, StatefulInteractiveElement, StyleRefinement, Styled, Subscription, Task,
TextStyleRefinement, Window, div, px,
};
use itertools::Itertools;
use language::{DiagnosticEntry, Language, LanguageRegistry};
@@ -63,26 +63,31 @@ pub fn show_keyboard_hover(
window: &mut Window,
cx: &mut Context<Editor>,
) -> bool {
let info_popovers = editor.hover_state.info_popovers.clone();
for p in info_popovers {
let keyboard_grace = p.keyboard_grace.borrow();
if *keyboard_grace {
if let Some(anchor) = p.anchor {
show_hover(editor, anchor, false, window, cx);
return true;
}
if let Some(anchor) = editor.hover_state.info_popovers.iter().find_map(|p| {
if *p.keyboard_grace.borrow() {
p.anchor
} else {
None
}
}) {
show_hover(editor, anchor, false, window, cx);
return true;
}
let diagnostic_popover = editor.hover_state.diagnostic_popover.clone();
if let Some(d) = diagnostic_popover {
let keyboard_grace = d.keyboard_grace.borrow();
if *keyboard_grace {
if let Some(anchor) = d.anchor {
show_hover(editor, anchor, false, window, cx);
return true;
if let Some(anchor) = editor
.hover_state
.diagnostic_popover
.as_ref()
.and_then(|d| {
if *d.keyboard_grace.borrow() {
d.anchor
} else {
None
}
}
})
{
show_hover(editor, anchor, false, window, cx);
return true;
}
false
@@ -163,6 +168,18 @@ pub fn hover_at_inlay(
let parsed_content = parse_blocks(&blocks, &language_registry, None, cx).await;
let scroll_handle = ScrollHandle::new();
let subscription = this
.update(cx, |_, cx| {
if let Some(parsed_content) = &parsed_content {
Some(cx.observe(parsed_content, |_, _, cx| cx.notify()))
} else {
None
}
})
.ok()
.flatten();
let hover_popover = InfoPopover {
symbol_range: RangeInEditor::Inlay(inlay_hover.range.clone()),
parsed_content,
@@ -170,6 +187,7 @@ pub fn hover_at_inlay(
scroll_handle,
keyboard_grace: Rc::new(RefCell::new(false)),
anchor: None,
_subscription: subscription,
};
this.update(cx, |this, cx| {
@@ -306,40 +324,44 @@ fn show_hover(
.anchor_after(local_diagnostic.range.end),
};
let mut border_color: Option<Hsla> = None;
let mut background_color: Option<Hsla> = None;
let (background_color, border_color) = cx.update(|_, cx| {
let status_colors = cx.theme().status();
match local_diagnostic.diagnostic.severity {
DiagnosticSeverity::ERROR => {
(status_colors.error_background, status_colors.error_border)
}
DiagnosticSeverity::WARNING => (
status_colors.warning_background,
status_colors.warning_border,
),
DiagnosticSeverity::INFORMATION => {
(status_colors.info_background, status_colors.info_border)
}
DiagnosticSeverity::HINT => {
(status_colors.hint_background, status_colors.hint_border)
}
_ => (
status_colors.ignored_background,
status_colors.ignored_border,
),
}
})?;
let parsed_content = cx
.new_window_entity(|_window, cx| {
let status_colors = cx.theme().status();
match local_diagnostic.diagnostic.severity {
DiagnosticSeverity::ERROR => {
background_color = Some(status_colors.error_background);
border_color = Some(status_colors.error_border);
}
DiagnosticSeverity::WARNING => {
background_color = Some(status_colors.warning_background);
border_color = Some(status_colors.warning_border);
}
DiagnosticSeverity::INFORMATION => {
background_color = Some(status_colors.info_background);
border_color = Some(status_colors.info_border);
}
DiagnosticSeverity::HINT => {
background_color = Some(status_colors.hint_background);
border_color = Some(status_colors.hint_border);
}
_ => {
background_color = Some(status_colors.ignored_background);
border_color = Some(status_colors.ignored_border);
}
};
Markdown::new_text(SharedString::new(text), cx)
})
.new(|cx| Markdown::new_text(SharedString::new(text), cx))
.ok();
let subscription = this
.update(cx, |_, cx| {
if let Some(parsed_content) = &parsed_content {
Some(cx.observe(parsed_content, |_, _, cx| cx.notify()))
} else {
None
}
})
.ok()
.flatten();
Some(DiagnosticPopover {
local_diagnostic,
parsed_content,
@@ -347,6 +369,7 @@ fn show_hover(
background_color,
keyboard_grace: Rc::new(RefCell::new(ignore_timeout)),
anchor: Some(anchor),
_subscription: subscription,
})
} else {
None
@@ -399,6 +422,16 @@ fn show_hover(
}];
let parsed_content = parse_blocks(&blocks, &language_registry, None, cx).await;
let scroll_handle = ScrollHandle::new();
let subscription = this
.update(cx, |_, cx| {
if let Some(parsed_content) = &parsed_content {
Some(cx.observe(parsed_content, |_, _, cx| cx.notify()))
} else {
None
}
})
.ok()
.flatten();
info_popovers.push(InfoPopover {
symbol_range: RangeInEditor::Text(range),
parsed_content,
@@ -406,6 +439,7 @@ fn show_hover(
scroll_handle,
keyboard_grace: Rc::new(RefCell::new(ignore_timeout)),
anchor: Some(anchor),
_subscription: subscription,
})
}
@@ -439,6 +473,16 @@ fn show_hover(
let parsed_content = parse_blocks(&blocks, &language_registry, language, cx).await;
let scroll_handle = ScrollHandle::new();
hover_highlights.push(range.clone());
let subscription = this
.update(cx, |_, cx| {
if let Some(parsed_content) = &parsed_content {
Some(cx.observe(parsed_content, |_, _, cx| cx.notify()))
} else {
None
}
})
.ok()
.flatten();
info_popovers.push(InfoPopover {
symbol_range: RangeInEditor::Text(range),
parsed_content,
@@ -446,6 +490,7 @@ fn show_hover(
scroll_handle,
keyboard_grace: Rc::new(RefCell::new(ignore_timeout)),
anchor: Some(anchor),
_subscription: subscription,
});
}
@@ -659,7 +704,7 @@ pub fn open_markdown_url(link: SharedString, window: &mut Window, cx: &mut App)
cx.open_url(&link);
}
#[derive(Default, Debug)]
#[derive(Default)]
pub struct HoverState {
pub info_popovers: Vec<InfoPopover>,
pub diagnostic_popover: Option<DiagnosticPopover>,
@@ -741,7 +786,6 @@ impl HoverState {
}
}
#[derive(Debug, Clone)]
pub(crate) struct InfoPopover {
pub(crate) symbol_range: RangeInEditor,
pub(crate) parsed_content: Option<Entity<Markdown>>,
@@ -749,6 +793,7 @@ pub(crate) struct InfoPopover {
pub(crate) scrollbar_state: ScrollbarState,
pub(crate) keyboard_grace: Rc<RefCell<bool>>,
pub(crate) anchor: Option<Anchor>,
_subscription: Option<Subscription>,
}
impl InfoPopover {
@@ -759,7 +804,7 @@ impl InfoPopover {
cx: &mut Context<Editor>,
) -> AnyElement {
let keyboard_grace = Rc::clone(&self.keyboard_grace);
let mut d = div()
div()
.id("info_popover")
.elevation_2(cx)
// Prevent a mouse down/move on the popover from being propagated to the editor,
@@ -769,11 +814,9 @@ impl InfoPopover {
let mut keyboard_grace = keyboard_grace.borrow_mut();
*keyboard_grace = false;
cx.stop_propagation();
});
if let Some(markdown) = &self.parsed_content {
d = d
.child(
})
.when_some(self.parsed_content.clone(), |this, markdown| {
this.child(
div()
.id("info-md-container")
.overflow_y_scroll()
@@ -782,19 +825,16 @@ impl InfoPopover {
.p_2()
.track_scroll(&self.scroll_handle)
.child(
MarkdownElement::new(
markdown.clone(),
hover_markdown_style(window, cx),
)
.code_block_renderer(markdown::CodeBlockRenderer::Default {
copy_button: false,
})
.on_url_click(open_markdown_url),
MarkdownElement::new(markdown, hover_markdown_style(window, cx))
.code_block_renderer(markdown::CodeBlockRenderer::Default {
copy_button: false,
})
.on_url_click(open_markdown_url),
),
)
.child(self.render_vertical_scrollbar(cx));
}
d.into_any_element()
.child(self.render_vertical_scrollbar(cx))
})
.into_any_element()
}
pub fn scroll(&self, amount: &ScrollAmount, window: &mut Window, cx: &mut Context<Editor>) {
@@ -841,14 +881,14 @@ impl InfoPopover {
}
}
#[derive(Debug, Clone)]
pub struct DiagnosticPopover {
pub(crate) local_diagnostic: DiagnosticEntry<Anchor>,
parsed_content: Option<Entity<Markdown>>,
border_color: Option<Hsla>,
background_color: Option<Hsla>,
border_color: Hsla,
background_color: Hsla,
pub keyboard_grace: Rc<RefCell<bool>>,
pub anchor: Option<Anchor>,
_subscription: Option<Subscription>,
}
impl DiagnosticPopover {
@@ -859,53 +899,7 @@ impl DiagnosticPopover {
cx: &mut Context<Editor>,
) -> AnyElement {
let keyboard_grace = Rc::clone(&self.keyboard_grace);
let mut markdown_div = div().py_1().px_2();
if let Some(markdown) = &self.parsed_content {
let settings = ThemeSettings::get_global(cx);
let mut base_text_style = window.text_style();
base_text_style.refine(&TextStyleRefinement {
font_family: Some(settings.ui_font.family.clone()),
font_fallbacks: settings.ui_font.fallbacks.clone(),
font_size: Some(settings.ui_font_size(cx).into()),
color: Some(cx.theme().colors().editor_foreground),
background_color: Some(gpui::transparent_black()),
..Default::default()
});
let markdown_style = MarkdownStyle {
base_text_style,
selection_background_color: { cx.theme().players().local().selection },
link: TextStyleRefinement {
underline: Some(gpui::UnderlineStyle {
thickness: px(1.),
color: Some(cx.theme().colors().editor_foreground),
wavy: false,
}),
..Default::default()
},
..Default::default()
};
markdown_div = markdown_div.child(
MarkdownElement::new(markdown.clone(), markdown_style)
.code_block_renderer(markdown::CodeBlockRenderer::Default {
copy_button: false,
})
.on_url_click(open_markdown_url),
);
}
if let Some(background_color) = &self.background_color {
markdown_div = markdown_div.bg(*background_color);
}
if let Some(border_color) = &self.border_color {
markdown_div = markdown_div
.border_1()
.border_color(*border_color)
.rounded_lg();
}
let diagnostic_div = div()
div()
.id("diagnostic")
.block()
.max_h(max_size.height)
@@ -927,9 +921,51 @@ impl DiagnosticPopover {
*keyboard_grace = false;
cx.stop_propagation();
})
.child(markdown_div);
diagnostic_div.into_any_element()
.when_some(self.parsed_content.clone(), |this, markdown| {
this.child(
div()
.py_1()
.px_2()
.child(
MarkdownElement::new(markdown, {
let settings = ThemeSettings::get_global(cx);
let mut base_text_style = window.text_style();
base_text_style.refine(&TextStyleRefinement {
font_family: Some(settings.ui_font.family.clone()),
font_fallbacks: settings.ui_font.fallbacks.clone(),
font_size: Some(settings.ui_font_size(cx).into()),
color: Some(cx.theme().colors().editor_foreground),
background_color: Some(gpui::transparent_black()),
..Default::default()
});
MarkdownStyle {
base_text_style,
selection_background_color: {
cx.theme().players().local().selection
},
link: TextStyleRefinement {
underline: Some(gpui::UnderlineStyle {
thickness: px(1.),
color: Some(cx.theme().colors().editor_foreground),
wavy: false,
}),
..Default::default()
},
..Default::default()
}
})
.code_block_renderer(markdown::CodeBlockRenderer::Default {
copy_button: false,
})
.on_url_click(open_markdown_url),
)
.bg(self.background_color)
.border_1()
.border_color(self.border_color)
.rounded_lg(),
)
})
.into_any_element()
}
}
@@ -1069,7 +1105,7 @@ mod tests {
editor.hover_state.info_popovers.len(),
1,
"Expected exactly one hover but got: {:?}",
editor.hover_state.info_popovers
editor.hover_state.info_popovers.len()
);
let rendered_text = editor
.hover_state
@@ -1109,7 +1145,7 @@ mod tests {
editor.hover_state.info_popovers.len(),
1,
"Expected exactly one hover but got: {:?}",
editor.hover_state.info_popovers
editor.hover_state.info_popovers.len()
);
let rendered_text = editor
.hover_state
@@ -1204,7 +1240,7 @@ mod tests {
editor.hover_state.info_popovers.len(),
1,
"Expected exactly one hover but got: {:?}",
editor.hover_state.info_popovers
editor.hover_state.info_popovers.len()
);
let rendered_text = editor
.hover_state
@@ -1269,7 +1305,7 @@ mod tests {
editor.hover_state.info_popovers.len(),
0,
"Expected no hovers but got but got: {:?}",
editor.hover_state.info_popovers
editor.hover_state.info_popovers.len()
);
});
@@ -1293,7 +1329,7 @@ mod tests {
editor.hover_state.info_popovers.len(),
1,
"Expected exactly one hover but got: {:?}",
editor.hover_state.info_popovers
editor.hover_state.info_popovers.len()
);
let rendered_text = editor
@@ -1351,7 +1387,7 @@ mod tests {
editor.hover_state.info_popovers.len(),
1,
"Expected exactly one hover but got: {:?}",
editor.hover_state.info_popovers
editor.hover_state.info_popovers.len()
);
let rendered_text = editor
.hover_state
@@ -1417,7 +1453,7 @@ mod tests {
editor.hover_state.info_popovers.len(),
1,
"Expected exactly one hover but got: {:?}",
editor.hover_state.info_popovers
editor.hover_state.info_popovers.len()
);
let rendered_text = editor
.hover_state
@@ -1794,7 +1830,7 @@ mod tests {
assert!(
hover_state.diagnostic_popover.is_none() && hover_state.info_popovers.len() == 1
);
let popover = hover_state.info_popovers.first().cloned().unwrap();
let popover = hover_state.info_popovers.first().unwrap();
let buffer_snapshot = editor.buffer().update(cx, |buffer, cx| buffer.snapshot(cx));
assert_eq!(
popover.symbol_range,
@@ -1849,7 +1885,7 @@ mod tests {
assert!(
hover_state.diagnostic_popover.is_none() && hover_state.info_popovers.len() == 1
);
let popover = hover_state.info_popovers.first().cloned().unwrap();
let popover = hover_state.info_popovers.first().unwrap();
let buffer_snapshot = editor.buffer().update(cx, |buffer, cx| buffer.snapshot(cx));
assert_eq!(
popover.symbol_range,

View File

@@ -555,12 +555,12 @@ impl InlayHintCache {
/// Completely forget of certain excerpts that were removed from the multibuffer.
pub(super) fn remove_excerpts(
&mut self,
excerpts_removed: Vec<ExcerptId>,
excerpts_removed: &[ExcerptId],
) -> Option<InlaySplice> {
let mut to_remove = Vec::new();
for excerpt_to_remove in excerpts_removed {
self.update_tasks.remove(&excerpt_to_remove);
if let Some(cached_hints) = self.hints.remove(&excerpt_to_remove) {
self.update_tasks.remove(excerpt_to_remove);
if let Some(cached_hints) = self.hints.remove(excerpt_to_remove) {
let cached_hints = cached_hints.read();
to_remove.extend(cached_hints.ordered_hints.iter().copied());
}

View File

@@ -85,6 +85,10 @@ pub fn lsp_tasks(
.map(|(name, buffer_ids)| {
let buffers = buffer_ids
.iter()
.filter(|&&buffer_id| match for_position {
Some(for_position) => for_position.buffer_id == Some(buffer_id),
None => true,
})
.filter_map(|&buffer_id| project.read(cx).buffer_for_id(buffer_id, cx))
.collect::<Vec<_>>();
language_server_for_buffers(project.clone(), name.clone(), buffers, cx)

View File

@@ -1,10 +1,9 @@
use crate::CopyAndTrim;
use crate::actions::FormatSelections;
use crate::{
Copy, CopyPermalinkToLine, Cut, DisplayPoint, DisplaySnapshot, Editor, EditorMode,
Copy, CopyAndTrim, CopyPermalinkToLine, Cut, DisplayPoint, DisplaySnapshot, Editor,
FindAllReferences, GoToDeclaration, GoToDefinition, GoToImplementation, GoToTypeDefinition,
Paste, Rename, RevealInFileManager, SelectMode, ToDisplayPoint, ToggleCodeActions,
actions::Format, selections_collection::SelectionsCollection,
actions::{Format, FormatSelections},
selections_collection::SelectionsCollection,
};
use gpui::prelude::FluentBuilder;
use gpui::{Context, DismissEvent, Entity, Focusable as _, Pixels, Point, Subscription, Window};
@@ -123,7 +122,7 @@ pub fn deploy_context_menu(
}
// Don't show context menu for inline editors
if editor.mode() != EditorMode::Full {
if !editor.mode().is_full() {
return;
}

View File

@@ -108,7 +108,7 @@ pub(crate) fn build_editor(
window: &mut Window,
cx: &mut Context<Editor>,
) -> Editor {
Editor::new(EditorMode::Full, buffer, None, window, cx)
Editor::new(EditorMode::full(), buffer, None, window, cx)
}
pub(crate) fn build_editor_with_project(
@@ -117,5 +117,5 @@ pub(crate) fn build_editor_with_project(
window: &mut Window,
cx: &mut Context<Editor>,
) -> Editor {
Editor::new(EditorMode::Full, buffer, Some(project), window, cx)
Editor::new(EditorMode::full(), buffer, Some(project), window, cx)
}

Some files were not shown because too many files have changed in this diff Show More