Compare commits

..

66 Commits

Author SHA1 Message Date
Michael Sloan
4a4ee4fed7 Remove cli example 2025-09-17 18:13:51 -06:00
Michael Sloan
ea4bf46a36 Return 0 results when declaration count limit exceeded
Co-authored-by: Agus <agus@zed.dev>
2025-09-17 18:10:32 -06:00
Michael Sloan
05545abab6 Checkpoint
Co-authored-by: Agus <agus@zed.dev>
2025-09-17 18:10:32 -06:00
Michael Sloan
a85608566d Checkpoint
Co-authored-by: Agus Zubiaga <hi@aguz.me>
2025-09-17 17:07:30 -06:00
Michael Sloan
69af5261ea Renames + fixes
Co-authored-by: Agus <agus@zed.dev>
2025-09-17 15:02:57 -06:00
Michael Sloan
b9e2f61a38 Expand declaration ranges to line boundaries and truncate, store text for file declarations
Co-authored-by: Agus <agus@zed.dev>
2025-09-17 15:00:29 -06:00
Michael Sloan
38bbb497dd Rename definition->declaration
Co-authored-by: Agus Zubiaga <hi@aguz.me>
2025-09-17 12:57:49 -06:00
Agus Zubiaga
0cc7b4a93c Simple call site snippet test 2025-09-17 12:47:36 -03:00
Agus
cc32bfdfdf Checkpoint: Get score_snippets to compile
Co-Authored-By: Finn <finn@zed.dev>
2025-09-17 11:48:22 -03:00
Michael Sloan
50de8ddc28 Progress on porting scored_declaration.rs 2025-09-17 02:11:58 -06:00
Michael Sloan
f770011d7f Add WIP zeta2 request types 2025-09-17 01:56:48 -06:00
Michael Sloan
f2a6b57909 Copy in experimental cli / declaration scoring code
Co-authored-by: Oleksiy <oleksiy@zed.dev>
2025-09-17 01:55:06 -06:00
Michael Sloan
96b67ac70e Add text similarity metrics 2025-09-17 01:38:02 -06:00
Michael Sloan
64d362cbce edit prediction: Initial implementation of Tree-sitter index (not yet used) (#38301)
Release Notes:

- N/A

---------

Co-authored-by: Agus <agus@zed.dev>
Co-authored-by: oleksiy <oleksiy@zed.dev>
2025-09-17 07:25:14 +00:00
Kyrilasa
5d561aa494 agent_ui: Fix agent panel insertion to use cursor position (#38253)
Fix agent panel insertion to use cursor position

Closes #38216

Release Notes:
- Fixed agent panel text insertion to respect cursor position instead of
always appending to the end

## Before

[before.webm](https://github.com/user-attachments/assets/684d3cbe-4710-4724-8d2d-ac08f430dea8)

## After

[output.webm](https://github.com/user-attachments/assets/d1122d99-4efb-4a24-a408-db128814f98c)
2025-09-17 07:10:21 +00:00
Lukas Wirth
4ee2daeded markdown: Fix indented codeblocks having incorrect content ranges (#38225)
Closes https://github.com/zed-industries/zed/issues/37743

Release Notes:

- Fixed agent panel panicking when streaming indented codeblocks from
agent output
2025-09-17 06:48:47 +00:00
Cole Miller
c27d8e0c7a editor: Don't pull diagnostics on excerpts change in diagnostics editors (#38212)
This can lead to an infinite regress when using a language server that
supports pull diagnostics, since the excerpts for the diagnostics editor
are set based on the project's diagnostics.

Closes #36772

Release Notes:

- Fixed a bug that could cause duplicated diagnostics with some language
servers.
2025-09-16 21:58:24 -04:00
Marshall Bowers
f6c5c68751 collab: Remove user backfiller (#38291)
This PR removes the user backfiller from Collab.

Release Notes:

- N/A
2025-09-16 22:53:44 +00:00
Marshall Bowers
74e5b848ff cloud_llm_client: Make default_model and default_fast_model optional (#38288)
This PR makes the `default_model` and `default_fast_model` fields
optional on the `ListModelsResponse`.

Release Notes:

- N/A
2025-09-16 22:24:03 +00:00
Smit Barmase
ee399ebccf macOS: Make it easier to debug NSAutoFillHeuristicControllerEnabled (#38285)
Uses `setObject` instead of `registerDefaults`, so that it can be read
with `defaults read dev.zed.Zed`. Still can be overrided.

Release Notes:

- N/A
2025-09-17 03:49:47 +05:30
Max Brunsfeld
54c82f2732 Windows: Unminimize a window when activating it (#38287)
Closes #36287

Release Notes:

- Windows: Fixed an issue where a Zed window would stay minimized when
opening an existing file in that window via the Zed CLI.
2025-09-16 22:12:02 +00:00
Uwe Krause
e14a4ab90d Fix small spelling mistakes (#38284)
Closes #ISSUE

Release Notes:

- N/A *or* Added/Fixed/Improved ...
2025-09-16 21:58:40 +00:00
David Kleingeld
0343b5ff06 Add new crate denoise required by audio (#38217)
The audio crate will use the denoise crate to remove background noises
from microphone input.

We intent to contribute this to rodio. Before that can happen a PR needs
to land in candle. Until then this lives here.

Uses a candle fork which removes the dependency on `protoc` and has the PR's mentioned above already applied.

Release Notes:

- N/A

---------

Co-authored-by: Mikayla <mikayla@zed.dev>
2025-09-16 21:49:26 +00:00
Marshall Bowers
26202e5af2 language_models: Use message field from Cloud error responses, if present (#38286)
This PR updates the Cloud language model provider to use the `message`
field from the Cloud error response, if it is present.

Previously we would always show the entire JSON payload in the error
message, but with this change we can show just the user-facing `message`
the error response is in a shape that we recognize.

Release Notes:

- N/A
2025-09-16 21:45:25 +00:00
George Waters
ee912366a3 Check if virtual environment is in worktree root (#37510)
The problem from issue #37509 comes from local virtual environments
created with certain approaches (including the 'simple' way of `python
-m venv`) not having a `.project` file with the path to the project's
root directory. When the toolchains are sorted, a virtual environment in
the project is not treated as being for that project and therefore is
not prioritized.

With this change, if a toolchain does not have a `project` associated
with it, we check to see if it is a virtual environment, and if it is we
use its parent directory as the `project`. This will make it the top
priority (i.e. the default) if there are no other virtual environments
for a project, which is what should be expected.

Closes #37509

Release Notes:

- Improved python toolchain prioritization of local virtual
environments.
2025-09-16 21:30:32 +02:00
David Kleingeld
673a98a277 Fix a number of spelling mistakes (#38281)
My pre push hooks keep failing on these. This is easier then disabling
and re-enabling those hooks all the time :)

Closes #ISSUE

Release Notes:

- N/A
2025-09-16 19:18:39 +00:00
VBB
5674445a61 Move keyboard shortcut for pane::GoForward (#38221)
Move keyboard shortcut for `pane:GoForward` so it's going to be
displayed as a shortcut hint in UI. Currently `Forward` is shown as a
hint, which isn't consistent with `GoBack` action and can be confusing.

Release Notes: 

- Improved the displayed keybinding for the `pane::GoForward` action on
Linux.
2025-09-16 18:33:55 +02:00
Jason Lee
53513cab23 Fix filled button hover background (#38235)
Release Notes:

- Fixed filled button hover background.

## Before


https://github.com/user-attachments/assets/fbc75890-d1a4-4a0c-b54e-ca2c7e63a661

## After


https://github.com/user-attachments/assets/a3595b01-e143-4cd0-8bc4-90db9ccfbf74


This appears to be a minor calculation error, not an intentional use of
this value.

If we pass `0.92` to `fade_out`, the calculated will be `alpha: 0.08`.

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
2025-09-16 14:47:10 +00:00
Smit Barmase
e885a939ba git_ui: Add tooltip for branch picker items (#38261)
Closes #38256

<img width="300" alt="image"
src="https://github.com/user-attachments/assets/5018951f-0f1b-4d5d-b59d-5b5266380e43"
/>


Release Notes:

- Added tooltip to Git branch picker items, making it easier to
distinguish long branch names.
2025-09-16 20:06:32 +05:30
Smit Barmase
a01a2ed0e0 languages: Add Tailwind CSS support for TypeScript (#38254)
Closes #37028

I noticed many projects use Tailwind in plain TypeScript (.ts) files, so
it makes sense to support them out of the box, alongside .js and .tsx
files we already handle. For example, see
[supabase](https://github.com/supabase/supabase/blob/master/packages/ui/src/lib/theme/defaultTheme.ts).

Note: You’ll still need to add `"classFunctions": ["cva", "cx"],`
manually for Tailwind completions to work in `cva` type methods. This is
because you don’t want completions on every string, only in specific
methods or regex matches. This is documented.

Release Notes:

- Added out-of-the-box support for Tailwind completions in `.ts` files.
2025-09-16 20:06:14 +05:30
Nathan Sobo
af3bc45a26 Drop ellipses from About Zed menu item (#38211)
Follow the macOS app style guideline.

Release Notes:

- N/A
2025-09-16 08:06:16 -06:00
Lukas Wirth
173074f248 search: Re-issue project search if search query is stale on replacement (#38251)
Closes https://github.com/zed-industries/zed/issues/34897

Release Notes:

- Fixed project search replacement replacing stale search results
2025-09-16 12:12:45 +00:00
Ben Brandt
a7cb64c64d Remove unused agent server settings module (#38250)
This was no longer in the module graph (the settings moved elsewhere) so
cleaning up the dead code.

Release Notes:

- N/A
2025-09-16 12:11:06 +00:00
Lukas Wirth
c6472fd7a8 agent_settings: Fix schema validation rejecting custom llm providers (#38248)
Closes https://github.com/zed-industries/zed/issues/37989

Release Notes:

- N/A
2025-09-16 10:23:49 +00:00
Ben Brandt
c0710fa8ca agent_servers: Set proxy env for all ACP agents (#38247)
- Use ProxySettings::proxy_url to read from settings or env 
- Export HTTP(S)_PROXY and NO_PROXY for agent CLIs 
- Add read_no_proxy_from_env and move parsing from main

Closes https://github.com/zed-industries/claude-code-acp/issues/46

Release Notes:

- acp: Pass proxy settings through to all ACP agents
2025-09-16 10:18:10 +00:00
Lukas Wirth
f321d02207 auto_update: Show update error on hover and open logs on click (#38241)
Release Notes:

- Improved error reporting when auto-updating fails
2025-09-16 08:07:02 +00:00
Lukas Wirth
1c09985fb3 worktree: Add more context to log_err calls (#38239)
Release Notes:

- N/A
2025-09-16 07:31:28 +00:00
Marshall Bowers
d986077592 client: Hide usage when not available (#38234)
Release Notes:

- N/A
2025-09-16 02:30:56 +00:00
Danilo Leal
555b6ee4e5 agent: Add small UI fixes (#38231)
Release Notes:

- N/A
2025-09-16 01:06:45 +00:00
Owen Kelly
6446963a0c agent: Make assistant panel input size configurable (#37975)
Release Notes:

- Added the `agent. message_editor_min_lines `setting to allow users to
customize the agent panel message editor default size by using a
different minimum number of lines.

<img width="800" height="1316" alt="Screenshot 2025-09-11 at 5 47 18 pm"
src="https://github.com/user-attachments/assets/20990b90-c4f9-4f5c-af59-76358642a273"
/>

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
2025-09-16 00:27:25 +00:00
Finn Evers
ceb907e0dc onboarding: Add scrollbar to pages (#38093)
Closes #37214

This PR adds a scrollbar to the onboarding view and additionally ensures
the scroll state is properly reset when switching between the different
pages each time.

Release Notes:

- N/A
2025-09-15 19:55:02 -03:00
Alvaro Parker
3dbccc828e Fix hover element on ACP thread mode selector (#38204)
Closes #38197

This will render `^ click to also ...` on MacOS and `Ctrl + click to
also ...` on Windows and Linux.

|Before|After|
|-|-|
| <img width="683" height="197" alt="image"
src="https://github.com/user-attachments/assets/09909f1b-3163-40d1-b025-4eb9b159fbf3"
/> | <img width="683" height="197" alt="image"
src="https://github.com/user-attachments/assets/47d0290d-afa2-4b1b-a588-adfe3130d0b1"
/>|

On Mac: 

<img width="683" height="197" alt="image"
src="https://github.com/user-attachments/assets/f63103b5-1ceb-4193-ae6c-be55b97106e0"
/>

Release Notes:

- Fixed keymap hint when hovering over mode selector

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
2025-09-15 22:48:04 +00:00
Michael Sloan
853e625259 edit predictions: Add new excerpt logic (not yet used) (#38226)
Release Notes:

- N/A

---------

Co-authored-by: agus <agus@zed.dev>
2025-09-15 16:29:58 -06:00
Kenny
0784bb8192 docs: Add "Copy as Markdown" button to toolbar (#38218)
## Summary
Adds a "Copy as Markdown" button to the documentation toolbar that
allows users to easily copy the raw markdown content of any
documentation page.

This feature is inspired by similar implementations on sites like
[Better Auth docs](https://www.better-auth.com/docs/installation) and
[Cloudflare Workers docs](https://developers.cloudflare.com/workers/)
which provide easy ways for users to copy documentation content.

## Features
- **Button placement**: Positioned between theme toggle and search icon
for optimal UX
- **Content fetching**: Retrieves raw markdown from GitHub's API for the
current page
- **Consistent styling**: Matches existing toolbar button patterns

## Test plan
- [x] Copy functionality works on all documentation pages
- [x] Toast notifications appear and disappear correctly
- [x] Button icon animations work properly (spinner → checkmark → copy)
- [x] Styling matches other toolbar buttons
- [x] Works in both light and dark themes

## Screenshots
The button appears as a copy icon between the theme and search buttons
in the left toolbar.
<img width="798" height="295" alt="image"
src="https://github.com/user-attachments/assets/37d41258-d71b-40f8-b8fe-16eaa46b8d7f"
/>
<img width="1628" height="358" alt="image"
src="https://github.com/user-attachments/assets/fc45bc04-a290-4a07-8d1a-a010a92be033"
/>

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
2025-09-15 21:57:23 +00:00
Mikayla Maki
9046091164 Add a test that would have caught the bug last week (#38222)
This adds a test to make sure that the default value of the auto update
setting is always true. We manually re-applied the broken code from last
week, and confirmed that this test fails with that code.

Release Notes:

- N/A

---------

Co-authored-by: Ben Kunkle <ben@zed.dev>
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-09-15 18:10:28 +00:00
Danilo Leal
6384966ab5 agent: Improve some items in the settings view UI (#38199)
All described in each commit; mostly small things, simplifying/clearing
up the UI.

Release Notes:

- N/A
2025-09-15 13:35:39 -03:00
Ben Kunkle
8b9c74726a docs: Call out Omarchy specifically in regards to issues with amdvlk (#38214)
Closes #28851


Release Notes:

- N/A *or* Added/Fixed/Improved ...
2025-09-15 16:03:45 +00:00
Kaan Kuscu
63586ff2e4 Add new injections for Go (#37605)
support for injecting sql, json, yaml, xml, html, css, js, lua and csv
value

if you use `/* lang */` before string literals, highlights them

**Example:**

```go
const sqlQuery = /* sql */ "SELECT * FROM users;" // highlights as SQL code
```

<img width="629" height="46" alt="Screenshot 2025-09-05 at 06 17 49"
src="https://github.com/user-attachments/assets/80f404d8-0a47-428d-bdb5-09fbee502cfe"
/>


Closes #ISSUE

Release Notes:

- Go: Added support for injecting sql, json, yaml, xml, html, css, js, lua and csv language highlights into string literals, when they are prefixed with `/* lang */`

**Example:**

```go
const sqlQuery = /* sql */ "SELECT * FROM users;" // Will be highlighted as SQL code
```
2025-09-15 15:51:03 +00:00
Conrad Irwin
35e5aa4e71 Re-add VSCode syntax node motions (#38208)
Closes #ISSUE

Release Notes:

- (preview only) restored ctrl-shift-{left,right} for Larger/Smaller
syntax node. This is VSCode's default and avoids the breaking change
from #37874
2025-09-15 09:18:07 -06:00
Richard Feldman
7ea94a32be Create failed tool call entries for missing tools (#38207)
Release Notes:

- When an agent requests a tool that doesn't exist, this is now treated
as a failed tool call instead of stopping the thread.
2025-09-15 15:07:14 +00:00
Piotr Osiewicz
6d6c3d648a lsp: Fix overnotifying about open buffers for unrelated servers (#38196)
Do not report all open buffers to new instances of the same language
server, as they can respond with ~spurious errors.

This regressed in  https://github.com/zed-industries/zed/pull/34142

Closes https://github.com/zed-industries/zed/issues/35017

Release Notes:

- Fixed Zed overly notifying language servers about open buffers, which
could've resulted in confusing errors in multi-language projects (in
e.g. Go).
2025-09-15 15:20:04 +02:00
Hichem
53b2f37452 Enhance layout and styling of tool list in AgentConfiguration (#38195)
Improve the layout and styling of the tool list in the
AgentConfiguration, ensuring better responsiveness and visual clarity.

closes #38194

<img width="1270" height="738" alt="image"
src="https://github.com/user-attachments/assets/86345e57-4fd0-43b8-8b8d-6209dc635dfb"
/>

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
2025-09-15 13:00:22 +00:00
Lukas Wirth
92b946e8e5 acp_thread: Properly use project terminal API (#38186)
Closes https://github.com/zed-industries/zed/issues/35603

Release Notes:

- Fixed shell selection for terminal tool
2025-09-15 12:43:41 +00:00
Hakan Ensari
e9b4f59e0f Fix external agent authentication with spaces in paths (#38175)
This fixes terminal-based authentication for external ACP agents (Claude
Code, Gemini CLI) when file paths contain spaces, like "Application
Support" on macOS and "Program Files" on Windows.

When users click authentication buttons or type `/login`, they get
errors like `Cannot find module '/Users/username/Library/Application'`
because the path gets split at the space.

The fix removes redundant `shlex::try_quote` calls from
`spawn_external_agent_login`. These were causing double-quoting since
the terminal spawning code already handles proper shell escaping.

Added a test to verify paths with spaces aren't pre-quoted.

Release Notes:

- Fixed external agent authentication failures when file paths contain
spaces

---------

Co-authored-by: Hakan Ensari <hakanensari@users.noreply.github.com>
Co-authored-by: Claude <claude@anthropic.com>
2025-09-15 10:20:27 +00:00
Finn Evers
989adde57b Add scrollbars to markdown preview and syntax tree view (#38183)
Closes https://github.com/zed-industries/zed/issues/38141

This PR adds default scrollbars to the markdown preview and syntax tree
view.

Release Notes:

- Added scrollbars to the markdown preview and syntax tree view.
2025-09-15 10:17:27 +00:00
Lukas Wirth
393d6787a3 terminal: Do not auto close shell terminals if they error out (#38182)
Closes https://github.com/zed-industries/zed/issues/38134

This also reduces an annoying level of shell nesting

Release Notes:

- N/A
2025-09-15 10:09:25 +00:00
Finn Evers
4a582504d4 ui: Follow-up improvements to the scrollbar component (#38178)
This PR lands some more improvements to the reworked scrollbars.

Namely, we will now explicitly paint a background in cases where a track
is requested for the specific scrollbar, which prevents a flicker, and
also reserve space only if space actually needs to be reserved. The
latter was a regression introduced by the recent changes.

Release Notes:

- N/A
2025-09-15 09:53:33 +00:00
Smit Barmase
cfb2925169 macOS: Disable NSAutoFillHeuristicController on macOS 26 (#38179)
Closes #33182

From
https://github.com/zed-industries/zed/issues/33182#issuecomment-3289846957,
thanks @mitchellh.

Release Notes:

- Fixed an issue where scrolling could sometimes feel choppy on macOS
26.
2025-09-15 15:17:27 +05:30
Lukas Wirth
14f4e867aa terminal: Do not auto close shell terminals if they error out (#38180)
cc https://github.com/zed-industries/zed/issues/38134
Release Notes:

- N/A
2025-09-15 09:43:05 +00:00
Ben Brandt
4d54ccf494 agent_servers: Let Gemini CLI know it is running in Zed (#38058)
By passing through Zed as the surface, Gemini can know which editor it
is running in.

Release Notes:

- N/A
2025-09-15 08:30:46 +00:00
Tim Vermeulen
5b1c87b6a6 Fix incorrect ANSI color contrast adjustment on some background colors (#38155)
The `Hsla` -> `Rgba` conversion sometimes results in negative (but very
close to 0) color components due to floating point imprecision, causing
the `.powf(constants.main_trc)` computations in the `srgb_to_y` function
to evaluate to `NaN`. This propagates to `apca_contrast` which then
makes `ensure_minimum_contrast` unconditionally return `black` for
certain background colors. This PR addresses this by clamping the rgba
components in `impl From<Hsla> for Rgba` to 0-1.

Before/after:
<img width="1044" height="48" alt="before"
src="https://github.com/user-attachments/assets/771f809f-3959-43e9-8ed0-152ff284cef8"
/>
<img width="1044" height="49" alt="after"
src="https://github.com/user-attachments/assets/5fd6ae25-1ef0-4334-90d1-7fc5acf48958"
/>

Release Notes:

- Fixed an issue where ANSI colors were incorrectly adjusted to improve
contrast on some background colors
2025-09-15 07:52:56 +00:00
Vladimir Varankin
0fef17baa2 Hide BasedPyright banner in toolbar when dismissed (#38135)
This PR fixes the `BasedPyrightBanner`, making sure the banner is
completely hidden in the toolbar, when it was dismissed, or it's not
installed.

Without the fix, the banner still occupies some space in the toolbar,
making the UI looks inconsistent when editing a Python file. The bug is
**especially prominent** when the toolbar is hidden in the user's
settings (see below).

_Banner is shown_
<img width="1470" height="254" alt="Screenshot 2025-09-14 at 11 36 37"
src="https://github.com/user-attachments/assets/1415b075-0660-41ed-8069-c2318ac3a7cf"
/>

_Banner dismissed_
<img width="1470" height="207" alt="Screenshot 2025-09-14 at 11 36 44"
src="https://github.com/user-attachments/assets/828a3fba-5c50-4aba-832c-3e0cc6ed464b"
/>

_Banner dismissed (and the toolbar is hidden)_
<img width="1470" height="177" alt="Screenshot 2025-09-14 at 12 07 25"
src="https://github.com/user-attachments/assets/41aa5861-87df-491f-ac7e-09fc1558dd84"
/>

Closes n/a

Release Notes:

- Fixed the basedpyright onboarding banner
2025-09-15 09:43:04 +02:00
Umesh Yadav
526196917b language_models: Add support for API key to Ollama provider (#34110)
Closes https://github.com/zed-industries/zed/issues/19491

Release Notes:

- Ollama: Added configuration of URL and API key for remote Ollama provider.

---------

Signed-off-by: Umesh Yadav <git@umesh.dev>
Co-authored-by: Peter Tripp <peter@zed.dev>
Co-authored-by: Oliver Azevedo Barnes <oliver@liquidvoting.io>
Co-authored-by: Michael Sloan <michael@zed.dev>
2025-09-15 06:34:26 +00:00
Michael Sloan
a598fbaa73 ai: Show "API key configured for {URL}" for non-default urls (#38170)
Followup to #38163, also makes some changes intended to be included in
that PR.

Release Notes:

- N/A
2025-09-15 05:49:25 +00:00
Michael Sloan
634ae72cad Misc cleanup + clear language model provider API key editors when API keys are submitted (#38165)
Followup to #38163 along with some other misc cleanups

Release Notes:

- N/A
2025-09-15 05:08:38 +00:00
Michael Sloan
98edf1bf0b Reload API keys when URLs configured for LLM providers change (#38163)
Three motivations for this:

* Changing provider URL could cause credentials for the prior URL to be
sent to the new URL.
* The UI is in a misleading state after URL change - it shows a
configured API key, but on restart it will show no API key.
* #34110 will add support for both URL and key configuration for Ollama.
This is the first provider to have UI for setting the URL, and this
makes these issues show up more directly as odd UI interactions.

#37610 implemented something similar for the OpenAI and OpenAI
compatible providers. This extracts out some shared code, uses it in all
relevant providers, and adds more safety around key use.

I haven't tested all providers, but the per-provider changes were pretty
mechanical, so hopefully work properly.

Release Notes:

- Fixed handling of changes to LLM provider URL in settings to also load
the associated API key.
2025-09-15 03:36:24 +00:00
145 changed files with 7357 additions and 2443 deletions

2
.rules
View File

@@ -59,7 +59,7 @@ Trying to update an entity while it's already being updated must be avoided as t
When `read_with`, `update`, or `update_in` are used with an async context, the closure's return value is wrapped in an `anyhow::Result`.
`WeakEntity<T>` is a weak handle. It has `read_with`, `update`, and `update_in` methods that work the same, but always return an `anyhow::Result` so that they can fail if the entity no longer exists. This can be useful to avoid memory leaks - if entities have mutually recursive handles to eachother they will never be dropped.
`WeakEntity<T>` is a weak handle. It has `read_with`, `update`, and `update_in` methods that work the same, but always return an `anyhow::Result` so that they can fail if the entity no longer exists. This can be useful to avoid memory leaks - if entities have mutually recursive handles to each other they will never be dropped.
## Concurrency

615
Cargo.lock generated
View File

@@ -39,7 +39,6 @@ dependencies = [
"util",
"uuid",
"watch",
"which 6.0.3",
"workspace-hack",
]
@@ -301,6 +300,7 @@ dependencies = [
"futures 0.3.31",
"gpui",
"gpui_tokio",
"http_client",
"indoc",
"language",
"language_model",
@@ -416,7 +416,6 @@ dependencies = [
"serde_json",
"serde_json_lenient",
"settings",
"shlex",
"smol",
"streaming_diff",
"task",
@@ -689,6 +688,9 @@ name = "arbitrary"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dde20b3d026af13f561bdd0f15edf01fc734f0dafcedbaf42bba506a9517f223"
dependencies = [
"derive_arbitrary",
]
[[package]]
name = "arc-swap"
@@ -1024,7 +1026,6 @@ dependencies = [
"util",
"watch",
"web_search",
"which 6.0.3",
"workspace",
"workspace-hack",
"zlog",
@@ -2189,7 +2190,7 @@ dependencies = [
"bitflags 2.9.0",
"cexpr",
"clang-sys",
"itertools 0.12.1",
"itertools 0.11.0",
"lazy_static",
"lazycell",
"log",
@@ -2687,6 +2688,53 @@ dependencies = [
"serde",
]
[[package]]
name = "candle-core"
version = "0.9.1"
source = "git+https://github.com/zed-industries/candle?branch=9.1-patched#724d75eb3deebefe83f2a7381a45d4fac6eda383"
dependencies = [
"byteorder",
"float8",
"gemm 0.17.1",
"half",
"memmap2",
"num-traits",
"num_cpus",
"rand 0.9.1",
"rand_distr",
"rayon",
"safetensors",
"thiserror 1.0.69",
"ug",
"yoke",
"zip 1.1.4",
]
[[package]]
name = "candle-nn"
version = "0.9.1"
source = "git+https://github.com/zed-industries/candle?branch=9.1-patched#724d75eb3deebefe83f2a7381a45d4fac6eda383"
dependencies = [
"candle-core",
"half",
"libc",
"num-traits",
"rayon",
"safetensors",
"serde",
"thiserror 1.0.69",
]
[[package]]
name = "candle-onnx"
version = "0.9.1"
source = "git+https://github.com/zed-industries/candle?branch=9.1-patched#724d75eb3deebefe83f2a7381a45d4fac6eda383"
dependencies = [
"candle-core",
"candle-nn",
"prost 0.12.6",
]
[[package]]
name = "cap-fs-ext"
version = "3.4.4"
@@ -4637,6 +4685,20 @@ version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da692b8d1080ea3045efaab14434d40468c3d8657e42abddfffca87b428f4c1b"
[[package]]
name = "denoise"
version = "0.1.0"
dependencies = [
"candle-core",
"candle-onnx",
"log",
"realfft",
"rodio",
"rustfft",
"thiserror 2.0.12",
"workspace-hack",
]
[[package]]
name = "der"
version = "0.6.1"
@@ -4668,6 +4730,17 @@ dependencies = [
"serde",
]
[[package]]
name = "derive_arbitrary"
version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e567bd82dcff979e4b03460c307b3cdc9e96fde3d73bed1496d2bc75d9dd62a"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.101",
]
[[package]]
name = "derive_more"
version = "0.99.19"
@@ -4823,7 +4896,7 @@ dependencies = [
"libc",
"option-ext",
"redox_users 0.5.0",
"windows-sys 0.59.0",
"windows-sys 0.60.2",
]
[[package]]
@@ -4981,6 +5054,25 @@ version = "1.0.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1c7a8fb8a9fbf66c1f703fe16184d10ca0ee9d23be5b4436400408ba54a95005"
[[package]]
name = "dyn-stack"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56e53799688f5632f364f8fb387488dd05db9fe45db7011be066fc20e7027f8b"
dependencies = [
"bytemuck",
"reborrow",
]
[[package]]
name = "dyn-stack"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "490bd48eb68fffcfed519b4edbfd82c69cbe741d175b84f0e0cbe8c57cbe0bdd"
dependencies = [
"bytemuck",
]
[[package]]
name = "ec4rs"
version = "1.2.0"
@@ -5042,6 +5134,36 @@ dependencies = [
"zeta",
]
[[package]]
name = "edit_prediction_context"
version = "0.1.0"
dependencies = [
"anyhow",
"arrayvec",
"clap",
"collections",
"futures 0.3.31",
"gpui",
"indoc",
"itertools 0.14.0",
"language",
"log",
"ordered-float 2.10.1",
"pretty_assertions",
"project",
"regex",
"serde",
"serde_json",
"settings",
"slotmap",
"strum 0.27.1",
"text",
"tree-sitter",
"util",
"workspace-hack",
"zlog",
]
[[package]]
name = "editor"
version = "0.1.0"
@@ -5225,6 +5347,18 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3d8a32ae18130a3c84dd492d4215c3d913c3b07c6b63c2eb3eb7ff1101ab7bf"
[[package]]
name = "enum-as-inner"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.101",
]
[[package]]
name = "enumflags2"
version = "0.7.11"
@@ -5855,6 +5989,18 @@ version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d"
[[package]]
name = "float8"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4203231de188ebbdfb85c11f3c20ca2b063945710de04e7b59268731e728b462"
dependencies = [
"half",
"num-traits",
"rand 0.9.1",
"rand_distr",
]
[[package]]
name = "float_next_after"
version = "1.0.0"
@@ -6309,6 +6455,243 @@ dependencies = [
"thread_local",
]
[[package]]
name = "gemm"
version = "0.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ab24cc62135b40090e31a76a9b2766a501979f3070fa27f689c27ec04377d32"
dependencies = [
"dyn-stack 0.10.0",
"gemm-c32 0.17.1",
"gemm-c64 0.17.1",
"gemm-common 0.17.1",
"gemm-f16 0.17.1",
"gemm-f32 0.17.1",
"gemm-f64 0.17.1",
"num-complex",
"num-traits",
"paste",
"raw-cpuid 10.7.0",
"seq-macro",
]
[[package]]
name = "gemm"
version = "0.18.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab96b703d31950f1aeddded248bc95543c9efc7ac9c4a21fda8703a83ee35451"
dependencies = [
"dyn-stack 0.13.0",
"gemm-c32 0.18.2",
"gemm-c64 0.18.2",
"gemm-common 0.18.2",
"gemm-f16 0.18.2",
"gemm-f32 0.18.2",
"gemm-f64 0.18.2",
"num-complex",
"num-traits",
"paste",
"raw-cpuid 11.6.0",
"seq-macro",
]
[[package]]
name = "gemm-c32"
version = "0.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9c030d0b983d1e34a546b86e08f600c11696fde16199f971cd46c12e67512c0"
dependencies = [
"dyn-stack 0.10.0",
"gemm-common 0.17.1",
"num-complex",
"num-traits",
"paste",
"raw-cpuid 10.7.0",
"seq-macro",
]
[[package]]
name = "gemm-c32"
version = "0.18.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6db9fd9f40421d00eea9dd0770045a5603b8d684654816637732463f4073847"
dependencies = [
"dyn-stack 0.13.0",
"gemm-common 0.18.2",
"num-complex",
"num-traits",
"paste",
"raw-cpuid 11.6.0",
"seq-macro",
]
[[package]]
name = "gemm-c64"
version = "0.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbb5f2e79fefb9693d18e1066a557b4546cd334b226beadc68b11a8f9431852a"
dependencies = [
"dyn-stack 0.10.0",
"gemm-common 0.17.1",
"num-complex",
"num-traits",
"paste",
"raw-cpuid 10.7.0",
"seq-macro",
]
[[package]]
name = "gemm-c64"
version = "0.18.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dfcad8a3d35a43758330b635d02edad980c1e143dc2f21e6fd25f9e4eada8edf"
dependencies = [
"dyn-stack 0.13.0",
"gemm-common 0.18.2",
"num-complex",
"num-traits",
"paste",
"raw-cpuid 11.6.0",
"seq-macro",
]
[[package]]
name = "gemm-common"
version = "0.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2e7ea062c987abcd8db95db917b4ffb4ecdfd0668471d8dc54734fdff2354e8"
dependencies = [
"bytemuck",
"dyn-stack 0.10.0",
"half",
"num-complex",
"num-traits",
"once_cell",
"paste",
"pulp 0.18.22",
"raw-cpuid 10.7.0",
"rayon",
"seq-macro",
"sysctl 0.5.5",
]
[[package]]
name = "gemm-common"
version = "0.18.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a352d4a69cbe938b9e2a9cb7a3a63b7e72f9349174a2752a558a8a563510d0f3"
dependencies = [
"bytemuck",
"dyn-stack 0.13.0",
"half",
"libm",
"num-complex",
"num-traits",
"once_cell",
"paste",
"pulp 0.21.5",
"raw-cpuid 11.6.0",
"rayon",
"seq-macro",
"sysctl 0.6.0",
]
[[package]]
name = "gemm-f16"
version = "0.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ca4c06b9b11952071d317604acb332e924e817bd891bec8dfb494168c7cedd4"
dependencies = [
"dyn-stack 0.10.0",
"gemm-common 0.17.1",
"gemm-f32 0.17.1",
"half",
"num-complex",
"num-traits",
"paste",
"raw-cpuid 10.7.0",
"rayon",
"seq-macro",
]
[[package]]
name = "gemm-f16"
version = "0.18.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cff95ae3259432f3c3410eaa919033cd03791d81cebd18018393dc147952e109"
dependencies = [
"dyn-stack 0.13.0",
"gemm-common 0.18.2",
"gemm-f32 0.18.2",
"half",
"num-complex",
"num-traits",
"paste",
"raw-cpuid 11.6.0",
"rayon",
"seq-macro",
]
[[package]]
name = "gemm-f32"
version = "0.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9a69f51aaefbd9cf12d18faf273d3e982d9d711f60775645ed5c8047b4ae113"
dependencies = [
"dyn-stack 0.10.0",
"gemm-common 0.17.1",
"num-complex",
"num-traits",
"paste",
"raw-cpuid 10.7.0",
"seq-macro",
]
[[package]]
name = "gemm-f32"
version = "0.18.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc8d3d4385393304f407392f754cd2dc4b315d05063f62cf09f47b58de276864"
dependencies = [
"dyn-stack 0.13.0",
"gemm-common 0.18.2",
"num-complex",
"num-traits",
"paste",
"raw-cpuid 11.6.0",
"seq-macro",
]
[[package]]
name = "gemm-f64"
version = "0.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa397a48544fadf0b81ec8741e5c0fba0043008113f71f2034def1935645d2b0"
dependencies = [
"dyn-stack 0.10.0",
"gemm-common 0.17.1",
"num-complex",
"num-traits",
"paste",
"raw-cpuid 10.7.0",
"seq-macro",
]
[[package]]
name = "gemm-f64"
version = "0.18.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35b2a4f76ce4b8b16eadc11ccf2e083252d8237c1b589558a49b0183545015bd"
dependencies = [
"dyn-stack 0.13.0",
"gemm-common 0.18.2",
"num-complex",
"num-traits",
"paste",
"raw-cpuid 11.6.0",
"seq-macro",
]
[[package]]
name = "generator"
version = "0.8.5"
@@ -7583,9 +7966,12 @@ version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9"
dependencies = [
"bytemuck",
"cfg-if",
"crunchy",
"num-traits",
"rand 0.9.1",
"rand_distr",
]
[[package]]
@@ -9179,6 +9565,7 @@ dependencies = [
"credentials_provider",
"deepseek",
"editor",
"fs",
"futures 0.3.31",
"google_ai",
"gpui",
@@ -9212,6 +9599,7 @@ dependencies = [
"vercel",
"workspace-hack",
"x_ai",
"zed_env_vars",
]
[[package]]
@@ -9305,6 +9693,7 @@ dependencies = [
"pet-fs",
"pet-poetry",
"pet-reporter",
"pet-virtualenv",
"pretty_assertions",
"project",
"regex",
@@ -10174,6 +10563,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f"
dependencies = [
"libc",
"stable_deref_trait",
]
[[package]]
@@ -10440,12 +10830,6 @@ version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
[[package]]
name = "multimap"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03"
[[package]]
name = "naga"
version = "25.0.1"
@@ -10819,6 +11203,7 @@ version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
"bytemuck",
"num-traits",
]
@@ -12512,6 +12897,15 @@ dependencies = [
"syn 2.0.101",
]
[[package]]
name = "primal-check"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc0d895b311e3af9902528fbb8f928688abbd95872819320517cc24ca6b2bd08"
dependencies = [
"num-integer",
]
[[package]]
name = "proc-macro-crate"
version = "3.3.0"
@@ -12820,7 +13214,7 @@ dependencies = [
"itertools 0.10.5",
"lazy_static",
"log",
"multimap 0.8.3",
"multimap",
"petgraph",
"prost 0.9.0",
"prost-types 0.9.0",
@@ -12837,9 +13231,9 @@ checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4"
dependencies = [
"bytes 1.10.1",
"heck 0.5.0",
"itertools 0.12.1",
"itertools 0.11.0",
"log",
"multimap 0.10.0",
"multimap",
"once_cell",
"petgraph",
"prettyplease",
@@ -12870,7 +13264,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1"
dependencies = [
"anyhow",
"itertools 0.12.1",
"itertools 0.11.0",
"proc-macro2",
"quote",
"syn 2.0.101",
@@ -13011,6 +13405,32 @@ dependencies = [
"wasmtime-math",
]
[[package]]
name = "pulp"
version = "0.18.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0a01a0dc67cf4558d279f0c25b0962bd08fc6dec0137699eae304103e882fe6"
dependencies = [
"bytemuck",
"libm",
"num-complex",
"reborrow",
]
[[package]]
name = "pulp"
version = "0.21.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96b86df24f0a7ddd5e4b95c94fc9ed8a98f1ca94d3b01bdce2824097e7835907"
dependencies = [
"bytemuck",
"cfg-if",
"libm",
"num-complex",
"reborrow",
"version_check",
]
[[package]]
name = "qoi"
version = "0.4.1"
@@ -13187,6 +13607,16 @@ dependencies = [
"getrandom 0.3.2",
]
[[package]]
name = "rand_distr"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463"
dependencies = [
"num-traits",
"rand 0.9.1",
]
[[package]]
name = "range-map"
version = "0.2.0"
@@ -13252,6 +13682,24 @@ dependencies = [
"rgb",
]
[[package]]
name = "raw-cpuid"
version = "10.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332"
dependencies = [
"bitflags 1.3.2",
]
[[package]]
name = "raw-cpuid"
version = "11.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186"
dependencies = [
"bitflags 2.9.0",
]
[[package]]
name = "raw-window-handle"
version = "0.6.2"
@@ -13300,6 +13748,21 @@ dependencies = [
"font-types",
]
[[package]]
name = "realfft"
version = "3.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f821338fddb99d089116342c46e9f1fbf3828dba077674613e734e01d6ea8677"
dependencies = [
"rustfft",
]
[[package]]
name = "reborrow"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430"
[[package]]
name = "recent_projects"
version = "0.1.0"
@@ -14116,6 +14579,20 @@ dependencies = [
"semver",
]
[[package]]
name = "rustfft"
version = "6.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6f140db74548f7c9d7cce60912c9ac414e74df5e718dc947d514b051b42f3f4"
dependencies = [
"num-complex",
"num-integer",
"num-traits",
"primal-check",
"strength_reduce",
"transpose",
]
[[package]]
name = "rustix"
version = "0.38.44"
@@ -14340,6 +14817,16 @@ version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
[[package]]
name = "safetensors"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44560c11236a6130a46ce36c836a62936dc81ebf8c36a37947423571be0e55b6"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "salsa20"
version = "0.10.2"
@@ -14721,6 +15208,12 @@ dependencies = [
"serde",
]
[[package]]
name = "seq-macro"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc"
[[package]]
name = "serde"
version = "1.0.221"
@@ -15675,6 +16168,12 @@ dependencies = [
"workspace-hack",
]
[[package]]
name = "strength_reduce"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82"
[[package]]
name = "strict-num"
version = "0.1.1"
@@ -16165,6 +16664,34 @@ dependencies = [
"libc",
]
[[package]]
name = "sysctl"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec7dddc5f0fee506baf8b9fdb989e242f17e4b11c61dfbb0635b705217199eea"
dependencies = [
"bitflags 2.9.0",
"byteorder",
"enum-as-inner",
"libc",
"thiserror 1.0.69",
"walkdir",
]
[[package]]
name = "sysctl"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01198a2debb237c62b6826ec7081082d951f46dbb64b0e8c7649a452230d1dfc"
dependencies = [
"bitflags 2.9.0",
"byteorder",
"enum-as-inner",
"libc",
"thiserror 1.0.69",
"walkdir",
]
[[package]]
name = "sysinfo"
version = "0.31.4"
@@ -17259,6 +17786,16 @@ dependencies = [
"syn 2.0.101",
]
[[package]]
name = "transpose"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ad61aed86bc3faea4300c7aee358b4c6d0c8d6ccc36524c96e4c92ccf26e77e"
dependencies = [
"num-integer",
"strength_reduce",
]
[[package]]
name = "tree-sitter"
version = "0.25.6"
@@ -17620,6 +18157,27 @@ dependencies = [
"winapi",
]
[[package]]
name = "ug"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90b70b37e9074642bc5f60bb23247fd072a84314ca9e71cdf8527593406a0dd3"
dependencies = [
"gemm 0.18.2",
"half",
"libloading",
"memmap2",
"num",
"num-traits",
"num_cpus",
"rayon",
"safetensors",
"serde",
"thiserror 1.0.69",
"tracing",
"yoke",
]
[[package]]
name = "ui"
version = "0.1.0"
@@ -18888,7 +19446,7 @@ dependencies = [
"reqwest 0.11.27",
"scratch",
"semver",
"zip",
"zip 0.6.6",
]
[[package]]
@@ -20033,7 +20591,7 @@ dependencies = [
"idna",
"indexmap",
"inout",
"itertools 0.12.1",
"itertools 0.11.0",
"itertools 0.13.0",
"jiff",
"lazy_static",
@@ -20047,6 +20605,7 @@ dependencies = [
"lyon_path",
"md-5",
"memchr",
"memmap2",
"mime_guess",
"miniz_oxide",
"mio 1.0.3",
@@ -20055,8 +20614,10 @@ dependencies = [
"nix 0.29.0",
"nix 0.30.1",
"nom 7.1.3",
"num",
"num-bigint",
"num-bigint-dig",
"num-complex",
"num-integer",
"num-iter",
"num-rational",
@@ -20072,6 +20633,7 @@ dependencies = [
"phf_shared",
"prettyplease",
"proc-macro2",
"prost 0.12.6",
"prost 0.9.0",
"prost-types 0.9.0",
"quote",
@@ -20079,6 +20641,7 @@ dependencies = [
"rand 0.9.1",
"rand_chacha 0.3.1",
"rand_core 0.6.4",
"rand_distr",
"regalloc2",
"regex",
"regex-automata",
@@ -20108,6 +20671,7 @@ dependencies = [
"sqlx-macros-core",
"sqlx-postgres",
"sqlx-sqlite",
"stable_deref_trait",
"strum 0.26.3",
"subtle",
"syn 1.0.109",
@@ -20141,6 +20705,7 @@ dependencies = [
"windows-sys 0.48.0",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
"windows-sys 0.60.2",
"winnow",
"zeroize",
"zvariant",
@@ -20677,6 +21242,7 @@ dependencies = [
name = "zed_env_vars"
version = "0.1.0"
dependencies = [
"gpui",
"workspace-hack",
]
@@ -20990,6 +21556,21 @@ dependencies = [
"zstd",
]
[[package]]
name = "zip"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9cc23c04387f4da0374be4533ad1208cbb091d5c11d070dfef13676ad6497164"
dependencies = [
"arbitrary",
"crc32fast",
"crossbeam-utils",
"displaydoc",
"indexmap",
"num_enum",
"thiserror 1.0.69",
]
[[package]]
name = "zlib-rs"
version = "0.5.0"

View File

@@ -52,10 +52,12 @@ members = [
"crates/debugger_tools",
"crates/debugger_ui",
"crates/deepseek",
"crates/denoise",
"crates/diagnostics",
"crates/docs_preprocessor",
"crates/edit_prediction",
"crates/edit_prediction_button",
"crates/edit_prediction_context",
"crates/editor",
"crates/eval",
"crates/explorer_command_injector",
@@ -312,6 +314,7 @@ icons = { path = "crates/icons" }
image_viewer = { path = "crates/image_viewer" }
edit_prediction = { path = "crates/edit_prediction" }
edit_prediction_button = { path = "crates/edit_prediction_button" }
edit_prediction_context = { path = "crates/edit_prediction_context" }
inspector_ui = { path = "crates/inspector_ui" }
install_cli = { path = "crates/install_cli" }
jj = { path = "crates/jj" }
@@ -582,6 +585,7 @@ pet-fs = { git = "https://github.com/microsoft/python-environment-tools.git", re
pet-pixi = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
pet-poetry = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
pet-reporter = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
pet-virtualenv = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
portable-pty = "0.9.0"
postage = { version = "0.5", features = ["futures-traits"] }
pretty_assertions = { version = "1.3.0", features = ["unstable"] }
@@ -630,6 +634,7 @@ sha2 = "0.10"
shellexpand = "2.1.0"
shlex = "1.3.0"
simplelog = "0.12.2"
slotmap = "1.0.6"
smallvec = { version = "1.6", features = ["union"] }
smol = "2.0"
sqlformat = "0.2"

View File

@@ -462,8 +462,8 @@
"ctrl-k ctrl-w": "workspace::CloseAllItemsAndPanes",
"back": "pane::GoBack",
"ctrl-alt--": "pane::GoBack",
"ctrl-alt-_": "pane::GoForward",
"forward": "pane::GoForward",
"ctrl-alt-_": "pane::GoForward",
"ctrl-alt-g": "search::SelectNextMatch",
"f3": "search::SelectNextMatch",
"ctrl-alt-shift-g": "search::SelectPreviousMatch",

View File

@@ -497,6 +497,8 @@
"shift-alt-down": "editor::DuplicateLineDown",
"shift-alt-right": "editor::SelectLargerSyntaxNode", // Expand selection
"shift-alt-left": "editor::SelectSmallerSyntaxNode", // Shrink selection
"ctrl-shift-right": "editor::SelectLargerSyntaxNode", // Expand selection (VSCode version)
"ctrl-shift-left": "editor::SelectSmallerSyntaxNode", // Shrink selection (VSCode version)
"ctrl-shift-l": "editor::SelectAllMatches", // Select all occurrences of current selection
"ctrl-f2": "editor::SelectAllMatches", // Select all occurrences of current word
"ctrl-d": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand

View File

@@ -914,7 +914,11 @@
/// Whether to have terminal cards in the agent panel expanded, showing the whole command output.
///
/// Default: true
"expand_terminal_card": true
"expand_terminal_card": true,
// Minimum number of lines to display in the agent message editor.
//
// Default: 4
"message_editor_min_lines": 4
},
// The settings for slash commands.
"slash_commands": {

View File

@@ -45,7 +45,6 @@ url.workspace = true
util.workspace = true
uuid.workspace = true
watch.workspace = true
which.workspace = true
workspace-hack.workspace = true
[dev-dependencies]

View File

@@ -7,12 +7,12 @@ use agent_settings::AgentSettings;
use collections::HashSet;
pub use connection::*;
pub use diff::*;
use futures::future::Shared;
use language::language_settings::FormatOnSave;
pub use mention::*;
use project::lsp_store::{FormatTrigger, LspFormatTarget};
use serde::{Deserialize, Serialize};
use settings::Settings as _;
use task::{Shell, ShellBuilder};
pub use terminal::*;
use action_log::ActionLog;
@@ -34,7 +34,7 @@ use std::rc::Rc;
use std::time::{Duration, Instant};
use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
use ui::App;
use util::{ResultExt, get_system_shell};
use util::{ResultExt, get_default_system_shell};
use uuid::Uuid;
#[derive(Debug)]
@@ -786,7 +786,6 @@ pub struct AcpThread {
token_usage: Option<TokenUsage>,
prompt_capabilities: acp::PromptCapabilities,
_observe_prompt_capabilities: Task<anyhow::Result<()>>,
determine_shell: Shared<Task<String>>,
terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
}
@@ -873,20 +872,6 @@ impl AcpThread {
}
});
let determine_shell = cx
.background_spawn(async move {
if cfg!(windows) {
return get_system_shell();
}
if which::which("bash").is_ok() {
"bash".into()
} else {
get_system_shell()
}
})
.shared();
Self {
action_log,
shared_buffers: Default::default(),
@@ -901,7 +886,6 @@ impl AcpThread {
prompt_capabilities,
_observe_prompt_capabilities: task,
terminals: HashMap::default(),
determine_shell,
}
}
@@ -1127,9 +1111,33 @@ impl AcpThread {
let update = update.into();
let languages = self.project.read(cx).languages().clone();
let ix = self
.index_for_tool_call(update.id())
.context("Tool call not found")?;
let ix = match self.index_for_tool_call(update.id()) {
Some(ix) => ix,
None => {
// Tool call not found - create a failed tool call entry
let failed_tool_call = ToolCall {
id: update.id().clone(),
label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
kind: acp::ToolKind::Fetch,
content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
acp::ContentBlock::Text(acp::TextContent {
text: "Tool call not found".to_string(),
annotations: None,
meta: None,
}),
&languages,
cx,
))],
status: ToolCallStatus::Failed,
locations: Vec::new(),
resolved_locations: Vec::new(),
raw_input: None,
raw_output: None,
};
self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
return Ok(());
}
};
let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
unreachable!()
};
@@ -1940,28 +1948,13 @@ impl AcpThread {
pub fn create_terminal(
&self,
mut command: String,
command: String,
args: Vec<String>,
extra_env: Vec<acp::EnvVariable>,
cwd: Option<PathBuf>,
output_byte_limit: Option<u64>,
cx: &mut Context<Self>,
) -> Task<Result<Entity<Terminal>>> {
for arg in args {
command.push(' ');
command.push_str(&arg);
}
let shell_command = if cfg!(windows) {
format!("$null | & {{{}}}", command.replace("\"", "'"))
} else if let Some(cwd) = cwd.as_ref().and_then(|cwd| cwd.as_os_str().to_str()) {
// Make sure once we're *inside* the shell, we cd into `cwd`
format!("(cd {cwd}; {}) </dev/null", command)
} else {
format!("({}) </dev/null", command)
};
let args = vec!["-c".into(), shell_command];
let env = match &cwd {
Some(dir) => self.project.update(cx, |project, cx| {
project.directory_environment(dir.as_path().into(), cx)
@@ -1982,20 +1975,30 @@ impl AcpThread {
let project = self.project.clone();
let language_registry = project.read(cx).languages().clone();
let determine_shell = self.determine_shell.clone();
let terminal_id = acp::TerminalId(Uuid::new_v4().to_string().into());
let terminal_task = cx.spawn({
let terminal_id = terminal_id.clone();
async move |_this, cx| {
let program = determine_shell.await;
let env = env.await;
let (command, args) = ShellBuilder::new(
project
.update(cx, |project, cx| {
project
.remote_client()
.and_then(|r| r.read(cx).default_system_shell())
})?
.as_deref(),
&Shell::Program(get_default_system_shell()),
)
.redirect_stdin_to_dev_null()
.build(Some(command), &args);
let terminal = project
.update(cx, |project, cx| {
project.create_terminal_task(
task::SpawnInTerminal {
command: Some(program),
args,
command: Some(command.clone()),
args: args.clone(),
cwd: cwd.clone(),
env,
..Default::default()
@@ -2008,7 +2011,7 @@ impl AcpThread {
cx.new(|cx| {
Terminal::new(
terminal_id,
command,
&format!("{} {}", command, args.join(" ")),
cwd,
output_byte_limit.map(|l| l as usize),
terminal,
@@ -3181,4 +3184,65 @@ mod tests {
Task::ready(Ok(()))
}
}
#[gpui::test]
async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, [], cx).await;
let connection = Rc::new(FakeAgentConnection::new());
let thread = cx
.update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
.await
.unwrap();
// Try to update a tool call that doesn't exist
let nonexistent_id = acp::ToolCallId("nonexistent-tool-call".into());
thread.update(cx, |thread, cx| {
let result = thread.handle_session_update(
acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
id: nonexistent_id.clone(),
fields: acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::Completed),
..Default::default()
},
meta: None,
}),
cx,
);
// The update should succeed (not return an error)
assert!(result.is_ok());
// There should now be exactly one entry in the thread
assert_eq!(thread.entries.len(), 1);
// The entry should be a failed tool call
if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
assert_eq!(tool_call.id, nonexistent_id);
assert!(matches!(tool_call.status, ToolCallStatus::Failed));
assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
// Check that the content contains the error message
assert_eq!(tool_call.content.len(), 1);
if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
match content_block {
ContentBlock::Markdown { markdown } => {
let markdown_text = markdown.read(cx).source();
assert!(markdown_text.contains("Tool call not found"));
}
ContentBlock::Empty => panic!("Expected markdown content, got empty"),
ContentBlock::ResourceLink { .. } => {
panic!("Expected markdown content, got resource link")
}
}
} else {
panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
}
} else {
panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
}
});
}
}

View File

@@ -28,7 +28,7 @@ pub struct TerminalOutput {
impl Terminal {
pub fn new(
id: acp::TerminalId,
command: String,
command_label: &str,
working_dir: Option<PathBuf>,
output_byte_limit: Option<usize>,
terminal: Entity<terminal::Terminal>,
@@ -40,7 +40,7 @@ impl Terminal {
id,
command: cx.new(|cx| {
Markdown::new(
format!("```\n{}\n```", command).into(),
format!("```\n{}\n```", command_label).into(),
Some(language_registry.clone()),
None,
cx,

View File

@@ -1,4 +1,4 @@
use auto_update::{AutoUpdateStatus, AutoUpdater, DismissErrorMessage, VersionCheckType};
use auto_update::{AutoUpdateStatus, AutoUpdater, DismissMessage, VersionCheckType};
use editor::Editor;
use extension_host::{ExtensionOperation, ExtensionStore};
use futures::StreamExt;
@@ -280,18 +280,13 @@ impl ActivityIndicator {
});
}
fn dismiss_error_message(
&mut self,
_: &DismissErrorMessage,
_: &mut Window,
cx: &mut Context<Self>,
) {
let error_dismissed = if let Some(updater) = &self.auto_updater {
updater.update(cx, |updater, cx| updater.dismiss_error(cx))
fn dismiss_message(&mut self, _: &DismissMessage, _: &mut Window, cx: &mut Context<Self>) {
let dismissed = if let Some(updater) = &self.auto_updater {
updater.update(cx, |updater, cx| updater.dismiss(cx))
} else {
false
};
if error_dismissed {
if dismissed {
return;
}
@@ -513,7 +508,7 @@ impl ActivityIndicator {
on_click: Some(Arc::new(move |this, window, cx| {
this.statuses
.retain(|status| !downloading.contains(&status.name));
this.dismiss_error_message(&DismissErrorMessage, window, cx)
this.dismiss_message(&DismissMessage, window, cx)
})),
tooltip_message: None,
});
@@ -542,7 +537,7 @@ impl ActivityIndicator {
on_click: Some(Arc::new(move |this, window, cx| {
this.statuses
.retain(|status| !checking_for_update.contains(&status.name));
this.dismiss_error_message(&DismissErrorMessage, window, cx)
this.dismiss_message(&DismissMessage, window, cx)
})),
tooltip_message: None,
});
@@ -650,13 +645,14 @@ impl ActivityIndicator {
.and_then(|updater| match &updater.read(cx).status() {
AutoUpdateStatus::Checking => Some(Content {
icon: Some(
Icon::new(IconName::Download)
Icon::new(IconName::LoadCircle)
.size(IconSize::Small)
.with_rotate_animation(3)
.into_any_element(),
),
message: "Checking for Zed updates…".to_string(),
on_click: Some(Arc::new(|this, window, cx| {
this.dismiss_error_message(&DismissErrorMessage, window, cx)
this.dismiss_message(&DismissMessage, window, cx)
})),
tooltip_message: None,
}),
@@ -668,19 +664,20 @@ impl ActivityIndicator {
),
message: "Downloading Zed update…".to_string(),
on_click: Some(Arc::new(|this, window, cx| {
this.dismiss_error_message(&DismissErrorMessage, window, cx)
this.dismiss_message(&DismissMessage, window, cx)
})),
tooltip_message: Some(Self::version_tooltip_message(version)),
}),
AutoUpdateStatus::Installing { version } => Some(Content {
icon: Some(
Icon::new(IconName::Download)
Icon::new(IconName::LoadCircle)
.size(IconSize::Small)
.with_rotate_animation(3)
.into_any_element(),
),
message: "Installing Zed update…".to_string(),
on_click: Some(Arc::new(|this, window, cx| {
this.dismiss_error_message(&DismissErrorMessage, window, cx)
this.dismiss_message(&DismissMessage, window, cx)
})),
tooltip_message: Some(Self::version_tooltip_message(version)),
}),
@@ -690,17 +687,18 @@ impl ActivityIndicator {
on_click: Some(Arc::new(move |_, _, cx| workspace::reload(cx))),
tooltip_message: Some(Self::version_tooltip_message(version)),
}),
AutoUpdateStatus::Errored => Some(Content {
AutoUpdateStatus::Errored { error } => Some(Content {
icon: Some(
Icon::new(IconName::Warning)
.size(IconSize::Small)
.into_any_element(),
),
message: "Auto update failed".to_string(),
message: "Failed to update Zed".to_string(),
on_click: Some(Arc::new(|this, window, cx| {
this.dismiss_error_message(&DismissErrorMessage, window, cx)
window.dispatch_action(Box::new(workspace::OpenLog), cx);
this.dismiss_message(&DismissMessage, window, cx);
})),
tooltip_message: None,
tooltip_message: Some(format!("{error}")),
}),
AutoUpdateStatus::Idle => None,
})
@@ -738,7 +736,7 @@ impl ActivityIndicator {
})),
message,
on_click: Some(Arc::new(|this, window, cx| {
this.dismiss_error_message(&Default::default(), window, cx)
this.dismiss_message(&Default::default(), window, cx)
})),
tooltip_message: None,
})
@@ -777,7 +775,7 @@ impl Render for ActivityIndicator {
let result = h_flex()
.id("activity-indicator")
.on_action(cx.listener(Self::show_error_message))
.on_action(cx.listener(Self::dismiss_error_message));
.on_action(cx.listener(Self::dismiss_message));
let Some(content) = self.content_to_render(cx) else {
return result;
};

View File

@@ -30,6 +30,7 @@ fs.workspace = true
futures.workspace = true
gpui.workspace = true
gpui_tokio = { workspace = true, optional = true }
http_client.workspace = true
indoc.workspace = true
language.workspace = true
language_model.workspace = true

View File

@@ -7,15 +7,19 @@ mod gemini;
pub mod e2e_tests;
pub use claude::*;
use client::ProxySettings;
use collections::HashMap;
pub use custom::*;
use fs::Fs;
pub use gemini::*;
use http_client::read_no_proxy_from_env;
use project::agent_server_store::AgentServerStore;
use acp_thread::AgentConnection;
use anyhow::Result;
use gpui::{App, Entity, SharedString, Task};
use gpui::{App, AppContext, Entity, SharedString, Task};
use project::Project;
use settings::SettingsStore;
use std::{any::Any, path::Path, rc::Rc, sync::Arc};
pub use acp::AcpConnection;
@@ -77,3 +81,25 @@ impl dyn AgentServer {
self.into_any().downcast().ok()
}
}
/// Load the default proxy environment variables to pass through to the agent
pub fn load_proxy_env(cx: &mut App) -> HashMap<String, String> {
let proxy_url = cx
.read_global(|settings: &SettingsStore, _| settings.get::<ProxySettings>(None).proxy_url());
let mut env = HashMap::default();
if let Some(proxy_url) = &proxy_url {
let env_var = if proxy_url.scheme() == "https" {
"HTTPS_PROXY"
} else {
"HTTP_PROXY"
};
env.insert(env_var.to_owned(), proxy_url.to_string());
}
if let Some(no_proxy) = read_no_proxy_from_env() {
env.insert("NO_PROXY".to_owned(), no_proxy);
}
env
}

View File

@@ -10,7 +10,7 @@ use anyhow::{Context as _, Result};
use gpui::{App, AppContext as _, SharedString, Task};
use project::agent_server_store::{AllAgentServersSettings, CLAUDE_CODE_NAME};
use crate::{AgentServer, AgentServerDelegate};
use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
use acp_thread::AgentConnection;
#[derive(Clone)]
@@ -60,6 +60,7 @@ impl AgentServer for ClaudeCode {
let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().to_string());
let is_remote = delegate.project.read(cx).is_via_remote_server();
let store = delegate.store.downgrade();
let extra_env = load_proxy_env(cx);
let default_mode = self.default_mode(cx);
cx.spawn(async move |cx| {
@@ -70,7 +71,7 @@ impl AgentServer for ClaudeCode {
.context("Claude Code is not registered")?;
anyhow::Ok(agent.get_command(
root_dir.as_deref(),
Default::default(),
extra_env,
delegate.status_tx,
delegate.new_version_available,
&mut cx.to_async(),

View File

@@ -1,4 +1,4 @@
use crate::AgentServerDelegate;
use crate::{AgentServerDelegate, load_proxy_env};
use acp_thread::AgentConnection;
use agent_client_protocol as acp;
use anyhow::{Context as _, Result};
@@ -65,6 +65,7 @@ impl crate::AgentServer for CustomAgentServer {
let is_remote = delegate.project.read(cx).is_via_remote_server();
let default_mode = self.default_mode(cx);
let store = delegate.store.downgrade();
let extra_env = load_proxy_env(cx);
cx.spawn(async move |cx| {
let (command, root_dir, login) = store
@@ -76,7 +77,7 @@ impl crate::AgentServer for CustomAgentServer {
})?;
anyhow::Ok(agent.get_command(
root_dir.as_deref(),
Default::default(),
extra_env,
delegate.status_tx,
delegate.new_version_available,
&mut cx.to_async(),

View File

@@ -1,15 +1,12 @@
use std::rc::Rc;
use std::{any::Any, path::Path};
use crate::{AgentServer, AgentServerDelegate};
use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
use acp_thread::AgentConnection;
use anyhow::{Context as _, Result};
use client::ProxySettings;
use collections::HashMap;
use gpui::{App, AppContext, SharedString, Task};
use gpui::{App, SharedString, Task};
use language_models::provider::google::GoogleLanguageModelProvider;
use project::agent_server_store::GEMINI_NAME;
use settings::SettingsStore;
#[derive(Clone)]
pub struct Gemini;
@@ -37,17 +34,20 @@ impl AgentServer for Gemini {
let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().to_string());
let is_remote = delegate.project.read(cx).is_via_remote_server();
let store = delegate.store.downgrade();
let proxy_url = cx.read_global(|settings: &SettingsStore, _| {
settings.get::<ProxySettings>(None).proxy.clone()
});
let mut extra_env = load_proxy_env(cx);
let default_mode = self.default_mode(cx);
cx.spawn(async move |cx| {
let mut extra_env = HashMap::default();
if let Some(api_key) = cx.update(GoogleLanguageModelProvider::api_key)?.await.ok() {
extra_env.insert("GEMINI_API_KEY".into(), api_key.key);
extra_env.insert("SURFACE".to_owned(), "zed".to_owned());
if let Some(api_key) = cx
.update(GoogleLanguageModelProvider::api_key_for_gemini_cli)?
.await
.ok()
{
extra_env.insert("GEMINI_API_KEY".into(), api_key);
}
let (mut command, root_dir, login) = store
let (command, root_dir, login) = store
.update(cx, |store, cx| {
let agent = store
.get_external_agent(&GEMINI_NAME.into())
@@ -62,14 +62,6 @@ impl AgentServer for Gemini {
})??
.await?;
// Add proxy flag if proxy settings are configured in Zed and not in the args
if let Some(proxy_url_value) = &proxy_url
&& !command.args.iter().any(|arg| arg.contains("--proxy"))
{
command.args.push("--proxy".into());
command.args.push(proxy_url_value.clone());
}
let connection = crate::acp::connect(
name,
command,

View File

@@ -1,125 +0,0 @@
use agent_client_protocol as acp;
use std::path::PathBuf;
use crate::AgentServerCommand;
use anyhow::Result;
use collections::HashMap;
use gpui::{App, SharedString};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsKey, SettingsSources, SettingsUi};
pub fn init(cx: &mut App) {
AllAgentServersSettings::register(cx);
}
#[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug, SettingsUi, SettingsKey)]
#[settings_key(key = "agent_servers")]
pub struct AllAgentServersSettings {
pub gemini: Option<BuiltinAgentServerSettings>,
pub claude: Option<BuiltinAgentServerSettings>,
/// Custom agent servers configured by the user
#[serde(flatten)]
pub custom: HashMap<SharedString, CustomAgentServerSettings>,
}
#[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug, PartialEq)]
pub struct BuiltinAgentServerSettings {
/// Absolute path to a binary to be used when launching this agent.
///
/// This can be used to run a specific binary without automatic downloads or searching `$PATH`.
#[serde(rename = "command")]
pub path: Option<PathBuf>,
/// If a binary is specified in `command`, it will be passed these arguments.
pub args: Option<Vec<String>>,
/// If a binary is specified in `command`, it will be passed these environment variables.
pub env: Option<HashMap<String, String>>,
/// Whether to skip searching `$PATH` for an agent server binary when
/// launching this agent.
///
/// This has no effect if a `command` is specified. Otherwise, when this is
/// `false`, Zed will search `$PATH` for an agent server binary and, if one
/// is found, use it for threads with this agent. If no agent binary is
/// found on `$PATH`, Zed will automatically install and use its own binary.
/// When this is `true`, Zed will not search `$PATH`, and will always use
/// its own binary.
///
/// Default: true
pub ignore_system_version: Option<bool>,
/// The default mode for new threads.
///
/// Note: Not all agents support modes.
///
/// Default: None
#[serde(skip_serializing_if = "Option::is_none")]
pub default_mode: Option<acp::SessionModeId>,
}
impl BuiltinAgentServerSettings {
pub(crate) fn custom_command(self) -> Option<AgentServerCommand> {
self.path.map(|path| AgentServerCommand {
path,
args: self.args.unwrap_or_default(),
env: self.env,
})
}
}
impl From<AgentServerCommand> for BuiltinAgentServerSettings {
fn from(value: AgentServerCommand) -> Self {
BuiltinAgentServerSettings {
path: Some(value.path),
args: Some(value.args),
env: value.env,
..Default::default()
}
}
}
#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug, PartialEq)]
pub struct CustomAgentServerSettings {
#[serde(flatten)]
pub command: AgentServerCommand,
/// The default mode for new threads.
///
/// Note: Not all agents support modes.
///
/// Default: None
#[serde(skip_serializing_if = "Option::is_none")]
pub default_mode: Option<acp::SessionModeId>,
}
impl settings::Settings for AllAgentServersSettings {
type FileContent = Self;
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
let mut settings = AllAgentServersSettings::default();
for AllAgentServersSettings {
gemini,
claude,
custom,
} in sources.defaults_and_customizations()
{
if gemini.is_some() {
settings.gemini = gemini.clone();
}
if claude.is_some() {
settings.claude = claude.clone();
}
// Merge custom agents
for (name, config) in custom {
// Skip built-in agent names to avoid conflicts
if name != "gemini" && name != "claude" {
settings.custom.insert(name.clone(), config.clone());
}
}
}
Ok(settings)
}
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
}

View File

@@ -75,6 +75,7 @@ pub struct AgentSettings {
pub expand_edit_card: bool,
pub expand_terminal_card: bool,
pub use_modifier_to_send: bool,
pub message_editor_min_lines: usize,
}
impl AgentSettings {
@@ -107,6 +108,10 @@ impl AgentSettings {
model,
});
}
pub fn set_message_editor_max_lines(&self) -> usize {
self.message_editor_min_lines * 2
}
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
@@ -320,6 +325,10 @@ pub struct AgentSettingsContent {
///
/// Default: false
use_modifier_to_send: Option<bool>,
/// Minimum number of lines of height the agent message editor should have.
///
/// Default: 4
message_editor_min_lines: Option<usize>,
}
#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Default)]
@@ -355,21 +364,30 @@ impl JsonSchema for LanguageModelProviderSetting {
}
fn json_schema(_: &mut schemars::SchemaGenerator) -> schemars::Schema {
// list the builtin providers as a subset so that we still auto complete them in the settings
json_schema!({
"enum": [
"amazon-bedrock",
"anthropic",
"copilot_chat",
"deepseek",
"google",
"lmstudio",
"mistral",
"ollama",
"openai",
"openrouter",
"vercel",
"x_ai",
"zed.dev"
"anyOf": [
{
"type": "string",
"enum": [
"amazon-bedrock",
"anthropic",
"copilot_chat",
"deepseek",
"google",
"lmstudio",
"mistral",
"ollama",
"openai",
"openrouter",
"vercel",
"x_ai",
"zed.dev"
]
},
{
"type": "string",
}
]
})
}
@@ -472,6 +490,10 @@ impl Settings for AgentSettings {
&mut settings.use_modifier_to_send,
value.use_modifier_to_send,
);
merge(
&mut settings.message_editor_min_lines,
value.message_editor_min_lines,
);
settings
.model_parameters

View File

@@ -80,7 +80,6 @@ serde.workspace = true
serde_json.workspace = true
serde_json_lenient.workspace = true
settings.workspace = true
shlex.workspace = true
smol.workspace = true
streaming_diff.workspace = true
task.workspace = true

View File

@@ -1099,11 +1099,16 @@ impl MessageEditor {
}
pub fn insert_selections(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let buffer = self.editor.read(cx).buffer().clone();
let Some(buffer) = buffer.read(cx).as_singleton() else {
let editor = self.editor.read(cx);
let editor_buffer = editor.buffer().read(cx);
let Some(buffer) = editor_buffer.as_singleton() else {
return;
};
let anchor = buffer.update(cx, |buffer, _cx| buffer.anchor_before(buffer.len()));
let cursor_anchor = editor.selections.newest_anchor().head();
let cursor_offset = cursor_anchor.to_offset(&editor_buffer.snapshot(cx));
let anchor = buffer.update(cx, |buffer, _cx| {
buffer.anchor_before(cursor_offset.min(buffer.len()))
});
let Some(workspace) = self.workspace.upgrade() else {
return;
};
@@ -1117,13 +1122,7 @@ impl MessageEditor {
return;
};
self.editor.update(cx, |message_editor, cx| {
message_editor.edit(
[(
multi_buffer::Anchor::max()..multi_buffer::Anchor::max(),
completion.new_text,
)],
cx,
);
message_editor.edit([(cursor_anchor..cursor_anchor, completion.new_text)], cx);
});
if let Some(confirm) = completion.confirm {
confirm(CompletionIntent::Complete, window, cx);

View File

@@ -107,13 +107,15 @@ impl ModeSelector {
.text_sm()
.text_color(Color::Muted.color(cx))
.child("Hold")
.child(div().pt_0p5().children(ui::render_modifiers(
&gpui::Modifiers::secondary_key(),
PlatformStyle::platform(),
None,
Some(ui::TextSize::Default.rems(cx).into()),
true,
)))
.child(h_flex().flex_shrink_0().children(
ui::render_modifiers(
&gpui::Modifiers::secondary_key(),
PlatformStyle::platform(),
None,
Some(ui::TextSize::Default.rems(cx).into()),
true,
),
))
.child(div().map(|this| {
if is_default {
this.child("to also unset as default")

View File

@@ -500,20 +500,24 @@ impl Render for AcpThreadHistory {
),
)
} else {
view.pr_5()
.child(
uniform_list(
"thread-history",
self.visible_items.len(),
cx.processor(|this, range: Range<usize>, window, cx| {
this.render_list_items(range, window, cx)
}),
)
.p_1()
.track_scroll(self.scroll_handle.clone())
.flex_grow(),
view.child(
uniform_list(
"thread-history",
self.visible_items.len(),
cx.processor(|this, range: Range<usize>, window, cx| {
this.render_list_items(range, window, cx)
}),
)
.vertical_scrollbar_for(self.scroll_handle.clone(), window, cx)
.p_1()
.pr_4()
.track_scroll(self.scroll_handle.clone())
.flex_grow(),
)
.vertical_scrollbar_for(
self.scroll_handle.clone(),
window,
cx,
)
}
})
}

View File

@@ -9,7 +9,7 @@ use agent_client_protocol::{self as acp, PromptCapabilities};
use agent_servers::{AgentServer, AgentServerDelegate};
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, NotifyWhenAgentWaiting};
use agent2::{DbThreadMetadata, HistoryEntry, HistoryEntryId, HistoryStore, NativeAgentServer};
use anyhow::{Context as _, Result, anyhow, bail};
use anyhow::{Result, anyhow, bail};
use arrayvec::ArrayVec;
use audio::{Audio, Sound};
use buffer_diff::BufferDiff;
@@ -71,9 +71,6 @@ use crate::{
RejectOnce, ToggleBurnMode, ToggleProfileSelector,
};
pub const MIN_EDITOR_LINES: usize = 4;
pub const MAX_EDITOR_LINES: usize = 8;
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum ThreadFeedback {
Positive,
@@ -357,8 +354,8 @@ impl AcpThreadView {
agent.name(),
&placeholder,
editor::EditorMode::AutoHeight {
min_lines: MIN_EDITOR_LINES,
max_lines: Some(MAX_EDITOR_LINES),
min_lines: AgentSettings::get_global(cx).message_editor_min_lines,
max_lines: Some(AgentSettings::get_global(cx).set_message_editor_max_lines()),
},
window,
cx,
@@ -857,10 +854,11 @@ impl AcpThreadView {
cx,
)
} else {
let agent_settings = AgentSettings::get_global(cx);
editor.set_mode(
EditorMode::AutoHeight {
min_lines: MIN_EDITOR_LINES,
max_lines: Some(MAX_EDITOR_LINES),
min_lines: agent_settings.message_editor_min_lines,
max_lines: Some(agent_settings.set_message_editor_max_lines()),
},
cx,
)
@@ -1584,19 +1582,6 @@ impl AcpThreadView {
window.spawn(cx, async move |cx| {
let mut task = login.clone();
task.command = task
.command
.map(|command| anyhow::Ok(shlex::try_quote(&command)?.to_string()))
.transpose()?;
task.args = task
.args
.iter()
.map(|arg| {
Ok(shlex::try_quote(arg)
.context("Failed to quote argument")?
.to_string())
})
.collect::<Result<Vec<_>>>()?;
task.full_label = task.label.clone();
task.id = task::TaskId(format!("external-agent-{}-login", task.label));
task.command_label = task.label.clone();
@@ -3197,10 +3182,14 @@ impl AcpThreadView {
};
Button::new(SharedString::from(method_id.clone()), name)
.when(ix == 0, |el| {
el.style(ButtonStyle::Tinted(ui::TintColor::Warning))
})
.label_size(LabelSize::Small)
.map(|this| {
if ix == 0 {
this.style(ButtonStyle::Tinted(TintColor::Warning))
} else {
this.style(ButtonStyle::Outlined)
}
})
.on_click({
cx.listener(move |this, _, window, cx| {
telemetry::event!(
@@ -5680,6 +5669,23 @@ pub(crate) mod tests {
});
}
#[gpui::test]
async fn test_spawn_external_agent_login_handles_spaces(cx: &mut TestAppContext) {
init_test(cx);
// Verify paths with spaces aren't pre-quoted
let path_with_spaces = "/Users/test/Library/Application Support/Zed/cli.js";
let login_task = task::SpawnInTerminal {
command: Some("node".to_string()),
args: vec![path_with_spaces.to_string(), "/login".to_string()],
..Default::default()
};
// Args should be passed as-is, not pre-quoted
assert!(!login_task.args[0].starts_with('"'));
assert!(!login_task.args[0].starts_with('\''));
}
#[gpui::test]
async fn test_notification_for_tool_authorization(cx: &mut TestAppContext) {
init_test(cx);

View File

@@ -274,13 +274,28 @@ impl AgentConfiguration {
*is_expanded = !*is_expanded;
}
})),
)
.when(provider.is_authenticated(cx), |parent| {
),
)
.child(
v_flex()
.w_full()
.px_2()
.gap_1()
.when(is_expanded, |parent| match configuration_view {
Some(configuration_view) => parent.child(configuration_view),
None => parent.child(Label::new(format!(
"No configuration view for {provider_name}",
))),
})
.when(is_expanded && provider.is_authenticated(cx), |parent| {
parent.child(
Button::new(
SharedString::from(format!("new-thread-{provider_id}")),
"Start New Thread",
)
.full_width()
.style(ButtonStyle::Filled)
.layer(ElevationIndex::ModalSurface)
.icon_position(IconPosition::Start)
.icon(IconName::Thread)
.icon_size(IconSize::Small)
@@ -297,17 +312,6 @@ impl AgentConfiguration {
)
}),
)
.child(
div()
.w_full()
.px_2()
.when(is_expanded, |parent| match configuration_view {
Some(configuration_view) => parent.child(configuration_view),
None => parent.child(Label::new(format!(
"No configuration view for {provider_name}",
))),
}),
)
}
fn render_provider_configuration_section(
@@ -561,11 +565,28 @@ impl AgentConfiguration {
.color(Color::Muted),
),
)
.children(
context_server_ids.into_iter().map(|context_server_id| {
self.render_context_server(context_server_id, window, cx)
}),
)
.map(|parent| {
if context_server_ids.is_empty() {
parent.child(
h_flex()
.p_4()
.justify_center()
.border_1()
.border_dashed()
.border_color(cx.theme().colors().border.opacity(0.6))
.rounded_sm()
.child(
Label::new("No MCP servers added yet.")
.color(Color::Muted)
.size(LabelSize::Small),
),
)
} else {
parent.children(context_server_ids.into_iter().map(|context_server_id| {
self.render_context_server(context_server_id, window, cx)
}))
}
})
.child(
h_flex()
.justify_between()
@@ -818,6 +839,8 @@ impl AgentConfiguration {
)
.child(
h_flex()
.flex_1()
.min_w_0()
.child(
Disclosure::new(
"tool-list-disclosure",
@@ -841,17 +864,19 @@ impl AgentConfiguration {
.id(SharedString::from(format!("tooltip-{}", item_id)))
.h_full()
.w_3()
.mx_1()
.ml_1()
.mr_1p5()
.justify_center()
.tooltip(Tooltip::text(tooltip_text))
.child(status_indicator),
)
.child(Label::new(item_id).ml_0p5())
.child(Label::new(item_id).truncate())
.child(
div()
.id("extension-source")
.mt_0p5()
.mx_1()
.flex_none()
.tooltip(Tooltip::text(source_tooltip))
.child(
Icon::new(source_icon)
@@ -873,7 +898,8 @@ impl AgentConfiguration {
)
.child(
h_flex()
.gap_1()
.gap_0p5()
.flex_none()
.child(context_server_configuration_menu)
.child(
Switch::new("context-server-switch", is_running.into())
@@ -1123,6 +1149,7 @@ impl AgentConfiguration {
SharedString::from(format!("start_acp_thread-{name}")),
"Start New Thread",
)
.layer(ElevationIndex::ModalSurface)
.label_size(LabelSize::Small)
.icon(IconName::Thread)
.icon_position(IconPosition::Start)

View File

@@ -63,7 +63,6 @@ ui.workspace = true
util.workspace = true
watch.workspace = true
web_search.workspace = true
which.workspace = true
workspace-hack.workspace = true
workspace.workspace = true

View File

@@ -52,7 +52,7 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
assistant_tool::init(cx);
let registry = ToolRegistry::global(cx);
registry.register_tool(TerminalTool::new(cx));
registry.register_tool(TerminalTool);
registry.register_tool(CreateDirectoryTool);
registry.register_tool(CopyPathTool);
registry.register_tool(DeletePathTool);

View File

@@ -160,7 +160,7 @@ mod tests {
&mut parser,
&mut rng
),
// This output is marlformed, so we're doing our best effort
// This output is malformed, so we're doing our best effort
"Hello world\n```\n\nThe end\n".to_string()
);
}
@@ -182,7 +182,7 @@ mod tests {
&mut parser,
&mut rng
),
// This output is marlformed, so we're doing our best effort
// This output is malformed, so we're doing our best effort
"```\nHello world\n```\n".to_string()
);
}

View File

@@ -916,7 +916,7 @@ impl Loader {
if !found_non_static {
found_non_static = true;
eprintln!(
"Warning: Found non-static non-tree-sitter functions in the external scannner"
"Warning: Found non-static non-tree-sitter functions in the external scanner"
);
}
eprintln!(" `{function_name}`");

View File

@@ -6,7 +6,7 @@ use action_log::ActionLog;
use agent_settings;
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{Tool, ToolCard, ToolResult, ToolUseStatus};
use futures::{FutureExt as _, future::Shared};
use futures::FutureExt as _;
use gpui::{
AnyWindowHandle, App, AppContext, Empty, Entity, EntityId, Task, TextStyleRefinement,
WeakEntity, Window,
@@ -26,11 +26,12 @@ use std::{
sync::Arc,
time::{Duration, Instant},
};
use task::{Shell, ShellBuilder};
use terminal_view::TerminalView;
use theme::ThemeSettings;
use ui::{CommonAnimationExt, Disclosure, Tooltip, prelude::*};
use util::{
ResultExt, get_system_shell, markdown::MarkdownInlineCode, size::format_file_size,
ResultExt, get_default_system_shell, markdown::MarkdownInlineCode, size::format_file_size,
time::duration_alt_display,
};
use workspace::Workspace;
@@ -45,29 +46,10 @@ pub struct TerminalToolInput {
cd: String,
}
pub struct TerminalTool {
determine_shell: Shared<Task<String>>,
}
pub struct TerminalTool;
impl TerminalTool {
pub const NAME: &str = "terminal";
pub(crate) fn new(cx: &mut App) -> Self {
let determine_shell = cx.background_spawn(async move {
if cfg!(windows) {
return get_system_shell();
}
if which::which("bash").is_ok() {
"bash".into()
} else {
get_system_shell()
}
});
Self {
determine_shell: determine_shell.shared(),
}
}
}
impl Tool for TerminalTool {
@@ -135,19 +117,6 @@ impl Tool for TerminalTool {
Ok(dir) => dir,
Err(err) => return Task::ready(Err(err)).into(),
};
let program = self.determine_shell.clone();
let command = if cfg!(windows) {
format!("$null | & {{{}}}", input.command.replace("\"", "'"))
} else if let Some(cwd) = working_dir
.as_ref()
.and_then(|cwd| cwd.as_os_str().to_str())
{
// Make sure once we're *inside* the shell, we cd into `cwd`
format!("(cd {cwd}; {}) </dev/null", input.command)
} else {
format!("({}) </dev/null", input.command)
};
let args = vec!["-c".into(), command];
let cwd = working_dir.clone();
let env = match &working_dir {
@@ -156,6 +125,11 @@ impl Tool for TerminalTool {
}),
None => Task::ready(None).shared(),
};
let remote_shell = project.update(cx, |project, cx| {
project
.remote_client()
.and_then(|r| r.read(cx).default_system_shell())
});
let env = cx.spawn(async move |_| {
let mut env = env.await.unwrap_or_default();
@@ -171,8 +145,13 @@ impl Tool for TerminalTool {
let task = cx.background_spawn(async move {
let env = env.await;
let pty_system = native_pty_system();
let program = program.await;
let mut cmd = CommandBuilder::new(program);
let (command, args) = ShellBuilder::new(
remote_shell.as_deref(),
&Shell::Program(get_default_system_shell()),
)
.redirect_stdin_to_dev_null()
.build(Some(input.command.clone()), &[]);
let mut cmd = CommandBuilder::new(command);
cmd.args(args);
for (k, v) in env {
cmd.env(k, v);
@@ -208,16 +187,22 @@ impl Tool for TerminalTool {
};
};
let command = input.command.clone();
let terminal = cx.spawn({
let project = project.downgrade();
async move |cx| {
let program = program.await;
let (command, args) = ShellBuilder::new(
remote_shell.as_deref(),
&Shell::Program(get_default_system_shell()),
)
.redirect_stdin_to_dev_null()
.build(Some(input.command), &[]);
let env = env.await;
project
.update(cx, |project, cx| {
project.create_terminal_task(
task::SpawnInTerminal {
command: Some(program),
command: Some(command),
args,
cwd,
env,
@@ -230,14 +215,8 @@ impl Tool for TerminalTool {
}
});
let command_markdown = cx.new(|cx| {
Markdown::new(
format!("```bash\n{}\n```", input.command).into(),
None,
None,
cx,
)
});
let command_markdown =
cx.new(|cx| Markdown::new(format!("```bash\n{}\n```", command).into(), None, None, cx));
let card = cx.new(|cx| {
TerminalToolCard::new(
@@ -288,7 +267,7 @@ impl Tool for TerminalTool {
let previous_len = content.len();
let (processed_content, finished_with_empty_output) = process_content(
&content,
&input.command,
&command,
exit_status.map(portable_pty::ExitStatus::from),
);
@@ -740,7 +719,6 @@ mod tests {
if cfg!(windows) {
return;
}
init_test(&executor, cx);
let fs = Arc::new(RealFs::new(None, executor));
@@ -763,7 +741,7 @@ mod tests {
};
let result = cx.update(|cx| {
TerminalTool::run(
Arc::new(TerminalTool::new(cx)),
Arc::new(TerminalTool),
serde_json::to_value(input).unwrap(),
Arc::default(),
project.clone(),
@@ -783,7 +761,6 @@ mod tests {
if cfg!(windows) {
return;
}
init_test(&executor, cx);
let fs = Arc::new(RealFs::new(None, executor));
@@ -798,7 +775,7 @@ mod tests {
let check = |input, expected, cx: &mut App| {
let headless_result = TerminalTool::run(
Arc::new(TerminalTool::new(cx)),
Arc::new(TerminalTool),
serde_json::to_value(input).unwrap(),
Arc::default(),
project.clone(),

View File

@@ -211,7 +211,7 @@ impl Audio {
agc_source.set_enabled(LIVE_SETTINGS.control_input_volume.load(Ordering::Relaxed));
})
.replayable(REPLAY_DURATION)
.expect("REPLAY_DURATION is longer then 100ms");
.expect("REPLAY_DURATION is longer than 100ms");
cx.update_default_global(|this: &mut Self, _cx| {
let output_mixer = this

View File

@@ -57,7 +57,7 @@ impl<S: Source> RodioExt for S {
/// replay is being read
///
/// # Errors
/// If duration is smaller then 100ms
/// If duration is smaller than 100ms
fn replayable(
self,
duration: Duration,
@@ -151,7 +151,7 @@ impl<S: Source> Source for TakeSamples<S> {
struct ReplayQueue {
inner: ArrayQueue<Vec<Sample>>,
normal_chunk_len: usize,
/// The last chunk in the queue may be smaller then
/// The last chunk in the queue may be smaller than
/// the normal chunk size. This is always equal to the
/// size of the last element in the queue.
/// (so normally chunk_size)
@@ -535,7 +535,7 @@ mod tests {
let (mut replay, mut source) = input
.replayable(Duration::from_secs(3))
.expect("longer then 100ms");
.expect("longer than 100ms");
source.by_ref().take(3).count();
let yielded: Vec<Sample> = replay.by_ref().take(3).collect();
@@ -552,7 +552,7 @@ mod tests {
let (mut replay, mut source) = input
.replayable(Duration::from_secs(2))
.expect("longer then 100ms");
.expect("longer than 100ms");
source.by_ref().take(5).count(); // get all items but do not end the source
let yielded: Vec<Sample> = replay.by_ref().take(2).collect();
@@ -567,7 +567,7 @@ mod tests {
let (replay, mut source) = input
.replayable(Duration::from_secs(2))
.expect("longer then 100ms");
.expect("longer than 100ms");
// exhaust but do not yet end source
source.by_ref().take(40_000).count();
@@ -586,7 +586,7 @@ mod tests {
let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
let (mut replay, source) = input
.replayable(Duration::from_secs(2))
.expect("longer then 100ms");
.expect("longer than 100ms");
assert_eq!(replay.by_ref().samples_ready(), 0);
source.take(8000).count(); // half a second

View File

@@ -32,3 +32,6 @@ workspace-hack.workspace = true
[target.'cfg(not(target_os = "windows"))'.dependencies]
which.workspace = true
[dev-dependencies]
gpui = { workspace = true, "features" = ["test-support"] }

View File

@@ -34,7 +34,7 @@ actions!(
/// Checks for available updates.
Check,
/// Dismisses the update error message.
DismissErrorMessage,
DismissMessage,
/// Opens the release notes for the current version in a browser.
ViewReleaseNotes,
]
@@ -55,14 +55,14 @@ pub enum VersionCheckType {
Semantic(SemanticVersion),
}
#[derive(Clone, PartialEq, Eq)]
#[derive(Clone)]
pub enum AutoUpdateStatus {
Idle,
Checking,
Downloading { version: VersionCheckType },
Installing { version: VersionCheckType },
Updated { version: VersionCheckType },
Errored,
Errored { error: Arc<anyhow::Error> },
}
impl AutoUpdateStatus {
@@ -383,7 +383,9 @@ impl AutoUpdater {
}
UpdateCheckType::Manual => {
log::error!("auto-update failed: error:{:?}", error);
AutoUpdateStatus::Errored
AutoUpdateStatus::Errored {
error: Arc::new(error),
}
}
};
@@ -402,8 +404,8 @@ impl AutoUpdater {
self.status.clone()
}
pub fn dismiss_error(&mut self, cx: &mut Context<Self>) -> bool {
if self.status == AutoUpdateStatus::Idle {
pub fn dismiss(&mut self, cx: &mut Context<Self>) -> bool {
if let AutoUpdateStatus::Idle = self.status {
return false;
}
self.status = AutoUpdateStatus::Idle;
@@ -992,8 +994,27 @@ pub fn finalize_auto_update_on_quit() {
#[cfg(test)]
mod tests {
use gpui::TestAppContext;
use settings::default_settings;
use super::*;
#[gpui::test]
fn test_auto_update_defaults_to_true(cx: &mut TestAppContext) {
cx.update(|cx| {
let mut store = SettingsStore::new(cx);
store
.set_default_settings(&default_settings(), cx)
.expect("Unable to set default settings");
store
.set_user_settings("{}", cx)
.expect("Unable to set user settings");
cx.set_global(store);
AutoUpdateSetting::register(cx);
assert!(AutoUpdateSetting::get_global(cx).0);
});
}
#[test]
fn test_stable_does_not_update_when_fetched_version_is_not_higher() {
let release_channel = ReleaseChannel::Stable;

View File

@@ -22,7 +22,7 @@ use futures::{
channel::oneshot, future::BoxFuture,
};
use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
use http_client::{HttpClient, HttpClientWithUrl, http};
use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
use parking_lot::RwLock;
use postage::watch;
use proxy::connect_proxy_stream;
@@ -132,6 +132,20 @@ pub struct ProxySettings {
pub proxy: Option<String>,
}
impl ProxySettings {
pub fn proxy_url(&self) -> Option<Url> {
self.proxy
.as_ref()
.and_then(|input| {
input
.parse::<Url>()
.inspect_err(|e| log::error!("Error parsing proxy settings: {}", e))
.ok()
})
.or_else(read_proxy_from_env)
}
}
impl Settings for ProxySettings {
type FileContent = ProxySettingsContent;

View File

@@ -754,6 +754,10 @@ impl UserStore {
}
pub fn model_request_usage(&self) -> Option<ModelRequestUsage> {
if self.plan().is_some_and(|plan| plan.is_v2()) {
return None;
}
self.model_request_usage
}

View File

@@ -39,7 +39,7 @@ pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";
/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached.
pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";
/// The name of the header used to indicate the the minimum required Zed version.
/// The name of the header used to indicate the minimum required Zed version.
///
/// This can be used to force a Zed upgrade in order to continue communicating
/// with the LLM service.
@@ -321,8 +321,8 @@ pub struct LanguageModel {
#[derive(Debug, Serialize, Deserialize)]
pub struct ListModelsResponse {
pub models: Vec<LanguageModel>,
pub default_model: LanguageModelId,
pub default_fast_model: LanguageModelId,
pub default_model: Option<LanguageModelId>,
pub default_fast_model: Option<LanguageModelId>,
pub recommended_models: Vec<LanguageModelId>,
}

View File

@@ -226,12 +226,6 @@ spec:
secretKeyRef:
name: supermaven
key: api_key
- name: USER_BACKFILLER_GITHUB_ACCESS_TOKEN
valueFrom:
secretKeyRef:
name: user-backfiller
key: github_access_token
optional: true
- name: INVITE_LINK_PREFIX
value: ${INVITE_LINK_PREFIX}
- name: RUST_BACKTRACE

View File

@@ -7,7 +7,6 @@ pub mod llm;
pub mod migrations;
pub mod rpc;
pub mod seed;
pub mod user_backfiller;
#[cfg(test)]
mod tests;
@@ -157,7 +156,6 @@ pub struct Config {
pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>,
pub supermaven_admin_api_key: Option<Arc<str>>,
pub user_backfiller_github_access_token: Option<Arc<str>>,
}
impl Config {
@@ -211,7 +209,6 @@ impl Config {
migrations_path: None,
seed_path: None,
supermaven_admin_api_key: None,
user_backfiller_github_access_token: None,
kinesis_region: None,
kinesis_access_key: None,
kinesis_secret_key: None,

View File

@@ -11,7 +11,6 @@ use collab::ServiceMode;
use collab::api::CloudflareIpCountryHeader;
use collab::llm::db::LlmDatabase;
use collab::migrations::run_database_migrations;
use collab::user_backfiller::spawn_user_backfiller;
use collab::{
AppState, Config, Result, api::fetch_extensions_from_blob_store_periodically, db, env,
executor::Executor, rpc::ResultExt,
@@ -114,7 +113,6 @@ async fn main() -> Result<()> {
if mode.is_api() {
fetch_extensions_from_blob_store_periodically(state.clone());
spawn_user_backfiller(state.clone());
app = app
.merge(collab::api::events::router())

View File

@@ -604,7 +604,6 @@ impl TestServer {
migrations_path: None,
seed_path: None,
supermaven_admin_api_key: None,
user_backfiller_github_access_token: None,
kinesis_region: None,
kinesis_stream: None,
kinesis_access_key: None,

View File

@@ -1,165 +0,0 @@
use std::sync::Arc;
use anyhow::{Context as _, Result};
use chrono::{DateTime, Utc};
use util::ResultExt;
use crate::db::Database;
use crate::executor::Executor;
use crate::{AppState, Config};
pub fn spawn_user_backfiller(app_state: Arc<AppState>) {
let Some(user_backfiller_github_access_token) =
app_state.config.user_backfiller_github_access_token.clone()
else {
log::info!("no USER_BACKFILLER_GITHUB_ACCESS_TOKEN set; not spawning user backfiller");
return;
};
let executor = app_state.executor.clone();
executor.spawn_detached({
let executor = executor.clone();
async move {
let user_backfiller = UserBackfiller::new(
app_state.config.clone(),
user_backfiller_github_access_token,
app_state.db.clone(),
executor,
);
log::info!("backfilling users");
user_backfiller
.backfill_github_user_created_at()
.await
.log_err();
}
});
}
const GITHUB_REQUESTS_PER_HOUR_LIMIT: usize = 5_000;
const SLEEP_DURATION_BETWEEN_USERS: std::time::Duration = std::time::Duration::from_millis(
(GITHUB_REQUESTS_PER_HOUR_LIMIT as f64 / 60. / 60. * 1000.) as u64,
);
struct UserBackfiller {
config: Config,
github_access_token: Arc<str>,
db: Arc<Database>,
http_client: reqwest::Client,
executor: Executor,
}
impl UserBackfiller {
fn new(
config: Config,
github_access_token: Arc<str>,
db: Arc<Database>,
executor: Executor,
) -> Self {
Self {
config,
github_access_token,
db,
http_client: reqwest::Client::new(),
executor,
}
}
async fn backfill_github_user_created_at(&self) -> Result<()> {
let initial_channel_id = self.config.auto_join_channel_id;
let users_missing_github_user_created_at =
self.db.get_users_missing_github_user_created_at().await?;
for user in users_missing_github_user_created_at {
match self
.fetch_github_user(&format!(
"https://api.github.com/user/{}",
user.github_user_id
))
.await
{
Ok(github_user) => {
self.db
.update_or_create_user_by_github_account(
&user.github_login,
github_user.id,
user.email_address.as_deref(),
user.name.as_deref(),
github_user.created_at,
initial_channel_id,
)
.await?;
log::info!("backfilled user: {}", user.github_login);
}
Err(err) => {
log::error!("failed to fetch GitHub user {}: {err}", user.github_login);
}
}
self.executor.sleep(SLEEP_DURATION_BETWEEN_USERS).await;
}
Ok(())
}
async fn fetch_github_user(&self, url: &str) -> Result<GithubUser> {
let response = self
.http_client
.get(url)
.header(
"authorization",
format!("Bearer {}", self.github_access_token),
)
.header("user-agent", "zed")
.send()
.await
.with_context(|| format!("failed to fetch '{url}'"))?;
let rate_limit_remaining = response
.headers()
.get("x-ratelimit-remaining")
.and_then(|value| value.to_str().ok())
.and_then(|value| value.parse::<i32>().ok());
let rate_limit_reset = response
.headers()
.get("x-ratelimit-reset")
.and_then(|value| value.to_str().ok())
.and_then(|value| value.parse::<i64>().ok())
.and_then(|value| DateTime::from_timestamp(value, 0));
if rate_limit_remaining == Some(0)
&& let Some(reset_at) = rate_limit_reset
{
let now = Utc::now();
if reset_at > now {
let sleep_duration = reset_at - now;
log::info!(
"rate limit reached. Sleeping for {} seconds",
sleep_duration.num_seconds()
);
self.executor.sleep(sleep_duration.to_std().unwrap()).await;
}
}
response
.error_for_status()
.context("fetching GitHub user")?
.json()
.await
.with_context(|| format!("failed to deserialize GitHub user from '{url}'"))
}
}
#[derive(serde::Deserialize)]
struct GithubUser {
id: i32,
created_at: DateTime<Utc>,
#[expect(
unused,
reason = "This field was found to be unused with serde library bump; it's left as is due to insufficient context on PO's side, but it *may* be fine to remove"
)]
name: Option<String>,
}

21
crates/denoise/Cargo.toml Normal file
View File

@@ -0,0 +1,21 @@
[package]
name = "denoise"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[dependencies]
candle-core = { version = "0.9.1", git ="https://github.com/zed-industries/candle", branch = "9.1-patched" }
candle-onnx = { version = "0.9.1", git ="https://github.com/zed-industries/candle", branch = "9.1-patched" }
log.workspace = true
rodio = { workspace = true, features = ["wav_output"] }
rustfft = { version = "6.2.0", features = ["avx"] }
realfft = "3.4.0"
thiserror.workspace = true
workspace-hack.workspace = true

1
crates/denoise/LICENSE-GPL Symbolic link
View File

@@ -0,0 +1 @@
LICENSE-GPL

20
crates/denoise/README.md Normal file
View File

@@ -0,0 +1,20 @@
Real time streaming audio denoising using a [Dual-Signal Transformation LSTM Network for Real-Time Noise Suppression](https://arxiv.org/abs/2005.07551).
Trivial to build as it uses the native rust Candle crate for inference. Easy to integrate into any Rodio pipeline.
```rust
# use rodio::{nz, source::UniformSourceIterator, wav_to_file};
let file = std::fs::File::open("clips_airconditioning.wav")?;
let decoder = rodio::Decoder::try_from(file)?;
let resampled = UniformSourceIterator::new(decoder, nz!(1), nz!(16_000));
let mut denoised = denoise::Denoiser::try_new(resampled)?;
wav_to_file(&mut denoised, "denoised.wav")?;
Result::Ok<(), Box<dyn std::error::Error>>
```
## Acknowledgements & License
The trained models in this repo are optimized versions of the models in the [breizhn/DTLN](https://github.com/breizhn/DTLN?tab=readme-ov-file#model-conversion-and-real-time-processing-with-onnx). These are licensed under MIT.
The FFT code was adapted from Datadog's [dtln-rs Repo](https://github.com/DataDog/dtln-rs/tree/main) also licensed under MIT.

View File

@@ -0,0 +1,11 @@
use rodio::{nz, source::UniformSourceIterator, wav_to_file};
fn main() -> Result<(), Box<dyn std::error::Error>> {
let file = std::fs::File::open("airconditioning.wav")?;
let decoder = rodio::Decoder::try_from(file)?;
let resampled = UniformSourceIterator::new(decoder, nz!(1), nz!(16_000));
let mut denoised = denoise::Denoiser::try_new(resampled)?;
wav_to_file(&mut denoised, "denoised.wav")?;
Ok(())
}

View File

@@ -0,0 +1,23 @@
use std::time::Duration;
use rodio::Source;
use rodio::wav_to_file;
use rodio::{nz, source::UniformSourceIterator};
fn main() -> Result<(), Box<dyn std::error::Error>> {
let file = std::fs::File::open("clips_airconditioning.wav")?;
let decoder = rodio::Decoder::try_from(file)?;
let resampled = UniformSourceIterator::new(decoder, nz!(1), nz!(16_000));
let mut enabled = true;
let denoised = denoise::Denoiser::try_new(resampled)?.periodic_access(
Duration::from_secs(2),
|denoised| {
enabled = !enabled;
denoised.set_enabled(enabled);
},
);
wav_to_file(denoised, "processed.wav")?;
Ok(())
}

View File

@@ -0,0 +1,204 @@
/// use something like https://netron.app/ to inspect the models and understand
/// the flow
use std::collections::HashMap;
use candle_core::{Device, IndexOp, Tensor};
use candle_onnx::onnx::ModelProto;
use candle_onnx::prost::Message;
use realfft::RealFftPlanner;
use rustfft::num_complex::Complex;
pub struct Engine {
spectral_model: ModelProto,
signal_model: ModelProto,
fft_planner: RealFftPlanner<f32>,
fft_scratch: Vec<Complex<f32>>,
spectrum: [Complex<f32>; FFT_OUT_SIZE],
signal: [f32; BLOCK_LEN],
in_magnitude: [f32; FFT_OUT_SIZE],
in_phase: [f32; FFT_OUT_SIZE],
spectral_memory: Tensor,
signal_memory: Tensor,
in_buffer: [f32; BLOCK_LEN],
out_buffer: [f32; BLOCK_LEN],
}
// 32 ms @ 16khz per DTLN docs: https://github.com/breizhn/DTLN
pub const BLOCK_LEN: usize = 512;
// 8 ms @ 16khz per DTLN docs.
pub const BLOCK_SHIFT: usize = 128;
pub const FFT_OUT_SIZE: usize = BLOCK_LEN / 2 + 1;
impl Engine {
pub fn new() -> Self {
let mut fft_planner = RealFftPlanner::new();
let fft_planned = fft_planner.plan_fft_forward(BLOCK_LEN);
let scratch_len = fft_planned.get_scratch_len();
Self {
// Models are 1.5MB and 2.5MB respectively. Its worth the binary
// size increase not to have to distribute the models separately.
spectral_model: ModelProto::decode(
include_bytes!("../models/model_1_converted_simplified.onnx").as_slice(),
)
.expect("The model should decode"),
signal_model: ModelProto::decode(
include_bytes!("../models/model_2_converted_simplified.onnx").as_slice(),
)
.expect("The model should decode"),
fft_planner,
fft_scratch: vec![Complex::ZERO; scratch_len],
spectrum: [Complex::ZERO; FFT_OUT_SIZE],
signal: [0f32; BLOCK_LEN],
in_magnitude: [0f32; FFT_OUT_SIZE],
in_phase: [0f32; FFT_OUT_SIZE],
spectral_memory: Tensor::from_slice::<_, f32>(
&[0f32; 512],
(1, 2, BLOCK_SHIFT, 2),
&Device::Cpu,
)
.expect("Tensor has the correct dimensions"),
signal_memory: Tensor::from_slice::<_, f32>(
&[0f32; 512],
(1, 2, BLOCK_SHIFT, 2),
&Device::Cpu,
)
.expect("Tensor has the correct dimensions"),
out_buffer: [0f32; BLOCK_LEN],
in_buffer: [0f32; BLOCK_LEN],
}
}
/// Add a clunk of samples and get the denoised chunk 4 feeds later
pub fn feed(&mut self, samples: &[f32]) -> [f32; BLOCK_SHIFT] {
/// The name of the output node of the onnx network
/// [Dual-Signal Transformation LSTM Network for Real-Time Noise Suppression](https://arxiv.org/abs/2005.07551).
const MEMORY_OUTPUT: &'static str = "Identity_1";
debug_assert_eq!(samples.len(), BLOCK_SHIFT);
// place new samples at the end of the `in_buffer`
self.in_buffer.copy_within(BLOCK_SHIFT.., 0);
self.in_buffer[(BLOCK_LEN - BLOCK_SHIFT)..].copy_from_slice(&samples);
// run inference
let inputs = self.spectral_inputs();
let mut spectral_outputs = candle_onnx::simple_eval(&self.spectral_model, inputs)
.expect("The embedded file must be valid");
self.spectral_memory = spectral_outputs
.remove(MEMORY_OUTPUT)
.expect("The model has an output named Identity_1");
let inputs = self.signal_inputs(spectral_outputs);
let mut signal_outputs = candle_onnx::simple_eval(&self.signal_model, inputs)
.expect("The embedded file must be valid");
self.signal_memory = signal_outputs
.remove(MEMORY_OUTPUT)
.expect("The model has an output named Identity_1");
let model_output = model_outputs(signal_outputs);
// place processed samples at the start of the `out_buffer`
// shift the rest left, fill the end with zeros. Zeros are needed as
// the out buffer is part of the input of the network
self.out_buffer.copy_within(BLOCK_SHIFT.., 0);
self.out_buffer[BLOCK_LEN - BLOCK_SHIFT..].fill(0f32);
for (a, b) in self.out_buffer.iter_mut().zip(model_output) {
*a += b;
}
// samples at the front of the `out_buffer` are now denoised
self.out_buffer[..BLOCK_SHIFT]
.try_into()
.expect("len is correct")
}
fn spectral_inputs(&mut self) -> HashMap<String, Tensor> {
// Prepare FFT input
let fft = self.fft_planner.plan_fft_forward(BLOCK_LEN);
// Perform real-to-complex FFT
let mut fft_in = self.in_buffer;
fft.process_with_scratch(&mut fft_in, &mut self.spectrum, &mut self.fft_scratch)
.expect("The fft should run, there is enough scratch space");
// Generate magnitude and phase
for ((magnitude, phase), complex) in self
.in_magnitude
.iter_mut()
.zip(self.in_phase.iter_mut())
.zip(self.spectrum)
{
*magnitude = complex.norm();
*phase = complex.arg();
}
const SPECTRUM_INPUT: &str = "input_2";
const MEMORY_INPUT: &str = "input_3";
let memory_input =
Tensor::from_slice::<_, f32>(&self.in_magnitude, (1, 1, FFT_OUT_SIZE), &Device::Cpu)
.expect("the in magnitude has enough elements to fill the Tensor");
let inputs = HashMap::from([
(MEMORY_INPUT.to_string(), memory_input),
(SPECTRUM_INPUT.to_string(), self.spectral_memory.clone()),
]);
inputs
}
fn signal_inputs(&mut self, outputs: HashMap<String, Tensor>) -> HashMap<String, Tensor> {
let magnitude_weight = model_outputs(outputs);
// Apply mask and reconstruct complex spectrum
let mut spectrum = [Complex::I; FFT_OUT_SIZE];
for i in 0..FFT_OUT_SIZE {
let magnitude = self.in_magnitude[i] * magnitude_weight[i];
let phase = self.in_phase[i];
let real = magnitude * phase.cos();
let imag = magnitude * phase.sin();
spectrum[i] = Complex::new(real, imag);
}
// Handle DC component (i = 0)
let magnitude = self.in_magnitude[0] * magnitude_weight[0];
spectrum[0] = Complex::new(magnitude, 0.0);
// Handle Nyquist component (i = N/2)
let magnitude = self.in_magnitude[FFT_OUT_SIZE - 1] * magnitude_weight[FFT_OUT_SIZE - 1];
spectrum[FFT_OUT_SIZE - 1] = Complex::new(magnitude, 0.0);
// Perform complex-to-real IFFT
let ifft = self.fft_planner.plan_fft_inverse(BLOCK_LEN);
ifft.process_with_scratch(&mut spectrum, &mut self.signal, &mut self.fft_scratch)
.expect("The fft should run, there is enough scratch space");
// Normalize the IFFT output
for real in &mut self.signal {
*real /= BLOCK_LEN as f32;
}
const SIGNAL_INPUT: &str = "input_4";
const SIGNAL_MEMORY: &str = "input_5";
let signal_input =
Tensor::from_slice::<_, f32>(&self.signal, (1, 1, BLOCK_LEN), &Device::Cpu).unwrap();
HashMap::from([
(SIGNAL_INPUT.to_string(), signal_input),
(SIGNAL_MEMORY.to_string(), self.signal_memory.clone()),
])
}
}
// Both models put their outputs in the same location
fn model_outputs(mut outputs: HashMap<String, Tensor>) -> Vec<f32> {
const NON_MEMORY_OUTPUT: &str = "Identity";
outputs
.remove(NON_MEMORY_OUTPUT)
.expect("The model has this output")
.i((0, 0))
.and_then(|tensor| tensor.to_vec1())
.expect("The tensor has the correct dimensions")
}

270
crates/denoise/src/lib.rs Normal file
View File

@@ -0,0 +1,270 @@
mod engine;
use core::fmt;
use std::{collections::VecDeque, sync::mpsc, thread};
pub use engine::Engine;
use rodio::{ChannelCount, Sample, SampleRate, Source, nz};
use crate::engine::BLOCK_SHIFT;
const SUPPORTED_SAMPLE_RATE: SampleRate = nz!(16_000);
const SUPPORTED_CHANNEL_COUNT: ChannelCount = nz!(1);
pub struct Denoiser<S: Source> {
inner: S,
input_tx: mpsc::Sender<[Sample; BLOCK_SHIFT]>,
denoised_rx: mpsc::Receiver<[Sample; BLOCK_SHIFT]>,
ready: [Sample; BLOCK_SHIFT],
next: usize,
state: IterState,
// When disabled instead of reading denoised sub-blocks from the engine through
// `denoised_rx` we read unprocessed from this queue. This maintains the same
// latency so we can 'trivially' re-enable
queued: Queue,
}
impl<S: Source> fmt::Debug for Denoiser<S> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Denoiser")
.field("state", &self.state)
.finish_non_exhaustive()
}
}
struct Queue(VecDeque<[Sample; BLOCK_SHIFT]>);
impl Queue {
fn new() -> Self {
Self(VecDeque::new())
}
fn push(&mut self, block: [Sample; BLOCK_SHIFT]) {
self.0.push_back(block);
self.0.resize(4, [0f32; BLOCK_SHIFT]);
}
fn pop(&mut self) -> [Sample; BLOCK_SHIFT] {
debug_assert!(self.0.len() == 4);
self.0.pop_front().expect(
"There is no State where the queue is popped while there are less then 4 entries",
)
}
}
#[derive(Debug, Clone, Copy)]
pub enum IterState {
Enabled,
StartingMidAudio { fed_to_denoiser: usize },
Disabled,
Startup { enabled: bool },
}
#[derive(Debug, thiserror::Error)]
pub enum DenoiserError {
#[error("This denoiser only works on sources with samplerate 16000")]
UnsupportedSampleRate,
#[error("This denoiser only works on mono sources (1 channel)")]
UnsupportedChannelCount,
}
// todo dvdsk needs constant source upstream in rodio
impl<S: Source> Denoiser<S> {
pub fn try_new(source: S) -> Result<Self, DenoiserError> {
if source.sample_rate() != SUPPORTED_SAMPLE_RATE {
return Err(DenoiserError::UnsupportedSampleRate);
}
if source.channels() != SUPPORTED_CHANNEL_COUNT {
return Err(DenoiserError::UnsupportedChannelCount);
}
let (input_tx, input_rx) = mpsc::channel();
let (denoised_tx, denoised_rx) = mpsc::channel();
thread::spawn(move || {
run_neural_denoiser(denoised_tx, input_rx);
});
Ok(Self {
inner: source,
input_tx,
denoised_rx,
ready: [0.0; BLOCK_SHIFT],
state: IterState::Startup { enabled: true },
next: BLOCK_SHIFT,
queued: Queue::new(),
})
}
pub fn set_enabled(&mut self, enabled: bool) {
self.state = match (enabled, self.state) {
(false, IterState::StartingMidAudio { .. }) | (false, IterState::Enabled) => {
IterState::Disabled
}
(false, IterState::Startup { enabled: true }) => IterState::Startup { enabled: false },
(true, IterState::Disabled) => IterState::StartingMidAudio { fed_to_denoiser: 0 },
(_, state) => state,
};
}
fn feed(&self, sub_block: [f32; BLOCK_SHIFT]) {
self.input_tx.send(sub_block).unwrap();
}
}
fn run_neural_denoiser(
denoised_tx: mpsc::Sender<[f32; BLOCK_SHIFT]>,
input_rx: mpsc::Receiver<[f32; BLOCK_SHIFT]>,
) {
let mut engine = Engine::new();
loop {
let Ok(sub_block) = input_rx.recv() else {
// tx must have dropped, stop thread
break;
};
let denoised_sub_block = engine.feed(&sub_block);
if denoised_tx.send(denoised_sub_block).is_err() {
break;
}
}
}
impl<S: Source> Source for Denoiser<S> {
fn current_span_len(&self) -> Option<usize> {
self.inner.current_span_len()
}
fn channels(&self) -> rodio::ChannelCount {
self.inner.channels()
}
fn sample_rate(&self) -> rodio::SampleRate {
self.inner.sample_rate()
}
fn total_duration(&self) -> Option<std::time::Duration> {
self.inner.total_duration()
}
}
impl<S: Source> Iterator for Denoiser<S> {
type Item = Sample;
#[inline]
fn next(&mut self) -> Option<Self::Item> {
self.next += 1;
if self.next < self.ready.len() {
let sample = self.ready[self.next];
return Some(sample);
}
// This is a separate function to prevent it from being inlined
// as this code only runs once every 128 samples
self.prepare_next_ready()
.inspect_err(|_| {
log::error!("Denoise engine crashed");
})
.ok()
.flatten()
}
}
#[derive(Debug, thiserror::Error)]
#[error("Could not send or receive from denoise thread. It must have crashed")]
struct DenoiseEngineCrashed;
impl<S: Source> Denoiser<S> {
#[cold]
fn prepare_next_ready(&mut self) -> Result<Option<f32>, DenoiseEngineCrashed> {
self.state = match self.state {
IterState::Startup { enabled } => {
// guaranteed to be coming from silence
for _ in 0..3 {
let Some(sub_block) = read_sub_block(&mut self.inner) else {
return Ok(None);
};
self.queued.push(sub_block);
self.input_tx
.send(sub_block)
.map_err(|_| DenoiseEngineCrashed)?;
}
let Some(sub_block) = read_sub_block(&mut self.inner) else {
return Ok(None);
};
self.queued.push(sub_block);
self.input_tx
.send(sub_block)
.map_err(|_| DenoiseEngineCrashed)?;
// throw out old blocks that are denoised silence
let _ = self.denoised_rx.iter().take(3).count();
self.ready = self.denoised_rx.recv().map_err(|_| DenoiseEngineCrashed)?;
let Some(sub_block) = read_sub_block(&mut self.inner) else {
return Ok(None);
};
self.queued.push(sub_block);
self.feed(sub_block);
if enabled {
IterState::Enabled
} else {
IterState::Disabled
}
}
IterState::Enabled => {
self.ready = self.denoised_rx.recv().map_err(|_| DenoiseEngineCrashed)?;
let Some(sub_block) = read_sub_block(&mut self.inner) else {
return Ok(None);
};
self.queued.push(sub_block);
self.input_tx
.send(sub_block)
.map_err(|_| DenoiseEngineCrashed)?;
IterState::Enabled
}
IterState::Disabled => {
// Need to maintain the same 512 samples delay such that
// we can re-enable at any point.
self.ready = self.queued.pop();
let Some(sub_block) = read_sub_block(&mut self.inner) else {
return Ok(None);
};
self.queued.push(sub_block);
IterState::Disabled
}
IterState::StartingMidAudio {
fed_to_denoiser: mut sub_blocks_fed,
} => {
self.ready = self.queued.pop();
let Some(sub_block) = read_sub_block(&mut self.inner) else {
return Ok(None);
};
self.queued.push(sub_block);
self.input_tx
.send(sub_block)
.map_err(|_| DenoiseEngineCrashed)?;
sub_blocks_fed += 1;
if sub_blocks_fed > 4 {
// throw out partially denoised blocks,
// next will be correctly denoised
let _ = self.denoised_rx.iter().take(3).count();
IterState::Enabled
} else {
IterState::StartingMidAudio {
fed_to_denoiser: sub_blocks_fed,
}
}
}
};
self.next = 0;
Ok(Some(self.ready[0]))
}
}
fn read_sub_block(s: &mut impl Source) -> Option<[f32; BLOCK_SHIFT]> {
let mut res = [0f32; BLOCK_SHIFT];
for sample in &mut res {
*sample = s.next()?;
}
Some(res)
}

View File

@@ -0,0 +1,46 @@
[package]
name = "edit_prediction_context"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/edit_prediction_context.rs"
[dependencies]
anyhow.workspace = true
arrayvec.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
itertools.workspace = true
language.workspace = true
log.workspace = true
ordered-float.workspace = true
project.workspace = true
regex.workspace = true
serde.workspace = true
slotmap.workspace = true
strum.workspace = true
text.workspace = true
tree-sitter.workspace = true
util.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
clap.workspace = true
futures.workspace = true
gpui = { workspace = true, features = ["test-support"] }
indoc.workspace = true
language = { workspace = true, features = ["test-support"] }
pretty_assertions.workspace = true
project = {workspace= true, features = ["test-support"]}
serde_json.workspace = true
settings = {workspace= true, features = ["test-support"]}
text = { workspace = true, features = ["test-support"] }
util = { workspace = true, features = ["test-support"] }
zlog.workspace = true

View File

@@ -0,0 +1 @@
../../LICENSE-GPL

View File

@@ -0,0 +1,193 @@
use language::LanguageId;
use project::ProjectEntryId;
use std::borrow::Cow;
use std::ops::Range;
use std::sync::Arc;
use text::{Bias, BufferId, Rope};
use crate::outline::OutlineDeclaration;
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct Identifier {
pub name: Arc<str>,
pub language_id: LanguageId,
}
slotmap::new_key_type! {
pub struct DeclarationId;
}
#[derive(Debug, Clone)]
pub enum Declaration {
File {
project_entry_id: ProjectEntryId,
declaration: FileDeclaration,
},
Buffer {
project_entry_id: ProjectEntryId,
buffer_id: BufferId,
rope: Rope,
declaration: BufferDeclaration,
},
}
const ITEM_TEXT_TRUNCATION_LENGTH: usize = 1024;
impl Declaration {
pub fn identifier(&self) -> &Identifier {
match self {
Declaration::File { declaration, .. } => &declaration.identifier,
Declaration::Buffer { declaration, .. } => &declaration.identifier,
}
}
pub fn project_entry_id(&self) -> Option<ProjectEntryId> {
match self {
Declaration::File {
project_entry_id, ..
} => Some(*project_entry_id),
Declaration::Buffer {
project_entry_id, ..
} => Some(*project_entry_id),
}
}
pub fn item_text(&self) -> (Cow<'_, str>, bool) {
match self {
Declaration::File { declaration, .. } => (
declaration.text.as_ref().into(),
declaration.text_is_truncated,
),
Declaration::Buffer {
rope, declaration, ..
} => (
rope.chunks_in_range(declaration.item_range.clone())
.collect::<Cow<str>>(),
declaration.item_range_is_truncated,
),
}
}
pub fn signature_text(&self) -> (Cow<'_, str>, bool) {
match self {
Declaration::File { declaration, .. } => (
declaration.text[declaration.signature_range_in_text.clone()].into(),
declaration.signature_is_truncated,
),
Declaration::Buffer {
rope, declaration, ..
} => (
rope.chunks_in_range(declaration.signature_range.clone())
.collect::<Cow<str>>(),
declaration.signature_range_is_truncated,
),
}
}
}
fn expand_range_to_line_boundaries_and_truncate(
range: &Range<usize>,
limit: usize,
rope: &Rope,
) -> (Range<usize>, bool) {
let mut point_range = rope.offset_to_point(range.start)..rope.offset_to_point(range.end);
point_range.start.column = 0;
point_range.end.row += 1;
point_range.end.column = 0;
let mut item_range =
rope.point_to_offset(point_range.start)..rope.point_to_offset(point_range.end);
let is_truncated = item_range.len() > limit;
if is_truncated {
item_range.end = item_range.start + limit;
}
item_range.end = rope.clip_offset(item_range.end, Bias::Left);
(item_range, is_truncated)
}
#[derive(Debug, Clone)]
pub struct FileDeclaration {
pub parent: Option<DeclarationId>,
pub identifier: Identifier,
/// offset range of the declaration in the file, expanded to line boundaries and truncated
pub item_range_in_file: Range<usize>,
/// text of `item_range_in_file`
pub text: Arc<str>,
/// whether `text` was truncated
pub text_is_truncated: bool,
/// offset range of the signature within `text`
pub signature_range_in_text: Range<usize>,
/// whether `signature` was truncated
pub signature_is_truncated: bool,
}
impl FileDeclaration {
pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> FileDeclaration {
let (item_range_in_file, text_is_truncated) = expand_range_to_line_boundaries_and_truncate(
&declaration.item_range,
ITEM_TEXT_TRUNCATION_LENGTH,
rope,
);
// TODO: consider logging if unexpected
let signature_start = declaration
.signature_range
.start
.saturating_sub(item_range_in_file.start);
let mut signature_end = declaration
.signature_range
.end
.saturating_sub(item_range_in_file.start);
let signature_is_truncated = signature_end > item_range_in_file.len();
if signature_is_truncated {
signature_end = item_range_in_file.len();
}
FileDeclaration {
parent: None,
identifier: declaration.identifier,
signature_range_in_text: signature_start..signature_end,
signature_is_truncated,
text: rope
.chunks_in_range(item_range_in_file.clone())
.collect::<String>()
.into(),
text_is_truncated,
item_range_in_file,
}
}
}
#[derive(Debug, Clone)]
pub struct BufferDeclaration {
pub parent: Option<DeclarationId>,
pub identifier: Identifier,
pub item_range: Range<usize>,
pub item_range_is_truncated: bool,
pub signature_range: Range<usize>,
pub signature_range_is_truncated: bool,
}
impl BufferDeclaration {
pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> Self {
let (item_range, item_range_is_truncated) = expand_range_to_line_boundaries_and_truncate(
&declaration.item_range,
ITEM_TEXT_TRUNCATION_LENGTH,
rope,
);
let (signature_range, signature_range_is_truncated) =
expand_range_to_line_boundaries_and_truncate(
&declaration.signature_range,
ITEM_TEXT_TRUNCATION_LENGTH,
rope,
);
Self {
parent: None,
identifier: declaration.identifier,
item_range,
item_range_is_truncated,
signature_range,
signature_range_is_truncated,
}
}
}

View File

@@ -0,0 +1,324 @@
use itertools::Itertools as _;
use language::BufferSnapshot;
use ordered_float::OrderedFloat;
use serde::Serialize;
use std::{collections::HashMap, ops::Range};
use strum::EnumIter;
use text::{OffsetRangeExt, Point, ToPoint};
use crate::{
Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
reference::{Reference, ReferenceRegion},
syntax_index::SyntaxIndexState,
text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
};
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
// TODO:
//
// * Consider adding declaration_file_count
#[derive(Clone, Debug)]
pub struct ScoredSnippet {
pub identifier: Identifier,
pub declaration: Declaration,
pub score_components: ScoreInputs,
pub scores: Scores,
}
// TODO: Consider having "Concise" style corresponding to `concise_text`
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum SnippetStyle {
Signature,
Declaration,
}
impl ScoredSnippet {
/// Returns the score for this snippet with the specified style.
pub fn score(&self, style: SnippetStyle) -> f32 {
match style {
SnippetStyle::Signature => self.scores.signature,
SnippetStyle::Declaration => self.scores.declaration,
}
}
pub fn size(&self, style: SnippetStyle) -> usize {
// TODO: how to handle truncation?
match &self.declaration {
Declaration::File { declaration, .. } => match style {
SnippetStyle::Signature => declaration.signature_range_in_text.len(),
SnippetStyle::Declaration => declaration.text.len(),
},
Declaration::Buffer { declaration, .. } => match style {
SnippetStyle::Signature => declaration.signature_range.len(),
SnippetStyle::Declaration => declaration.item_range.len(),
},
}
}
pub fn score_density(&self, style: SnippetStyle) -> f32 {
self.score(style) / (self.size(style)) as f32
}
}
pub fn scored_snippets(
index: &SyntaxIndexState,
excerpt: &EditPredictionExcerpt,
excerpt_text: &EditPredictionExcerptText,
identifier_to_references: HashMap<Identifier, Vec<Reference>>,
cursor_offset: usize,
current_buffer: &BufferSnapshot,
) -> Vec<ScoredSnippet> {
let containing_range_identifier_occurrences =
IdentifierOccurrences::within_string(&excerpt_text.body);
let cursor_point = cursor_offset.to_point(&current_buffer);
let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
let end_point = Point::new(cursor_point.row + 1, 0);
let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
&current_buffer
.text_for_range(start_point..end_point)
.collect::<String>(),
);
let mut snippets = identifier_to_references
.into_iter()
.flat_map(|(identifier, references)| {
let declarations =
index.declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier);
let declaration_count = declarations.len();
declarations
.iter()
.filter_map(|declaration| match declaration {
Declaration::Buffer {
buffer_id,
declaration: buffer_declaration,
..
} => {
let is_same_file = buffer_id == &current_buffer.remote_id();
if is_same_file {
range_intersection(
&buffer_declaration.item_range.to_offset(&current_buffer),
&excerpt.range,
)
.is_none()
.then(|| {
let declaration_line = buffer_declaration
.item_range
.start
.to_point(current_buffer)
.row;
(
true,
(cursor_point.row as i32 - declaration_line as i32).abs()
as u32,
declaration,
)
})
} else {
Some((false, 0, declaration))
}
}
Declaration::File { .. } => {
// We can assume that a file declaration is in a different file,
// because the current one must be open
Some((false, 0, declaration))
}
})
.sorted_by_key(|&(_, distance, _)| distance)
.enumerate()
.map(
|(
declaration_line_distance_rank,
(is_same_file, declaration_line_distance, declaration),
)| {
let same_file_declaration_count = index.file_declaration_count(declaration);
score_snippet(
&identifier,
&references,
declaration.clone(),
is_same_file,
declaration_line_distance,
declaration_line_distance_rank,
same_file_declaration_count,
declaration_count,
&containing_range_identifier_occurrences,
&adjacent_identifier_occurrences,
cursor_point,
current_buffer,
)
},
)
.collect::<Vec<_>>()
})
.flatten()
.collect::<Vec<_>>();
snippets.sort_unstable_by_key(|snippet| {
OrderedFloat(
snippet
.score_density(SnippetStyle::Declaration)
.max(snippet.score_density(SnippetStyle::Signature)),
)
});
snippets
}
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
let start = a.start.clone().max(b.start.clone());
let end = a.end.clone().min(b.end.clone());
if start < end {
Some(Range { start, end })
} else {
None
}
}
fn score_snippet(
identifier: &Identifier,
references: &[Reference],
declaration: Declaration,
is_same_file: bool,
declaration_line_distance: u32,
declaration_line_distance_rank: usize,
same_file_declaration_count: usize,
declaration_count: usize,
containing_range_identifier_occurrences: &IdentifierOccurrences,
adjacent_identifier_occurrences: &IdentifierOccurrences,
cursor: Point,
current_buffer: &BufferSnapshot,
) -> Option<ScoredSnippet> {
let is_referenced_nearby = references
.iter()
.any(|r| r.region == ReferenceRegion::Nearby);
let is_referenced_in_breadcrumb = references
.iter()
.any(|r| r.region == ReferenceRegion::Breadcrumb);
let reference_count = references.len();
let reference_line_distance = references
.iter()
.map(|r| {
let reference_line = r.range.start.to_point(current_buffer).row as i32;
(cursor.row as i32 - reference_line).abs() as u32
})
.min()
.unwrap();
let item_source_occurrences = IdentifierOccurrences::within_string(&declaration.item_text().0);
let item_signature_occurrences =
IdentifierOccurrences::within_string(&declaration.signature_text().0);
let containing_range_vs_item_jaccard = jaccard_similarity(
containing_range_identifier_occurrences,
&item_source_occurrences,
);
let containing_range_vs_signature_jaccard = jaccard_similarity(
containing_range_identifier_occurrences,
&item_signature_occurrences,
);
let adjacent_vs_item_jaccard =
jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
let adjacent_vs_signature_jaccard =
jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
containing_range_identifier_occurrences,
&item_source_occurrences,
);
let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
containing_range_identifier_occurrences,
&item_signature_occurrences,
);
let adjacent_vs_item_weighted_overlap =
weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
let adjacent_vs_signature_weighted_overlap =
weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
let score_components = ScoreInputs {
is_same_file,
is_referenced_nearby,
is_referenced_in_breadcrumb,
reference_line_distance,
declaration_line_distance,
declaration_line_distance_rank,
reference_count,
same_file_declaration_count,
declaration_count,
containing_range_vs_item_jaccard,
containing_range_vs_signature_jaccard,
adjacent_vs_item_jaccard,
adjacent_vs_signature_jaccard,
containing_range_vs_item_weighted_overlap,
containing_range_vs_signature_weighted_overlap,
adjacent_vs_item_weighted_overlap,
adjacent_vs_signature_weighted_overlap,
};
Some(ScoredSnippet {
identifier: identifier.clone(),
declaration: declaration,
scores: score_components.score(),
score_components,
})
}
#[derive(Clone, Debug, Serialize)]
pub struct ScoreInputs {
pub is_same_file: bool,
pub is_referenced_nearby: bool,
pub is_referenced_in_breadcrumb: bool,
pub reference_count: usize,
pub same_file_declaration_count: usize,
pub declaration_count: usize,
pub reference_line_distance: u32,
pub declaration_line_distance: u32,
pub declaration_line_distance_rank: usize,
pub containing_range_vs_item_jaccard: f32,
pub containing_range_vs_signature_jaccard: f32,
pub adjacent_vs_item_jaccard: f32,
pub adjacent_vs_signature_jaccard: f32,
pub containing_range_vs_item_weighted_overlap: f32,
pub containing_range_vs_signature_weighted_overlap: f32,
pub adjacent_vs_item_weighted_overlap: f32,
pub adjacent_vs_signature_weighted_overlap: f32,
}
#[derive(Clone, Debug, Serialize)]
pub struct Scores {
pub signature: f32,
pub declaration: f32,
}
impl ScoreInputs {
fn score(&self) -> Scores {
// Score related to how likely this is the correct declaration, range 0 to 1
let accuracy_score = if self.is_same_file {
// TODO: use declaration_line_distance_rank
1.0 / self.same_file_declaration_count as f32
} else {
1.0 / self.declaration_count as f32
};
// Score related to the distance between the reference and cursor, range 0 to 1
let distance_score = if self.is_referenced_nearby {
1.0 / (1.0 + self.reference_line_distance as f32 / 10.0).powf(2.0)
} else {
// same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
0.5
};
// For now instead of linear combination, the scores are just multiplied together.
let combined_score = 10.0 * accuracy_score * distance_score;
Scores {
signature: combined_score * self.containing_range_vs_signature_weighted_overlap,
// declaration score gets boosted both by being multipled by 2 and by there being more
// weighted overlap.
declaration: 2.0 * combined_score * self.containing_range_vs_item_weighted_overlap,
}
}
}

View File

@@ -0,0 +1,220 @@
mod declaration;
mod declaration_scoring;
mod excerpt;
mod outline;
mod reference;
mod syntax_index;
mod text_similarity;
pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier};
pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
use gpui::{App, AppContext as _, Entity, Task};
use language::BufferSnapshot;
pub use reference::references_in_excerpt;
pub use syntax_index::SyntaxIndex;
use text::{Point, ToOffset as _};
use crate::declaration_scoring::{ScoredSnippet, scored_snippets};
pub struct EditPredictionContext {
pub excerpt: EditPredictionExcerpt,
pub excerpt_text: EditPredictionExcerptText,
pub snippets: Vec<ScoredSnippet>,
}
impl EditPredictionContext {
pub fn gather(
cursor_point: Point,
buffer: BufferSnapshot,
excerpt_options: EditPredictionExcerptOptions,
syntax_index: Entity<SyntaxIndex>,
cx: &mut App,
) -> Task<Self> {
let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
cx.background_spawn(async move {
let index_state = index_state.lock().await;
let excerpt =
EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &excerpt_options)
.unwrap();
let excerpt_text = excerpt.text(&buffer);
let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer);
let cursor_offset = cursor_point.to_offset(&buffer);
let snippets = scored_snippets(
&index_state,
&excerpt,
&excerpt_text,
references,
cursor_offset,
&buffer,
);
Self {
excerpt,
excerpt_text,
snippets,
}
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use gpui::{Entity, TestAppContext};
use indoc::indoc;
use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use util::path;
use crate::{EditPredictionExcerptOptions, SyntaxIndex};
#[gpui::test]
async fn test_call_site(cx: &mut TestAppContext) {
let (project, index, _rust_lang_id) = init_test(cx).await;
let buffer = project
.update(cx, |project, cx| {
let project_path = project.find_project_path("c.rs", cx).unwrap();
project.open_buffer(project_path, cx)
})
.await
.unwrap();
cx.run_until_parked();
// first process_data call site
let cursor_point = language::Point::new(8, 21);
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
let context = cx
.update(|cx| {
EditPredictionContext::gather(
cursor_point,
buffer_snapshot,
EditPredictionExcerptOptions {
max_bytes: 40,
min_bytes: 10,
target_before_cursor_over_total_bytes: 0.5,
include_parent_signatures: false,
},
index,
cx,
)
})
.await;
assert_eq!(context.snippets.len(), 1);
assert_eq!(context.snippets[0].identifier.name.as_ref(), "process_data");
drop(buffer);
}
async fn init_test(
cx: &mut TestAppContext,
) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
});
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/root"),
json!({
"a.rs": indoc! {r#"
fn main() {
let x = 1;
let y = 2;
let z = add(x, y);
println!("Result: {}", z);
}
fn add(a: i32, b: i32) -> i32 {
a + b
}
"#},
"b.rs": indoc! {"
pub struct Config {
pub name: String,
pub value: i32,
}
impl Config {
pub fn new(name: String, value: i32) -> Self {
Config { name, value }
}
}
"},
"c.rs": indoc! {r#"
use std::collections::HashMap;
fn main() {
let args: Vec<String> = std::env::args().collect();
let data: Vec<i32> = args[1..]
.iter()
.filter_map(|s| s.parse().ok())
.collect();
let result = process_data(data);
println!("{:?}", result);
}
fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
let mut counts = HashMap::new();
for value in data {
*counts.entry(value).or_insert(0) += 1;
}
counts
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_process_data() {
let data = vec![1, 2, 2, 3];
let result = process_data(data);
assert_eq!(result.get(&2), Some(&2));
}
}
"#}
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
let lang = rust_lang();
let lang_id = lang.id();
language_registry.add(Arc::new(lang));
let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
cx.run_until_parked();
(project, index, lang_id)
}
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(tree_sitter_rust::LANGUAGE.into()),
)
.with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
.unwrap()
.with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
.unwrap()
}
}

View File

@@ -0,0 +1,616 @@
use language::BufferSnapshot;
use std::ops::Range;
use text::{OffsetRangeExt as _, Point, ToOffset as _, ToPoint as _};
use tree_sitter::{Node, TreeCursor};
use util::RangeExt;
// TODO:
//
// - Test parent signatures
//
// - Decide whether to count signatures against the excerpt size. Could instead defer this to prompt
// planning.
//
// - Still return an excerpt even if the line around the cursor doesn't fit (e.g. for a markdown
// paragraph).
//
// - Truncation of long lines.
//
// - Filter outer syntax layers that don't support edit prediction.
#[derive(Debug, Clone)]
pub struct EditPredictionExcerptOptions {
/// Limit for the number of bytes in the window around the cursor.
pub max_bytes: usize,
/// Minimum number of bytes in the window around the cursor. When syntax tree selection results
/// in an excerpt smaller than this, it will fall back on line-based selection.
pub min_bytes: usize,
/// Target ratio of bytes before the cursor divided by total bytes in the window.
pub target_before_cursor_over_total_bytes: f32,
/// Whether to include parent signatures
pub include_parent_signatures: bool,
}
#[derive(Debug, Clone)]
pub struct EditPredictionExcerpt {
pub range: Range<usize>,
pub parent_signature_ranges: Vec<Range<usize>>,
pub size: usize,
}
#[derive(Clone)]
pub struct EditPredictionExcerptText {
pub body: String,
pub parent_signatures: Vec<String>,
}
impl EditPredictionExcerpt {
pub fn text(&self, buffer: &BufferSnapshot) -> EditPredictionExcerptText {
let body = buffer
.text_for_range(self.range.clone())
.collect::<String>();
let parent_signatures = self
.parent_signature_ranges
.iter()
.map(|range| buffer.text_for_range(range.clone()).collect::<String>())
.collect();
EditPredictionExcerptText {
body,
parent_signatures,
}
}
/// Selects an excerpt around a buffer position, attempting to choose logical boundaries based
/// on TreeSitter structure and approximately targeting a goal ratio of bytesbefore vs after the
/// cursor. When `include_parent_signatures` is true, the excerpt also includes the signatures
/// of parent outline items.
///
/// First tries to use AST node boundaries to select the excerpt, and falls back on line-based
/// expansion.
///
/// Returns `None` if the line around the cursor doesn't fit.
pub fn select_from_buffer(
query_point: Point,
buffer: &BufferSnapshot,
options: &EditPredictionExcerptOptions,
) -> Option<Self> {
if buffer.len() <= options.max_bytes {
log::debug!(
"using entire file for excerpt since source length ({}) <= window max bytes ({})",
buffer.len(),
options.max_bytes
);
return Some(EditPredictionExcerpt::new(0..buffer.len(), Vec::new()));
}
let query_offset = query_point.to_offset(buffer);
let query_range = Point::new(query_point.row, 0).to_offset(buffer)
..Point::new(query_point.row + 1, 0).to_offset(buffer);
if query_range.len() >= options.max_bytes {
return None;
}
// TODO: Don't compute text / annotation_range / skip converting to and from anchors.
let outline_items = if options.include_parent_signatures {
buffer
.outline_items_containing(query_range.clone(), false, None)
.into_iter()
.flat_map(|item| {
Some(ExcerptOutlineItem {
item_range: item.range.to_offset(&buffer),
signature_range: item.signature_range?.to_offset(&buffer),
})
})
.collect()
} else {
Vec::new()
};
let excerpt_selector = ExcerptSelector {
query_offset,
query_range,
outline_items: &outline_items,
buffer,
options,
};
if let Some(excerpt_ranges) = excerpt_selector.select_tree_sitter_nodes() {
if excerpt_ranges.size >= options.min_bytes {
return Some(excerpt_ranges);
}
log::debug!(
"tree-sitter excerpt was {} bytes, smaller than min of {}, falling back on line-based selection",
excerpt_ranges.size,
options.min_bytes
);
} else {
log::debug!(
"couldn't find excerpt via tree-sitter, falling back on line-based selection"
);
}
excerpt_selector.select_lines()
}
fn new(range: Range<usize>, parent_signature_ranges: Vec<Range<usize>>) -> Self {
let size = range.len()
+ parent_signature_ranges
.iter()
.map(|r| r.len())
.sum::<usize>();
Self {
range,
parent_signature_ranges,
size,
}
}
fn with_expanded_range(&self, new_range: Range<usize>) -> Self {
if !new_range.contains_inclusive(&self.range) {
// this is an issue because parent_signature_ranges may be incorrect
log::error!("bug: with_expanded_range called with disjoint range");
}
let mut parent_signature_ranges = Vec::with_capacity(self.parent_signature_ranges.len());
let mut size = new_range.len();
for range in &self.parent_signature_ranges {
if range.contains_inclusive(&new_range) {
break;
}
parent_signature_ranges.push(range.clone());
size += range.len();
}
Self {
range: new_range,
parent_signature_ranges,
size,
}
}
fn parent_signatures_size(&self) -> usize {
self.size - self.range.len()
}
}
struct ExcerptSelector<'a> {
query_offset: usize,
query_range: Range<usize>,
outline_items: &'a [ExcerptOutlineItem],
buffer: &'a BufferSnapshot,
options: &'a EditPredictionExcerptOptions,
}
struct ExcerptOutlineItem {
item_range: Range<usize>,
signature_range: Range<usize>,
}
impl<'a> ExcerptSelector<'a> {
/// Finds the largest node that is smaller than the window size and contains `query_range`.
fn select_tree_sitter_nodes(&self) -> Option<EditPredictionExcerpt> {
let selected_layer_root = self.select_syntax_layer()?;
let mut cursor = selected_layer_root.walk();
loop {
let excerpt_range = node_line_start(cursor.node()).to_offset(&self.buffer)
..node_line_end(cursor.node()).to_offset(&self.buffer);
if excerpt_range.contains_inclusive(&self.query_range) {
let excerpt = self.make_excerpt(excerpt_range);
if excerpt.size <= self.options.max_bytes {
return Some(self.expand_to_siblings(&mut cursor, excerpt));
}
} else {
// TODO: Should still be able to handle this case via AST nodes. For example, this
// can happen if the cursor is between two methods in a large class file.
return None;
}
if cursor
.goto_first_child_for_byte(self.query_range.start)
.is_none()
{
return None;
}
}
}
/// Select the smallest syntax layer that exceeds max_len, or the largest if none exceed max_len.
fn select_syntax_layer(&self) -> Option<Node<'_>> {
let mut smallest_exceeding_max_len: Option<Node<'_>> = None;
let mut largest: Option<Node<'_>> = None;
for layer in self
.buffer
.syntax_layers_for_range(self.query_range.start..self.query_range.start, true)
{
let layer_range = layer.node().byte_range();
if !layer_range.contains_inclusive(&self.query_range) {
continue;
}
if layer_range.len() > self.options.max_bytes {
match &smallest_exceeding_max_len {
None => smallest_exceeding_max_len = Some(layer.node()),
Some(existing) => {
if layer_range.len() < existing.byte_range().len() {
smallest_exceeding_max_len = Some(layer.node());
}
}
}
} else {
match &largest {
None => largest = Some(layer.node()),
Some(existing) if layer_range.len() > existing.byte_range().len() => {
largest = Some(layer.node())
}
_ => {}
}
}
}
smallest_exceeding_max_len.or(largest)
}
// motivation for this and `goto_previous_named_sibling` is to avoid including things like
// trailing unnamed "}" in body nodes
fn goto_next_named_sibling(cursor: &mut TreeCursor) -> bool {
while cursor.goto_next_sibling() {
if cursor.node().is_named() {
return true;
}
}
false
}
fn goto_previous_named_sibling(cursor: &mut TreeCursor) -> bool {
while cursor.goto_previous_sibling() {
if cursor.node().is_named() {
return true;
}
}
false
}
fn expand_to_siblings(
&self,
cursor: &mut TreeCursor,
mut excerpt: EditPredictionExcerpt,
) -> EditPredictionExcerpt {
let mut forward_cursor = cursor.clone();
let backward_cursor = cursor;
let mut forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
let mut backward_done = !Self::goto_previous_named_sibling(backward_cursor);
loop {
if backward_done && forward_done {
break;
}
let mut forward = None;
while !forward_done {
let new_end = node_line_end(forward_cursor.node()).to_offset(&self.buffer);
if new_end > excerpt.range.end {
let new_excerpt = excerpt.with_expanded_range(excerpt.range.start..new_end);
if new_excerpt.size <= self.options.max_bytes {
forward = Some(new_excerpt);
break;
} else {
log::debug!("halting forward expansion, as it doesn't fit");
forward_done = true;
break;
}
}
forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
}
let mut backward = None;
while !backward_done {
let new_start = node_line_start(backward_cursor.node()).to_offset(&self.buffer);
if new_start < excerpt.range.start {
let new_excerpt = excerpt.with_expanded_range(new_start..excerpt.range.end);
if new_excerpt.size <= self.options.max_bytes {
backward = Some(new_excerpt);
break;
} else {
log::debug!("halting backward expansion, as it doesn't fit");
backward_done = true;
break;
}
}
backward_done = !Self::goto_previous_named_sibling(backward_cursor);
}
let go_forward = match (forward, backward) {
(Some(forward), Some(backward)) => {
let go_forward = self.is_better_excerpt(&forward, &backward);
if go_forward {
excerpt = forward;
} else {
excerpt = backward;
}
go_forward
}
(Some(forward), None) => {
log::debug!("expanding forward, since backward expansion has halted");
excerpt = forward;
true
}
(None, Some(backward)) => {
log::debug!("expanding backward, since forward expansion has halted");
excerpt = backward;
false
}
(None, None) => break,
};
if go_forward {
forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
} else {
backward_done = !Self::goto_previous_named_sibling(backward_cursor);
}
}
excerpt
}
fn select_lines(&self) -> Option<EditPredictionExcerpt> {
// early return if line containing query_offset is already too large
let excerpt = self.make_excerpt(self.query_range.clone());
if excerpt.size > self.options.max_bytes {
log::debug!(
"excerpt for cursor line is {} bytes, which exceeds the window",
excerpt.size
);
return None;
}
let signatures_size = excerpt.parent_signatures_size();
let bytes_remaining = self.options.max_bytes.saturating_sub(signatures_size);
let before_bytes =
(self.options.target_before_cursor_over_total_bytes * bytes_remaining as f32) as usize;
let start_point = {
let offset = self.query_offset.saturating_sub(before_bytes);
let point = offset.to_point(self.buffer);
Point::new(point.row + 1, 0)
};
let start_offset = start_point.to_offset(&self.buffer);
let end_point = {
let offset = start_offset + bytes_remaining;
let point = offset.to_point(self.buffer);
Point::new(point.row, 0)
};
let end_offset = end_point.to_offset(&self.buffer);
// this could be expanded further since recalculated `signature_size` may be smaller, but
// skipping that for now for simplicity
//
// TODO: could also consider checking if lines immediately before / after fit.
let excerpt = self.make_excerpt(start_offset..end_offset);
if excerpt.size > self.options.max_bytes {
log::error!(
"bug: line-based excerpt selection has size {}, \
which is {} bytes larger than the max size",
excerpt.size,
excerpt.size - self.options.max_bytes
);
}
return Some(excerpt);
}
fn make_excerpt(&self, range: Range<usize>) -> EditPredictionExcerpt {
let parent_signature_ranges = self
.outline_items
.iter()
.filter(|item| item.item_range.contains_inclusive(&range))
.map(|item| item.signature_range.clone())
.collect();
EditPredictionExcerpt::new(range, parent_signature_ranges)
}
/// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt.
fn is_better_excerpt(
&self,
forward: &EditPredictionExcerpt,
backward: &EditPredictionExcerpt,
) -> bool {
let forward_ratio = self.excerpt_range_ratio(forward);
let backward_ratio = self.excerpt_range_ratio(backward);
let forward_delta =
(forward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
let backward_delta =
(backward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
let forward_is_better = forward_delta <= backward_delta;
if forward_is_better {
log::debug!(
"expanding forward since {} is closer than {} to {}",
forward_ratio,
backward_ratio,
self.options.target_before_cursor_over_total_bytes
);
} else {
log::debug!(
"expanding backward since {} is closer than {} to {}",
backward_ratio,
forward_ratio,
self.options.target_before_cursor_over_total_bytes
);
}
forward_is_better
}
/// Returns the ratio of bytes before the cursor over bytes within the range.
fn excerpt_range_ratio(&self, excerpt: &EditPredictionExcerpt) -> f32 {
let Some(bytes_before_cursor) = self.query_offset.checked_sub(excerpt.range.start) else {
log::error!("bug: edit prediction cursor offset is not outside the excerpt");
return 0.0;
};
bytes_before_cursor as f32 / excerpt.range.len() as f32
}
}
fn node_line_start(node: Node) -> Point {
Point::new(node.start_position().row as u32, 0)
}
fn node_line_end(node: Node) -> Point {
Point::new(node.end_position().row as u32 + 1, 0)
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::{AppContext, TestAppContext};
use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
use util::test::{generate_marked_text, marked_text_offsets_by};
fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot {
let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang().into(), cx));
buffer.read_with(cx, |buffer, _| buffer.snapshot())
}
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(tree_sitter_rust::LANGUAGE.into()),
)
.with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
.unwrap()
}
fn cursor_and_excerpt_range(text: &str) -> (String, usize, Range<usize>) {
let (text, offsets) = marked_text_offsets_by(text, vec!['ˇ', '«', '»']);
(text, offsets[&'ˇ'][0], offsets[&'«'][0]..offsets[&'»'][0])
}
fn check_example(options: EditPredictionExcerptOptions, text: &str, cx: &mut TestAppContext) {
let (text, cursor, expected_excerpt) = cursor_and_excerpt_range(text);
let buffer = create_buffer(&text, cx);
let cursor_point = cursor.to_point(&buffer);
let excerpt = EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options)
.expect("Should select an excerpt");
pretty_assertions::assert_eq!(
generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false),
generate_marked_text(&text, &[expected_excerpt], false)
);
assert!(excerpt.size <= options.max_bytes);
assert!(excerpt.range.contains(&cursor));
}
#[gpui::test]
fn test_ast_based_selection_current_node(cx: &mut TestAppContext) {
zlog::init_test();
let text = r#"
fn main() {
let x = 1;
« let ˇy = 2;
» let z = 3;
}"#;
let options = EditPredictionExcerptOptions {
max_bytes: 20,
min_bytes: 10,
target_before_cursor_over_total_bytes: 0.5,
include_parent_signatures: false,
};
check_example(options, text, cx);
}
#[gpui::test]
fn test_ast_based_selection_parent_node(cx: &mut TestAppContext) {
zlog::init_test();
let text = r#"
fn foo() {}
«fn main() {
let x = 1;
let ˇy = 2;
let z = 3;
}
»
fn bar() {}"#;
let options = EditPredictionExcerptOptions {
max_bytes: 65,
min_bytes: 10,
target_before_cursor_over_total_bytes: 0.5,
include_parent_signatures: false,
};
check_example(options, text, cx);
}
#[gpui::test]
fn test_ast_based_selection_expands_to_siblings(cx: &mut TestAppContext) {
zlog::init_test();
let text = r#"
fn main() {
« let x = 1;
let ˇy = 2;
let z = 3;
»}"#;
let options = EditPredictionExcerptOptions {
max_bytes: 50,
min_bytes: 10,
target_before_cursor_over_total_bytes: 0.5,
include_parent_signatures: false,
};
check_example(options, text, cx);
}
#[gpui::test]
fn test_line_based_selection(cx: &mut TestAppContext) {
zlog::init_test();
let text = r#"
fn main() {
let x = 1;
« if true {
let ˇy = 2;
}
let z = 3;
»}"#;
let options = EditPredictionExcerptOptions {
max_bytes: 60,
min_bytes: 45,
target_before_cursor_over_total_bytes: 0.5,
include_parent_signatures: false,
};
check_example(options, text, cx);
}
#[gpui::test]
fn test_line_based_selection_with_before_cursor_ratio(cx: &mut TestAppContext) {
zlog::init_test();
let text = r#"
fn main() {
« let a = 1;
let b = 2;
let c = 3;
let ˇd = 4;
let e = 5;
let f = 6;
»
let g = 7;
}"#;
let options = EditPredictionExcerptOptions {
max_bytes: 120,
min_bytes: 10,
target_before_cursor_over_total_bytes: 0.6,
include_parent_signatures: false,
};
check_example(options, text, cx);
}
}

View File

@@ -0,0 +1,126 @@
use language::{BufferSnapshot, SyntaxMapMatches};
use std::{cmp::Reverse, ops::Range};
use crate::declaration::Identifier;
// TODO:
//
// * how to handle multiple name captures? for now last one wins
//
// * annotation ranges
//
// * new "signature" capture for outline queries
//
// * Check parent behavior of "int x, y = 0" declarations in a test
pub struct OutlineDeclaration {
pub parent_index: Option<usize>,
pub identifier: Identifier,
pub item_range: Range<usize>,
pub signature_range: Range<usize>,
}
pub fn declarations_in_buffer(buffer: &BufferSnapshot) -> Vec<OutlineDeclaration> {
declarations_overlapping_range(0..buffer.len(), buffer)
}
pub fn declarations_overlapping_range(
range: Range<usize>,
buffer: &BufferSnapshot,
) -> Vec<OutlineDeclaration> {
let mut declarations = OutlineIterator::new(range, buffer).collect::<Vec<_>>();
declarations.sort_unstable_by_key(|item| (item.item_range.start, Reverse(item.item_range.end)));
let mut parent_stack: Vec<(usize, Range<usize>)> = Vec::new();
for (index, declaration) in declarations.iter_mut().enumerate() {
while let Some((top_parent_index, top_parent_range)) = parent_stack.last() {
if declaration.item_range.start >= top_parent_range.end {
parent_stack.pop();
} else {
declaration.parent_index = Some(*top_parent_index);
break;
}
}
parent_stack.push((index, declaration.item_range.clone()));
}
declarations
}
/// Iterates outline items without being ordered w.r.t. nested items and without populating
/// `parent`.
pub struct OutlineIterator<'a> {
buffer: &'a BufferSnapshot,
matches: SyntaxMapMatches<'a>,
}
impl<'a> OutlineIterator<'a> {
pub fn new(range: Range<usize>, buffer: &'a BufferSnapshot) -> Self {
let matches = buffer.syntax.matches(range, &buffer.text, |grammar| {
grammar.outline_config.as_ref().map(|c| &c.query)
});
Self { buffer, matches }
}
}
impl<'a> Iterator for OutlineIterator<'a> {
type Item = OutlineDeclaration;
fn next(&mut self) -> Option<Self::Item> {
while let Some(mat) = self.matches.peek() {
let config = self.matches.grammars()[mat.grammar_index]
.outline_config
.as_ref()
.unwrap();
let mut name_range = None;
let mut item_range = None;
let mut signature_start = None;
let mut signature_end = None;
let mut add_to_signature = |range: Range<usize>| {
if signature_start.is_none() {
signature_start = Some(range.start);
}
signature_end = Some(range.end);
};
for capture in mat.captures {
let range = capture.node.byte_range();
if capture.index == config.name_capture_ix {
name_range = Some(range.clone());
add_to_signature(range);
} else if Some(capture.index) == config.context_capture_ix
|| Some(capture.index) == config.extra_context_capture_ix
{
add_to_signature(range);
} else if capture.index == config.item_capture_ix {
item_range = Some(range.clone());
}
}
let language_id = mat.language.id();
self.matches.advance();
if let Some(name_range) = name_range
&& let Some(item_range) = item_range
&& let Some(signature_start) = signature_start
&& let Some(signature_end) = signature_end
{
let name = self
.buffer
.text_for_range(name_range)
.collect::<String>()
.into();
return Some(OutlineDeclaration {
identifier: Identifier { name, language_id },
item_range: item_range,
signature_range: signature_start..signature_end,
parent_index: None,
});
}
}
None
}
}

View File

@@ -0,0 +1,109 @@
use language::BufferSnapshot;
use std::collections::HashMap;
use std::ops::Range;
use crate::{
declaration::Identifier,
excerpt::{EditPredictionExcerpt, EditPredictionExcerptText},
};
#[derive(Debug)]
pub struct Reference {
pub identifier: Identifier,
pub range: Range<usize>,
pub region: ReferenceRegion,
}
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ReferenceRegion {
Breadcrumb,
Nearby,
}
pub fn references_in_excerpt(
excerpt: &EditPredictionExcerpt,
excerpt_text: &EditPredictionExcerptText,
snapshot: &BufferSnapshot,
) -> HashMap<Identifier, Vec<Reference>> {
let mut references = identifiers_in_range(
excerpt.range.clone(),
excerpt_text.body.as_str(),
ReferenceRegion::Nearby,
snapshot,
);
for (range, text) in excerpt
.parent_signature_ranges
.iter()
.zip(excerpt_text.parent_signatures.iter())
{
references.extend(identifiers_in_range(
range.clone(),
text.as_str(),
ReferenceRegion::Breadcrumb,
snapshot,
));
}
let mut identifier_to_references: HashMap<Identifier, Vec<Reference>> = HashMap::new();
for reference in references {
identifier_to_references
.entry(reference.identifier.clone())
.or_insert_with(Vec::new)
.push(reference);
}
identifier_to_references
}
/// Finds all nodes which have a "variable" match from the highlights query within the offset range.
pub fn identifiers_in_range(
range: Range<usize>,
range_text: &str,
reference_region: ReferenceRegion,
buffer: &BufferSnapshot,
) -> Vec<Reference> {
let mut matches = buffer
.syntax
.matches(range.clone(), &buffer.text, |grammar| {
grammar
.highlights_config
.as_ref()
.map(|config| &config.query)
});
let mut references = Vec::new();
let mut last_added_range = None;
while let Some(mat) = matches.peek() {
let config = matches.grammars()[mat.grammar_index]
.highlights_config
.as_ref();
for capture in mat.captures {
if let Some(config) = config {
if config.identifier_capture_indices.contains(&capture.index) {
let node_range = capture.node.byte_range();
// sometimes multiple highlight queries match - this deduplicates them
if Some(node_range.clone()) == last_added_range {
continue;
}
let identifier_text =
&range_text[node_range.start - range.start..node_range.end - range.start];
references.push(Reference {
identifier: Identifier {
name: identifier_text.into(),
language_id: mat.language.id(),
},
range: node_range.clone(),
region: reference_region,
});
last_added_range = Some(node_range);
}
}
}
matches.advance();
}
references
}

View File

@@ -0,0 +1,853 @@
use std::sync::Arc;
use collections::{HashMap, HashSet};
use futures::lock::Mutex;
use gpui::{App, AppContext as _, Context, Entity, Task, WeakEntity};
use language::{Buffer, BufferEvent};
use project::buffer_store::{BufferStore, BufferStoreEvent};
use project::worktree_store::{WorktreeStore, WorktreeStoreEvent};
use project::{PathChange, Project, ProjectEntryId, ProjectPath};
use slotmap::SlotMap;
use text::BufferId;
use util::{ResultExt as _, debug_panic, some_or_debug_panic};
use crate::declaration::{
BufferDeclaration, Declaration, DeclarationId, FileDeclaration, Identifier,
};
use crate::outline::declarations_in_buffer;
// TODO:
//
// * Skip for remote projects
//
// * Consider making SyntaxIndex not an Entity.
// Potential future improvements:
//
// * Send multiple selected excerpt ranges. Challenge is that excerpt ranges influence which
// references are present and their scores.
// Potential future optimizations:
//
// * Cache of buffers for files
//
// * Parse files directly instead of loading into a Rope. Make SyntaxMap generic to handle embedded
// languages? Will also need to find line boundaries, but that can be done by scanning characters in
// the flat representation.
//
// * Use something similar to slotmap without key versions.
//
// * Concurrent slotmap
//
// * Use queue for parsing
//
pub struct SyntaxIndex {
state: Arc<Mutex<SyntaxIndexState>>,
project: WeakEntity<Project>,
}
#[derive(Default)]
pub struct SyntaxIndexState {
declarations: SlotMap<DeclarationId, Declaration>,
identifiers: HashMap<Identifier, HashSet<DeclarationId>>,
files: HashMap<ProjectEntryId, FileState>,
buffers: HashMap<BufferId, BufferState>,
}
#[derive(Debug, Default)]
struct FileState {
declarations: Vec<DeclarationId>,
task: Option<Task<()>>,
}
#[derive(Default)]
struct BufferState {
declarations: Vec<DeclarationId>,
task: Option<Task<()>>,
}
impl SyntaxIndex {
pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
let mut this = Self {
project: project.downgrade(),
state: Arc::new(Mutex::new(SyntaxIndexState::default())),
};
let worktree_store = project.read(cx).worktree_store();
cx.subscribe(&worktree_store, Self::handle_worktree_store_event)
.detach();
for worktree in worktree_store
.read(cx)
.worktrees()
.map(|w| w.read(cx).snapshot())
.collect::<Vec<_>>()
{
for entry in worktree.files(false, 0) {
this.update_file(
entry.id,
ProjectPath {
worktree_id: worktree.id(),
path: entry.path.clone(),
},
cx,
);
}
}
let buffer_store = project.read(cx).buffer_store().clone();
for buffer in buffer_store.read(cx).buffers().collect::<Vec<_>>() {
this.register_buffer(&buffer, cx);
}
cx.subscribe(&buffer_store, Self::handle_buffer_store_event)
.detach();
this
}
fn handle_worktree_store_event(
&mut self,
_worktree_store: Entity<WorktreeStore>,
event: &WorktreeStoreEvent,
cx: &mut Context<Self>,
) {
use WorktreeStoreEvent::*;
match event {
WorktreeUpdatedEntries(worktree_id, updated_entries_set) => {
let state = Arc::downgrade(&self.state);
let worktree_id = *worktree_id;
let updated_entries_set = updated_entries_set.clone();
cx.spawn(async move |this, cx| {
let Some(state) = state.upgrade() else { return };
for (path, entry_id, path_change) in updated_entries_set.iter() {
if let PathChange::Removed = path_change {
state.lock().await.files.remove(entry_id);
} else {
let project_path = ProjectPath {
worktree_id,
path: path.clone(),
};
this.update(cx, |this, cx| {
this.update_file(*entry_id, project_path, cx);
})
.ok();
}
}
})
.detach();
}
WorktreeDeletedEntry(_worktree_id, project_entry_id) => {
let project_entry_id = *project_entry_id;
self.with_state(cx, move |state| {
state.files.remove(&project_entry_id);
})
}
_ => {}
}
}
fn handle_buffer_store_event(
&mut self,
_buffer_store: Entity<BufferStore>,
event: &BufferStoreEvent,
cx: &mut Context<Self>,
) {
use BufferStoreEvent::*;
match event {
BufferAdded(buffer) => self.register_buffer(buffer, cx),
BufferOpened { .. }
| BufferChangedFilePath { .. }
| BufferDropped { .. }
| SharedBufferClosed { .. } => {}
}
}
pub fn state(&self) -> &Arc<Mutex<SyntaxIndexState>> {
&self.state
}
fn with_state(&self, cx: &mut App, f: impl FnOnce(&mut SyntaxIndexState) + Send + 'static) {
if let Some(mut state) = self.state.try_lock() {
f(&mut state);
return;
}
let state = Arc::downgrade(&self.state);
cx.background_spawn(async move {
let Some(state) = state.upgrade() else {
return None;
};
let mut state = state.lock().await;
Some(f(&mut state))
})
.detach();
}
fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
let buffer_id = buffer.read(cx).remote_id();
cx.observe_release(buffer, move |this, _buffer, cx| {
this.with_state(cx, move |state| {
if let Some(buffer_state) = state.buffers.remove(&buffer_id) {
SyntaxIndexState::remove_buffer_declarations(
&buffer_state.declarations,
&mut state.declarations,
&mut state.identifiers,
);
}
})
})
.detach();
cx.subscribe(buffer, Self::handle_buffer_event).detach();
self.update_buffer(buffer.clone(), cx);
}
fn handle_buffer_event(
&mut self,
buffer: Entity<Buffer>,
event: &BufferEvent,
cx: &mut Context<Self>,
) {
match event {
BufferEvent::Edited => self.update_buffer(buffer, cx),
_ => {}
}
}
fn update_buffer(&mut self, buffer_entity: Entity<Buffer>, cx: &mut Context<Self>) {
let buffer = buffer_entity.read(cx);
let Some(project_entry_id) =
project::File::from_dyn(buffer.file()).and_then(|f| f.project_entry_id(cx))
else {
return;
};
let buffer_id = buffer.remote_id();
let mut parse_status = buffer.parse_status();
let snapshot_task = cx.spawn({
let weak_buffer = buffer_entity.downgrade();
async move |_, cx| {
while *parse_status.borrow() != language::ParseStatus::Idle {
parse_status.changed().await?;
}
weak_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())
}
});
let parse_task = cx.background_spawn(async move {
let snapshot = snapshot_task.await?;
let rope = snapshot.text.as_rope().clone();
anyhow::Ok((
declarations_in_buffer(&snapshot)
.into_iter()
.map(|item| {
(
item.parent_index,
BufferDeclaration::from_outline(item, &rope),
)
})
.collect::<Vec<_>>(),
rope,
))
});
let task = cx.spawn({
async move |this, cx| {
let Ok((declarations, rope)) = parse_task.await else {
return;
};
this.update(cx, move |this, cx| {
this.with_state(cx, move |state| {
let buffer_state = state
.buffers
.entry(buffer_id)
.or_insert_with(Default::default);
SyntaxIndexState::remove_buffer_declarations(
&buffer_state.declarations,
&mut state.declarations,
&mut state.identifiers,
);
let mut new_ids = Vec::with_capacity(declarations.len());
state.declarations.reserve(declarations.len());
for (parent_index, mut declaration) in declarations {
declaration.parent = parent_index
.and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));
let identifier = declaration.identifier.clone();
let declaration_id = state.declarations.insert(Declaration::Buffer {
rope: rope.clone(),
buffer_id,
declaration,
project_entry_id,
});
new_ids.push(declaration_id);
state
.identifiers
.entry(identifier)
.or_default()
.insert(declaration_id);
}
buffer_state.declarations = new_ids;
});
})
.ok();
}
});
self.with_state(cx, move |state| {
state
.buffers
.entry(buffer_id)
.or_insert_with(Default::default)
.task = Some(task)
});
}
fn update_file(
&mut self,
entry_id: ProjectEntryId,
project_path: ProjectPath,
cx: &mut Context<Self>,
) {
let Some(project) = self.project.upgrade() else {
return;
};
let project = project.read(cx);
let Some(worktree) = project.worktree_for_id(project_path.worktree_id, cx) else {
return;
};
let language_registry = project.languages().clone();
let snapshot_task = worktree.update(cx, |worktree, cx| {
let load_task = worktree.load_file(&project_path.path, cx);
cx.spawn(async move |_this, cx| {
let loaded_file = load_task.await?;
let language = language_registry
.language_for_file_path(&project_path.path)
.await
.log_err();
let buffer = cx.new(|cx| {
let mut buffer = Buffer::local(loaded_file.text, cx);
buffer.set_language(language, cx);
buffer
})?;
let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
while *parse_status.borrow() != language::ParseStatus::Idle {
parse_status.changed().await?;
}
buffer.read_with(cx, |buffer, _cx| buffer.snapshot())
})
});
let parse_task = cx.background_spawn(async move {
let snapshot = snapshot_task.await?;
let rope = snapshot.as_rope();
let declarations = declarations_in_buffer(&snapshot)
.into_iter()
.map(|item| (item.parent_index, FileDeclaration::from_outline(item, rope)))
.collect::<Vec<_>>();
anyhow::Ok(declarations)
});
let task = cx.spawn({
async move |this, cx| {
// TODO: how to handle errors?
let Ok(declarations) = parse_task.await else {
return;
};
this.update(cx, |this, cx| {
this.with_state(cx, move |state| {
let file_state =
state.files.entry(entry_id).or_insert_with(Default::default);
for old_declaration_id in &file_state.declarations {
let Some(declaration) = state.declarations.remove(*old_declaration_id)
else {
debug_panic!("declaration not found");
continue;
};
if let Some(identifier_declarations) =
state.identifiers.get_mut(declaration.identifier())
{
identifier_declarations.remove(old_declaration_id);
}
}
let mut new_ids = Vec::with_capacity(declarations.len());
state.declarations.reserve(declarations.len());
for (parent_index, mut declaration) in declarations {
declaration.parent = parent_index
.and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));
let identifier = declaration.identifier.clone();
let declaration_id = state.declarations.insert(Declaration::File {
project_entry_id: entry_id,
declaration,
});
new_ids.push(declaration_id);
state
.identifiers
.entry(identifier)
.or_default()
.insert(declaration_id);
}
file_state.declarations = new_ids;
});
})
.ok();
}
});
self.with_state(cx, move |state| {
state
.files
.entry(entry_id)
.or_insert_with(Default::default)
.task = Some(task);
});
}
}
impl SyntaxIndexState {
pub fn declaration(&self, id: DeclarationId) -> Option<&Declaration> {
self.declarations.get(id)
}
/// Returns declarations for the identifier. If the limit is exceeded, returns an empty vector.
///
/// TODO: Consider doing some pre-ranking and instead truncating when N is exceeded.
pub fn declarations_for_identifier<const N: usize>(
&self,
identifier: &Identifier,
) -> Vec<Declaration> {
// make sure to not have a large stack allocation
assert!(N < 32);
let Some(declaration_ids) = self.identifiers.get(&identifier) else {
return vec![];
};
let mut result = Vec::with_capacity(N);
let mut included_buffer_entry_ids = arrayvec::ArrayVec::<_, N>::new();
let mut file_declarations = Vec::new();
for declaration_id in declaration_ids {
let declaration = self.declarations.get(*declaration_id);
let Some(declaration) = some_or_debug_panic(declaration) else {
continue;
};
match declaration {
Declaration::Buffer {
project_entry_id, ..
} => {
included_buffer_entry_ids.push(*project_entry_id);
result.push(declaration.clone());
if result.len() == N {
return Vec::new();
}
}
Declaration::File {
project_entry_id, ..
} => {
if !included_buffer_entry_ids.contains(&project_entry_id) {
file_declarations.push(declaration.clone());
}
}
}
}
for declaration in file_declarations {
match declaration {
Declaration::File {
project_entry_id, ..
} => {
if !included_buffer_entry_ids.contains(&project_entry_id) {
result.push(declaration);
if result.len() == N {
return Vec::new();
}
}
}
Declaration::Buffer { .. } => {}
}
}
result
}
pub fn file_declaration_count(&self, declaration: &Declaration) -> usize {
match declaration {
Declaration::File {
project_entry_id, ..
} => self
.files
.get(project_entry_id)
.map(|file_state| file_state.declarations.len())
.unwrap_or_default(),
Declaration::Buffer { buffer_id, .. } => self
.buffers
.get(buffer_id)
.map(|buffer_state| buffer_state.declarations.len())
.unwrap_or_default(),
}
}
fn remove_buffer_declarations(
old_declaration_ids: &[DeclarationId],
declarations: &mut SlotMap<DeclarationId, Declaration>,
identifiers: &mut HashMap<Identifier, HashSet<DeclarationId>>,
) {
for old_declaration_id in old_declaration_ids {
let Some(declaration) = declarations.remove(*old_declaration_id) else {
debug_panic!("declaration not found");
continue;
};
if let Some(identifier_declarations) = identifiers.get_mut(declaration.identifier()) {
identifier_declarations.remove(old_declaration_id);
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::{path::Path, sync::Arc};
use gpui::TestAppContext;
use indoc::indoc;
use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use text::OffsetRangeExt as _;
use util::path;
use crate::syntax_index::SyntaxIndex;
#[gpui::test]
async fn test_unopen_indexed_files(cx: &mut TestAppContext) {
let (project, index, rust_lang_id) = init_test(cx).await;
let main = Identifier {
name: "main".into(),
language_id: rust_lang_id,
};
let index_state = index.read_with(cx, |index, _cx| index.state().clone());
let index_state = index_state.lock().await;
cx.update(|cx| {
let decls = index_state.declarations_for_identifier::<8>(&main);
assert_eq!(decls.len(), 2);
let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
assert_eq!(decl.identifier, main.clone());
assert_eq!(decl.item_range_in_file, 32..280);
let decl = expect_file_decl("a.rs", &decls[1], &project, cx);
assert_eq!(decl.identifier, main);
assert_eq!(decl.item_range_in_file, 0..98);
});
}
#[gpui::test]
async fn test_parents_in_file(cx: &mut TestAppContext) {
let (project, index, rust_lang_id) = init_test(cx).await;
let test_process_data = Identifier {
name: "test_process_data".into(),
language_id: rust_lang_id,
};
let index_state = index.read_with(cx, |index, _cx| index.state().clone());
let index_state = index_state.lock().await;
cx.update(|cx| {
let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
assert_eq!(decls.len(), 1);
let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
assert_eq!(decl.identifier, test_process_data);
let parent_id = decl.parent.unwrap();
let parent = index_state.declaration(parent_id).unwrap();
let parent_decl = expect_file_decl("c.rs", &parent, &project, cx);
assert_eq!(
parent_decl.identifier,
Identifier {
name: "tests".into(),
language_id: rust_lang_id
}
);
assert_eq!(parent_decl.parent, None);
});
}
#[gpui::test]
async fn test_parents_in_buffer(cx: &mut TestAppContext) {
let (project, index, rust_lang_id) = init_test(cx).await;
let test_process_data = Identifier {
name: "test_process_data".into(),
language_id: rust_lang_id,
};
let buffer = project
.update(cx, |project, cx| {
let project_path = project.find_project_path("c.rs", cx).unwrap();
project.open_buffer(project_path, cx)
})
.await
.unwrap();
cx.run_until_parked();
let index_state = index.read_with(cx, |index, _cx| index.state().clone());
let index_state = index_state.lock().await;
cx.update(|cx| {
let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
assert_eq!(decls.len(), 1);
let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx);
assert_eq!(decl.identifier, test_process_data);
let parent_id = decl.parent.unwrap();
let parent = index_state.declaration(parent_id).unwrap();
let parent_decl = expect_buffer_decl("c.rs", &parent, &project, cx);
assert_eq!(
parent_decl.identifier,
Identifier {
name: "tests".into(),
language_id: rust_lang_id
}
);
assert_eq!(parent_decl.parent, None);
});
drop(buffer);
}
#[gpui::test]
async fn test_declarations_limt(cx: &mut TestAppContext) {
let (_, index, rust_lang_id) = init_test(cx).await;
let index_state = index.read_with(cx, |index, _cx| index.state().clone());
let index_state = index_state.lock().await;
let decls = index_state.declarations_for_identifier::<1>(&Identifier {
name: "main".into(),
language_id: rust_lang_id,
});
assert_eq!(decls.len(), 0);
}
#[gpui::test]
async fn test_buffer_shadow(cx: &mut TestAppContext) {
let (project, index, rust_lang_id) = init_test(cx).await;
let main = Identifier {
name: "main".into(),
language_id: rust_lang_id,
};
let buffer = project
.update(cx, |project, cx| {
let project_path = project.find_project_path("c.rs", cx).unwrap();
project.open_buffer(project_path, cx)
})
.await
.unwrap();
cx.run_until_parked();
let index_state_arc = index.read_with(cx, |index, _cx| index.state().clone());
{
let index_state = index_state_arc.lock().await;
cx.update(|cx| {
let decls = index_state.declarations_for_identifier::<8>(&main);
assert_eq!(decls.len(), 2);
let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx);
assert_eq!(decl.identifier, main);
assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..279);
expect_file_decl("a.rs", &decls[1], &project, cx);
});
}
// Drop the buffer and wait for release
cx.update(|_| {
drop(buffer);
});
cx.run_until_parked();
let index_state = index_state_arc.lock().await;
cx.update(|cx| {
let decls = index_state.declarations_for_identifier::<8>(&main);
assert_eq!(decls.len(), 2);
expect_file_decl("c.rs", &decls[0], &project, cx);
expect_file_decl("a.rs", &decls[1], &project, cx);
});
}
fn expect_buffer_decl<'a>(
path: &str,
declaration: &'a Declaration,
project: &Entity<Project>,
cx: &App,
) -> &'a BufferDeclaration {
if let Declaration::Buffer {
declaration,
project_entry_id,
..
} = declaration
{
let project_path = project
.read(cx)
.path_for_entry(*project_entry_id, cx)
.unwrap();
assert_eq!(project_path.path.as_ref(), Path::new(path),);
declaration
} else {
panic!("Expected a buffer declaration, found {:?}", declaration);
}
}
fn expect_file_decl<'a>(
path: &str,
declaration: &'a Declaration,
project: &Entity<Project>,
cx: &App,
) -> &'a FileDeclaration {
if let Declaration::File {
declaration,
project_entry_id: file,
} = declaration
{
assert_eq!(
project
.read(cx)
.path_for_entry(*file, cx)
.unwrap()
.path
.as_ref(),
Path::new(path),
);
declaration
} else {
panic!("Expected a file declaration, found {:?}", declaration);
}
}
async fn init_test(
cx: &mut TestAppContext,
) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
});
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/root"),
json!({
"a.rs": indoc! {r#"
fn main() {
let x = 1;
let y = 2;
let z = add(x, y);
println!("Result: {}", z);
}
fn add(a: i32, b: i32) -> i32 {
a + b
}
"#},
"b.rs": indoc! {"
pub struct Config {
pub name: String,
pub value: i32,
}
impl Config {
pub fn new(name: String, value: i32) -> Self {
Config { name, value }
}
}
"},
"c.rs": indoc! {r#"
use std::collections::HashMap;
fn main() {
let args: Vec<String> = std::env::args().collect();
let data: Vec<i32> = args[1..]
.iter()
.filter_map(|s| s.parse().ok())
.collect();
let result = process_data(data);
println!("{:?}", result);
}
fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
let mut counts = HashMap::new();
for value in data {
*counts.entry(value).or_insert(0) += 1;
}
counts
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_process_data() {
let data = vec![1, 2, 2, 3];
let result = process_data(data);
assert_eq!(result.get(&2), Some(&2));
}
}
"#}
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
let lang = rust_lang();
let lang_id = lang.id();
language_registry.add(Arc::new(lang));
let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
cx.run_until_parked();
(project, index, lang_id)
}
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(tree_sitter_rust::LANGUAGE.into()),
)
.with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
.unwrap()
}
}

View File

@@ -0,0 +1,241 @@
use regex::Regex;
use std::{collections::HashMap, sync::LazyLock};
use crate::reference::Reference;
// TODO: Consider implementing sliding window similarity matching like
// https://github.com/sourcegraph/cody-public-snapshot/blob/8e20ac6c1460c08b0db581c0204658112a246eda/vscode/src/completions/context/retrievers/jaccard-similarity/bestJaccardMatch.ts
//
// That implementation could actually be more efficient - no need to track words in the window that
// are not in the query.
static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
#[derive(Debug)]
pub struct IdentifierOccurrences {
identifier_to_count: HashMap<String, usize>,
total_count: usize,
}
impl IdentifierOccurrences {
pub fn within_string(code: &str) -> Self {
Self::from_iterator(IDENTIFIER_REGEX.find_iter(code).map(|mat| mat.as_str()))
}
#[allow(dead_code)]
pub fn within_references(references: &[Reference]) -> Self {
Self::from_iterator(
references
.iter()
.map(|reference| reference.identifier.name.as_ref()),
)
}
pub fn from_iterator<'a>(identifier_iterator: impl Iterator<Item = &'a str>) -> Self {
let mut identifier_to_count = HashMap::new();
let mut total_count = 0;
for identifier in identifier_iterator {
// TODO: Score matches that match case higher?
//
// TODO: Also include unsplit identifier?
for identifier_part in split_identifier(identifier) {
identifier_to_count
.entry(identifier_part.to_lowercase())
.and_modify(|count| *count += 1)
.or_insert(1);
total_count += 1;
}
}
IdentifierOccurrences {
identifier_to_count,
total_count,
}
}
}
// Splits camelcase / snakecase / kebabcase / pascalcase
//
// TODO: Make this more efficient / elegant.
fn split_identifier<'a>(identifier: &'a str) -> Vec<&'a str> {
let mut parts = Vec::new();
let mut start = 0;
let chars: Vec<char> = identifier.chars().collect();
if chars.is_empty() {
return parts;
}
let mut i = 0;
while i < chars.len() {
let ch = chars[i];
// Handle explicit delimiters (underscore and hyphen)
if ch == '_' || ch == '-' {
if i > start {
parts.push(&identifier[start..i]);
}
start = i + 1;
i += 1;
continue;
}
// Handle camelCase and PascalCase transitions
if i > 0 && i < chars.len() {
let prev_char = chars[i - 1];
// Transition from lowercase/digit to uppercase
if (prev_char.is_lowercase() || prev_char.is_ascii_digit()) && ch.is_uppercase() {
parts.push(&identifier[start..i]);
start = i;
}
// Handle sequences like "XMLParser" -> ["XML", "Parser"]
else if i + 1 < chars.len()
&& ch.is_uppercase()
&& chars[i + 1].is_lowercase()
&& prev_char.is_uppercase()
{
parts.push(&identifier[start..i]);
start = i;
}
}
i += 1;
}
// Add the last part if there's any remaining
if start < identifier.len() {
parts.push(&identifier[start..]);
}
// Filter out empty strings
parts.into_iter().filter(|s| !s.is_empty()).collect()
}
pub fn jaccard_similarity<'a>(
mut set_a: &'a IdentifierOccurrences,
mut set_b: &'a IdentifierOccurrences,
) -> f32 {
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
std::mem::swap(&mut set_a, &mut set_b);
}
let intersection = set_a
.identifier_to_count
.keys()
.filter(|key| set_b.identifier_to_count.contains_key(*key))
.count();
let union = set_a.identifier_to_count.len() + set_b.identifier_to_count.len() - intersection;
intersection as f32 / union as f32
}
// TODO
#[allow(dead_code)]
pub fn overlap_coefficient<'a>(
mut set_a: &'a IdentifierOccurrences,
mut set_b: &'a IdentifierOccurrences,
) -> f32 {
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
std::mem::swap(&mut set_a, &mut set_b);
}
let intersection = set_a
.identifier_to_count
.keys()
.filter(|key| set_b.identifier_to_count.contains_key(*key))
.count();
intersection as f32 / set_a.identifier_to_count.len() as f32
}
// TODO
#[allow(dead_code)]
pub fn weighted_jaccard_similarity<'a>(
mut set_a: &'a IdentifierOccurrences,
mut set_b: &'a IdentifierOccurrences,
) -> f32 {
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
std::mem::swap(&mut set_a, &mut set_b);
}
let mut numerator = 0;
let mut denominator_a = 0;
let mut used_count_b = 0;
for (symbol, count_a) in set_a.identifier_to_count.iter() {
let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
numerator += count_a.min(count_b);
denominator_a += count_a.max(count_b);
used_count_b += count_b;
}
let denominator = denominator_a + (set_b.total_count - used_count_b);
if denominator == 0 {
0.0
} else {
numerator as f32 / denominator as f32
}
}
pub fn weighted_overlap_coefficient<'a>(
mut set_a: &'a IdentifierOccurrences,
mut set_b: &'a IdentifierOccurrences,
) -> f32 {
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
std::mem::swap(&mut set_a, &mut set_b);
}
let mut numerator = 0;
for (symbol, count_a) in set_a.identifier_to_count.iter() {
let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
numerator += count_a.min(count_b);
}
let denominator = set_a.total_count.min(set_b.total_count);
if denominator == 0 {
0.0
} else {
numerator as f32 / denominator as f32
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_split_identifier() {
assert_eq!(split_identifier("snake_case"), vec!["snake", "case"]);
assert_eq!(split_identifier("kebab-case"), vec!["kebab", "case"]);
assert_eq!(split_identifier("PascalCase"), vec!["Pascal", "Case"]);
assert_eq!(split_identifier("camelCase"), vec!["camel", "Case"]);
assert_eq!(split_identifier("XMLParser"), vec!["XML", "Parser"]);
}
#[test]
fn test_similarity_functions() {
// 10 identifier parts, 8 unique
// Repeats: 2 "outline", 2 "items"
let set_a = IdentifierOccurrences::within_string(
"let mut outline_items = query_outline_items(&language, &tree, &source);",
);
// 14 identifier parts, 11 unique
// Repeats: 2 "outline", 2 "language", 2 "tree"
let set_b = IdentifierOccurrences::within_string(
"pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
);
// 6 overlaps: "outline", "items", "query", "language", "tree", "source"
// 7 non-overlaps: "let", "mut", "pub", "fn", "vec", "item", "str"
assert_eq!(jaccard_similarity(&set_a, &set_b), 6.0 / (6.0 + 7.0));
// Numerator is one more than before due to both having 2 "outline".
// Denominator is the same except for 3 more due to the non-overlapping duplicates
assert_eq!(
weighted_jaccard_similarity(&set_a, &set_b),
7.0 / (7.0 + 7.0 + 3.0)
);
// Numerator is the same as jaccard_similarity. Denominator is the size of the smaller set, 8.
assert_eq!(overlap_coefficient(&set_a, &set_b), 6.0 / 8.0);
// Numerator is the same as weighted_jaccard_similarity. Denominator is the total weight of
// the smaller set, 10.
assert_eq!(weighted_overlap_coefficient(&set_a, &set_b), 7.0 / 10.0);
}
}

View File

@@ -0,0 +1,35 @@
// To discuss: What to send to the new endpoint? Thinking it'd make sense to put `prompt.rs` from
// `zeta_context.rs` in cloud.
//
// * Run excerpt selection at several different sizes, send the largest size with offsets within for
// the smaller sizes.
//
// * Longer event history.
//
// * Many more snippets than could fit in model context - allows ranking experimentation.
pub struct Zeta2Request {
pub event_history: Vec<Event>,
pub excerpt: String,
pub excerpt_subsets: Vec<Zeta2ExcerptSubset>,
/// Within `excerpt`
pub cursor_position: usize,
pub signatures: Vec<String>,
pub retrieved_declarations: Vec<ReferencedDeclaration>,
}
pub struct Zeta2ExcerptSubset {
/// Within `excerpt` text.
pub excerpt_range: Range<usize>,
/// Within `signatures`.
pub parent_signatures: Vec<usize>,
}
pub struct ReferencedDeclaration {
pub text: Arc<str>,
/// Range within `text`
pub signature_range: Range<usize>,
/// Indices within `signatures`.
pub parent_signatures: Vec<usize>,
// A bunch of score metrics
}

View File

@@ -20549,7 +20549,9 @@ impl Editor {
)
.detach();
}
self.update_lsp_data(false, Some(buffer_id), window, cx);
if self.active_diagnostics != ActiveDiagnostic::All {
self.update_lsp_data(false, Some(buffer_id), window, cx);
}
cx.emit(EditorEvent::ExcerptsAdded {
buffer: buffer.clone(),
predecessor: *predecessor,

View File

@@ -19265,7 +19265,7 @@ async fn test_expand_diff_hunk_at_excerpt_boundary(cx: &mut TestAppContext) {
cx.executor().run_until_parked();
// When the start of a hunk coincides with the start of its excerpt,
// the hunk is expanded. When the start of a a hunk is earlier than
// the hunk is expanded. When the start of a hunk is earlier than
// the start of its excerpt, the hunk is not expanded.
cx.assert_state_with_diff(
"

View File

@@ -9694,7 +9694,7 @@ impl EditorScrollbars {
editor_bounds.bottom_left(),
size(
// The horizontal viewport size differs from the space available for the
// horizontal scrollbar, so we have to manually stich it together here.
// horizontal scrollbar, so we have to manually stitch it together here.
editor_bounds.size.width - right_margin,
scrollbar_width,
),

View File

@@ -521,6 +521,14 @@ impl PickerDelegate for BranchListDelegate {
.inset(true)
.spacing(ListItemSpacing::Sparse)
.toggle_state(selected)
.tooltip({
let branch_name = entry.branch.name().to_string();
if entry.is_new {
Tooltip::text(format!("Create branch \"{}\"", branch_name))
} else {
Tooltip::text(branch_name)
}
})
.child(
v_flex()
.w_full()

View File

@@ -3748,7 +3748,10 @@ impl GitPanel {
.custom_scrollbars(
Scrollbars::for_settings::<GitPanelSettings>()
.tracked_scroll_handle(self.scroll_handle.clone())
.with_track_along(ScrollAxes::Horizontal),
.with_track_along(
ScrollAxes::Horizontal,
cx.theme().colors().panel_background,
),
window,
cx,
),

View File

@@ -115,7 +115,7 @@ seahash = "4.1"
semantic_version.workspace = true
serde.workspace = true
serde_json.workspace = true
slotmap = "1.0.6"
slotmap.workspace = true
smallvec.workspace = true
smol.workspace = true
stacksafe.workspace = true

View File

@@ -151,9 +151,9 @@ impl From<Hsla> for Rgba {
};
Rgba {
r,
g,
b,
r: r.clamp(0., 1.),
g: g.clamp(0., 1.),
b: b.clamp(0., 1.),
a: color.a,
}
}

View File

@@ -82,6 +82,10 @@ unsafe fn build_classes() {
APP_DELEGATE_CLASS = unsafe {
let mut decl = ClassDecl::new("GPUIApplicationDelegate", class!(NSResponder)).unwrap();
decl.add_ivar::<*mut c_void>(MAC_PLATFORM_IVAR);
decl.add_method(
sel!(applicationWillFinishLaunching:),
will_finish_launching as extern "C" fn(&mut Object, Sel, id),
);
decl.add_method(
sel!(applicationDidFinishLaunching:),
did_finish_launching as extern "C" fn(&mut Object, Sel, id),
@@ -1356,6 +1360,23 @@ unsafe fn get_mac_platform(object: &mut Object) -> &MacPlatform {
}
}
extern "C" fn will_finish_launching(_this: &mut Object, _: Sel, _: id) {
unsafe {
let user_defaults: id = msg_send![class!(NSUserDefaults), standardUserDefaults];
// The autofill heuristic controller causes slowdown and high CPU usage.
// We don't know exactly why. This disables the full heuristic controller.
//
// Adapted from: https://github.com/ghostty-org/ghostty/pull/8625
let name = ns_string("NSAutoFillHeuristicControllerEnabled");
let existing_value: id = msg_send![user_defaults, objectForKey: name];
if existing_value == nil {
let false_value: id = msg_send![class!(NSNumber), numberWithBool:false];
let _: () = msg_send![user_defaults, setObject: false_value forKey: name];
}
}
}
extern "C" fn did_finish_launching(this: &mut Object, _: Sel, _: id) {
unsafe {
let app: id = msg_send![APP_CLASS, sharedApplication];

View File

@@ -1016,7 +1016,7 @@ fn handle_gpu_device_lost(
all_windows: &std::sync::Weak<RwLock<SmallVec<[SafeHwnd; 4]>>>,
text_system: &std::sync::Weak<DirectWriteTextSystem>,
) {
// Here we wait a bit to ensure the the system has time to recover from the device lost state.
// Here we wait a bit to ensure the system has time to recover from the device lost state.
// If we don't wait, the final drawing result will be blank.
std::thread::sleep(std::time::Duration::from_millis(350));

View File

@@ -684,8 +684,16 @@ impl PlatformWindow for WindowsWindow {
.executor
.spawn(async move {
this.set_window_placement().log_err();
unsafe { SetActiveWindow(hwnd).log_err() };
unsafe { SetFocus(Some(hwnd)).log_err() };
unsafe {
// If the window is minimized, restore it.
if IsIconic(hwnd).as_bool() {
ShowWindowAsync(hwnd, SW_RESTORE).ok().log_err();
}
SetActiveWindow(hwnd).log_err();
SetFocus(Some(hwnd)).log_err();
}
// premium ragebait by windows, this is needed because the window
// must have received an input event to be able to set itself to foreground

View File

@@ -318,6 +318,12 @@ pub fn read_proxy_from_env() -> Option<Url> {
.and_then(|env| env.parse().ok())
}
pub fn read_no_proxy_from_env() -> Option<String> {
const ENV_VARS: &[&str] = &["NO_PROXY", "no_proxy"];
ENV_VARS.iter().find_map(|var| std::env::var(var).ok())
}
pub struct BlockedHttpClient;
impl BlockedHttpClient {

View File

@@ -68,7 +68,7 @@ With both approaches, would need to record the buffer version and use that when
* Mode to navigate to source code on every element change while picking.
* Tracking of more source locations - currently the source location is often in a ui compoenent. Ideally this would have a way for the components to indicate that they are probably not the source location the user is looking for.
* Tracking of more source locations - currently the source location is often in a ui component. Ideally this would have a way for the components to indicate that they are probably not the source location the user is looking for.
- Could have `InspectorElementId` be `Vec<(ElementId, Option<Location>)>`, but if there are multiple code paths that construct the same element this would cause them to be considered different.

View File

@@ -145,7 +145,7 @@ struct BufferBranchState {
/// state of a buffer.
pub struct BufferSnapshot {
pub text: text::BufferSnapshot,
pub(crate) syntax: SyntaxSnapshot,
pub syntax: SyntaxSnapshot,
file: Option<Arc<dyn File>>,
diagnostics: SmallVec<[(LanguageServerId, DiagnosticSet); 2]>,
remote_selections: TreeMap<ReplicaId, SelectionSet>,
@@ -660,7 +660,10 @@ impl HighlightedTextBuilder {
syntax_snapshot: &'a SyntaxSnapshot,
) -> BufferChunks<'a> {
let captures = syntax_snapshot.captures(range.clone(), snapshot, |grammar| {
grammar.highlights_query.as_ref()
grammar
.highlights_config
.as_ref()
.map(|config| &config.query)
});
let highlight_maps = captures
@@ -3246,7 +3249,10 @@ impl BufferSnapshot {
fn get_highlights(&self, range: Range<usize>) -> (SyntaxMapCaptures<'_>, Vec<HighlightMap>) {
let captures = self.syntax.captures(range, &self.text, |grammar| {
grammar.highlights_query.as_ref()
grammar
.highlights_config
.as_ref()
.map(|config| &config.query)
});
let highlight_maps = captures
.grammars()
@@ -3310,18 +3316,25 @@ impl BufferSnapshot {
/// Iterates over every [`SyntaxLayer`] in the buffer.
pub fn syntax_layers(&self) -> impl Iterator<Item = SyntaxLayer<'_>> + '_ {
self.syntax
.layers_for_range(0..self.len(), &self.text, true)
self.syntax_layers_for_range(0..self.len(), true)
}
pub fn syntax_layer_at<D: ToOffset>(&self, position: D) -> Option<SyntaxLayer<'_>> {
let offset = position.to_offset(self);
self.syntax
.layers_for_range(offset..offset, &self.text, false)
self.syntax_layers_for_range(offset..offset, false)
.filter(|l| l.node().end_byte() > offset)
.last()
}
pub fn syntax_layers_for_range<D: ToOffset>(
&self,
range: Range<D>,
include_hidden: bool,
) -> impl Iterator<Item = SyntaxLayer<'_>> + '_ {
self.syntax
.layers_for_range(range, &self.text, include_hidden)
}
pub fn smallest_syntax_layer_containing<D: ToOffset>(
&self,
range: Range<D>,
@@ -3859,9 +3872,12 @@ impl BufferSnapshot {
text: item.text,
highlight_ranges: item.highlight_ranges,
name_ranges: item.name_ranges,
body_range: item.body_range.map(|body_range| {
self.anchor_after(body_range.start)..self.anchor_before(body_range.end)
}),
signature_range: item
.signature_range
.map(|r| self.anchor_after(r.start)..self.anchor_before(r.end)),
body_range: item
.body_range
.map(|r| self.anchor_after(r.start)..self.anchor_before(r.end)),
annotation_range: annotation_row_range.map(|annotation_range| {
self.anchor_after(Point::new(annotation_range.start, 0))
..self.anchor_before(Point::new(
@@ -3901,38 +3917,51 @@ impl BufferSnapshot {
let mut open_point = None;
let mut close_point = None;
let mut buffer_ranges = Vec::new();
for capture in mat.captures {
let node_is_name;
if capture.index == config.name_capture_ix {
node_is_name = true;
} else if Some(capture.index) == config.context_capture_ix
|| (Some(capture.index) == config.extra_context_capture_ix && include_extra_context)
{
node_is_name = false;
} else {
if Some(capture.index) == config.open_capture_ix {
open_point = Some(Point::from_ts_point(capture.node.end_position()));
} else if Some(capture.index) == config.close_capture_ix {
close_point = Some(Point::from_ts_point(capture.node.start_position()));
}
continue;
let mut signature_start = None;
let mut signature_end = None;
let mut extend_signature_range = |node: tree_sitter::Node| {
if signature_start.is_none() {
signature_start = Some(Point::from_ts_point(node.start_position()));
}
signature_end = Some(Point::from_ts_point(node.end_position()));
};
let mut range = capture.node.start_byte()..capture.node.end_byte();
let start = capture.node.start_position();
if capture.node.end_position().row > start.row {
let mut buffer_ranges = Vec::new();
let mut add_to_buffer_ranges = |node: tree_sitter::Node, node_is_name| {
let mut range = node.start_byte()..node.end_byte();
let start = node.start_position();
if node.end_position().row > start.row {
range.end = range.start + self.line_len(start.row as u32) as usize - start.column;
}
if !range.is_empty() {
buffer_ranges.push((range, node_is_name));
}
};
for capture in mat.captures {
if capture.index == config.name_capture_ix {
add_to_buffer_ranges(capture.node, true);
extend_signature_range(capture.node);
} else if Some(capture.index) == config.context_capture_ix
|| (Some(capture.index) == config.extra_context_capture_ix && include_extra_context)
{
add_to_buffer_ranges(capture.node, false);
extend_signature_range(capture.node);
} else {
if Some(capture.index) == config.open_capture_ix {
open_point = Some(Point::from_ts_point(capture.node.end_position()));
} else if Some(capture.index) == config.close_capture_ix {
close_point = Some(Point::from_ts_point(capture.node.start_position()));
}
}
}
if buffer_ranges.is_empty() {
return None;
}
let mut text = String::new();
let mut highlight_ranges = Vec::new();
let mut name_ranges = Vec::new();
@@ -3941,7 +3970,6 @@ impl BufferSnapshot {
true,
);
let mut last_buffer_range_end = 0;
for (buffer_range, is_name) in buffer_ranges {
let space_added = !text.is_empty() && buffer_range.start > last_buffer_range_end;
if space_added {
@@ -3983,12 +4011,17 @@ impl BufferSnapshot {
last_buffer_range_end = buffer_range.end;
}
let signature_range = signature_start
.zip(signature_end)
.map(|(start, end)| start..end);
Some(OutlineItem {
depth: 0, // We'll calculate the depth later
range: item_point_range,
text,
highlight_ranges,
name_ranges,
signature_range,
body_range: open_point.zip(close_point).map(|(start, end)| start..end),
annotation_range: None,
})

View File

@@ -81,7 +81,9 @@ pub use language_registry::{
};
pub use lsp::{LanguageServerId, LanguageServerName};
pub use outline::*;
pub use syntax_map::{OwnedSyntaxLayer, SyntaxLayer, ToTreeSitterPoint, TreeSitterOptions};
pub use syntax_map::{
OwnedSyntaxLayer, SyntaxLayer, SyntaxMapMatches, ToTreeSitterPoint, TreeSitterOptions,
};
pub use text::{AnchorRangeExt, LineEnding};
pub use tree_sitter::{Node, Parser, Tree, TreeCursor};
@@ -1154,7 +1156,7 @@ pub struct Grammar {
id: GrammarId,
pub ts_language: tree_sitter::Language,
pub(crate) error_query: Option<Query>,
pub(crate) highlights_query: Option<Query>,
pub highlights_config: Option<HighlightsConfig>,
pub(crate) brackets_config: Option<BracketsConfig>,
pub(crate) redactions_config: Option<RedactionConfig>,
pub(crate) runnable_config: Option<RunnableConfig>,
@@ -1168,6 +1170,11 @@ pub struct Grammar {
pub(crate) highlight_map: Mutex<HighlightMap>,
}
pub struct HighlightsConfig {
pub query: Query,
pub identifier_capture_indices: Vec<u32>,
}
struct IndentConfig {
query: Query,
indent_capture_ix: u32,
@@ -1332,7 +1339,7 @@ impl Language {
grammar: ts_language.map(|ts_language| {
Arc::new(Grammar {
id: GrammarId::new(),
highlights_query: None,
highlights_config: None,
brackets_config: None,
outline_config: None,
text_object_config: None,
@@ -1430,7 +1437,29 @@ impl Language {
pub fn with_highlights_query(mut self, source: &str) -> Result<Self> {
let grammar = self.grammar_mut()?;
grammar.highlights_query = Some(Query::new(&grammar.ts_language, source)?);
let query = Query::new(&grammar.ts_language, source)?;
let mut identifier_capture_indices = Vec::new();
for name in [
"variable",
"constant",
"constructor",
"function",
"function.method",
"function.method.call",
"function.special",
"property",
"type",
"type.interface",
] {
identifier_capture_indices.extend(query.capture_index_for_name(name));
}
grammar.highlights_config = Some(HighlightsConfig {
query,
identifier_capture_indices,
});
Ok(self)
}
@@ -1856,7 +1885,10 @@ impl Language {
let tree = grammar.parse_text(text, None);
let captures =
SyntaxSnapshot::single_tree_captures(range.clone(), text, &tree, self, |grammar| {
grammar.highlights_query.as_ref()
grammar
.highlights_config
.as_ref()
.map(|config| &config.query)
});
let highlight_maps = vec![grammar.highlight_map()];
let mut offset = 0;
@@ -1885,10 +1917,10 @@ impl Language {
pub fn set_theme(&self, theme: &SyntaxTheme) {
if let Some(grammar) = self.grammar.as_ref()
&& let Some(highlights_query) = &grammar.highlights_query
&& let Some(highlights_config) = &grammar.highlights_config
{
*grammar.highlight_map.lock() =
HighlightMap::new(highlights_query.capture_names(), theme);
HighlightMap::new(highlights_config.query.capture_names(), theme);
}
}
@@ -2103,8 +2135,9 @@ impl Grammar {
pub fn highlight_id_for_name(&self, name: &str) -> Option<HighlightId> {
let capture_id = self
.highlights_query
.highlights_config
.as_ref()?
.query
.capture_index_for_name(name)?;
Some(self.highlight_map.lock().get(capture_id))
}

View File

@@ -552,6 +552,7 @@ pub struct LanguageSettingsContent {
///
/// Default: ["..."]
#[serde(default)]
#[settings_ui(skip)]
pub language_servers: Option<Vec<String>>,
/// Controls where the `editor::Rewrap` action is allowed for this language.
///

View File

@@ -19,6 +19,7 @@ pub struct OutlineItem<T> {
pub text: String,
pub highlight_ranges: Vec<(Range<usize>, HighlightStyle)>,
pub name_ranges: Vec<Range<usize>>,
pub signature_range: Option<Range<T>>,
pub body_range: Option<Range<T>>,
pub annotation_range: Option<Range<T>>,
}
@@ -35,6 +36,10 @@ impl<T: ToPoint> OutlineItem<T> {
text: self.text.clone(),
highlight_ranges: self.highlight_ranges.clone(),
name_ranges: self.name_ranges.clone(),
signature_range: self
.signature_range
.as_ref()
.map(|r| r.start.to_point(buffer)..r.end.to_point(buffer)),
body_range: self
.body_range
.as_ref()
@@ -208,6 +213,7 @@ mod tests {
text: "class Foo".to_string(),
highlight_ranges: vec![],
name_ranges: vec![6..9],
signature_range: None,
body_range: None,
annotation_range: None,
},
@@ -217,6 +223,7 @@ mod tests {
text: "private".to_string(),
highlight_ranges: vec![],
name_ranges: vec![],
signature_range: None,
body_range: None,
annotation_range: None,
},
@@ -241,6 +248,7 @@ mod tests {
text: "fn process".to_string(),
highlight_ranges: vec![],
name_ranges: vec![3..10],
signature_range: None,
body_range: None,
annotation_range: None,
},
@@ -250,6 +258,7 @@ mod tests {
text: "struct DataProcessor".to_string(),
highlight_ranges: vec![],
name_ranges: vec![7..20],
signature_range: None,
body_range: None,
annotation_range: None,
},

View File

@@ -1409,12 +1409,15 @@ fn assert_capture_ranges(
) {
let mut actual_ranges = Vec::<Range<usize>>::new();
let captures = syntax_map.captures(0..buffer.len(), buffer, |grammar| {
grammar.highlights_query.as_ref()
grammar
.highlights_config
.as_ref()
.map(|config| &config.query)
});
let queries = captures
.grammars()
.iter()
.map(|grammar| grammar.highlights_query.as_ref().unwrap())
.map(|grammar| &grammar.highlights_config.as_ref().unwrap().query)
.collect::<Vec<_>>();
for capture in captures {
let name = &queries[capture.grammar_index].capture_names()[capture.index as usize];

View File

@@ -29,6 +29,7 @@ copilot.workspace = true
credentials_provider.workspace = true
deepseek = { workspace = true, features = ["schemars"] }
editor.workspace = true
fs.workspace = true
futures.workspace = true
google_ai = { workspace = true, features = ["schemars"] }
gpui.workspace = true
@@ -61,6 +62,7 @@ util.workspace = true
vercel = { workspace = true, features = ["schemars"] }
workspace-hack.workspace = true
x_ai = { workspace = true, features = ["schemars"] }
zed_env_vars.workspace = true
[dev-dependencies]
editor = { workspace = true, features = ["test-support"] }

View File

@@ -0,0 +1,295 @@
use anyhow::{Result, anyhow};
use credentials_provider::CredentialsProvider;
use futures::{FutureExt, future};
use gpui::{AsyncApp, Context, SharedString, Task};
use language_model::AuthenticateError;
use std::{
fmt::{Display, Formatter},
sync::Arc,
};
use util::ResultExt as _;
use zed_env_vars::EnvVar;
/// Manages a single API key for a language model provider. API keys either come from environment
/// variables or the system keychain.
///
/// Keys from the system keychain are associated with a provider URL, and this ensures that they are
/// only used with that URL.
pub struct ApiKeyState {
url: SharedString,
load_status: LoadStatus,
load_task: Option<future::Shared<Task<()>>>,
}
#[derive(Debug, Clone)]
pub enum LoadStatus {
NotPresent,
Error(String),
Loaded(ApiKey),
}
#[derive(Debug, Clone)]
pub struct ApiKey {
source: ApiKeySource,
key: Arc<str>,
}
impl ApiKeyState {
pub fn new(url: SharedString) -> Self {
Self {
url,
load_status: LoadStatus::NotPresent,
load_task: None,
}
}
pub fn has_key(&self) -> bool {
matches!(self.load_status, LoadStatus::Loaded { .. })
}
pub fn is_from_env_var(&self) -> bool {
match &self.load_status {
LoadStatus::Loaded(ApiKey {
source: ApiKeySource::EnvVar { .. },
..
}) => true,
_ => false,
}
}
/// Get the stored API key, verifying that it is associated with the URL. Returns `None` if
/// there is no key or for URL mismatches, and the mismatch case is logged.
///
/// To avoid URL mismatches, expects that `load_if_needed` or `handle_url_change` has been
/// called with this URL.
pub fn key(&self, url: &str) -> Option<Arc<str>> {
let api_key = match &self.load_status {
LoadStatus::Loaded(api_key) => api_key,
_ => return None,
};
if url == self.url.as_str() {
Some(api_key.key.clone())
} else if let ApiKeySource::EnvVar(var_name) = &api_key.source {
log::warn!(
"{} is now being used with URL {}, when initially it was used with URL {}",
var_name,
url,
self.url
);
Some(api_key.key.clone())
} else {
// bug case because load_if_needed should be called whenever the url may have changed
log::error!(
"bug: Attempted to use API key associated with URL {} instead with URL {}",
self.url,
url
);
None
}
}
/// Set or delete the API key in the system keychain.
pub fn store<Ent: 'static>(
&mut self,
url: SharedString,
key: Option<String>,
get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
cx: &Context<Ent>,
) -> Task<Result<()>> {
if self.is_from_env_var() {
return Task::ready(Err(anyhow!(
"bug: attempted to store API key in system keychain when API key is from env var",
)));
}
let credentials_provider = <dyn CredentialsProvider>::global(cx);
cx.spawn(async move |ent, cx| {
if let Some(key) = &key {
credentials_provider
.write_credentials(&url, "Bearer", key.as_bytes(), cx)
.await
.log_err();
} else {
credentials_provider
.delete_credentials(&url, cx)
.await
.log_err();
}
ent.update(cx, |ent, cx| {
let this = get_this(ent);
this.url = url;
this.load_status = match &key {
Some(key) => LoadStatus::Loaded(ApiKey {
source: ApiKeySource::SystemKeychain,
key: key.as_str().into(),
}),
None => LoadStatus::NotPresent,
};
cx.notify();
})
})
}
/// Reloads the API key if the current API key is associated with a different URL.
///
/// Note that it is not efficient to use this or `load_if_needed` with multiple URLs
/// interchangeably - URL change should correspond to some user initiated change.
pub fn handle_url_change<Ent: 'static>(
&mut self,
url: SharedString,
env_var: &EnvVar,
get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
cx: &mut Context<Ent>,
) {
if url != self.url {
if !self.is_from_env_var() {
// loading will continue even though this result task is dropped
let _task = self.load_if_needed(url, env_var, get_this, cx);
}
}
}
/// If needed, loads the API key associated with the given URL from the system keychain. When a
/// non-empty environment variable is provided, it will be used instead. If called when an API
/// key was already loaded for a different URL, that key will be cleared before loading.
///
/// Dropping the returned Task does not cancel key loading.
pub fn load_if_needed<Ent: 'static>(
&mut self,
url: SharedString,
env_var: &EnvVar,
get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
cx: &mut Context<Ent>,
) -> Task<Result<(), AuthenticateError>> {
if let LoadStatus::Loaded { .. } = &self.load_status
&& self.url == url
{
return Task::ready(Ok(()));
}
if let Some(key) = &env_var.value
&& !key.is_empty()
{
let api_key = ApiKey::from_env(env_var.name.clone(), key);
self.url = url;
self.load_status = LoadStatus::Loaded(api_key);
self.load_task = None;
cx.notify();
return Task::ready(Ok(()));
}
let task = if let Some(load_task) = &self.load_task {
load_task.clone()
} else {
let load_task = Self::load(url.clone(), get_this.clone(), cx).shared();
self.url = url;
self.load_status = LoadStatus::NotPresent;
self.load_task = Some(load_task.clone());
cx.notify();
load_task
};
cx.spawn(async move |ent, cx| {
task.await;
ent.update(cx, |ent, _cx| {
get_this(ent).load_status.clone().into_authenticate_result()
})
.ok();
Ok(())
})
}
fn load<Ent: 'static>(
url: SharedString,
get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
cx: &Context<Ent>,
) -> Task<()> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
cx.spawn({
async move |ent, cx| {
let load_status =
ApiKey::load_from_system_keychain_impl(&url, credentials_provider.as_ref(), cx)
.await;
ent.update(cx, |ent, cx| {
let this = get_this(ent);
this.url = url;
this.load_status = load_status;
this.load_task = None;
cx.notify();
})
.ok();
}
})
}
}
impl ApiKey {
pub fn key(&self) -> &str {
&self.key
}
pub fn from_env(env_var_name: SharedString, key: &str) -> Self {
Self {
source: ApiKeySource::EnvVar(env_var_name),
key: key.into(),
}
}
pub async fn load_from_system_keychain(
url: &str,
credentials_provider: &dyn CredentialsProvider,
cx: &AsyncApp,
) -> Result<Self, AuthenticateError> {
Self::load_from_system_keychain_impl(url, credentials_provider, cx)
.await
.into_authenticate_result()
}
async fn load_from_system_keychain_impl(
url: &str,
credentials_provider: &dyn CredentialsProvider,
cx: &AsyncApp,
) -> LoadStatus {
if url.is_empty() {
return LoadStatus::NotPresent;
}
let read_result = credentials_provider.read_credentials(&url, cx).await;
let api_key = match read_result {
Ok(Some((_, api_key))) => api_key,
Ok(None) => return LoadStatus::NotPresent,
Err(err) => return LoadStatus::Error(err.to_string()),
};
let key = match str::from_utf8(&api_key) {
Ok(key) => key,
Err(_) => return LoadStatus::Error(format!("API key for URL {url} is not utf8")),
};
LoadStatus::Loaded(Self {
source: ApiKeySource::SystemKeychain,
key: key.into(),
})
}
}
impl LoadStatus {
fn into_authenticate_result(self) -> Result<ApiKey, AuthenticateError> {
match self {
LoadStatus::Loaded(api_key) => Ok(api_key),
LoadStatus::NotPresent => Err(AuthenticateError::CredentialsNotFound),
LoadStatus::Error(err) => Err(AuthenticateError::Other(anyhow!(err))),
}
}
}
#[derive(Debug, Clone)]
enum ApiKeySource {
EnvVar(SharedString),
SystemKeychain,
}
impl Display for ApiKeySource {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
ApiKeySource::EnvVar(var) => write!(f, "environment variable {}", var),
ApiKeySource::SystemKeychain => write!(f, "system keychain"),
}
}
}

View File

@@ -7,6 +7,7 @@ use gpui::{App, Context, Entity};
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
use provider::deepseek::DeepSeekLanguageModelProvider;
mod api_key;
pub mod provider;
mod settings;
pub mod ui;

View File

@@ -1,18 +1,14 @@
use crate::AllLanguageModelSettings;
use crate::api_key::ApiKeyState;
use crate::ui::InstructionListItem;
use anthropic::{
AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent, ToolResultContent,
ToolResultPart, Usage,
ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent,
ToolResultContent, ToolResultPart, Usage,
};
use anyhow::{Context as _, Result, anyhow};
use anyhow::{Result, anyhow};
use collections::{BTreeMap, HashMap};
use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
};
use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Entity, FontStyle, Task, TextStyle, WhiteSpace};
use http_client::HttpClient;
use language_model::{
AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
@@ -27,11 +23,12 @@ use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::ResultExt;
use util::{ResultExt, truncate_and_trailoff};
use zed_env_vars::{EnvVar, env_var};
const PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ANTHROPIC_PROVIDER_NAME;
@@ -97,91 +94,52 @@ pub struct AnthropicLanguageModelProvider {
state: gpui::Entity<State>,
}
const ANTHROPIC_API_KEY_VAR: &str = "ANTHROPIC_API_KEY";
const API_KEY_ENV_VAR_NAME: &str = "ANTHROPIC_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
pub struct State {
api_key: Option<String>,
api_key_from_env: bool,
_subscription: Subscription,
api_key_state: ApiKeyState,
}
impl State {
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.anthropic
.api_url
.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.delete_credentials(&api_url, cx)
.await
.ok();
this.update(cx, |this, cx| {
this.api_key = None;
this.api_key_from_env = false;
cx.notify();
})
})
}
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.anthropic
.api_url
.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
.await
.ok();
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
cx.notify();
})
})
}
fn is_authenticated(&self) -> bool {
self.api_key.is_some()
self.api_key_state.has_key()
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
return Task::ready(Ok(()));
}
let key = AnthropicLanguageModelProvider::api_key(cx);
cx.spawn(async move |this, cx| {
let key = key.await?;
this.update(cx, |this, cx| {
this.api_key = Some(key.key);
this.api_key_from_env = key.from_env;
cx.notify();
})?;
Ok(())
})
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let api_url = AnthropicLanguageModelProvider::api_url(cx);
self.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
}
}
pub struct ApiKey {
pub key: String,
pub from_env: bool,
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let api_url = AnthropicLanguageModelProvider::api_url(cx);
self.api_key_state.load_if_needed(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
)
}
}
impl AnthropicLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
let state = cx.new(|cx| State {
api_key: None,
api_key_from_env: false,
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let api_url = Self::api_url(cx);
this.api_key_state.handle_url_change(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
);
cx.notify();
}),
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx)),
}
});
Self { http_client, state }
@@ -197,30 +155,16 @@ impl AnthropicLanguageModelProvider {
})
}
pub fn api_key(cx: &mut App) -> Task<Result<ApiKey, AuthenticateError>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.anthropic
.api_url
.clone();
fn settings(cx: &App) -> &AnthropicSettings {
&crate::AllLanguageModelSettings::get_global(cx).anthropic
}
if let Ok(key) = std::env::var(ANTHROPIC_API_KEY_VAR) {
Task::ready(Ok(ApiKey {
key,
from_env: true,
}))
fn api_url(cx: &App) -> SharedString {
let api_url = &Self::settings(cx).api_url;
if api_url.is_empty() {
ANTHROPIC_API_URL.into()
} else {
cx.spawn(async move |cx| {
let (_, api_key) = credentials_provider
.read_credentials(&api_url, cx)
.await?
.ok_or(AuthenticateError::CredentialsNotFound)?;
Ok(ApiKey {
key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
from_env: false,
})
})
SharedString::new(api_url.as_str())
}
}
}
@@ -275,11 +219,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
}
// Override with available models from settings
for model in AllLanguageModelSettings::get_global(cx)
.anthropic
.available_models
.iter()
{
for model in &AnthropicLanguageModelProvider::settings(cx).available_models {
models.insert(
model.name.clone(),
anthropic::Model::Custom {
@@ -327,7 +267,8 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.reset_api_key(cx))
self.state
.update(cx, |state, cx| state.set_api_key(None, cx))
}
}
@@ -417,11 +358,11 @@ impl AnthropicModel {
> {
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
(state.api_key.clone(), settings.api_url.clone())
let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
let api_url = AnthropicLanguageModelProvider::api_url(cx);
(state.api_key_state.key(&api_url), api_url)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed();
return future::ready(Err(anyhow!("App state dropped").into())).boxed();
};
let beta_headers = self.model.beta_headers();
@@ -483,7 +424,10 @@ impl LanguageModel for AnthropicModel {
}
fn api_key(&self, cx: &App) -> Option<String> {
self.state.read(cx).api_key.clone()
self.state.read_with(cx, |state, cx| {
let api_url = AnthropicLanguageModelProvider::api_url(cx);
state.api_key_state.key(&api_url).map(|key| key.to_string())
})
}
fn max_token_count(&self) -> u64 {
@@ -984,15 +928,17 @@ impl ConfigurationView {
return;
}
// url changes can cause the editor to be displayed again
self.api_key_editor
.update(cx, |editor, cx| editor.set_text("", window, cx));
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
.update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -1001,11 +947,11 @@ impl ConfigurationView {
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
state
.update(cx, |state, cx| state.set_api_key(None, cx))?
.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
@@ -1040,7 +986,7 @@ impl ConfigurationView {
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let env_var_set = self.state.read(cx).api_key_from_env;
let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
if self.load_credentials_task.is_some() {
div().child(Label::new("Loading credentials...")).into_any()
@@ -1079,7 +1025,7 @@ impl Render for ConfigurationView {
)
.child(
Label::new(
format!("You can also assign the {ANTHROPIC_API_KEY_VAR} environment variable and restart Zed."),
format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
)
.size(LabelSize::Small)
.color(Color::Muted),
@@ -1099,9 +1045,14 @@ impl Render for ConfigurationView {
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
format!("API key set in {ANTHROPIC_API_KEY_VAR} environment variable.")
format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
} else {
"API key configured.".to_string()
let api_url = AnthropicLanguageModelProvider::api_url(cx);
if api_url == ANTHROPIC_API_URL {
"API key configured".to_string()
} else {
format!("API key configured for {}", truncate_and_trailoff(&api_url, 32))
}
})),
)
.child(
@@ -1112,7 +1063,7 @@ impl Render for ConfigurationView {
.icon_position(IconPosition::Start)
.disabled(env_var_set)
.when(env_var_set, |this| {
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {ANTHROPIC_API_KEY_VAR} environment variable.")))
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
})
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
)

View File

@@ -215,11 +215,21 @@ impl State {
self.default_model = models
.iter()
.find(|model| model.id == response.default_model)
.find(|model| {
response
.default_model
.as_ref()
.is_some_and(|default_model_id| &model.id == default_model_id)
})
.cloned();
self.default_fast_model = models
.iter()
.find(|model| model.id == response.default_fast_model)
.find(|model| {
response
.default_fast_model
.as_ref()
.is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
})
.cloned();
self.recommended_models = response
.recommended_models
@@ -541,29 +551,36 @@ where
impl From<ApiError> for LanguageModelCompletionError {
fn from(error: ApiError) -> Self {
if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body)
&& cloud_error.code.starts_with("upstream_http_")
{
let status = if let Some(status) = cloud_error.upstream_status {
status
} else if cloud_error.code.ends_with("_error") {
error.status
} else {
// If there's a status code in the code string (e.g. "upstream_http_429")
// then use that; otherwise, see if the JSON contains a status code.
cloud_error
.code
.strip_prefix("upstream_http_")
.and_then(|code_str| code_str.parse::<u16>().ok())
.and_then(|code| StatusCode::from_u16(code).ok())
.unwrap_or(error.status)
};
if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
if cloud_error.code.starts_with("upstream_http_") {
let status = if let Some(status) = cloud_error.upstream_status {
status
} else if cloud_error.code.ends_with("_error") {
error.status
} else {
// If there's a status code in the code string (e.g. "upstream_http_429")
// then use that; otherwise, see if the JSON contains a status code.
cloud_error
.code
.strip_prefix("upstream_http_")
.and_then(|code_str| code_str.parse::<u16>().ok())
.and_then(|code| StatusCode::from_u16(code).ok())
.unwrap_or(error.status)
};
return LanguageModelCompletionError::UpstreamProviderError {
message: cloud_error.message,
status,
retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
};
return LanguageModelCompletionError::UpstreamProviderError {
message: cloud_error.message,
status,
retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
};
}
return LanguageModelCompletionError::from_http_status(
PROVIDER_NAME,
error.status,
cloud_error.message,
None,
);
}
let retry_after = None;

View File

@@ -1,12 +1,12 @@
use anyhow::{Context as _, Result, anyhow};
use anyhow::{Result, anyhow};
use collections::{BTreeMap, HashMap};
use credentials_provider::CredentialsProvider;
use deepseek::DEEPSEEK_API_URL;
use editor::{Editor, EditorElement, EditorStyle};
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use futures::{FutureExt, StreamExt, future, future::BoxFuture, stream::BoxStream};
use gpui::{
AnyView, AppContext as _, AsyncApp, Entity, FontStyle, Subscription, Task, TextStyle,
WhiteSpace,
AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace,
Window,
};
use http_client::HttpClient;
use language_model::{
@@ -21,16 +21,19 @@ use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::sync::{Arc, LazyLock};
use theme::ThemeSettings;
use ui::{Icon, IconName, List, prelude::*};
use util::ResultExt;
use util::{ResultExt, truncate_and_trailoff};
use zed_env_vars::{EnvVar, env_var};
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
use crate::{api_key::ApiKeyState, ui::InstructionListItem};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek");
const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY";
const API_KEY_ENV_VAR_NAME: &str = "DEEPSEEK_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
#[derive(Default)]
struct RawToolCall {
@@ -59,95 +62,48 @@ pub struct DeepSeekLanguageModelProvider {
}
pub struct State {
api_key: Option<String>,
api_key_from_env: bool,
_subscription: Subscription,
api_key_state: ApiKeyState,
}
impl State {
fn is_authenticated(&self) -> bool {
self.api_key.is_some()
self.api_key_state.has_key()
}
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.deepseek
.api_url
.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.delete_credentials(&api_url, cx)
.await
.log_err();
this.update(cx, |this, cx| {
this.api_key = None;
this.api_key_from_env = false;
cx.notify();
})
})
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let api_url = DeepSeekLanguageModelProvider::api_url(cx);
self.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
}
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.deepseek
.api_url
.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
.await?;
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
cx.notify();
})
})
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
return Task::ready(Ok(()));
}
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.deepseek
.api_url
.clone();
cx.spawn(async move |this, cx| {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(DEEPSEEK_API_KEY_VAR) {
(api_key, true)
} else {
let (_, api_key) = credentials_provider
.read_credentials(&api_url, cx)
.await?
.ok_or(AuthenticateError::CredentialsNotFound)?;
(
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
false,
)
};
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})?;
Ok(())
})
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let api_url = DeepSeekLanguageModelProvider::api_url(cx);
self.api_key_state.load_if_needed(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
)
}
}
impl DeepSeekLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
let state = cx.new(|cx| State {
api_key: None,
api_key_from_env: false,
_subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let api_url = Self::api_url(cx);
this.api_key_state.handle_url_change(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
);
cx.notify();
}),
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx)),
}
});
Self { http_client, state }
@@ -160,7 +116,20 @@ impl DeepSeekLanguageModelProvider {
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
}
fn settings(cx: &App) -> &DeepSeekSettings {
&crate::AllLanguageModelSettings::get_global(cx).deepseek
}
fn api_url(cx: &App) -> SharedString {
let api_url = &Self::settings(cx).api_url;
if api_url.is_empty() {
DEEPSEEK_API_URL.into()
} else {
SharedString::new(api_url.as_str())
}
}
}
@@ -199,11 +168,7 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
models.insert("deepseek-chat", deepseek::Model::Chat);
models.insert("deepseek-reasoner", deepseek::Model::Reasoner);
for available_model in AllLanguageModelSettings::get_global(cx)
.deepseek
.available_models
.iter()
{
for available_model in &Self::settings(cx).available_models {
models.insert(
&available_model.name,
deepseek::Model::Custom {
@@ -240,7 +205,8 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.reset_api_key(cx))
self.state
.update(cx, |state, cx| state.set_api_key(None, cx))
}
}
@@ -259,15 +225,20 @@ impl DeepSeekLanguageModel {
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<deepseek::StreamResponse>>>> {
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).deepseek;
(state.api_key.clone(), settings.api_url.clone())
let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
let api_url = DeepSeekLanguageModelProvider::api_url(cx);
(state.api_key_state.key(&api_url), api_url)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
return future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let future = self.request_limiter.stream(async move {
let api_key = api_key.context("Missing DeepSeek API Key")?;
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
});
};
let request =
deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let response = request.await?;
@@ -610,7 +581,7 @@ impl ConfigurationView {
}
fn save_api_key(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
let api_key = self.api_key_editor.read(cx).text(cx);
let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
if api_key.is_empty() {
return;
}
@@ -618,12 +589,10 @@ impl ConfigurationView {
let state = self.state.clone();
cx.spawn(async move |_, cx| {
state
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
.update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -631,10 +600,12 @@ impl ConfigurationView {
.update(cx, |editor, cx| editor.set_text("", window, cx));
let state = self.state.clone();
cx.spawn(async move |_, cx| state.update(cx, |state, cx| state.reset_api_key(cx))?.await)
.detach_and_log_err(cx);
cx.notify();
cx.spawn(async move |_, cx| {
state
.update(cx, |state, cx| state.set_api_key(None, cx))?
.await
})
.detach_and_log_err(cx);
}
fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
@@ -672,7 +643,7 @@ impl ConfigurationView {
impl Render for ConfigurationView {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let env_var_set = self.state.read(cx).api_key_from_env;
let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
if self.load_credentials_task.is_some() {
div().child(Label::new("Loading credentials...")).into_any()
@@ -706,8 +677,7 @@ impl Render for ConfigurationView {
)
.child(
Label::new(format!(
"Or set the {} environment variable.",
DEEPSEEK_API_KEY_VAR
"Or set the {API_KEY_ENV_VAR_NAME} environment variable."
))
.size(LabelSize::Small)
.color(Color::Muted),
@@ -727,9 +697,17 @@ impl Render for ConfigurationView {
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
format!("API key set in {}", DEEPSEEK_API_KEY_VAR)
format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
} else {
"API key configured".to_string()
let api_url = DeepSeekLanguageModelProvider::api_url(cx);
if api_url == DEEPSEEK_API_URL {
"API key configured".to_string()
} else {
format!(
"API key configured for {}",
truncate_and_trailoff(&api_url, 32)
)
}
})),
)
.child(

View File

@@ -2,13 +2,14 @@ use anyhow::{Context as _, Result, anyhow};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture};
use google_ai::{
FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
ThinkingConfig, UsageMetadata,
};
use gpui::{
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace,
Window,
};
use http_client::HttpClient;
use language_model::{
@@ -26,19 +27,19 @@ use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::sync::{
Arc,
Arc, LazyLock,
atomic::{self, AtomicU64},
};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::ResultExt;
use util::{ResultExt, truncate_and_trailoff};
use zed_env_vars::EnvVar;
use crate::AllLanguageModelSettings;
use crate::api_key::ApiKey;
use crate::api_key::ApiKeyState;
use crate::ui::InstructionListItem;
use super::anthropic::ApiKey;
const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
@@ -91,101 +92,56 @@ pub struct GoogleLanguageModelProvider {
}
pub struct State {
api_key: Option<String>,
api_key_from_env: bool,
_subscription: Subscription,
api_key_state: ApiKeyState,
}
const GEMINI_API_KEY_VAR: &str = "GEMINI_API_KEY";
const GOOGLE_AI_API_KEY_VAR: &str = "GOOGLE_AI_API_KEY";
const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = LazyLock::new(|| {
// Try GEMINI_API_KEY first as primary, fallback to GOOGLE_AI_API_KEY
EnvVar::new(GEMINI_API_KEY_VAR_NAME.into()).or(EnvVar::new(GOOGLE_AI_API_KEY_VAR_NAME.into()))
});
impl State {
fn is_authenticated(&self) -> bool {
self.api_key.is_some()
self.api_key_state.has_key()
}
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.google
.api_url
.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.delete_credentials(&api_url, cx)
.await
.log_err();
this.update(cx, |this, cx| {
this.api_key = None;
this.api_key_from_env = false;
cx.notify();
})
})
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let api_url = GoogleLanguageModelProvider::api_url(cx);
self.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
}
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.google
.api_url
.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
.await?;
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
cx.notify();
})
})
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
return Task::ready(Ok(()));
}
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.google
.api_url
.clone();
cx.spawn(async move |this, cx| {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(GOOGLE_AI_API_KEY_VAR) {
(api_key, true)
} else if let Ok(api_key) = std::env::var(GEMINI_API_KEY_VAR) {
(api_key, true)
} else {
let (_, api_key) = credentials_provider
.read_credentials(&api_url, cx)
.await?
.ok_or(AuthenticateError::CredentialsNotFound)?;
(
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
false,
)
};
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})?;
Ok(())
})
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let api_url = GoogleLanguageModelProvider::api_url(cx);
self.api_key_state.load_if_needed(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
)
}
}
impl GoogleLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
let state = cx.new(|cx| State {
api_key: None,
api_key_from_env: false,
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let api_url = Self::api_url(cx);
this.api_key_state.handle_url_change(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
);
cx.notify();
}),
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx)),
}
});
Self { http_client, state }
@@ -201,30 +157,32 @@ impl GoogleLanguageModelProvider {
})
}
pub fn api_key(cx: &mut App) -> Task<Result<ApiKey>> {
pub fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
if let Some(key) = API_KEY_ENV_VAR.value.clone() {
return Task::ready(Ok(key));
}
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.google
.api_url
.clone();
if let Ok(key) = std::env::var(GEMINI_API_KEY_VAR) {
Task::ready(Ok(ApiKey {
key,
from_env: true,
}))
} else {
cx.spawn(async move |cx| {
let (_, api_key) = credentials_provider
.read_credentials(&api_url, cx)
let api_url = Self::api_url(cx).to_string();
cx.spawn(async move |cx| {
Ok(
ApiKey::load_from_system_keychain(&api_url, credentials_provider.as_ref(), cx)
.await?
.ok_or(AuthenticateError::CredentialsNotFound)?;
.key()
.to_string(),
)
})
}
Ok(ApiKey {
key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
from_env: false,
})
})
fn settings(cx: &App) -> &GoogleSettings {
&crate::AllLanguageModelSettings::get_global(cx).google
}
fn api_url(cx: &App) -> SharedString {
let api_url = &Self::settings(cx).api_url;
if api_url.is_empty() {
google_ai::API_URL.into()
} else {
SharedString::new(api_url.as_str())
}
}
}
@@ -269,10 +227,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
}
// Override with available models from settings
for model in &AllLanguageModelSettings::get_global(cx)
.google
.available_models
{
for model in &GoogleLanguageModelProvider::settings(cx).available_models {
models.insert(
model.name.clone(),
google_ai::Model::Custom {
@@ -317,7 +272,8 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.reset_api_key(cx))
self.state
.update(cx, |state, cx| state.set_api_key(None, cx))
}
}
@@ -340,11 +296,11 @@ impl GoogleLanguageModel {
> {
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).google;
(state.api_key.clone(), settings.api_url.clone())
let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
let api_url = GoogleLanguageModelProvider::api_url(cx);
(state.api_key_state.key(&api_url), api_url)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
return future::ready(Err(anyhow!("App state dropped"))).boxed();
};
async move {
@@ -418,13 +374,16 @@ impl LanguageModel for GoogleLanguageModel {
let model_id = self.model.request_id().to_string();
let request = into_google(request, model_id, self.model.mode());
let http_client = self.http_client.clone();
let api_key = self.state.read(cx).api_key.clone();
let settings = &AllLanguageModelSettings::get_global(cx).google;
let api_url = settings.api_url.clone();
let api_url = GoogleLanguageModelProvider::api_url(cx);
let api_key = self.state.read(cx).api_key_state.key(&api_url);
async move {
let api_key = api_key.context("Missing Google API key")?;
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
}
.into());
};
let response = google_ai::count_tokens(
http_client.as_ref(),
&api_url,
@@ -852,20 +811,22 @@ impl ConfigurationView {
}
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
let api_key = self.api_key_editor.read(cx).text(cx);
let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
if api_key.is_empty() {
return;
}
// url changes can cause the editor to be displayed again
self.api_key_editor
.update(cx, |editor, cx| editor.set_text("", window, cx));
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
.update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -874,11 +835,11 @@ impl ConfigurationView {
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
state
.update(cx, |state, cx| state.set_api_key(None, cx))?
.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
@@ -913,7 +874,7 @@ impl ConfigurationView {
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let env_var_set = self.state.read(cx).api_key_from_env;
let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
if self.load_credentials_task.is_some() {
div().child(Label::new("Loading credentials...")).into_any()
@@ -950,7 +911,7 @@ impl Render for ConfigurationView {
)
.child(
Label::new(
format!("You can also assign the {GEMINI_API_KEY_VAR} environment variable and restart Zed."),
format!("You can also assign the {GEMINI_API_KEY_VAR_NAME} environment variable and restart Zed."),
)
.size(LabelSize::Small).color(Color::Muted),
)
@@ -969,9 +930,14 @@ impl Render for ConfigurationView {
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
format!("API key set in {GEMINI_API_KEY_VAR} environment variable.")
format!("API key set in {} environment variable", API_KEY_ENV_VAR.name)
} else {
"API key configured.".to_string()
let api_url = GoogleLanguageModelProvider::api_url(cx);
if api_url == google_ai::API_URL {
"API key configured".to_string()
} else {
format!("API key configured for {}", truncate_and_trailoff(&api_url, 32))
}
})),
)
.child(
@@ -982,7 +948,7 @@ impl Render for ConfigurationView {
.icon_position(IconPosition::Start)
.disabled(env_var_set)
.when(env_var_set, |this| {
this.tooltip(Tooltip::text(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR} and {GOOGLE_AI_API_KEY_VAR} environment variables are unset.")))
this.tooltip(Tooltip::text(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR_NAME} and {GOOGLE_AI_API_KEY_VAR_NAME} environment variables are unset.")))
})
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
)

View File

@@ -1,10 +1,10 @@
use anyhow::{Context as _, Result, anyhow};
use anyhow::{Result, anyhow};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::BoxStream};
use gpui::{
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace,
Window,
};
use http_client::HttpClient;
use language_model::{
@@ -14,24 +14,28 @@ use language_model::{
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
RateLimiter, Role, StopReason, TokenUsage,
};
use mistral::StreamResponse;
use mistral::{MISTRAL_API_URL, StreamResponse};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::collections::HashMap;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::ResultExt;
use util::{ResultExt, truncate_and_trailoff};
use zed_env_vars::{EnvVar, env_var};
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
use crate::{api_key::ApiKeyState, ui::InstructionListItem};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");
const API_KEY_ENV_VAR_NAME: &str = "MISTRAL_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
pub api_url: String,
@@ -56,96 +60,48 @@ pub struct MistralLanguageModelProvider {
}
pub struct State {
api_key: Option<String>,
api_key_from_env: bool,
_subscription: Subscription,
api_key_state: ApiKeyState,
}
const MISTRAL_API_KEY_VAR: &str = "MISTRAL_API_KEY";
impl State {
fn is_authenticated(&self) -> bool {
self.api_key.is_some()
self.api_key_state.has_key()
}
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.mistral
.api_url
.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.delete_credentials(&api_url, cx)
.await
.log_err();
this.update(cx, |this, cx| {
this.api_key = None;
this.api_key_from_env = false;
cx.notify();
})
})
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let api_url = MistralLanguageModelProvider::api_url(cx);
self.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
}
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.mistral
.api_url
.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
.await?;
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
cx.notify();
})
})
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
return Task::ready(Ok(()));
}
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.mistral
.api_url
.clone();
cx.spawn(async move |this, cx| {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(MISTRAL_API_KEY_VAR) {
(api_key, true)
} else {
let (_, api_key) = credentials_provider
.read_credentials(&api_url, cx)
.await?
.ok_or(AuthenticateError::CredentialsNotFound)?;
(
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
false,
)
};
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})?;
Ok(())
})
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let api_url = MistralLanguageModelProvider::api_url(cx);
self.api_key_state.load_if_needed(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
)
}
}
impl MistralLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
let state = cx.new(|cx| State {
api_key: None,
api_key_from_env: false,
_subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let api_url = Self::api_url(cx);
this.api_key_state.handle_url_change(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
);
cx.notify();
}),
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx)),
}
});
Self { http_client, state }
@@ -160,6 +116,19 @@ impl MistralLanguageModelProvider {
request_limiter: RateLimiter::new(4),
})
}
fn settings(cx: &App) -> &MistralSettings {
&crate::AllLanguageModelSettings::get_global(cx).mistral
}
fn api_url(cx: &App) -> SharedString {
let api_url = &Self::settings(cx).api_url;
if api_url.is_empty() {
mistral::MISTRAL_API_URL.into()
} else {
SharedString::new(api_url.as_str())
}
}
}
impl LanguageModelProviderState for MistralLanguageModelProvider {
@@ -202,10 +171,7 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
}
// Override with available models from settings
for model in &AllLanguageModelSettings::get_global(cx)
.mistral
.available_models
{
for model in &Self::settings(cx).available_models {
models.insert(
model.name.clone(),
mistral::Model::Custom {
@@ -254,7 +220,8 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.reset_api_key(cx))
self.state
.update(cx, |state, cx| state.set_api_key(None, cx))
}
}
@@ -276,15 +243,20 @@ impl MistralLanguageModel {
Result<futures::stream::BoxStream<'static, Result<mistral::StreamResponse>>>,
> {
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).mistral;
(state.api_key.clone(), settings.api_url.clone())
let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
let api_url = MistralLanguageModelProvider::api_url(cx);
(state.api_key_state.key(&api_url), api_url)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
return future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let future = self.request_limiter.stream(async move {
let api_key = api_key.context("Missing Mistral API Key")?;
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
});
};
let request =
mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let response = request.await?;
@@ -780,20 +752,22 @@ impl ConfigurationView {
}
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
let api_key = self.api_key_editor.read(cx).text(cx);
let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
if api_key.is_empty() {
return;
}
// url changes can cause the editor to be displayed again
self.api_key_editor
.update(cx, |editor, cx| editor.set_text("", window, cx));
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
.update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -802,11 +776,11 @@ impl ConfigurationView {
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
state
.update(cx, |state, cx| state.set_api_key(None, cx))?
.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
@@ -841,7 +815,7 @@ impl ConfigurationView {
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let env_var_set = self.state.read(cx).api_key_from_env;
let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
if self.load_credentials_task.is_some() {
div().child(Label::new("Loading credentials...")).into_any()
@@ -878,7 +852,7 @@ impl Render for ConfigurationView {
)
.child(
Label::new(
format!("You can also assign the {MISTRAL_API_KEY_VAR} environment variable and restart Zed."),
format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
)
.size(LabelSize::Small).color(Color::Muted),
)
@@ -897,9 +871,14 @@ impl Render for ConfigurationView {
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
format!("API key set in {MISTRAL_API_KEY_VAR} environment variable.")
format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
} else {
"API key configured.".to_string()
let api_url = MistralLanguageModelProvider::api_url(cx);
if api_url == MISTRAL_API_URL {
"API key configured".to_string()
} else {
format!("API key configured for {}", truncate_and_trailoff(&api_url, 32))
}
})),
)
.child(
@@ -910,7 +889,7 @@ impl Render for ConfigurationView {
.icon_position(IconPosition::Start)
.disabled(env_var_set)
.when(env_var_set, |this| {
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {MISTRAL_API_KEY_VAR} environment variable.")))
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
})
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
)

View File

@@ -1,7 +1,8 @@
use anyhow::{Result, anyhow};
use fs::Fs;
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use futures::{Stream, TryFutureExt, stream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use gpui::{AnyView, App, AsyncApp, Context, Task};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
@@ -10,20 +11,25 @@ use language_model::{
LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
};
use menu;
use ollama::{
ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionCall,
OllamaFunctionTool, OllamaToolCall, get_models, show_model, stream_chat_completion,
ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OLLAMA_API_URL,
OllamaFunctionCall, OllamaFunctionTool, OllamaToolCall, get_models, show_model,
stream_chat_completion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use settings::{Settings, SettingsStore, update_settings_file};
use std::pin::Pin;
use std::sync::LazyLock;
use std::sync::atomic::{AtomicU64, Ordering};
use std::{collections::HashMap, sync::Arc};
use ui::{ButtonLike, Indicator, List, prelude::*};
use util::ResultExt;
use ui::{ButtonLike, ElevationIndex, List, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use zed_env_vars::{EnvVar, env_var};
use crate::AllLanguageModelSettings;
use crate::api_key::ApiKeyState;
use crate::ui::InstructionListItem;
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
@@ -33,6 +39,9 @@ const OLLAMA_SITE: &str = "https://ollama.com/";
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Ollama");
const API_KEY_ENV_VAR_NAME: &str = "OLLAMA_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
pub api_url: String,
@@ -63,25 +72,61 @@ pub struct OllamaLanguageModelProvider {
}
pub struct State {
api_key_state: ApiKeyState,
http_client: Arc<dyn HttpClient>,
available_models: Vec<ollama::Model>,
fetched_models: Vec<ollama::Model>,
fetch_model_task: Option<Task<Result<()>>>,
_subscription: Subscription,
}
impl State {
fn is_authenticated(&self) -> bool {
!self.available_models.is_empty()
!self.fetched_models.is_empty()
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let api_url = OllamaLanguageModelProvider::api_url(cx);
let task = self
.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx);
self.fetched_models.clear();
cx.spawn(async move |this, cx| {
let result = task.await;
this.update(cx, |this, cx| this.restart_fetch_models_task(cx))
.ok();
result
})
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let api_url = OllamaLanguageModelProvider::api_url(cx);
let task = self.api_key_state.load_if_needed(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
);
// Always try to fetch models - if no API key is needed (local Ollama), it will work
// If API key is needed and provided, it will work
// If API key is needed and not provided, it will fail gracefully
cx.spawn(async move |this, cx| {
let result = task.await;
this.update(cx, |this, cx| this.restart_fetch_models_task(cx))
.ok();
result
})
}
fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = Arc::clone(&self.http_client);
let api_url = settings.api_url.clone();
let api_url = OllamaLanguageModelProvider::api_url(cx);
let api_key = self.api_key_state.key(&api_url);
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
cx.spawn(async move |this, cx| {
let models = get_models(http_client.as_ref(), &api_url, None).await?;
let models =
get_models(http_client.as_ref(), &api_url, api_key.as_deref(), None).await?;
let tasks = models
.into_iter()
@@ -92,9 +137,12 @@ impl State {
.map(|model| {
let http_client = Arc::clone(&http_client);
let api_url = api_url.clone();
let api_key = api_key.clone();
async move {
let name = model.name.as_str();
let capabilities = show_model(http_client.as_ref(), &api_url, name).await?;
let capabilities =
show_model(http_client.as_ref(), &api_url, api_key.as_deref(), name)
.await?;
let ollama_model = ollama::Model::new(
name,
None,
@@ -119,7 +167,7 @@ impl State {
ollama_models.sort_by(|a, b| a.name.cmp(&b.name));
this.update(cx, |this, cx| {
this.available_models = ollama_models;
this.fetched_models = ollama_models;
cx.notify();
})
})
@@ -129,15 +177,6 @@ impl State {
let task = self.fetch_models(cx);
self.fetch_model_task.replace(task);
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
return Task::ready(Ok(()));
}
let fetch_models_task = self.fetch_models(cx);
cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
}
}
impl OllamaLanguageModelProvider {
@@ -145,30 +184,47 @@ impl OllamaLanguageModelProvider {
let this = Self {
http_client: http_client.clone(),
state: cx.new(|cx| {
let subscription = cx.observe_global::<SettingsStore>({
let mut settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
cx.observe_global::<SettingsStore>({
let mut last_settings = OllamaLanguageModelProvider::settings(cx).clone();
move |this: &mut State, cx| {
let new_settings = &AllLanguageModelSettings::get_global(cx).ollama;
if &settings != new_settings {
settings = new_settings.clone();
this.restart_fetch_models_task(cx);
let current_settings = OllamaLanguageModelProvider::settings(cx);
let settings_changed = current_settings != &last_settings;
if settings_changed {
let url_changed = last_settings.api_url != current_settings.api_url;
last_settings = current_settings.clone();
if url_changed {
this.fetched_models.clear();
this.authenticate(cx).detach();
}
cx.notify();
}
}
});
})
.detach();
State {
http_client,
available_models: Default::default(),
fetched_models: Default::default(),
fetch_model_task: None,
_subscription: subscription,
api_key_state: ApiKeyState::new(Self::api_url(cx)),
}
}),
};
this.state
.update(cx, |state, cx| state.restart_fetch_models_task(cx));
this
}
fn settings(cx: &App) -> &OllamaSettings {
&AllLanguageModelSettings::get_global(cx).ollama
}
fn api_url(cx: &App) -> SharedString {
let api_url = &Self::settings(cx).api_url;
if api_url.is_empty() {
OLLAMA_API_URL.into()
} else {
SharedString::new(api_url.as_str())
}
}
}
impl LanguageModelProviderState for OllamaLanguageModelProvider {
@@ -208,16 +264,12 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
let mut models: HashMap<String, ollama::Model> = HashMap::new();
// Add models from the Ollama API
for model in self.state.read(cx).available_models.iter() {
for model in self.state.read(cx).fetched_models.iter() {
models.insert(model.name.clone(), model.clone());
}
// Override with available models from settings
for model in AllLanguageModelSettings::get_global(cx)
.ollama
.available_models
.iter()
{
for model in &OllamaLanguageModelProvider::settings(cx).available_models {
models.insert(
model.name.clone(),
ollama::Model {
@@ -240,6 +292,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
model,
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
state: self.state.clone(),
}) as Arc<dyn LanguageModel>
})
.collect::<Vec<_>>();
@@ -267,7 +320,8 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.fetch_models(cx))
self.state
.update(cx, |state, cx| state.set_api_key(None, cx))
}
}
@@ -276,6 +330,7 @@ pub struct OllamaLanguageModel {
model: ollama::Model,
http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
state: gpui::Entity<State>,
}
impl OllamaLanguageModel {
@@ -454,15 +509,17 @@ impl LanguageModel for OllamaLanguageModel {
let request = self.to_ollama_request(request);
let http_client = self.http_client.clone();
let Ok(api_url) = cx.update(|cx| {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
settings.api_url.clone()
let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
let api_url = OllamaLanguageModelProvider::api_url(cx);
(state.api_key_state.key(&api_url), api_url)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed();
};
let future = self.request_limiter.stream(async move {
let stream = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
let stream =
stream_chat_completion(http_client.as_ref(), &api_url, api_key.as_deref(), request)
.await?;
let stream = map_to_language_model_completion_events(stream);
Ok(stream)
});
@@ -574,39 +631,221 @@ fn map_to_language_model_completion_events(
}
struct ConfigurationView {
api_key_editor: gpui::Entity<SingleLineInput>,
api_url_editor: gpui::Entity<SingleLineInput>,
state: gpui::Entity<State>,
loading_models_task: Option<Task<()>>,
}
impl ConfigurationView {
pub fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
let loading_models_task = Some(cx.spawn_in(window, {
let state = state.clone();
async move |this, cx| {
if let Some(task) = state
.update(cx, |state, cx| state.authenticate(cx))
.log_err()
{
task.await.log_err();
}
this.update(cx, |this, cx| {
this.loading_models_task = None;
cx.notify();
})
.log_err();
}
}));
let api_key_editor =
cx.new(|cx| SingleLineInput::new(window, cx, "63e02e...").label("API key"));
let api_url_editor = cx.new(|cx| {
let input = SingleLineInput::new(window, cx, OLLAMA_API_URL).label("API URL");
input.set_text(OllamaLanguageModelProvider::api_url(cx), window, cx);
input
});
cx.observe(&state, |_, _, cx| {
cx.notify();
})
.detach();
Self {
api_key_editor,
api_url_editor,
state,
loading_models_task,
}
}
fn retry_connection(&self, cx: &mut App) {
self.state
.update(cx, |state, cx| state.fetch_models(cx))
.detach_and_log_err(cx);
.update(cx, |state, cx| state.restart_fetch_models_task(cx));
}
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
if api_key.is_empty() {
return;
}
// url changes can cause the editor to be displayed again
self.api_key_editor
.update(cx, |input, cx| input.set_text("", window, cx));
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
.update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
.await
})
.detach_and_log_err(cx);
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.api_key_editor
.update(cx, |input, cx| input.set_text("", window, cx));
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
.update(cx, |state, cx| state.set_api_key(None, cx))?
.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn save_api_url(&mut self, cx: &mut Context<Self>) {
let api_url = self.api_url_editor.read(cx).text(cx).trim().to_string();
let current_url = OllamaLanguageModelProvider::api_url(cx);
if !api_url.is_empty() && &api_url != &current_url {
let fs = <dyn Fs>::global(cx);
update_settings_file::<AllLanguageModelSettings>(fs, cx, move |settings, _| {
if let Some(settings) = settings.ollama.as_mut() {
settings.api_url = Some(api_url);
} else {
settings.ollama = Some(crate::settings::OllamaSettingsContent {
api_url: Some(api_url),
available_models: None,
});
}
});
}
}
fn reset_api_url(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.api_url_editor
.update(cx, |input, cx| input.set_text("", window, cx));
let fs = <dyn Fs>::global(cx);
update_settings_file::<AllLanguageModelSettings>(fs, cx, |settings, _cx| {
if let Some(settings) = settings.ollama.as_mut() {
settings.api_url = Some(OLLAMA_API_URL.into());
}
});
cx.notify();
}
fn render_instructions() -> Div {
v_flex()
.gap_2()
.child(Label::new(
"Run LLMs locally on your machine with Ollama, or connect to an Ollama server. \
Can provide access to Llama, Mistral, Gemma, and hundreds of other models.",
))
.child(Label::new("To use local Ollama:"))
.child(
List::new()
.child(InstructionListItem::new(
"Download and install Ollama from",
Some("ollama.com"),
Some("https://ollama.com/download"),
))
.child(InstructionListItem::text_only(
"Start Ollama and download a model: `ollama run gpt-oss:20b`",
))
.child(InstructionListItem::text_only(
"Click 'Connect' below to start using Ollama in Zed",
)),
)
.child(Label::new(
"Alternatively, you can connect to an Ollama server by specifying its \
URL and API key (may not be required):",
))
}
fn render_api_key_editor(&self, cx: &Context<Self>) -> Div {
let state = self.state.read(cx);
let env_var_set = state.api_key_state.is_from_env_var();
if !state.api_key_state.has_key() {
v_flex()
.on_action(cx.listener(Self::save_api_key))
.child(self.api_key_editor.clone())
.child(
Label::new(
format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed.")
)
.size(LabelSize::Small)
.color(Color::Muted),
)
} else {
h_flex()
.p_3()
.justify_between()
.rounded_md()
.border_1()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().elevated_surface_background)
.child(
h_flex()
.gap_2()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(
Label::new(
if env_var_set {
format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.")
} else {
"API key configured".to_string()
}
)
)
)
.child(
Button::new("reset-api-key", "Reset API Key")
.label_size(LabelSize::Small)
.icon(IconName::Undo)
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
.layer(ElevationIndex::ModalSurface)
.when(env_var_set, |this| {
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
})
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
)
}
}
fn render_api_url_editor(&self, cx: &Context<Self>) -> Div {
let api_url = OllamaLanguageModelProvider::api_url(cx);
let custom_api_url_set = api_url != OLLAMA_API_URL;
if custom_api_url_set {
h_flex()
.p_3()
.justify_between()
.rounded_md()
.border_1()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().elevated_surface_background)
.child(
h_flex()
.gap_2()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(v_flex().gap_1().child(Label::new(api_url))),
)
.child(
Button::new("reset-api-url", "Reset API URL")
.label_size(LabelSize::Small)
.icon(IconName::Undo)
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
.layer(ElevationIndex::ModalSurface)
.on_click(
cx.listener(|this, _, window, cx| this.reset_api_url(window, cx)),
),
)
} else {
v_flex()
.on_action(cx.listener(|this, _: &menu::Confirm, _window, cx| {
this.save_api_url(cx);
cx.notify();
}))
.gap_2()
.child(self.api_url_editor.clone())
}
}
}
@@ -614,98 +853,83 @@ impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let is_authenticated = self.state.read(cx).is_authenticated();
let ollama_intro =
"Get up & running with Llama 3.3, Mistral, Gemma 2, and other LLMs with Ollama.";
if self.loading_models_task.is_some() {
div().child(Label::new("Loading models...")).into_any()
} else {
v_flex()
.gap_2()
.child(
v_flex().gap_1().child(Label::new(ollama_intro)).child(
List::new()
.child(InstructionListItem::text_only("Ollama must be running with at least one model installed to use it in the assistant."))
.child(InstructionListItem::text_only(
"Once installed, try `ollama run llama3.2`",
)),
),
)
.child(
h_flex()
.w_full()
.justify_between()
.gap_2()
.child(
h_flex()
.w_full()
.gap_2()
.map(|this| {
if is_authenticated {
this.child(
Button::new("ollama-site", "Ollama")
.style(ButtonStyle::Subtle)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE))
.into_any_element(),
)
} else {
this.child(
Button::new(
"download_ollama_button",
"Download Ollama",
)
v_flex()
.gap_2()
.child(Self::render_instructions())
.child(self.render_api_url_editor(cx))
.child(self.render_api_key_editor(cx))
.child(
h_flex()
.w_full()
.justify_between()
.gap_2()
.child(
h_flex()
.w_full()
.gap_2()
.map(|this| {
if is_authenticated {
this.child(
Button::new("ollama-site", "Ollama")
.style(ButtonStyle::Subtle)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::Small)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE))
.into_any_element(),
)
} else {
this.child(
Button::new("download_ollama_button", "Download Ollama")
.style(ButtonStyle::Subtle)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.on_click(move |_, _, cx| {
cx.open_url(OLLAMA_DOWNLOAD_URL)
})
.into_any_element(),
)
}
})
.child(
Button::new("view-models", "View All Models")
.style(ButtonStyle::Subtle)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
),
)
.map(|this| {
if is_authenticated {
this.child(
ButtonLike::new("connected")
.disabled(true)
.cursor_style(gpui::CursorStyle::Arrow)
.child(
h_flex()
.gap_2()
.child(Indicator::dot().color(Color::Success))
.child(Label::new("Connected"))
.into_any_element(),
),
)
} else {
this.child(
Button::new("retry_ollama_models", "Connect")
.icon_position(IconPosition::Start)
.icon_size(IconSize::XSmall)
.icon(IconName::PlayFilled)
.on_click(cx.listener(move |this, _, _, cx| {
)
}
})
.child(
Button::new("view-models", "View All Models")
.style(ButtonStyle::Subtle)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
),
)
.map(|this| {
if is_authenticated {
this.child(
ButtonLike::new("connected")
.disabled(true)
.cursor_style(gpui::CursorStyle::Arrow)
.child(
h_flex()
.gap_2()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new("Connected"))
.into_any_element(),
),
)
} else {
this.child(
Button::new("retry_ollama_models", "Connect")
.icon_position(IconPosition::Start)
.icon_size(IconSize::XSmall)
.icon(IconName::PlayOutlined)
.on_click(
cx.listener(move |this, _, _, cx| {
this.retry_connection(cx)
})),
)
}
})
)
.into_any()
}
}),
),
)
}
}),
)
}
}

View File

@@ -1,10 +1,8 @@
use anyhow::{Context as _, Result, anyhow};
use anyhow::{Result, anyhow};
use collections::{BTreeMap, HashMap};
use credentials_provider::CredentialsProvider;
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
use futures::{FutureExt, StreamExt, future, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
@@ -14,24 +12,29 @@ use language_model::{
RateLimiter, Role, StopReason, TokenUsage,
};
use menu;
use open_ai::{ImageUrl, Model, ReasoningEffort, ResponseStreamEvent, stream_completion};
use open_ai::{
ImageUrl, Model, OPEN_AI_API_URL, ReasoningEffort, ResponseStreamEvent, stream_completion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;
use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use ui::{ElevationIndex, List, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use util::ResultExt;
use util::{ResultExt, truncate_and_trailoff};
use zed_env_vars::{EnvVar, env_var};
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
use crate::{api_key::ApiKeyState, ui::InstructionListItem};
const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;
const API_KEY_ENV_VAR_NAME: &str = "OPENAI_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
pub api_url: String,
@@ -54,132 +57,48 @@ pub struct OpenAiLanguageModelProvider {
}
pub struct State {
api_key: Option<String>,
api_key_from_env: bool,
last_api_url: String,
_subscription: Subscription,
api_key_state: ApiKeyState,
}
const OPENAI_API_KEY_VAR: &str = "OPENAI_API_KEY";
impl State {
fn is_authenticated(&self) -> bool {
self.api_key.is_some()
self.api_key_state.has_key()
}
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.openai
.api_url
.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.delete_credentials(&api_url, cx)
.await
.log_err();
this.update(cx, |this, cx| {
this.api_key = None;
this.api_key_from_env = false;
cx.notify();
})
})
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let api_url = OpenAiLanguageModelProvider::api_url(cx);
self.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
}
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.openai
.api_url
.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
.await
.log_err();
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
cx.notify();
})
})
}
fn get_api_key(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.openai
.api_url
.clone();
cx.spawn(async move |this, cx| {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENAI_API_KEY_VAR) {
(api_key, true)
} else {
let (_, api_key) = credentials_provider
.read_credentials(&api_url, cx)
.await?
.ok_or(AuthenticateError::CredentialsNotFound)?;
(
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
false,
)
};
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})?;
Ok(())
})
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
return Task::ready(Ok(()));
}
self.get_api_key(cx)
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let api_url = OpenAiLanguageModelProvider::api_url(cx);
self.api_key_state.load_if_needed(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
)
}
}
impl OpenAiLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
let initial_api_url = AllLanguageModelSettings::get_global(cx)
.openai
.api_url
.clone();
let state = cx.new(|cx| State {
api_key: None,
api_key_from_env: false,
last_api_url: initial_api_url.clone(),
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let current_api_url = AllLanguageModelSettings::get_global(cx)
.openai
.api_url
.clone();
if this.last_api_url != current_api_url {
this.last_api_url = current_api_url;
if !this.api_key_from_env {
this.api_key = None;
let spawn_task = cx.spawn(async move |handle, cx| {
if let Ok(task) = handle.update(cx, |this, cx| this.get_api_key(cx)) {
if let Err(_) = task.await {
handle
.update(cx, |this, _| {
this.api_key = None;
this.api_key_from_env = false;
})
.ok();
}
}
});
spawn_task.detach();
}
}
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let api_url = Self::api_url(cx);
this.api_key_state.handle_url_change(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
);
cx.notify();
}),
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx)),
}
});
Self { http_client, state }
@@ -194,6 +113,19 @@ impl OpenAiLanguageModelProvider {
request_limiter: RateLimiter::new(4),
})
}
fn settings(cx: &App) -> &OpenAiSettings {
&crate::AllLanguageModelSettings::get_global(cx).openai
}
fn api_url(cx: &App) -> SharedString {
let api_url = &Self::settings(cx).api_url;
if api_url.is_empty() {
open_ai::OPEN_AI_API_URL.into()
} else {
SharedString::new(api_url.as_str())
}
}
}
impl LanguageModelProviderState for OpenAiLanguageModelProvider {
@@ -236,10 +168,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
}
// Override with available models from settings
for model in &AllLanguageModelSettings::get_global(cx)
.openai
.available_models
{
for model in &OpenAiLanguageModelProvider::settings(cx).available_models {
models.insert(
model.name.clone(),
open_ai::Model::Custom {
@@ -278,7 +207,8 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.reset_api_key(cx))
self.state
.update(cx, |state, cx| state.set_api_key(None, cx))
}
}
@@ -298,11 +228,12 @@ impl OpenAiLanguageModel {
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
{
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).openai;
(state.api_key.clone(), settings.api_url.clone())
let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
let api_url = OpenAiLanguageModelProvider::api_url(cx);
(state.api_key_state.key(&api_url), api_url)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
return future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let future = self.request_limiter.stream(async move {
@@ -802,45 +733,35 @@ impl ConfigurationView {
}
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
let api_key = self
.api_key_editor
.read(cx)
.editor()
.read(cx)
.text(cx)
.trim()
.to_string();
// Don't proceed if no API key is provided and we're not authenticated
if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
if api_key.is_empty() {
return;
}
// url changes can cause the editor to be displayed again
self.api_key_editor
.update(cx, |editor, cx| editor.set_text("", window, cx));
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
.update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.api_key_editor.update(cx, |input, cx| {
input.editor.update(cx, |editor, cx| {
editor.set_text("", window, cx);
});
});
self.api_key_editor
.update(cx, |input, cx| input.set_text("", window, cx));
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
state
.update(cx, |state, cx| state.set_api_key(None, cx))?
.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
@@ -850,7 +771,7 @@ impl ConfigurationView {
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let env_var_set = self.state.read(cx).api_key_from_env;
let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
let api_key_section = if self.should_render_editor(cx) {
v_flex()
@@ -872,10 +793,11 @@ impl Render for ConfigurationView {
)
.child(self.api_key_editor.clone())
.child(
Label::new(
format!("You can also assign the {OPENAI_API_KEY_VAR} environment variable and restart Zed."),
)
.size(LabelSize::Small).color(Color::Muted),
Label::new(format!(
"You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
))
.size(LabelSize::Small)
.color(Color::Muted),
)
.child(
Label::new(
@@ -898,9 +820,14 @@ impl Render for ConfigurationView {
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
format!("API key set in {OPENAI_API_KEY_VAR} environment variable.")
format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
} else {
"API key configured.".to_string()
let api_url = OpenAiLanguageModelProvider::api_url(cx);
if api_url == OPEN_AI_API_URL {
"API key configured".to_string()
} else {
format!("API key configured for {}", truncate_and_trailoff(&api_url, 32))
}
})),
)
.child(
@@ -911,7 +838,7 @@ impl Render for ConfigurationView {
.icon_position(IconPosition::Start)
.layer(ElevationIndex::ModalSurface)
.when(env_var_set, |this| {
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENAI_API_KEY_VAR} environment variable.")))
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
})
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
)

View File

@@ -1,9 +1,7 @@
use anyhow::{Context as _, Result, anyhow};
use credentials_provider::CredentialsProvider;
use anyhow::{Result, anyhow};
use convert_case::{Case, Casing};
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
use futures::{FutureExt, StreamExt, future, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
@@ -17,12 +15,12 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use ui::{ElevationIndex, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use util::ResultExt;
use util::{ResultExt, truncate_and_trailoff};
use zed_env_vars::EnvVar;
use crate::AllLanguageModelSettings;
use crate::api_key::ApiKeyState;
use crate::provider::open_ai::{OpenAiEventMapper, into_open_ai};
#[derive(Default, Clone, Debug, PartialEq)]
@@ -70,124 +68,67 @@ pub struct OpenAiCompatibleLanguageModelProvider {
pub struct State {
id: Arc<str>,
env_var_name: Arc<str>,
api_key: Option<String>,
api_key_from_env: bool,
api_key_env_var: EnvVar,
api_key_state: ApiKeyState,
settings: OpenAiCompatibleSettings,
_subscription: Subscription,
}
impl State {
fn is_authenticated(&self) -> bool {
self.api_key.is_some()
self.api_key_state.has_key()
}
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = self.settings.api_url.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.delete_credentials(&api_url, cx)
.await
.log_err();
this.update(cx, |this, cx| {
this.api_key = None;
this.api_key_from_env = false;
cx.notify();
})
})
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let api_url = SharedString::new(self.settings.api_url.as_str());
self.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
}
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = self.settings.api_url.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
.await
.log_err();
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
cx.notify();
})
})
}
fn get_api_key(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let env_var_name = self.env_var_name.clone();
let api_url = self.settings.api_url.clone();
cx.spawn(async move |this, cx| {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(env_var_name.as_ref()) {
(api_key, true)
} else {
let (_, api_key) = credentials_provider
.read_credentials(&api_url, cx)
.await?
.ok_or(AuthenticateError::CredentialsNotFound)?;
(
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
false,
)
};
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})?;
Ok(())
})
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
return Task::ready(Ok(()));
}
self.get_api_key(cx)
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let api_url = SharedString::new(self.settings.api_url.clone());
self.api_key_state.load_if_needed(
api_url,
&self.api_key_env_var,
|this| &mut this.api_key_state,
cx,
)
}
}
impl OpenAiCompatibleLanguageModelProvider {
pub fn new(id: Arc<str>, http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
fn resolve_settings<'a>(id: &'a str, cx: &'a App) -> Option<&'a OpenAiCompatibleSettings> {
AllLanguageModelSettings::get_global(cx)
crate::AllLanguageModelSettings::get_global(cx)
.openai_compatible
.get(id)
}
let state = cx.new(|cx| State {
id: id.clone(),
env_var_name: format!("{}_API_KEY", id).to_case(Case::Constant).into(),
settings: resolve_settings(&id, cx).cloned().unwrap_or_default(),
api_key: None,
api_key_from_env: false,
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let api_key_env_var_name = format!("{}_API_KEY", id).to_case(Case::UpperSnake).into();
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let Some(settings) = resolve_settings(&this.id, cx).cloned() else {
return;
};
if &this.settings != &settings {
if settings.api_url != this.settings.api_url && !this.api_key_from_env {
let spawn_task = cx.spawn(async move |handle, cx| {
if let Ok(task) = handle.update(cx, |this, cx| this.get_api_key(cx)) {
if let Err(_) = task.await {
handle
.update(cx, |this, _| {
this.api_key = None;
this.api_key_from_env = false;
})
.ok();
}
}
});
spawn_task.detach();
}
let api_url = SharedString::new(settings.api_url.as_str());
this.api_key_state.handle_url_change(
api_url,
&this.api_key_env_var,
|this| &mut this.api_key_state,
cx,
);
this.settings = settings;
cx.notify();
}
}),
})
.detach();
let settings = resolve_settings(&id, cx).cloned().unwrap_or_default();
State {
id: id.clone(),
api_key_env_var: EnvVar::new(api_key_env_var_name),
api_key_state: ApiKeyState::new(SharedString::new(settings.api_url.as_str())),
settings,
}
});
Self {
@@ -274,7 +215,8 @@ impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider {
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.reset_api_key(cx))
self.state
.update(cx, |state, cx| state.set_api_key(None, cx))
}
}
@@ -296,10 +238,15 @@ impl OpenAiCompatibleLanguageModel {
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
{
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, _| {
(state.api_key.clone(), state.settings.api_url.clone())
let Ok((api_key, api_url)) = self.state.read_with(cx, |state, _cx| {
let api_url = &state.settings.api_url;
(
state.api_key_state.key(api_url),
state.settings.api_url.clone(),
)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
return future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let provider = self.provider_name.clone();
@@ -469,56 +416,47 @@ impl ConfigurationView {
}
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
let api_key = self
.api_key_editor
.read(cx)
.editor()
.read(cx)
.text(cx)
.trim()
.to_string();
// Don't proceed if no API key is provided and we're not authenticated
if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
if api_key.is_empty() {
return;
}
// url changes can cause the editor to be displayed again
self.api_key_editor
.update(cx, |input, cx| input.set_text("", window, cx));
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
.update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.api_key_editor.update(cx, |input, cx| {
input.editor.update(cx, |editor, cx| {
editor.set_text("", window, cx);
});
});
self.api_key_editor
.update(cx, |input, cx| input.set_text("", window, cx));
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
state
.update(cx, |state, cx| state.set_api_key(None, cx))?
.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
fn should_render_editor(&self, cx: &Context<Self>) -> bool {
!self.state.read(cx).is_authenticated()
}
}
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let env_var_set = self.state.read(cx).api_key_from_env;
let env_var_name = self.state.read(cx).env_var_name.clone();
let state = self.state.read(cx);
let env_var_set = state.api_key_state.is_from_env_var();
let env_var_name = &state.api_key_env_var.name;
let api_key_section = if self.should_render_editor(cx) {
v_flex()
@@ -550,9 +488,9 @@ impl Render for ConfigurationView {
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
format!("API key set in {env_var_name} environment variable.")
format!("API key set in {env_var_name} environment variable")
} else {
"API key configured.".to_string()
format!("API key configured for {}", truncate_and_trailoff(&state.settings.api_url, 32))
})),
)
.child(

View File

@@ -1,10 +1,9 @@
use anyhow::{Context as _, Result, anyhow};
use anyhow::{Result, anyhow};
use collections::HashMap;
use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture};
use gpui::{
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace,
};
use http_client::HttpClient;
use language_model::{
@@ -15,24 +14,28 @@ use language_model::{
LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
};
use open_router::{
Model, ModelMode as OpenRouterModelMode, Provider, ResponseStreamEvent, list_models,
stream_completion,
Model, ModelMode as OpenRouterModelMode, OPEN_ROUTER_API_URL, Provider, ResponseStreamEvent,
list_models,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;
use std::sync::{Arc, LazyLock};
use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::ResultExt;
use util::{ResultExt, truncate_and_trailoff};
use zed_env_vars::{EnvVar, env_var};
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
use crate::{api_key::ApiKeyState, ui::InstructionListItem};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter");
const API_KEY_ENV_VAR_NAME: &str = "OPENROUTER_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenRouterSettings {
pub api_url: String,
@@ -90,93 +93,37 @@ pub struct OpenRouterLanguageModelProvider {
}
pub struct State {
api_key: Option<String>,
api_key_from_env: bool,
api_key_state: ApiKeyState,
http_client: Arc<dyn HttpClient>,
available_models: Vec<open_router::Model>,
fetch_models_task: Option<Task<Result<(), LanguageModelCompletionError>>>,
settings: OpenRouterSettings,
_subscription: Subscription,
}
const OPENROUTER_API_KEY_VAR: &str = "OPENROUTER_API_KEY";
impl State {
fn is_authenticated(&self) -> bool {
self.api_key.is_some()
self.api_key_state.has_key()
}
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.open_router
.api_url
.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.delete_credentials(&api_url, cx)
.await
.log_err();
this.update(cx, |this, cx| {
this.api_key = None;
this.api_key_from_env = false;
cx.notify();
})
})
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let api_url = OpenRouterLanguageModelProvider::api_url(cx);
self.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
}
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.open_router
.api_url
.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
.await
.log_err();
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
this.restart_fetch_models_task(cx);
cx.notify();
})
})
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
return Task::ready(Ok(()));
}
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.open_router
.api_url
.clone();
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let api_url = OpenRouterLanguageModelProvider::api_url(cx);
let task = self.api_key_state.load_if_needed(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
);
cx.spawn(async move |this, cx| {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENROUTER_API_KEY_VAR) {
(api_key, true)
} else {
let (_, api_key) = credentials_provider
.read_credentials(&api_url, cx)
.await?
.ok_or(AuthenticateError::CredentialsNotFound)?;
(
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
false,
)
};
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
this.restart_fetch_models_task(cx);
cx.notify();
})?;
Ok(())
let result = task.await;
this.update(cx, |this, cx| this.restart_fetch_models_task(cx))
.ok();
result
})
}
@@ -184,10 +131,9 @@ impl State {
&mut self,
cx: &mut Context<Self>,
) -> Task<Result<(), LanguageModelCompletionError>> {
let settings = &AllLanguageModelSettings::get_global(cx).open_router;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
let Some(api_key) = self.api_key.clone() else {
let api_url = OpenRouterLanguageModelProvider::api_url(cx);
let Some(api_key) = self.api_key_state.key(&api_url) else {
return Task::ready(Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
}));
@@ -216,33 +162,52 @@ impl State {
if self.is_authenticated() {
let task = self.fetch_models(cx);
self.fetch_models_task.replace(task);
} else {
self.available_models = Vec::new();
}
}
}
impl OpenRouterLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
let state = cx.new(|cx| State {
api_key: None,
api_key_from_env: false,
http_client: http_client.clone(),
available_models: Vec::new(),
fetch_models_task: None,
settings: OpenRouterSettings::default(),
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let current_settings = &AllLanguageModelSettings::get_global(cx).open_router;
let settings_changed = current_settings != &this.settings;
if settings_changed {
this.settings = current_settings.clone();
this.restart_fetch_models_task(cx);
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>({
let mut last_settings = OpenRouterLanguageModelProvider::settings(cx).clone();
move |this: &mut State, cx| {
let current_settings = OpenRouterLanguageModelProvider::settings(cx);
let settings_changed = current_settings != &last_settings;
if settings_changed {
last_settings = current_settings.clone();
this.authenticate(cx).detach();
cx.notify();
}
}
cx.notify();
}),
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx)),
http_client: http_client.clone(),
available_models: Vec::new(),
fetch_models_task: None,
}
});
Self { http_client, state }
}
fn settings(cx: &App) -> &OpenRouterSettings {
&crate::AllLanguageModelSettings::get_global(cx).open_router
}
fn api_url(cx: &App) -> SharedString {
let api_url = &Self::settings(cx).api_url;
if api_url.is_empty() {
OPEN_ROUTER_API_URL.into()
} else {
SharedString::new(api_url.as_str())
}
}
fn create_language_model(&self, model: open_router::Model) -> Arc<dyn LanguageModel> {
Arc::new(OpenRouterLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
@@ -287,10 +252,7 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider {
let mut models_from_api = self.state.read(cx).available_models.clone();
let mut settings_models = Vec::new();
for model in &AllLanguageModelSettings::get_global(cx)
.open_router
.available_models
{
for model in &Self::settings(cx).available_models {
settings_models.push(open_router::Model {
name: model.name.clone(),
display_name: model.display_name.clone(),
@@ -338,7 +300,8 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider {
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.reset_api_key(cx))
self.state
.update(cx, |state, cx| state.set_api_key(None, cx))
}
}
@@ -366,14 +329,11 @@ impl OpenRouterLanguageModel {
>,
> {
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).open_router;
(state.api_key.clone(), settings.api_url.clone())
let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
let api_url = OpenRouterLanguageModelProvider::api_url(cx);
(state.api_key_state.key(&api_url), api_url)
}) else {
return futures::future::ready(Err(LanguageModelCompletionError::Other(anyhow!(
"App state dropped"
))))
.boxed();
return future::ready(Err(anyhow!("App state dropped").into())).boxed();
};
async move {
@@ -382,7 +342,8 @@ impl OpenRouterLanguageModel {
provider: PROVIDER_NAME,
});
};
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let request =
open_router::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
request.await.map_err(Into::into)
}
.boxed()
@@ -830,20 +791,22 @@ impl ConfigurationView {
}
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
let api_key = self.api_key_editor.read(cx).text(cx);
let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
if api_key.is_empty() {
return;
}
// url changes can cause the editor to be displayed again
self.api_key_editor
.update(cx, |editor, cx| editor.set_text("", window, cx));
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
.update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -852,11 +815,11 @@ impl ConfigurationView {
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
state
.update(cx, |state, cx| state.set_api_key(None, cx))?
.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
@@ -891,7 +854,7 @@ impl ConfigurationView {
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let env_var_set = self.state.read(cx).api_key_from_env;
let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
if self.load_credentials_task.is_some() {
div().child(Label::new("Loading credentials...")).into_any()
@@ -928,7 +891,7 @@ impl Render for ConfigurationView {
)
.child(
Label::new(
format!("You can also assign the {OPENROUTER_API_KEY_VAR} environment variable and restart Zed."),
format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
)
.size(LabelSize::Small).color(Color::Muted),
)
@@ -947,9 +910,14 @@ impl Render for ConfigurationView {
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
format!("API key set in {OPENROUTER_API_KEY_VAR} environment variable.")
format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
} else {
"API key configured.".to_string()
let api_url = OpenRouterLanguageModelProvider::api_url(cx);
if api_url == OPEN_ROUTER_API_URL {
"API key configured".to_string()
} else {
format!("API key configured for {}", truncate_and_trailoff(&api_url, 32))
}
})),
)
.child(
@@ -960,7 +928,7 @@ impl Render for ConfigurationView {
.icon_position(IconPosition::Start)
.disabled(env_var_set)
.when(env_var_set, |this| {
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENROUTER_API_KEY_VAR} environment variable.")))
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
})
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
)

View File

@@ -1,8 +1,7 @@
use anyhow::{Context as _, Result, anyhow};
use anyhow::{Result, anyhow};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
use futures::{FutureExt, StreamExt, future, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
@@ -10,24 +9,26 @@ use language_model::{
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, RateLimiter, Role,
};
use menu;
use open_ai::ResponseStreamEvent;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use vercel::Model;
use ui::{ElevationIndex, List, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use util::ResultExt;
use util::{ResultExt, truncate_and_trailoff};
use vercel::{Model, VERCEL_API_URL};
use zed_env_vars::{EnvVar, env_var};
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
use crate::{api_key::ApiKeyState, ui::InstructionListItem};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("vercel");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Vercel");
const API_KEY_ENV_VAR_NAME: &str = "VERCEL_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
#[derive(Default, Clone, Debug, PartialEq)]
pub struct VercelSettings {
pub api_url: String,
@@ -49,103 +50,48 @@ pub struct VercelLanguageModelProvider {
}
pub struct State {
api_key: Option<String>,
api_key_from_env: bool,
_subscription: Subscription,
api_key_state: ApiKeyState,
}
const VERCEL_API_KEY_VAR: &str = "VERCEL_API_KEY";
impl State {
fn is_authenticated(&self) -> bool {
self.api_key.is_some()
self.api_key_state.has_key()
}
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let settings = &AllLanguageModelSettings::get_global(cx).vercel;
let api_url = if settings.api_url.is_empty() {
vercel::VERCEL_API_URL.to_string()
} else {
settings.api_url.clone()
};
cx.spawn(async move |this, cx| {
credentials_provider
.delete_credentials(&api_url, cx)
.await
.log_err();
this.update(cx, |this, cx| {
this.api_key = None;
this.api_key_from_env = false;
cx.notify();
})
})
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let api_url = VercelLanguageModelProvider::api_url(cx);
self.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
}
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let settings = &AllLanguageModelSettings::get_global(cx).vercel;
let api_url = if settings.api_url.is_empty() {
vercel::VERCEL_API_URL.to_string()
} else {
settings.api_url.clone()
};
cx.spawn(async move |this, cx| {
credentials_provider
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
.await
.log_err();
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
cx.notify();
})
})
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
return Task::ready(Ok(()));
}
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let settings = &AllLanguageModelSettings::get_global(cx).vercel;
let api_url = if settings.api_url.is_empty() {
vercel::VERCEL_API_URL.to_string()
} else {
settings.api_url.clone()
};
cx.spawn(async move |this, cx| {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(VERCEL_API_KEY_VAR) {
(api_key, true)
} else {
let (_, api_key) = credentials_provider
.read_credentials(&api_url, cx)
.await?
.ok_or(AuthenticateError::CredentialsNotFound)?;
(
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
false,
)
};
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})?;
Ok(())
})
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let api_url = VercelLanguageModelProvider::api_url(cx);
self.api_key_state.load_if_needed(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
)
}
}
impl VercelLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
let state = cx.new(|cx| State {
api_key: None,
api_key_from_env: false,
_subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let api_url = Self::api_url(cx);
this.api_key_state.handle_url_change(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
);
cx.notify();
}),
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx)),
}
});
Self { http_client, state }
@@ -160,6 +106,19 @@ impl VercelLanguageModelProvider {
request_limiter: RateLimiter::new(4),
})
}
fn settings(cx: &App) -> &VercelSettings {
&crate::AllLanguageModelSettings::get_global(cx).vercel
}
fn api_url(cx: &App) -> SharedString {
let api_url = &Self::settings(cx).api_url;
if api_url.is_empty() {
VERCEL_API_URL.into()
} else {
SharedString::new(api_url.as_str())
}
}
}
impl LanguageModelProviderState for VercelLanguageModelProvider {
@@ -200,10 +159,7 @@ impl LanguageModelProvider for VercelLanguageModelProvider {
}
}
for model in &AllLanguageModelSettings::get_global(cx)
.vercel
.available_models
{
for model in &Self::settings(cx).available_models {
models.insert(
model.name.clone(),
vercel::Model::Custom {
@@ -241,7 +197,8 @@ impl LanguageModelProvider for VercelLanguageModelProvider {
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.reset_api_key(cx))
self.state
.update(cx, |state, cx| state.set_api_key(None, cx))
}
}
@@ -261,16 +218,12 @@ impl VercelLanguageModel {
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
{
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).vercel;
let api_url = if settings.api_url.is_empty() {
vercel::VERCEL_API_URL.to_string()
} else {
settings.api_url.clone()
};
(state.api_key.clone(), api_url)
let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
let api_url = VercelLanguageModelProvider::api_url(cx);
(state.api_key_state.key(&api_url), api_url)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
return future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let future = self.request_limiter.stream(async move {
@@ -466,45 +419,35 @@ impl ConfigurationView {
}
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
let api_key = self
.api_key_editor
.read(cx)
.editor()
.read(cx)
.text(cx)
.trim()
.to_string();
// Don't proceed if no API key is provided and we're not authenticated
if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
if api_key.is_empty() {
return;
}
// url changes can cause the editor to be displayed again
self.api_key_editor
.update(cx, |editor, cx| editor.set_text("", window, cx));
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
.update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.api_key_editor.update(cx, |input, cx| {
input.editor.update(cx, |editor, cx| {
editor.set_text("", window, cx);
});
});
self.api_key_editor
.update(cx, |input, cx| input.set_text("", window, cx));
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
state
.update(cx, |state, cx| state.set_api_key(None, cx))?
.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
@@ -514,7 +457,7 @@ impl ConfigurationView {
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let env_var_set = self.state.read(cx).api_key_from_env;
let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
let api_key_section = if self.should_render_editor(cx) {
v_flex()
@@ -534,7 +477,7 @@ impl Render for ConfigurationView {
.child(self.api_key_editor.clone())
.child(
Label::new(format!(
"You can also assign the {VERCEL_API_KEY_VAR} environment variable and restart Zed."
"You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
))
.size(LabelSize::Small)
.color(Color::Muted),
@@ -559,9 +502,14 @@ impl Render for ConfigurationView {
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
format!("API key set in {VERCEL_API_KEY_VAR} environment variable.")
format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
} else {
"API key configured.".to_string()
let api_url = VercelLanguageModelProvider::api_url(cx);
if api_url == VERCEL_API_URL {
"API key configured".to_string()
} else {
format!("API key configured for {}", truncate_and_trailoff(&api_url, 32))
}
})),
)
.child(
@@ -572,7 +520,7 @@ impl Render for ConfigurationView {
.icon_position(IconPosition::Start)
.layer(ElevationIndex::ModalSurface)
.when(env_var_set, |this| {
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {VERCEL_API_KEY_VAR} environment variable.")))
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
})
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
)

View File

@@ -1,8 +1,7 @@
use anyhow::{Context as _, Result, anyhow};
use anyhow::{Result, anyhow};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
use futures::{FutureExt, StreamExt, future, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task, Window};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
@@ -10,23 +9,25 @@ use language_model::{
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter, Role,
};
use menu;
use open_ai::ResponseStreamEvent;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use x_ai::Model;
use ui::{ElevationIndex, List, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use util::ResultExt;
use util::{ResultExt, truncate_and_trailoff};
use x_ai::{Model, XAI_API_URL};
use zed_env_vars::{EnvVar, env_var};
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
use crate::{api_key::ApiKeyState, ui::InstructionListItem};
const PROVIDER_ID: &str = "x_ai";
const PROVIDER_NAME: &str = "xAI";
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");
const API_KEY_ENV_VAR_NAME: &str = "XAI_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
#[derive(Default, Clone, Debug, PartialEq)]
pub struct XAiSettings {
@@ -49,103 +50,48 @@ pub struct XAiLanguageModelProvider {
}
pub struct State {
api_key: Option<String>,
api_key_from_env: bool,
_subscription: Subscription,
api_key_state: ApiKeyState,
}
const XAI_API_KEY_VAR: &str = "XAI_API_KEY";
impl State {
fn is_authenticated(&self) -> bool {
self.api_key.is_some()
self.api_key_state.has_key()
}
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
let api_url = if settings.api_url.is_empty() {
x_ai::XAI_API_URL.to_string()
} else {
settings.api_url.clone()
};
cx.spawn(async move |this, cx| {
credentials_provider
.delete_credentials(&api_url, cx)
.await
.log_err();
this.update(cx, |this, cx| {
this.api_key = None;
this.api_key_from_env = false;
cx.notify();
})
})
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let api_url = XAiLanguageModelProvider::api_url(cx);
self.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
}
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
let api_url = if settings.api_url.is_empty() {
x_ai::XAI_API_URL.to_string()
} else {
settings.api_url.clone()
};
cx.spawn(async move |this, cx| {
credentials_provider
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
.await
.log_err();
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
cx.notify();
})
})
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
return Task::ready(Ok(()));
}
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
let api_url = if settings.api_url.is_empty() {
x_ai::XAI_API_URL.to_string()
} else {
settings.api_url.clone()
};
cx.spawn(async move |this, cx| {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(XAI_API_KEY_VAR) {
(api_key, true)
} else {
let (_, api_key) = credentials_provider
.read_credentials(&api_url, cx)
.await?
.ok_or(AuthenticateError::CredentialsNotFound)?;
(
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
false,
)
};
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})?;
Ok(())
})
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let api_url = XAiLanguageModelProvider::api_url(cx);
self.api_key_state.load_if_needed(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
)
}
}
impl XAiLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
let state = cx.new(|cx| State {
api_key: None,
api_key_from_env: false,
_subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let api_url = Self::api_url(cx);
this.api_key_state.handle_url_change(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
);
cx.notify();
}),
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx)),
}
});
Self { http_client, state }
@@ -160,6 +106,19 @@ impl XAiLanguageModelProvider {
request_limiter: RateLimiter::new(4),
})
}
fn settings(cx: &App) -> &XAiSettings {
&crate::AllLanguageModelSettings::get_global(cx).x_ai
}
fn api_url(cx: &App) -> SharedString {
let api_url = &Self::settings(cx).api_url;
if api_url.is_empty() {
XAI_API_URL.into()
} else {
SharedString::new(api_url.as_str())
}
}
}
impl LanguageModelProviderState for XAiLanguageModelProvider {
@@ -172,11 +131,11 @@ impl LanguageModelProviderState for XAiLanguageModelProvider {
impl LanguageModelProvider for XAiLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -200,10 +159,7 @@ impl LanguageModelProvider for XAiLanguageModelProvider {
}
}
for model in &AllLanguageModelSettings::get_global(cx)
.x_ai
.available_models
{
for model in &Self::settings(cx).available_models {
models.insert(
model.name.clone(),
x_ai::Model::Custom {
@@ -241,7 +197,8 @@ impl LanguageModelProvider for XAiLanguageModelProvider {
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.reset_api_key(cx))
self.state
.update(cx, |state, cx| state.set_api_key(None, cx))
}
}
@@ -261,20 +218,20 @@ impl XAiLanguageModel {
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
{
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
let api_url = if settings.api_url.is_empty() {
x_ai::XAI_API_URL.to_string()
} else {
settings.api_url.clone()
};
(state.api_key.clone(), api_url)
let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
let api_url = XAiLanguageModelProvider::api_url(cx);
(state.api_key_state.key(&api_url), api_url)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
return future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let future = self.request_limiter.stream(async move {
let api_key = api_key.context("Missing xAI API Key")?;
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
});
};
let request =
open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let response = request.await?;
@@ -295,11 +252,11 @@ impl LanguageModel for XAiLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@@ -456,45 +413,35 @@ impl ConfigurationView {
}
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
let api_key = self
.api_key_editor
.read(cx)
.editor()
.read(cx)
.text(cx)
.trim()
.to_string();
// Don't proceed if no API key is provided and we're not authenticated
if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
if api_key.is_empty() {
return;
}
// url changes can cause the editor to be displayed again
self.api_key_editor
.update(cx, |editor, cx| editor.set_text("", window, cx));
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
.update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.api_key_editor.update(cx, |input, cx| {
input.editor.update(cx, |editor, cx| {
editor.set_text("", window, cx);
});
});
self.api_key_editor
.update(cx, |input, cx| input.set_text("", window, cx));
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
state
.update(cx, |state, cx| state.set_api_key(None, cx))?
.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
@@ -504,7 +451,7 @@ impl ConfigurationView {
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let env_var_set = self.state.read(cx).api_key_from_env;
let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
let api_key_section = if self.should_render_editor(cx) {
v_flex()
@@ -524,7 +471,7 @@ impl Render for ConfigurationView {
.child(self.api_key_editor.clone())
.child(
Label::new(format!(
"You can also assign the {XAI_API_KEY_VAR} environment variable and restart Zed."
"You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
))
.size(LabelSize::Small)
.color(Color::Muted),
@@ -549,9 +496,14 @@ impl Render for ConfigurationView {
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
format!("API key set in {XAI_API_KEY_VAR} environment variable.")
format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
} else {
"API key configured.".to_string()
let api_url = XAiLanguageModelProvider::api_url(cx);
if api_url == XAI_API_URL {
"API key configured".to_string()
} else {
format!("API key configured for {}", truncate_and_trailoff(&api_url, 32))
}
})),
)
.child(
@@ -562,7 +514,7 @@ impl Render for ConfigurationView {
.icon_position(IconPosition::Start)
.layer(ElevationIndex::ModalSurface)
.when(env_var_set, |this| {
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {XAI_API_KEY_VAR} environment variable.")))
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
})
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
)

View File

@@ -30,6 +30,10 @@ impl BasedPyrightBanner {
_subscriptions: [subscription],
}
}
fn onboarding_banner_enabled(&self) -> bool {
!self.dismissed && self.have_basedpyright
}
}
impl EventEmitter<ToolbarItemEvent> for BasedPyrightBanner {}
@@ -38,7 +42,7 @@ impl Render for BasedPyrightBanner {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
div()
.id("basedpyright-banner")
.when(!self.dismissed && self.have_basedpyright, |el| {
.when(self.onboarding_banner_enabled(), |el| {
el.child(
Banner::new()
.child(
@@ -81,6 +85,9 @@ impl ToolbarItemView for BasedPyrightBanner {
_window: &mut ui::Window,
cx: &mut Context<Self>,
) -> ToolbarItemLocation {
if !self.onboarding_banner_enabled() {
return ToolbarItemLocation::Hidden;
}
if let Some(item) = active_pane_item
&& let Some(editor) = item.act_as::<Editor>(cx)
&& let Some(path) = editor.update(cx, |editor, cx| editor.target_file_abs_path(cx))

View File

@@ -12,7 +12,8 @@ use theme::ActiveTheme;
use tree_sitter::{Node, TreeCursor};
use ui::{
ButtonCommon, ButtonLike, Clickable, Color, ContextMenu, FluentBuilder as _, IconButton,
IconName, Label, LabelCommon, LabelSize, PopoverMenu, StyledExt, Tooltip, h_flex, v_flex,
IconName, Label, LabelCommon, LabelSize, PopoverMenu, StyledExt, Tooltip, WithScrollbar,
h_flex, v_flex,
};
use workspace::{
Event as WorkspaceEvent, SplitDirection, ToolbarItemEvent, ToolbarItemLocation,
@@ -487,7 +488,7 @@ impl SyntaxTreeView {
}
impl Render for SyntaxTreeView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
div()
.flex_1()
.bg(cx.theme().colors().editor_background)
@@ -512,6 +513,8 @@ impl Render for SyntaxTreeView {
.text_bg(cx.theme().colors().background)
.into_any_element(),
)
.vertical_scrollbar_for(self.list_scroll_handle.clone(), window, cx)
.into_any_element()
} else {
let inner_content = v_flex()
.items_center()
@@ -540,6 +543,7 @@ impl Render for SyntaxTreeView {
.size_full()
.justify_center()
.child(inner_content)
.into_any_element()
}
})
}

View File

@@ -57,6 +57,7 @@ pet-core.workspace = true
pet-fs.workspace = true
pet-poetry.workspace = true
pet-reporter.workspace = true
pet-virtualenv.workspace = true
pet.workspace = true
project.workspace = true
regex.workspace = true

View File

@@ -10,4 +10,365 @@
(raw_string_literal)
(interpreted_string_literal)
] @injection.content
(#set! injection.language "regex")))
(#set! injection.language "regex")
))
; INJECT SQL
(
[
; var, const or short declaration of raw or interpreted string literal
((comment) @comment
.
(expression_list
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content
))
; when passing as a literal element (to struct field eg.)
((comment) @comment
.
(literal_element
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content
))
; when passing as a function parameter
((comment) @comment
.
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content)
]
(#match? @comment "^\\/\\*\\s*sql\\s*\\*\\/") ; /* sql */ or /*sql*/
(#set! injection.language "sql")
)
; INJECT JSON
(
[
; var, const or short declaration of raw or interpreted string literal
((comment) @comment
.
(expression_list
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content
))
; when passing as a literal element (to struct field eg.)
((comment) @comment
.
(literal_element
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content
))
; when passing as a function parameter
((comment) @comment
.
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content)
]
(#match? @comment "^\\/\\*\\s*json\\s*\\*\\/") ; /* json */ or /*json*/
(#set! injection.language "json")
)
; INJECT YAML
(
[
; var, const or short declaration of raw or interpreted string literal
((comment) @comment
.
(expression_list
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content
))
; when passing as a literal element (to struct field eg.)
((comment) @comment
.
(literal_element
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content
))
; when passing as a function parameter
((comment) @comment
.
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content)
]
(#match? @comment "^\\/\\*\\s*yaml\\s*\\*\\/") ; /* yaml */ or /*yaml*/
(#set! injection.language "yaml")
)
; INJECT XML
(
[
; var, const or short declaration of raw or interpreted string literal
((comment) @comment
.
(expression_list
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content
))
; when passing as a literal element (to struct field eg.)
((comment) @comment
.
(literal_element
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content
))
; when passing as a function parameter
((comment) @comment
.
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content)
]
(#match? @comment "^\\/\\*\\s*xml\\s*\\*\\/") ; /* xml */ or /*xml*/
(#set! injection.language "xml")
)
; INJECT HTML
(
[
; var, const or short declaration of raw or interpreted string literal
((comment) @comment
.
(expression_list
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content
))
; when passing as a literal element (to struct field eg.)
((comment) @comment
.
(literal_element
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content
))
; when passing as a function parameter
((comment) @comment
.
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content)
]
(#match? @comment "^\\/\\*\\s*html\\s*\\*\\/") ; /* html */ or /*html*/
(#set! injection.language "html")
)
; INJECT JS
(
[
; var, const or short declaration of raw or interpreted string literal
((comment) @comment
.
(expression_list
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content
))
; when passing as a literal element (to struct field eg.)
((comment) @comment
.
(literal_element
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content
))
; when passing as a function parameter
((comment) @comment
.
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content)
]
(#match? @comment "^\\/\\*\\s*js\\s*\\*\\/") ; /* js */ or /*js*/
(#set! injection.language "javascript")
)
; INJECT CSS
(
[
; var, const or short declaration of raw or interpreted string literal
((comment) @comment
.
(expression_list
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content
))
; when passing as a literal element (to struct field eg.)
((comment) @comment
.
(literal_element
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content
))
; when passing as a function parameter
((comment) @comment
.
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content)
]
(#match? @comment "^\\/\\*\\s*css\\s*\\*\\/") ; /* css */ or /*css*/
(#set! injection.language "css")
)
; INJECT LUA
(
[
; var, const or short declaration of raw or interpreted string literal
((comment) @comment
.
(expression_list
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content
))
; when passing as a literal element (to struct field eg.)
((comment) @comment
.
(literal_element
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content
))
; when passing as a function parameter
((comment) @comment
.
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content)
]
(#match? @comment "^\\/\\*\\s*lua\\s*\\*\\/") ; /* lua */ or /*lua*/
(#set! injection.language "lua")
)
; INJECT BASH
(
[
; var, const or short declaration of raw or interpreted string literal
((comment) @comment
.
(expression_list
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content
))
; when passing as a literal element (to struct field eg.)
((comment) @comment
.
(literal_element
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content
))
; when passing as a function parameter
((comment) @comment
.
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content)
]
(#match? @comment "^\\/\\*\\s*bash\\s*\\*\\/") ; /* bash */ or /*bash*/
(#set! injection.language "bash")
)
; INJECT CSV
(
[
; var, const or short declaration of raw or interpreted string literal
((comment) @comment
.
(expression_list
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content
))
; when passing as a literal element (to struct field eg.)
((comment) @comment
.
(literal_element
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content
))
; when passing as a function parameter
((comment) @comment
.
[
(interpreted_string_literal)
(raw_string_literal)
] @injection.content)
]
(#match? @comment "^\\/\\*\\s*csv\\s*\\*\\/") ; /* csv */ or /*csv*/
(#set! injection.language "csv")
)

View File

@@ -286,6 +286,7 @@ pub fn init(languages: Arc<LanguageRegistry>, fs: Arc<dyn Fs>, node: NodeRuntime
"HEEX",
"HTML",
"JavaScript",
"TypeScript",
"PHP",
"Svelte",
"TSX",

View File

@@ -16,6 +16,7 @@ use node_runtime::{NodeRuntime, VersionStrategy};
use pet_core::Configuration;
use pet_core::os_environment::Environment;
use pet_core::python_environment::{PythonEnvironment, PythonEnvironmentKind};
use pet_virtualenv::is_virtualenv_dir;
use project::Fs;
use project::lsp_store::language_server_settings;
use serde_json::{Value, json};
@@ -900,6 +901,21 @@ fn python_module_name_from_relative_path(relative_path: &str) -> String {
.to_string()
}
fn is_python_env_global(k: &PythonEnvironmentKind) -> bool {
matches!(
k,
PythonEnvironmentKind::Homebrew
| PythonEnvironmentKind::Pyenv
| PythonEnvironmentKind::GlobalPaths
| PythonEnvironmentKind::MacPythonOrg
| PythonEnvironmentKind::MacCommandLineTools
| PythonEnvironmentKind::LinuxGlobal
| PythonEnvironmentKind::MacXCode
| PythonEnvironmentKind::WindowsStore
| PythonEnvironmentKind::WindowsRegistry
)
}
fn python_env_kind_display(k: &PythonEnvironmentKind) -> &'static str {
match k {
PythonEnvironmentKind::Conda => "Conda",
@@ -966,6 +982,26 @@ async fn get_worktree_venv_declaration(worktree_root: &Path) -> Option<String> {
Some(venv_name.trim().to_string())
}
fn get_venv_parent_dir(env: &PythonEnvironment) -> Option<PathBuf> {
// If global, we aren't a virtual environment
if let Some(kind) = env.kind
&& is_python_env_global(&kind)
{
return None;
}
// Check to be sure we are a virtual environment using pet's most generic
// virtual environment type, VirtualEnv
let venv = env
.executable
.as_ref()
.and_then(|p| p.parent())
.and_then(|p| p.parent())
.filter(|p| is_virtualenv_dir(p))?;
venv.parent().map(|parent| parent.to_path_buf())
}
#[async_trait]
impl ToolchainLister for PythonToolchainProvider {
async fn list(
@@ -1025,11 +1061,15 @@ impl ToolchainLister for PythonToolchainProvider {
});
// Compare project paths against worktree root
let proj_ordering = || match (&lhs.project, &rhs.project) {
(Some(l), Some(r)) => (r == &wr).cmp(&(l == &wr)),
(Some(l), None) if l == &wr => Ordering::Less,
(None, Some(r)) if r == &wr => Ordering::Greater,
_ => Ordering::Equal,
let proj_ordering = || {
let lhs_project = lhs.project.clone().or_else(|| get_venv_parent_dir(lhs));
let rhs_project = rhs.project.clone().or_else(|| get_venv_parent_dir(rhs));
match (&lhs_project, &rhs_project) {
(Some(l), Some(r)) => (r == &wr).cmp(&(l == &wr)),
(Some(l), None) if l == &wr => Ordering::Less,
(None, Some(r)) if r == &wr => Ordering::Greater,
_ => Ordering::Equal,
}
};
// Compare environment priorities
@@ -1131,7 +1171,7 @@ impl ToolchainLister for PythonToolchainProvider {
let activate_keyword = match shell {
ShellKind::Cmd => ".",
ShellKind::Nushell => "overlay use",
ShellKind::Powershell => ".",
ShellKind::PowerShell => ".",
ShellKind::Fish => "source",
ShellKind::Csh => "source",
ShellKind::Posix => "source",
@@ -1141,7 +1181,7 @@ impl ToolchainLister for PythonToolchainProvider {
ShellKind::Csh => "activate.csh",
ShellKind::Fish => "activate.fish",
ShellKind::Nushell => "activate.nu",
ShellKind::Powershell => "activate.ps1",
ShellKind::PowerShell => "activate.ps1",
ShellKind::Cmd => "activate.bat",
};
let path = prefix.join(BINARY_DIR).join(activate_script_name);
@@ -1165,7 +1205,7 @@ impl ToolchainLister for PythonToolchainProvider {
ShellKind::Fish => Some(format!("\"{pyenv}\" shell - fish {version}")),
ShellKind::Posix => Some(format!("\"{pyenv}\" shell - sh {version}")),
ShellKind::Nushell => Some(format!("\"{pyenv}\" shell - nu {version}")),
ShellKind::Powershell => None,
ShellKind::PowerShell => None,
ShellKind::Csh => None,
ShellKind::Cmd => None,
})

View File

@@ -146,6 +146,7 @@ impl LspAdapter for TailwindLspAdapter {
"html": "html",
"css": "css",
"javascript": "javascript",
"typescript": "typescript",
"typescriptreact": "typescriptreact",
},
})))
@@ -178,6 +179,7 @@ impl LspAdapter for TailwindLspAdapter {
(LanguageName::new("HTML"), "html".to_string()),
(LanguageName::new("CSS"), "css".to_string()),
(LanguageName::new("JavaScript"), "javascript".to_string()),
(LanguageName::new("TypeScript"), "typescript".to_string()),
(LanguageName::new("TSX"), "typescriptreact".to_string()),
(LanguageName::new("Svelte"), "svelte".to_string()),
(LanguageName::new("Elixir"), "phoenix-heex".to_string()),

View File

@@ -21,9 +21,11 @@ word_characters = ["#", "$"]
prettier_parser_name = "typescript"
tab_size = 2
debuggers = ["JavaScript"]
scope_opt_in_language_servers = ["tailwindcss-language-server"]
[overrides.string]
completion_query_characters = ["."]
completion_query_characters = ["-", "."]
opt_into_language_servers = ["tailwindcss-language-server"]
prefer_label_for_snippet = true
[overrides.function_name_before_type_arguments]

View File

@@ -1079,7 +1079,7 @@ impl Element for MarkdownElement {
{
builder.modify_current_div(|el| {
let content_range = parser::extract_code_block_content_range(
parsed_markdown.source()[range.clone()].trim(),
&parsed_markdown.source()[range.clone()],
);
let content_range = content_range.start + range.start
..content_range.end + range.start;
@@ -1110,7 +1110,7 @@ impl Element for MarkdownElement {
{
builder.modify_current_div(|el| {
let content_range = parser::extract_code_block_content_range(
parsed_markdown.source()[range.clone()].trim(),
&parsed_markdown.source()[range.clone()],
);
let content_range = content_range.start + range.start
..content_range.end + range.start;

Some files were not shown because too many files have changed in this diff Show More