Compare commits

..

93 Commits

Author SHA1 Message Date
Richard Feldman
9cc517e0dd Fix some extension auto install bugs 2025-12-11 00:52:08 -05:00
Richard Feldman
d1390a5b78 Merge remote-tracking branch 'origin/main' into migrate-provider-extensions 2025-12-11 00:26:09 -05:00
Richard Feldman
ee4faede38 Migrate on auto-load 2025-12-11 00:22:38 -05:00
Richard Feldman
8d96a699b3 Revise migration system some more 2025-12-11 00:13:11 -05:00
Richard Feldman
8cfb7471db Minimize how we're tracking migrations 2025-12-10 23:21:14 -05:00
Richard Feldman
def9c87837 Migrate credentials without touching settings 2025-12-10 22:29:48 -05:00
CharlesChen0823
7ed5d42696 git: Fix git hook hang with prek (#44212)
Fix git hook hang when using with `prek`. Can see
[comments](https://github.com/zed-industries/zed/issues/44057#issuecomment-3606837089),
this is easy test, should using release build, debug build sometimes not
hang.

The issue existing long time, see issue #37293 , and then in commit
#42239 this issue had fixed. but in commit #43285 broken again. So I
reference the implementation in #42239, then this code work.

I MUST CLAIM, I really don't known what happend, and why this code work.
But it worked.

Release Notes:

- N/A

---------

Co-authored-by: Cole Miller <cole@zed.dev>
2025-12-11 03:17:13 +00:00
Richard Feldman
0313ab6d41 Change open-router to openrouter in default.json 2025-12-10 22:10:29 -05:00
Richard Feldman
c5329fdff2 Rename extension from open-router to openrouter 2025-12-10 22:09:59 -05:00
Richard Feldman
a676a6895b Remove redundant set_builtin_provider_hiding_fn call 2025-12-10 22:05:03 -05:00
Richard Feldman
3b5d7d7d89 Minor cleanups 2025-12-10 22:04:35 -05:00
Richard Feldman
91f01131b1 Introduce DEBUG_ALLOW_UNPUBLISHED_AUTO_EXTENSIONS 2025-12-10 21:29:10 -05:00
Richard Feldman
5fa5226286 Remove llm_provider_authenticate() 2025-12-10 21:28:58 -05:00
Richard Feldman
ae94007227 Merge remote-tracking branch 'origin/main' into migrate-provider-extensions 2025-12-10 21:13:57 -05:00
Max Brunsfeld
25d74480aa Rework edit prediction CLI (#44562)
This PR restructures the commands of the Edit Prediction CLI (now called
`ep`), to support some flows that are important for the training
process:
* generating zeta2 prompt and expected output, without running
predictions
* scoring outputs that are generated by a system other than the
production code (to evaluate the model during training)

To achieve this, we've restructured the CLI commands so that they all
take as input, and produce as output, a consistent, uniform data format:
a set of one or more `Example` structs, expressible either as the
original markdown format, or as a JSON lines. The `Example` struct
starts with the basic fields that are in human-readable eval format, but
contain a number of optional fields that are filled in by different
steps in the processing pipeline (`context`, `predict`, `format-prompt`,
and `score`).

### To do

* [x] Adjust the teacher model output parsing to use the full buffer
contents
* [x] Move udiff to cli
* [x] Align `format-prompt` with Zeta2's production code
* [x] Change score output to assume same provider
* [x] Move pretty reporting to `eval` command
* [x] Store cursor point in addition to cursor offset
* [x] Rename `edit_prediction_cli2` -> `edit_prediction_cli` (nuke the
old one)

Release Notes:

- N/A

---------

Co-authored-by: Oleksiy Syvokon <oleksiy@zed.dev>
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Ben Kunkle <ben@zed.dev>
2025-12-10 17:36:51 -08:00
Cole Miller
37077a8ebb git: Avoid calling git help -a on every commit (#44586)
Updates #43993 

Release Notes:

- N/A
2025-12-11 01:03:35 +00:00
Finn Evers
7c4a85f5f1 ci: Explicitly set git committer information in protobuf check (#44582)
This should hopefully fix the flakes for good.

Release Notes:

- N/A
2025-12-10 23:35:02 +00:00
Cole Miller
d21628c349 Revert "Increase askpass timeout for git operations (#42946)" (#44578)
This reverts commit a74aac88c9.

cc @11happy, we need to do a bit more than just running `git hook
pre-push` before pushing, as described
[here](https://github.com/zed-industries/zed/pull/42946#issuecomment-3550570438).
Right now this is also running the pre-push hook twice.

Release Notes:

- N/A
2025-12-10 18:07:01 -05:00
Xipeng Jin
9e628505f3 git: Add tree view support to Git Panel (#44089)
Closes #35803

This PR adds tree view support to the git panel UI as an additional
setting and moves git entry checkboxes to the right. Tree view only
supports sorting by paths behavior since sorting by status can become
noisy, due to having to duplicate directories that have entries with
different statuses.

### Tree vs Flat View
<img width="358" height="250" alt="image"
src="https://github.com/user-attachments/assets/c6b95d57-12fc-4c5e-8537-ee129963e50c"
/>
<img width="362" height="152" alt="image"
src="https://github.com/user-attachments/assets/0a69e00f-3878-4807-ae45-65e2d54174fc"
/>


#### Architecture changes

Before this PR, `GitPanel::entries` represented all entries and all
visible entries because both sets were equal to one another. However,
this equality isn't true for tree view, because entries can be
collapsed. To fix this, `TreeState` was added as a logical indices field
that is used to filter out non-visible entries. A benefit of this field
is that it could be used in the future to implement searching in the
GitPanel.

Another significant thing this PR changed was adding a HashMap field
`entries_by_indices` on `GitPanel`. We did this because `entry_by_path`
used binary search, which becomes overly complicated to implement for
tree view. The performance of this function matters because it's a hot
code path, so a linear search wasn't ideal either. The solution was
using a hash map to improve time complexity from O(log n) to O(1), where
n is the count of entries.

#### Follow-ups
In the future, we could use `ui::ListItem` to render entries in the tree
view to improve UI consistency.
 
Release Notes:

- Added tree view for Git panel. Users are able to switch between Flat
and Tree view in Git panel.

---------

Co-authored-by: Anthony Eid <anthony@zed.dev>
Co-authored-by: Remco Smits <djsmits12@gmail.com>
2025-12-10 15:11:36 -05:00
KyleBarton
3a84ec38ac Introduce MVP Dev Containers support (#44442)
Partially addresses #11473 

MVP of dev containers with the following capabilities:

- If in a project with `.devcontainer/devcontainer.json`, a pop-up
notification will ask if you want to open the project in a dev
container. This can be dismissed:
<img width="1478" height="1191" alt="Screenshot 2025-12-08 at 3 15
23 PM"
src="https://github.com/user-attachments/assets/ec2e20d6-28ec-4495-8f23-4c1d48a9ce78"
/>
- Similarly, if a `devcontainer.json` file is in the project, you can
open a devcontainer (or go the devcontainer.json file for further
editing) via the `open remote` modal:


https://github.com/user-attachments/assets/61f2fdaa-2808-4efc-994c-7b444a92c0b1

*Limitations*

This is a first release, and comes with some limitations:
- Zed extensions are not managed in `devcontainer.json` yet. They will
need to be installed either on host or in the container. Host +
Container sync their extensions, so there is not currently a concept of
what is installed in the container vs what is installed on host: they
come from the same list of manifests
- This implementation uses the [devcontainer
CLI](https://github.com/devcontainers/cli) for its control plane. Hence,
it does not yet support the `forwardPorts` directive. A single port can
be opened with `appPort`. See reference in docs
[here](https://github.com/devcontainers/cli/tree/main/example-usage#how-the-tool-examples-work)
- Editing devcontainer.json does not automatically cause the dev
container to be rebuilt. So if you add features, change images, etc, you
will need to `docker kill` the existing dev container before proceeding.
- Currently takes a hard dependency on `docker` being available in the
user's `PATH`.


Release Notes:

- Added ability to Open a project in a DevContainer, provided a
`.devcontainer/devcontainer.json` is present

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Danilo Leal <67129314+danilo-leal@users.noreply.github.com>
2025-12-10 12:10:43 -08:00
Richard Feldman
8f425a1bd5 Fix unused arg 2025-12-10 13:11:30 -05:00
Richard Feldman
743c414e7b Refresh models list after successful auth 2025-12-10 13:10:55 -05:00
Richard Feldman
0fe335efc5 Revise Copilot auth 2025-12-10 13:02:38 -05:00
Danilo Leal
a61bf33fb0 Fix label copy for file history menu items (#44569)
Buttons and menu items should preferably always start with an infinitive
verb that describes what will happen when you trigger them. Instead of
just "File History", we should say "_View_ File History".

Release Notes:

- N/A
2025-12-10 18:00:11 +00:00
Richard Feldman
36b95aac4b Debugging extension loading timing and fallbacks 2025-12-10 12:55:41 -05:00
Richard Feldman
b2df70ab58 Clean up extension markdown for settings 2025-12-10 12:55:23 -05:00
John Tur
d83201256d Use shell to launch MCP and ACP servers (#42382)
`npx`, and any `npm install`-ed programs, exist as batch
scripts/PowerShell scripts on the PATH. We have to use a shell to launch
these programs.

Fixes https://github.com/zed-industries/zed/issues/41435
Closes https://github.com/zed-industries/zed/pull/42651


Release Notes:

- windows: Custom MCP and ACP servers installed through `npm` now launch
correctly.

---------

Co-authored-by: Lukas Wirth <me@lukaswirth.dev>
2025-12-10 12:08:37 -05:00
Ben Kunkle
8ee85eab3c vim: Remove ctrl-6 keybinding alias for pane::AlternateFile (#44560)
Closes #ISSUE

It seems that `ctrl-6` is used exclusively as an alias, as can be seen
in the [linked section of the vim
docs](https://vimhelp.org/editing.txt.html#CTRL-%5E) from the initial PR
that added it. This however conflicts with the `ctrl-{n}` bindings for
`pane::ActivateItem` on macOS, leading to confusing file selection when
`ctrl-6` is pressed.

Release Notes:

- vim(BREAKING): Removed a keybinding conflict between the default macOS
bindings for `pane::ActivateItem` and the `ctrl-6` alias
for`pane::AlternateFile` which is primarily bound to `ctrl-^`. `ctrl-6`
is no longer treated as an alias for `ctrl-^` in vim mode. If you'd like
to restore `ctrl-6` as a binding for `pane::AlternateFile`, paste the
following into your `keymap.json` file:
```
  {
    "context": "VimControl && !menu",
    "bindings": {
      "ctrl-6": "pane::AlternateFile"
    }
  }
```
2025-12-10 16:55:50 +00:00
Ben Brandt
5b309ef986 acp: Better telemetry IDs for ACP agents (#44544)
We were defining these in multiple places and also weren't leveraging
the ids the agents were already providing.

This should make sure we use them consistently and avoid issues in the
future.

Release Notes:

- N/A
2025-12-10 16:48:08 +00:00
Mayank Verma
326ebb5230 git: Fix failing commits when hook command is not available (#43993) 2025-12-10 16:34:49 +00:00
Bennet Bo Fenner
f5babf96e1 agent_ui: Fix project path not found error when pasting code from other project (#44555)
The problem with inserting the absolute paths is that the agent will try
to read them. However, we don't allow the agent to read files outside
the current project. For now, we will only insert the crease in case the
code that is getting pasted is from the same project

Release Notes:

- Fixed an issue where pasting code into the agent panel from another
window would show an error
2025-12-10 16:30:10 +00:00
Joseph T. Lyons
f48aa252f8 Bump Zed to v0.218 (#44551)
Release Notes:

- N/A
2025-12-10 15:28:39 +00:00
Finn Evers
4106c8a188 Disable OmniSharp by default for C# files (#44427)
In preparation for https://github.com/zed-extensions/csharp/pull/11. Do
not merge before that PR is published.

Release Notes:

- Added support for Roslyn in C# files. Roslyn will now be the default
language server for C#
2025-12-10 10:12:41 -05:00
Richard Feldman
36293d7dd9 Debugging 2025-12-09 17:04:58 -05:00
Richard Feldman
3ae3e1fce8 Don't use a heuristic for icon path 2025-12-09 14:55:44 -05:00
Richard Feldman
e5f1fc7478 Fix some regressions 2025-12-09 14:48:31 -05:00
Richard Feldman
a4f6076da7 Migrate to extensions with fallback to builtin 2025-12-09 14:14:56 -05:00
Richard Feldman
43726b2620 Restore ai_anthropic icon svg 2025-12-09 12:00:36 -05:00
Richard Feldman
94980ffb49 Reduce duplication in compute_configured_providers 2025-12-09 11:55:37 -05:00
Richard Feldman
22cc731450 Remove some duplication from icon logic 2025-12-09 11:54:58 -05:00
Richard Feldman
d9396373e3 Eliminate more code duplication 2025-12-09 11:54:00 -05:00
Richard Feldman
48002be135 Use | instead of code duplication 2025-12-09 11:53:18 -05:00
Richard Feldman
58db83f8f5 more icon code cleanup 2025-12-09 11:48:06 -05:00
Richard Feldman
0243d5b542 Clean up some more icon code 2025-12-09 11:44:10 -05:00
Richard Feldman
06230327fa Clean up some icon code 2025-12-09 11:44:05 -05:00
Richard Feldman
ca5c8992f9 Merge remote-tracking branch 'origin/main' into migrate-provider-extensions 2025-12-08 20:23:32 -05:00
Richard Feldman
1038e1c2ef Clean up some duplicated code 2025-12-08 16:59:49 -05:00
Richard Feldman
e1fe0b3287 Restore providers, deduplicate if extensions are present 2025-12-08 16:25:41 -05:00
Richard Feldman
a0e10a91bf Merge remote-tracking branch 'origin/main' into migrate-provider-extensions 2025-12-08 15:35:44 -05:00
Richard Feldman
272b1aa4bc Remove obsolete llm_provider_authenticate 2025-12-08 14:46:04 -05:00
Richard Feldman
9ef0537b44 Add the other extensions to auto-install 2025-12-07 23:13:52 -05:00
Richard Feldman
77f1de742b delete hardcoded AI providers in favor of extnesions 2025-12-07 21:31:00 -05:00
Richard Feldman
e054cabd41 Migrate Google AI over to the extension 2025-12-07 20:57:00 -05:00
Richard Feldman
3b95cb5682 Migrate Copilot and Anthropic to extensions 2025-12-07 20:48:42 -05:00
Richard Feldman
c89653bd07 Fix bugs around logging out from provider extensions 2025-12-05 17:07:25 -05:00
Richard Feldman
b90ac2dc07 Fix Drop impl for WasmExtension 2025-12-05 16:21:53 -05:00
Marshall Bowers
c9998541f0 Revert spurious changes to default.json 2025-12-05 13:25:03 -05:00
Marshall Bowers
e2b49b3cd3 Restore blank lines from main 2025-12-05 13:08:30 -05:00
Marshall Bowers
d1e77397c6 Don't make v0.8.0 available on Stable/Preview yet 2025-12-05 13:07:36 -05:00
Richard Feldman
cc5f5e35e4 Clean up some comments 2025-12-05 13:00:19 -05:00
Richard Feldman
7183b8a1cd Fix API key bug 2025-12-05 12:59:19 -05:00
Richard Feldman
b1934fb712 Remove builtin Anthropic provider 2025-12-05 12:11:51 -05:00
Richard Feldman
a198b6c0d1 Use icon in more places 2025-12-05 11:48:11 -05:00
Richard Feldman
8b5b2712c8 Update Cargo.lock 2025-12-05 11:32:58 -05:00
Richard Feldman
4464392e8e Use kebab-case for open-router extension too. 2025-12-05 11:19:10 -05:00
Richard Feldman
a0d3bc31e9 Rename copilot_chat to copilot-chat 2025-12-05 11:15:43 -05:00
Richard Feldman
ccd6672d1a Revert "Remove builtin extensions for now"
This reverts commit 5559726fd7.
2025-12-05 11:13:29 -05:00
Richard Feldman
21de6d35dd Revert "Revert auto-install extensions for now"
This reverts commit 2031ca17e5.
2025-12-05 11:13:22 -05:00
Richard Feldman
2031ca17e5 Revert auto-install extensions for now 2025-12-05 11:06:12 -05:00
Richard Feldman
8b1ce75a57 Move wit extensions into their own module 2025-12-05 10:30:02 -05:00
Richard Feldman
5559726fd7 Remove builtin extensions for now 2025-12-04 17:20:47 -05:00
Richard Feldman
e1a9269921 Delete example provider extension 2025-12-04 17:20:47 -05:00
Richard Feldman
3b6b3ff504 Specify env vars for the builtin extensions 2025-12-04 17:19:35 -05:00
Richard Feldman
aabed94970 Add OAuth via web authentication to llm extensions, migrate copilot 2025-12-04 17:12:55 -05:00
Richard Feldman
2d3a3521ba Add OAuth Web Flow auth option for llm provider extensions 2025-12-04 17:12:55 -05:00
Richard Feldman
a48bd10da0 Add llm extensions to auto_install_extensions 2025-12-04 17:12:55 -05:00
Richard Feldman
fec9525be4 Add env var checkbox 2025-12-04 17:12:23 -05:00
Richard Feldman
bf2b8e999e use fill=black over fill=currentColor 2025-12-04 16:51:47 -05:00
Richard Feldman
63c35d2b00 Use local icons in llm extensions 2025-12-04 16:48:25 -05:00
Richard Feldman
1396c68010 Add svg icons to llm provider extensions 2025-12-04 16:43:49 -05:00
Richard Feldman
fcb3d3dec6 Update a comment 2025-12-04 16:28:29 -05:00
Richard Feldman
f54e7f8c9d Add trailing newlines 2025-12-04 16:18:43 -05:00
Richard Feldman
2a89529d7f Use named fields 2025-12-04 16:17:50 -05:00
Richard Feldman
58207325e2 restore impl Drop for WasmExtension 2025-12-04 16:12:21 -05:00
Richard Feldman
e08ab99e8d Add extensions for LLM providers 2025-12-04 16:03:51 -05:00
Richard Feldman
a95f3f33a4 Clean up debug logging 2025-12-04 12:38:06 -05:00
Richard Feldman
b0767c1b1f Merge remote-tracking branch 'origin/main' into provider-extensions 2025-12-04 12:27:15 -05:00
Richard Feldman
b200e10bc4 Clean up debug statements 2025-12-04 11:30:44 -05:00
Richard Feldman
948905d916 Revise provider extensions for Gemini API 2025-12-03 20:22:10 -05:00
Richard Feldman
04de456373 Use extension-llm- prefix for credential keys 2025-12-03 15:55:10 -05:00
Richard Feldman
e5ce32e936 Add provider extension API key in settings 2025-12-03 14:41:39 -05:00
Richard Feldman
d7caae30de Fix auth and subscriptions for provider extensions 2025-12-03 13:00:53 -05:00
Richard Feldman
c7e77674a1 Initial Claude Opus 4.5 implementation of Provider Extensions 2025-12-02 13:50:00 -05:00
178 changed files with 18889 additions and 5626 deletions

View File

@@ -497,6 +497,8 @@ jobs:
env:
GIT_AUTHOR_NAME: Protobuf Action
GIT_AUTHOR_EMAIL: ci@zed.dev
GIT_COMMITTER_NAME: Protobuf Action
GIT_COMMITTER_EMAIL: ci@zed.dev
steps:
- name: steps::checkout_repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683

46
Cargo.lock generated
View File

@@ -3111,16 +3111,6 @@ dependencies = [
"uuid",
]
[[package]]
name = "cloud_zeta2_prompt"
version = "0.1.0"
dependencies = [
"anyhow",
"cloud_llm_client",
"indoc",
"serde",
]
[[package]]
name = "cmake"
version = "0.1.54"
@@ -3595,6 +3585,7 @@ dependencies = [
"settings",
"smol",
"tempfile",
"terminal",
"url",
"util",
]
@@ -5118,7 +5109,6 @@ dependencies = [
"clock",
"cloud_api_types",
"cloud_llm_client",
"cloud_zeta2_prompt",
"collections",
"copilot",
"credentials_provider",
@@ -5149,8 +5139,6 @@ dependencies = [
"serde",
"serde_json",
"settings",
"smol",
"strsim",
"strum 0.27.2",
"telemetry",
"telemetry_events",
@@ -5161,6 +5149,7 @@ dependencies = [
"workspace",
"worktree",
"zed_actions",
"zeta_prompt",
"zlog",
]
@@ -5174,11 +5163,10 @@ dependencies = [
"clap",
"client",
"cloud_llm_client",
"cloud_zeta2_prompt",
"collections",
"debug_adapter_extension",
"dirs 4.0.0",
"edit_prediction",
"edit_prediction_context",
"extension",
"fs",
"futures 0.3.31",
@@ -5208,9 +5196,10 @@ dependencies = [
"sqlez",
"sqlez_macros",
"terminal_view",
"toml 0.8.23",
"util",
"wasmtime",
"watch",
"zeta_prompt",
"zlog",
]
@@ -5238,6 +5227,7 @@ dependencies = [
"text",
"tree-sitter",
"util",
"zeta_prompt",
"zlog",
]
@@ -5259,7 +5249,6 @@ dependencies = [
"buffer_diff",
"client",
"cloud_llm_client",
"cloud_zeta2_prompt",
"codestral",
"command_palette_hooks",
"copilot",
@@ -5290,6 +5279,7 @@ dependencies = [
"util",
"workspace",
"zed_actions",
"zeta_prompt",
]
[[package]]
@@ -5853,9 +5843,12 @@ dependencies = [
"async-trait",
"client",
"collections",
"credentials_provider",
"criterion",
"ctor",
"dap",
"dirs 4.0.0",
"editor",
"extension",
"fs",
"futures 0.3.31",
@@ -5864,8 +5857,11 @@ dependencies = [
"http_client",
"language",
"language_extension",
"language_model",
"log",
"lsp",
"markdown",
"menu",
"moka",
"node_runtime",
"parking_lot",
@@ -5880,12 +5876,14 @@ dependencies = [
"serde_json",
"serde_json_lenient",
"settings",
"smol",
"task",
"telemetry",
"tempfile",
"theme",
"theme_extension",
"toml 0.8.23",
"ui",
"url",
"util",
"wasmparser 0.221.3",
@@ -8852,6 +8850,8 @@ dependencies = [
"credentials_provider",
"deepseek",
"editor",
"extension",
"extension_host",
"fs",
"futures 0.3.31",
"google_ai",
@@ -13156,6 +13156,7 @@ dependencies = [
"askpass",
"auto_update",
"dap",
"db",
"editor",
"extension_host",
"file_finder",
@@ -13167,6 +13168,7 @@ dependencies = [
"log",
"markdown",
"menu",
"node_runtime",
"ordered-float 2.10.1",
"paths",
"picker",
@@ -13185,6 +13187,7 @@ dependencies = [
"util",
"windows-registry 0.6.1",
"workspace",
"worktree",
"zed_actions",
]
@@ -20469,7 +20472,7 @@ dependencies = [
[[package]]
name = "zed"
version = "0.217.0"
version = "0.218.0"
dependencies = [
"acp_tools",
"activity_indicator",
@@ -20929,6 +20932,13 @@ dependencies = [
"syn 2.0.106",
]
[[package]]
name = "zeta_prompt"
version = "0.1.0"
dependencies = [
"serde",
]
[[package]]
name = "zip"
version = "0.6.6"

View File

@@ -32,7 +32,6 @@ members = [
"crates/cloud_api_client",
"crates/cloud_api_types",
"crates/cloud_llm_client",
"crates/cloud_zeta2_prompt",
"crates/collab",
"crates/collab_ui",
"crates/collections",
@@ -202,6 +201,7 @@ members = [
"crates/zed_actions",
"crates/zed_env_vars",
"crates/edit_prediction_cli",
"crates/zeta_prompt",
"crates/zlog",
"crates/zlog_settings",
"crates/ztracing",
@@ -266,7 +266,6 @@ clock = { path = "crates/clock" }
cloud_api_client = { path = "crates/cloud_api_client" }
cloud_api_types = { path = "crates/cloud_api_types" }
cloud_llm_client = { path = "crates/cloud_llm_client" }
cloud_zeta2_prompt = { path = "crates/cloud_zeta2_prompt" }
collab_ui = { path = "crates/collab_ui" }
collections = { path = "crates/collections", version = "0.1.0" }
command_palette = { path = "crates/command_palette" }
@@ -425,6 +424,7 @@ zed = { path = "crates/zed" }
zed_actions = { path = "crates/zed_actions" }
zed_env_vars = { path = "crates/zed_env_vars" }
edit_prediction = { path = "crates/edit_prediction" }
zeta_prompt = { path = "crates/zeta_prompt" }
zlog = { path = "crates/zlog" }
zlog_settings = { path = "crates/zlog_settings" }
ztracing = { path = "crates/ztracing" }
@@ -657,6 +657,7 @@ time = { version = "0.3", features = [
tiny_http = "0.8"
tokio = { version = "1" }
tokio-tungstenite = { version = "0.26", features = ["__rustls-tls"] }
tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io", "tokio"] }
toml = "0.8"
toml_edit = { version = "0.22", default-features = false, features = ["display", "parse", "serde"] }
tower-http = "0.4.4"

5
assets/icons/box.svg Normal file
View File

@@ -0,0 +1,5 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M13.3996 5.59852C13.3994 5.3881 13.3439 5.18144 13.2386 4.99926C13.1333 4.81709 12.9819 4.66581 12.7997 4.56059L8.59996 2.16076C8.41755 2.05544 8.21063 2 8 2C7.78937 2 7.58246 2.05544 7.40004 2.16076L3.20033 4.56059C3.0181 4.66581 2.86674 4.81709 2.76144 4.99926C2.65613 5.18144 2.60059 5.3881 2.60037 5.59852V10.3982C2.60059 10.6086 2.65613 10.8153 2.76144 10.9975C2.86674 11.1796 3.0181 11.3309 3.20033 11.4361L7.40004 13.836C7.58246 13.9413 7.78937 13.9967 8 13.9967C8.21063 13.9967 8.41755 13.9413 8.59996 13.836L12.7997 11.4361C12.9819 11.3309 13.1333 11.1796 13.2386 10.9975C13.3439 10.8153 13.3994 10.6086 13.3996 10.3982V5.59852Z" stroke="white" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M2.78033 4.99857L7.99998 7.99836L13.2196 4.99857" stroke="white" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M8 13.9979V7.99829" stroke="white" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

After

Width:  |  Height:  |  Size: 1.1 KiB

View File

@@ -180,7 +180,6 @@
"ctrl-w g shift-d": "editor::GoToTypeDefinitionSplit",
"ctrl-w space": "editor::OpenExcerptsSplit",
"ctrl-w g space": "editor::OpenExcerptsSplit",
"ctrl-6": "pane::AlternateFile",
"ctrl-^": "pane::AlternateFile",
".": "vim::Repeat"
}

View File

@@ -870,6 +870,10 @@
//
// Default: false
"collapse_untracked_diff": false,
/// Whether to show entries with tree or flat view in the panel
///
/// Default: false
"tree_view": false,
"scrollbar": {
// When to show the scrollbar in the git panel.
//
@@ -1721,7 +1725,12 @@
// If you don't want any of these extensions, add this field to your settings
// and change the value to `false`.
"auto_install_extensions": {
"html": true
"html": true,
"copilot-chat": true,
"anthropic": true,
"google-ai": true,
"openai": true,
"openrouter": true,
},
// The capabilities granted to extensions.
//

View File

@@ -1372,7 +1372,7 @@ impl AcpThread {
let path_style = self.project.read(cx).path_style(cx);
let id = update.tool_call_id.clone();
let agent = self.connection().telemetry_id();
let agent_telemetry_id = self.connection().telemetry_id();
let session = self.session_id();
if let ToolCallStatus::Completed | ToolCallStatus::Failed = status {
let status = if matches!(status, ToolCallStatus::Completed) {
@@ -1380,7 +1380,12 @@ impl AcpThread {
} else {
"failed"
};
telemetry::event!("Agent Tool Call Completed", agent, session, status);
telemetry::event!(
"Agent Tool Call Completed",
agent_telemetry_id,
session,
status
);
}
if let Some(ix) = self.index_for_tool_call(&id) {
@@ -3556,8 +3561,8 @@ mod tests {
}
impl AgentConnection for FakeAgentConnection {
fn telemetry_id(&self) -> &'static str {
"fake"
fn telemetry_id(&self) -> SharedString {
"fake".into()
}
fn auth_methods(&self) -> &[acp::AuthMethod] {

View File

@@ -20,7 +20,7 @@ impl UserMessageId {
}
pub trait AgentConnection {
fn telemetry_id(&self) -> &'static str;
fn telemetry_id(&self) -> SharedString;
fn new_thread(
self: Rc<Self>,
@@ -204,12 +204,21 @@ pub trait AgentModelSelector: 'static {
}
}
/// Icon for a model in the model selector.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentModelIcon {
/// A built-in icon from Zed's icon set.
Named(IconName),
/// Path to a custom SVG icon file.
Path(SharedString),
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentModelInfo {
pub id: acp::ModelId,
pub name: SharedString,
pub description: Option<SharedString>,
pub icon: Option<IconName>,
pub icon: Option<AgentModelIcon>,
}
impl From<acp::ModelInfo> for AgentModelInfo {
@@ -322,8 +331,8 @@ mod test_support {
}
impl AgentConnection for StubAgentConnection {
fn telemetry_id(&self) -> &'static str {
"stub"
fn telemetry_id(&self) -> SharedString {
"stub".into()
}
fn auth_methods(&self) -> &[acp::AuthMethod] {

View File

@@ -777,7 +777,7 @@ impl ActionLog {
#[derive(Clone)]
pub struct ActionLogTelemetry {
pub agent_telemetry_id: &'static str,
pub agent_telemetry_id: SharedString,
pub session_id: Arc<str>,
}

View File

@@ -739,7 +739,7 @@ impl ActivityIndicator {
extension_store.outstanding_operations().iter().next()
{
let (message, icon, rotate) = match operation {
ExtensionOperation::Install => (
ExtensionOperation::Install | ExtensionOperation::AutoInstall => (
format!("Installing {extension_id} extension…"),
IconName::LoadCircle,
true,

View File

@@ -18,7 +18,7 @@ pub use templates::*;
pub use thread::*;
pub use tools::*;
use acp_thread::{AcpThread, AgentModelSelector};
use acp_thread::{AcpThread, AgentModelIcon, AgentModelSelector};
use agent_client_protocol as acp;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
@@ -105,7 +105,7 @@ impl LanguageModels {
fn refresh_list(&mut self, cx: &App) {
let providers = LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.visible_providers()
.into_iter()
.filter(|provider| provider.is_authenticated(cx))
.collect::<Vec<_>>();
@@ -161,11 +161,16 @@ impl LanguageModels {
model: &Arc<dyn LanguageModel>,
provider: &Arc<dyn LanguageModelProvider>,
) -> acp_thread::AgentModelInfo {
let icon = if let Some(path) = provider.icon_path() {
Some(AgentModelIcon::Path(path))
} else {
Some(AgentModelIcon::Named(provider.icon()))
};
acp_thread::AgentModelInfo {
id: Self::model_id(model),
name: model.name().0,
description: None,
icon: Some(provider.icon()),
icon,
}
}
@@ -947,8 +952,8 @@ impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
}
impl acp_thread::AgentConnection for NativeAgentConnection {
fn telemetry_id(&self) -> &'static str {
"zed"
fn telemetry_id(&self) -> SharedString {
"zed".into()
}
fn new_thread(
@@ -1356,7 +1361,7 @@ mod internal_tests {
id: acp::ModelId::new("fake/fake"),
name: "Fake".into(),
description: None,
icon: Some(ui::IconName::ZedAssistant),
icon: Some(AgentModelIcon::Named(ui::IconName::ZedAssistant)),
}]
)])
);

View File

@@ -21,10 +21,6 @@ impl NativeAgentServer {
}
impl AgentServer for NativeAgentServer {
fn telemetry_id(&self) -> &'static str {
"zed"
}
fn name(&self) -> SharedString {
"Zed Agent".into()
}

View File

@@ -9,6 +9,10 @@ use futures::io::BufReader;
use project::Project;
use project::agent_server_store::AgentServerCommand;
use serde::Deserialize;
use settings::Settings as _;
use task::ShellBuilder;
#[cfg(windows)]
use task::ShellKind;
use util::ResultExt as _;
use std::path::PathBuf;
@@ -21,7 +25,7 @@ use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntit
use acp_thread::{AcpThread, AuthRequired, LoadError, TerminalProviderEvent};
use terminal::TerminalBuilder;
use terminal::terminal_settings::{AlternateScroll, CursorShape};
use terminal::terminal_settings::{AlternateScroll, CursorShape, TerminalSettings};
#[derive(Debug, Error)]
#[error("Unsupported version")]
@@ -29,7 +33,7 @@ pub struct UnsupportedVersion;
pub struct AcpConnection {
server_name: SharedString,
telemetry_id: &'static str,
telemetry_id: SharedString,
connection: Rc<acp::ClientSideConnection>,
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
auth_methods: Vec<acp::AuthMethod>,
@@ -54,7 +58,6 @@ pub struct AcpSession {
pub async fn connect(
server_name: SharedString,
telemetry_id: &'static str,
command: AgentServerCommand,
root_dir: &Path,
default_mode: Option<acp::SessionModeId>,
@@ -64,7 +67,6 @@ pub async fn connect(
) -> Result<Rc<dyn AgentConnection>> {
let conn = AcpConnection::stdio(
server_name,
telemetry_id,
command.clone(),
root_dir,
default_mode,
@@ -81,7 +83,6 @@ const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::ProtocolVersion::V1
impl AcpConnection {
pub async fn stdio(
server_name: SharedString,
telemetry_id: &'static str,
command: AgentServerCommand,
root_dir: &Path,
default_mode: Option<acp::SessionModeId>,
@@ -89,9 +90,26 @@ impl AcpConnection {
is_remote: bool,
cx: &mut AsyncApp,
) -> Result<Self> {
let mut child = util::command::new_smol_command(&command.path);
let shell = cx.update(|cx| TerminalSettings::get(None, cx).shell.clone())?;
let builder = ShellBuilder::new(&shell, cfg!(windows));
#[cfg(windows)]
let kind = builder.kind();
let (cmd, args) = builder.build(Some(command.path.display().to_string()), &command.args);
let mut child = util::command::new_smol_command(cmd);
#[cfg(windows)]
if kind == ShellKind::Cmd {
use smol::process::windows::CommandExt;
for arg in args {
child.raw_arg(arg);
}
} else {
child.args(args);
}
#[cfg(not(windows))]
child.args(args);
child
.args(command.args.iter().map(|arg| arg.as_str()))
.envs(command.env.iter().flatten())
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
@@ -199,6 +217,13 @@ impl AcpConnection {
return Err(UnsupportedVersion.into());
}
let telemetry_id = response
.agent_info
// Use the one the agent provides if we have one
.map(|info| info.name.into())
// Otherwise, just use the name
.unwrap_or_else(|| server_name.clone());
Ok(Self {
auth_methods: response.auth_methods,
root_dir: root_dir.to_owned(),
@@ -233,8 +258,8 @@ impl Drop for AcpConnection {
}
impl AgentConnection for AcpConnection {
fn telemetry_id(&self) -> &'static str {
self.telemetry_id
fn telemetry_id(&self) -> SharedString {
self.telemetry_id.clone()
}
fn new_thread(

View File

@@ -56,7 +56,6 @@ impl AgentServerDelegate {
pub trait AgentServer: Send {
fn logo(&self) -> ui::IconName;
fn name(&self) -> SharedString;
fn telemetry_id(&self) -> &'static str;
fn default_mode(&self, _cx: &mut App) -> Option<agent_client_protocol::SessionModeId> {
None
}

View File

@@ -22,10 +22,6 @@ pub struct AgentServerLoginCommand {
}
impl AgentServer for ClaudeCode {
fn telemetry_id(&self) -> &'static str {
"claude-code"
}
fn name(&self) -> SharedString {
"Claude Code".into()
}
@@ -83,7 +79,6 @@ impl AgentServer for ClaudeCode {
cx: &mut App,
) -> Task<Result<(Rc<dyn AgentConnection>, Option<task::SpawnInTerminal>)>> {
let name = self.name();
let telemetry_id = self.telemetry_id();
let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().into_owned());
let is_remote = delegate.project.read(cx).is_via_remote_server();
let store = delegate.store.downgrade();
@@ -108,7 +103,6 @@ impl AgentServer for ClaudeCode {
.await?;
let connection = crate::acp::connect(
name,
telemetry_id,
command,
root_dir.as_ref(),
default_mode,

View File

@@ -23,10 +23,6 @@ pub(crate) mod tests {
}
impl AgentServer for Codex {
fn telemetry_id(&self) -> &'static str {
"codex"
}
fn name(&self) -> SharedString {
"Codex".into()
}
@@ -84,7 +80,6 @@ impl AgentServer for Codex {
cx: &mut App,
) -> Task<Result<(Rc<dyn AgentConnection>, Option<task::SpawnInTerminal>)>> {
let name = self.name();
let telemetry_id = self.telemetry_id();
let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().into_owned());
let is_remote = delegate.project.read(cx).is_via_remote_server();
let store = delegate.store.downgrade();
@@ -110,7 +105,6 @@ impl AgentServer for Codex {
let connection = crate::acp::connect(
name,
telemetry_id,
command,
root_dir.as_ref(),
default_mode,

View File

@@ -1,4 +1,4 @@
use crate::{AgentServerDelegate, load_proxy_env};
use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
use acp_thread::AgentConnection;
use agent_client_protocol as acp;
use anyhow::{Context as _, Result};
@@ -20,11 +20,7 @@ impl CustomAgentServer {
}
}
impl crate::AgentServer for CustomAgentServer {
fn telemetry_id(&self) -> &'static str {
"custom"
}
impl AgentServer for CustomAgentServer {
fn name(&self) -> SharedString {
self.name.clone()
}
@@ -112,14 +108,12 @@ impl crate::AgentServer for CustomAgentServer {
cx: &mut App,
) -> Task<Result<(Rc<dyn AgentConnection>, Option<task::SpawnInTerminal>)>> {
let name = self.name();
let telemetry_id = self.telemetry_id();
let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().into_owned());
let is_remote = delegate.project.read(cx).is_via_remote_server();
let default_mode = self.default_mode(cx);
let default_model = self.default_model(cx);
let store = delegate.store.downgrade();
let extra_env = load_proxy_env(cx);
cx.spawn(async move |cx| {
let (command, root_dir, login) = store
.update(cx, |store, cx| {
@@ -139,7 +133,6 @@ impl crate::AgentServer for CustomAgentServer {
.await?;
let connection = crate::acp::connect(
name,
telemetry_id,
command,
root_dir.as_ref(),
default_mode,

View File

@@ -5,17 +5,13 @@ use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
use acp_thread::AgentConnection;
use anyhow::{Context as _, Result};
use gpui::{App, SharedString, Task};
use language_models::provider::google::GoogleLanguageModelProvider;
use language_models::api_key_for_gemini_cli;
use project::agent_server_store::GEMINI_NAME;
#[derive(Clone)]
pub struct Gemini;
impl AgentServer for Gemini {
fn telemetry_id(&self) -> &'static str {
"gemini-cli"
}
fn name(&self) -> SharedString {
"Gemini CLI".into()
}
@@ -31,7 +27,6 @@ impl AgentServer for Gemini {
cx: &mut App,
) -> Task<Result<(Rc<dyn AgentConnection>, Option<task::SpawnInTerminal>)>> {
let name = self.name();
let telemetry_id = self.telemetry_id();
let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().into_owned());
let is_remote = delegate.project.read(cx).is_via_remote_server();
let store = delegate.store.downgrade();
@@ -42,11 +37,7 @@ impl AgentServer for Gemini {
cx.spawn(async move |cx| {
extra_env.insert("SURFACE".to_owned(), "zed".to_owned());
if let Some(api_key) = cx
.update(GoogleLanguageModelProvider::api_key_for_gemini_cli)?
.await
.ok()
{
if let Some(api_key) = cx.update(api_key_for_gemini_cli)?.await.ok() {
extra_env.insert("GEMINI_API_KEY".into(), api_key);
}
let (command, root_dir, login) = store
@@ -66,7 +57,6 @@ impl AgentServer for Gemini {
let connection = crate::acp::connect(
name,
telemetry_id,
command,
root_dir.as_ref(),
default_mode,

View File

@@ -565,8 +565,26 @@ impl MessageEditor {
if let Some((workspace, selections)) =
self.workspace.upgrade().zip(editor_clipboard_selections)
{
cx.stop_propagation();
let Some(first_selection) = selections.first() else {
return;
};
if let Some(file_path) = &first_selection.file_path {
// In case someone pastes selections from another window
// with a different project, we don't want to insert the
// crease (containing the absolute path) since the agent
// cannot access files outside the project.
let is_in_project = workspace
.read(cx)
.project()
.read(cx)
.project_path_for_absolute_path(file_path, cx)
.is_some();
if !is_in_project {
return;
}
}
cx.stop_propagation();
let insertion_target = self
.editor
.read(cx)

View File

@@ -1,6 +1,6 @@
use std::{cmp::Reverse, rc::Rc, sync::Arc};
use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector};
use agent_servers::AgentServer;
use anyhow::Result;
use collections::IndexMap;
@@ -292,12 +292,18 @@ impl PickerDelegate for AcpModelPickerDelegate {
h_flex()
.w_full()
.gap_1p5()
.when_some(model_info.icon, |this, icon| {
this.child(
Icon::new(icon)
.map(|this| match &model_info.icon {
Some(AgentModelIcon::Path(path)) => this.child(
Icon::from_external_svg(path.clone())
.color(model_icon_color)
.size(IconSize::Small)
)
.size(IconSize::Small),
),
Some(AgentModelIcon::Named(icon)) => this.child(
Icon::new(*icon)
.color(model_icon_color)
.size(IconSize::Small),
),
None => this,
})
.child(Label::new(model_info.name.clone()).truncate()),
)

View File

@@ -1,7 +1,7 @@
use std::rc::Rc;
use std::sync::Arc;
use acp_thread::{AgentModelInfo, AgentModelSelector};
use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelSelector};
use agent_servers::AgentServer;
use fs::Fs;
use gpui::{Entity, FocusHandle};
@@ -64,7 +64,7 @@ impl Render for AcpModelSelectorPopover {
.map(|model| model.name.clone())
.unwrap_or_else(|| SharedString::from("Select a Model"));
let model_icon = model.as_ref().and_then(|model| model.icon);
let model_icon = model.as_ref().and_then(|model| model.icon.clone());
let focus_handle = self.focus_handle.clone();
@@ -78,8 +78,15 @@ impl Render for AcpModelSelectorPopover {
self.selector.clone(),
ButtonLike::new("active-model")
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
.when_some(model_icon, |this, icon| {
this.child(Icon::new(icon).color(color).size(IconSize::XSmall))
.when_some(model_icon, |this, icon| match icon {
AgentModelIcon::Path(path) => this.child(
Icon::from_external_svg(path)
.color(color)
.size(IconSize::XSmall),
),
AgentModelIcon::Named(icon_name) => {
this.child(Icon::new(icon_name).color(color).size(IconSize::XSmall))
}
})
.child(
Label::new(model_name)

View File

@@ -170,7 +170,7 @@ impl ThreadFeedbackState {
}
}
let session_id = thread.read(cx).session_id().clone();
let agent = thread.read(cx).connection().telemetry_id();
let agent_telemetry_id = thread.read(cx).connection().telemetry_id();
let task = telemetry.thread_data(&session_id, cx);
let rating = match feedback {
ThreadFeedback::Positive => "positive",
@@ -180,7 +180,7 @@ impl ThreadFeedbackState {
let thread = task.await?;
telemetry::event!(
"Agent Thread Rated",
agent = agent,
agent = agent_telemetry_id,
session_id = session_id,
rating = rating,
thread = thread
@@ -207,13 +207,13 @@ impl ThreadFeedbackState {
self.comments_editor.take();
let session_id = thread.read(cx).session_id().clone();
let agent = thread.read(cx).connection().telemetry_id();
let agent_telemetry_id = thread.read(cx).connection().telemetry_id();
let task = telemetry.thread_data(&session_id, cx);
cx.background_spawn(async move {
let thread = task.await?;
telemetry::event!(
"Agent Thread Feedback Comments",
agent = agent,
agent = agent_telemetry_id,
session_id = session_id,
comments = comments,
thread = thread
@@ -333,6 +333,7 @@ impl AcpThreadView {
project: Entity<Project>,
history_store: Entity<HistoryStore>,
prompt_store: Option<Entity<PromptStore>>,
track_load_event: bool,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -391,8 +392,9 @@ impl AcpThreadView {
),
];
let show_codex_windows_warning = crate::ExternalAgent::parse_built_in(agent.as_ref())
== Some(crate::ExternalAgent::Codex);
let show_codex_windows_warning = cfg!(windows)
&& project.read(cx).is_local()
&& agent.clone().downcast::<agent_servers::Codex>().is_some();
Self {
agent: agent.clone(),
@@ -404,6 +406,7 @@ impl AcpThreadView {
resume_thread.clone(),
workspace.clone(),
project.clone(),
track_load_event,
window,
cx,
),
@@ -448,6 +451,7 @@ impl AcpThreadView {
self.resume_thread_metadata.clone(),
self.workspace.clone(),
self.project.clone(),
true,
window,
cx,
);
@@ -461,6 +465,7 @@ impl AcpThreadView {
resume_thread: Option<DbThreadMetadata>,
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
track_load_event: bool,
window: &mut Window,
cx: &mut Context<Self>,
) -> ThreadState {
@@ -519,6 +524,10 @@ impl AcpThreadView {
}
};
if track_load_event {
telemetry::event!("Agent Thread Started", agent = connection.telemetry_id());
}
let result = if let Some(native_agent) = connection
.clone()
.downcast::<agent::NativeAgentConnection>()
@@ -1133,8 +1142,8 @@ impl AcpThreadView {
let Some(thread) = self.thread() else {
return;
};
let agent_telemetry_id = self.agent.telemetry_id();
let session_id = thread.read(cx).session_id().clone();
let agent_telemetry_id = thread.read(cx).connection().telemetry_id();
let thread = thread.downgrade();
if self.should_be_following {
self.workspace
@@ -1512,6 +1521,7 @@ impl AcpThreadView {
else {
return;
};
let agent_telemetry_id = connection.telemetry_id();
// Check for the experimental "terminal-auth" _meta field
let auth_method = connection.auth_methods().iter().find(|m| m.id == method);
@@ -1579,19 +1589,18 @@ impl AcpThreadView {
);
cx.notify();
self.auth_task = Some(cx.spawn_in(window, {
let agent = self.agent.clone();
async move |this, cx| {
let result = authenticate.await;
match &result {
Ok(_) => telemetry::event!(
"Authenticate Agent Succeeded",
agent = agent.telemetry_id()
agent = agent_telemetry_id
),
Err(_) => {
telemetry::event!(
"Authenticate Agent Failed",
agent = agent.telemetry_id(),
agent = agent_telemetry_id,
)
}
}
@@ -1675,6 +1684,7 @@ impl AcpThreadView {
None,
this.workspace.clone(),
this.project.clone(),
true,
window,
cx,
)
@@ -1730,43 +1740,38 @@ impl AcpThreadView {
connection.authenticate(method, cx)
};
cx.notify();
self.auth_task =
Some(cx.spawn_in(window, {
let agent = self.agent.clone();
async move |this, cx| {
let result = authenticate.await;
self.auth_task = Some(cx.spawn_in(window, {
async move |this, cx| {
let result = authenticate.await;
match &result {
Ok(_) => telemetry::event!(
"Authenticate Agent Succeeded",
agent = agent.telemetry_id()
),
Err(_) => {
telemetry::event!(
"Authenticate Agent Failed",
agent = agent.telemetry_id(),
)
}
match &result {
Ok(_) => telemetry::event!(
"Authenticate Agent Succeeded",
agent = agent_telemetry_id
),
Err(_) => {
telemetry::event!("Authenticate Agent Failed", agent = agent_telemetry_id,)
}
this.update_in(cx, |this, window, cx| {
if let Err(err) = result {
if let ThreadState::Unauthenticated {
pending_auth_method,
..
} = &mut this.thread_state
{
pending_auth_method.take();
}
this.handle_thread_error(err, cx);
} else {
this.reset(window, cx);
}
this.auth_task.take()
})
.ok();
}
}));
this.update_in(cx, |this, window, cx| {
if let Err(err) = result {
if let ThreadState::Unauthenticated {
pending_auth_method,
..
} = &mut this.thread_state
{
pending_auth_method.take();
}
this.handle_thread_error(err, cx);
} else {
this.reset(window, cx);
}
this.auth_task.take()
})
.ok();
}
}));
}
fn spawn_external_agent_login(
@@ -1896,10 +1901,11 @@ impl AcpThreadView {
let Some(thread) = self.thread() else {
return;
};
let agent_telemetry_id = thread.read(cx).connection().telemetry_id();
telemetry::event!(
"Agent Tool Call Authorized",
agent = self.agent.telemetry_id(),
agent = agent_telemetry_id,
session = thread.read(cx).session_id(),
option = option_kind
);
@@ -3509,6 +3515,8 @@ impl AcpThreadView {
(method.id.0.clone(), method.name.clone())
};
let agent_telemetry_id = connection.telemetry_id();
Button::new(method_id.clone(), name)
.label_size(LabelSize::Small)
.map(|this| {
@@ -3528,7 +3536,7 @@ impl AcpThreadView {
cx.listener(move |this, _, window, cx| {
telemetry::event!(
"Authenticate Agent Started",
agent = this.agent.telemetry_id(),
agent = agent_telemetry_id,
method = method_id
);
@@ -5376,47 +5384,39 @@ impl AcpThreadView {
)
}
fn render_codex_windows_warning(&self, cx: &mut Context<Self>) -> Option<Callout> {
if self.show_codex_windows_warning {
Some(
Callout::new()
.icon(IconName::Warning)
.severity(Severity::Warning)
.title("Codex on Windows")
.description(
"For best performance, run Codex in Windows Subsystem for Linux (WSL2)",
)
.actions_slot(
Button::new("open-wsl-modal", "Open in WSL")
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.on_click(cx.listener({
move |_, _, _window, cx| {
#[cfg(windows)]
_window.dispatch_action(
zed_actions::wsl_actions::OpenWsl::default().boxed_clone(),
cx,
);
cx.notify();
}
})),
)
.dismiss_action(
IconButton::new("dismiss", IconName::Close)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.tooltip(Tooltip::text("Dismiss Warning"))
.on_click(cx.listener({
move |this, _, _, cx| {
this.show_codex_windows_warning = false;
cx.notify();
}
})),
),
fn render_codex_windows_warning(&self, cx: &mut Context<Self>) -> Callout {
Callout::new()
.icon(IconName::Warning)
.severity(Severity::Warning)
.title("Codex on Windows")
.description("For best performance, run Codex in Windows Subsystem for Linux (WSL2)")
.actions_slot(
Button::new("open-wsl-modal", "Open in WSL")
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.on_click(cx.listener({
move |_, _, _window, cx| {
#[cfg(windows)]
_window.dispatch_action(
zed_actions::wsl_actions::OpenWsl::default().boxed_clone(),
cx,
);
cx.notify();
}
})),
)
.dismiss_action(
IconButton::new("dismiss", IconName::Close)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.tooltip(Tooltip::text("Dismiss Warning"))
.on_click(cx.listener({
move |this, _, _, cx| {
this.show_codex_windows_warning = false;
cx.notify();
}
})),
)
} else {
None
}
}
fn render_thread_error(&mut self, window: &mut Window, cx: &mut Context<Self>) -> Option<Div> {
@@ -5936,12 +5936,8 @@ impl Render for AcpThreadView {
_ => this,
})
.children(self.render_thread_retry_status_callout(window, cx))
.children({
if cfg!(windows) && self.project.read(cx).is_local() {
self.render_codex_windows_warning(cx)
} else {
None
}
.when(self.show_codex_windows_warning, |this| {
this.child(self.render_codex_windows_warning(cx))
})
.children(self.render_thread_error(window, cx))
.when_some(
@@ -6398,6 +6394,7 @@ pub(crate) mod tests {
project,
history_store,
None,
false,
window,
cx,
)
@@ -6475,10 +6472,6 @@ pub(crate) mod tests {
where
C: 'static + AgentConnection + Send + Clone,
{
fn telemetry_id(&self) -> &'static str {
"test"
}
fn logo(&self) -> ui::IconName {
ui::IconName::Ai
}
@@ -6505,8 +6498,8 @@ pub(crate) mod tests {
struct SaboteurAgentConnection;
impl AgentConnection for SaboteurAgentConnection {
fn telemetry_id(&self) -> &'static str {
"saboteur"
fn telemetry_id(&self) -> SharedString {
"saboteur".into()
}
fn new_thread(
@@ -6569,8 +6562,8 @@ pub(crate) mod tests {
struct RefusalAgentConnection;
impl AgentConnection for RefusalAgentConnection {
fn telemetry_id(&self) -> &'static str {
"refusal"
fn telemetry_id(&self) -> SharedString {
"refusal".into()
}
fn new_thread(
@@ -6671,6 +6664,7 @@ pub(crate) mod tests {
project.clone(),
history_store.clone(),
None,
false,
window,
cx,
)

View File

@@ -36,7 +36,7 @@ use settings::{Settings, SettingsStore, update_settings_file};
use ui::{
Button, ButtonStyle, Chip, CommonAnimationExt, ContextMenu, ContextMenuEntry, Disclosure,
Divider, DividerColor, ElevationIndex, IconName, IconPosition, IconSize, Indicator, LabelSize,
PopoverMenu, Switch, Tooltip, WithScrollbar, prelude::*,
PopoverMenu, Switch, SwitchColor, Tooltip, WithScrollbar, prelude::*,
};
use util::ResultExt as _;
use workspace::{Workspace, create_and_open_local_file};
@@ -117,7 +117,7 @@ impl AgentConfiguration {
}
fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let providers = LanguageModelRegistry::read_global(cx).providers();
let providers = LanguageModelRegistry::read_global(cx).visible_providers();
for provider in providers {
self.add_provider_configuration_view(&provider, window, cx);
}
@@ -260,11 +260,15 @@ impl AgentConfiguration {
h_flex()
.w_full()
.gap_1p5()
.child(
.child(if let Some(icon_path) = provider.icon_path() {
Icon::from_external_svg(icon_path)
.size(IconSize::Small)
.color(Color::Muted)
} else {
Icon::new(provider.icon())
.size(IconSize::Small)
.color(Color::Muted),
)
.color(Color::Muted)
})
.child(
h_flex()
.w_full()
@@ -416,7 +420,7 @@ impl AgentConfiguration {
&mut self,
cx: &mut Context<Self>,
) -> impl IntoElement {
let providers = LanguageModelRegistry::read_global(cx).providers();
let providers = LanguageModelRegistry::read_global(cx).visible_providers();
let popover_menu = PopoverMenu::new("add-provider-popover")
.trigger(
@@ -879,6 +883,7 @@ impl AgentConfiguration {
.child(context_server_configuration_menu)
.child(
Switch::new("context-server-switch", is_running.into())
.color(SwitchColor::Accent)
.on_click({
let context_server_manager = self.context_server_store.clone();
let fs = self.fs.clone();

View File

@@ -77,7 +77,8 @@ impl Render for AgentModelSelector {
.map(|model| model.model.name().0)
.unwrap_or_else(|| SharedString::from("Select a Model"));
let provider_icon = model.as_ref().map(|model| model.provider.icon());
let provider_icon_path = model.as_ref().and_then(|model| model.provider.icon_path());
let provider_icon_name = model.as_ref().map(|model| model.provider.icon());
let color = if self.menu_handle.is_deployed() {
Color::Accent
} else {
@@ -89,8 +90,17 @@ impl Render for AgentModelSelector {
PickerPopoverMenu::new(
self.selector.clone(),
ButtonLike::new("active-model")
.when_some(provider_icon, |this, icon| {
this.child(Icon::new(icon).color(color).size(IconSize::XSmall))
.when_some(provider_icon_path.clone(), |this, icon_path| {
this.child(
Icon::from_external_svg(icon_path)
.color(color)
.size(IconSize::XSmall),
)
})
.when(provider_icon_path.is_none(), |this| {
this.when_some(provider_icon_name, |this, icon| {
this.child(Icon::new(icon).color(color).size(IconSize::XSmall))
})
})
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
.child(
@@ -102,7 +112,7 @@ impl Render for AgentModelSelector {
.child(
Icon::new(IconName::ChevronDown)
.color(color)
.size(IconSize::Small),
.size(IconSize::XSmall),
),
move |_window, cx| {
Tooltip::for_action_in("Change Model", &ToggleModelSelector, &focus_handle, cx)

View File

@@ -305,6 +305,7 @@ impl ActiveView {
project,
history_store,
prompt_store,
false,
window,
cx,
)
@@ -885,10 +886,6 @@ impl AgentPanel {
let server = ext_agent.server(fs, history);
if !loading {
telemetry::event!("Agent Thread Started", agent = server.telemetry_id());
}
this.update_in(cx, |this, window, cx| {
let selected_agent = ext_agent.into();
if this.selected_agent != selected_agent {
@@ -905,6 +902,7 @@ impl AgentPanel {
project,
this.history_store.clone(),
this.prompt_store.clone(),
!loading,
window,
cx,
)
@@ -2294,7 +2292,7 @@ impl AgentPanel {
let history_is_empty = self.history_store.read(cx).is_empty(cx);
let has_configured_non_zed_providers = LanguageModelRegistry::read_global(cx)
.providers()
.visible_providers()
.iter()
.any(|provider| {
provider.is_authenticated(cx)

View File

@@ -160,16 +160,6 @@ pub enum ExternalAgent {
}
impl ExternalAgent {
pub fn parse_built_in(server: &dyn agent_servers::AgentServer) -> Option<Self> {
match server.telemetry_id() {
"gemini-cli" => Some(Self::Gemini),
"claude-code" => Some(Self::ClaudeCode),
"codex" => Some(Self::Codex),
"zed" => Some(Self::NativeAgent),
_ => None,
}
}
pub fn server(
&self,
fs: Arc<dyn fs::Fs>,
@@ -348,7 +338,8 @@ fn init_language_model_settings(cx: &mut App) {
|_, event: &language_model::Event, cx| match event {
language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
| language_model::Event::RemovedProvider(_)
| language_model::Event::ProvidersChanged => {
update_active_language_model_from_settings(cx);
}
_ => {}
@@ -367,26 +358,49 @@ fn update_active_language_model_from_settings(cx: &mut App) {
}
}
let default = settings.default_model.as_ref().map(to_selected_model);
// Filter out models from providers that are not authenticated
fn is_provider_authenticated(
selection: &LanguageModelSelection,
registry: &LanguageModelRegistry,
cx: &App,
) -> bool {
let provider_id = LanguageModelProviderId::from(selection.provider.0.clone());
registry
.provider(&provider_id)
.map_or(false, |provider| provider.is_authenticated(cx))
}
let registry = LanguageModelRegistry::global(cx);
let registry_ref = registry.read(cx);
let default = settings
.default_model
.as_ref()
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
.map(to_selected_model);
let inline_assistant = settings
.inline_assistant_model
.as_ref()
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
.map(to_selected_model);
let commit_message = settings
.commit_message_model
.as_ref()
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
.map(to_selected_model);
let thread_summary = settings
.thread_summary_model
.as_ref()
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
.map(to_selected_model);
let inline_alternatives = settings
.inline_alternatives
.iter()
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
.map(to_selected_model)
.collect::<Vec<_>>();
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.update(cx, |registry, cx| {
registry.select_default_model(default.as_ref(), cx);
registry.select_inline_assistant_model(inline_assistant.as_ref(), cx);
registry.select_commit_message_model(commit_message.as_ref(), cx);

View File

@@ -1,13 +1,12 @@
use std::{cmp::Reverse, sync::Arc};
use collections::IndexMap;
use futures::{StreamExt, channel::mpsc};
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
use gpui::{
Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
};
use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
use language_model::{
AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
LanguageModelRegistry,
AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProvider,
LanguageModelProviderId, LanguageModelRegistry,
};
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
@@ -47,7 +46,9 @@ pub fn language_model_selector(
}
fn all_models(cx: &App) -> GroupedModels {
let providers = LanguageModelRegistry::global(cx).read(cx).providers();
let providers = LanguageModelRegistry::global(cx)
.read(cx)
.visible_providers();
let recommended = providers
.iter()
@@ -57,12 +58,12 @@ fn all_models(cx: &App) -> GroupedModels {
.into_iter()
.map(|model| ModelInfo {
model,
icon: provider.icon(),
icon: ProviderIcon::from_provider(provider.as_ref()),
})
})
.collect();
let all = providers
let all: Vec<ModelInfo> = providers
.iter()
.flat_map(|provider| {
provider
@@ -70,7 +71,7 @@ fn all_models(cx: &App) -> GroupedModels {
.into_iter()
.map(|model| ModelInfo {
model,
icon: provider.icon(),
icon: ProviderIcon::from_provider(provider.as_ref()),
})
})
.collect();
@@ -78,10 +79,26 @@ fn all_models(cx: &App) -> GroupedModels {
GroupedModels::new(all, recommended)
}
#[derive(Clone)]
enum ProviderIcon {
Name(IconName),
Path(SharedString),
}
impl ProviderIcon {
fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
if let Some(path) = provider.icon_path() {
Self::Path(path)
} else {
Self::Name(provider.icon())
}
}
}
#[derive(Clone)]
struct ModelInfo {
model: Arc<dyn LanguageModel>,
icon: IconName,
icon: ProviderIcon,
}
pub struct LanguageModelPickerDelegate {
@@ -91,7 +108,7 @@ pub struct LanguageModelPickerDelegate {
filtered_entries: Vec<LanguageModelPickerEntry>,
selected_index: usize,
_authenticate_all_providers_task: Task<()>,
_subscriptions: Vec<Subscription>,
_refresh_models_task: Task<()>,
popover_styles: bool,
focus_handle: FocusHandle,
}
@@ -116,24 +133,43 @@ impl LanguageModelPickerDelegate {
filtered_entries: entries,
get_active_model: Arc::new(get_active_model),
_authenticate_all_providers_task: Self::authenticate_all_providers(cx),
_subscriptions: vec![cx.subscribe_in(
&LanguageModelRegistry::global(cx),
window,
|picker, _, event, window, cx| {
match event {
language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
let query = picker.query(cx);
picker.delegate.all_models = Arc::new(all_models(cx));
// Update matches will automatically drop the previous task
// if we get a provider event again
picker.update_matches(query, window, cx)
}
_ => {}
_refresh_models_task: {
// Create a channel to signal when models need refreshing
let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
// Subscribe to registry events and send refresh signals through the channel
let registry = LanguageModelRegistry::global(cx);
cx.subscribe(&registry, move |_picker, _, event, _cx| match event {
language_model::Event::ProviderStateChanged(_) => {
refresh_tx.unbounded_send(()).ok();
}
},
)],
language_model::Event::AddedProvider(_) => {
refresh_tx.unbounded_send(()).ok();
}
language_model::Event::RemovedProvider(_) => {
refresh_tx.unbounded_send(()).ok();
}
language_model::Event::ProvidersChanged => {
refresh_tx.unbounded_send(()).ok();
}
_ => {}
})
.detach();
// Spawn a task that listens for refresh signals and updates the picker
cx.spawn_in(window, async move |this, cx| {
while let Some(()) = refresh_rx.next().await {
let result = this.update_in(cx, |picker, window, cx| {
picker.delegate.all_models = Arc::new(all_models(cx));
picker.refresh(window, cx);
});
if result.is_err() {
// Picker was dropped, exit the loop
break;
}
}
})
},
popover_styles,
focus_handle,
}
@@ -392,7 +428,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
let configured_providers = language_model_registry
.read(cx)
.providers()
.visible_providers()
.into_iter()
.filter(|provider| provider.is_authenticated(cx))
.collect::<Vec<_>>();
@@ -504,11 +540,16 @@ impl PickerDelegate for LanguageModelPickerDelegate {
h_flex()
.w_full()
.gap_1p5()
.child(
Icon::new(model_info.icon)
.child(match &model_info.icon {
ProviderIcon::Name(icon_name) => Icon::new(*icon_name)
.color(model_icon_color)
.size(IconSize::Small),
)
ProviderIcon::Path(icon_path) => {
Icon::from_external_svg(icon_path.clone())
.color(model_icon_color)
.size(IconSize::Small)
}
})
.child(Label::new(model_info.model.name().0).truncate()),
)
.end_slot(div().pr_3().when(is_selected, |this| {
@@ -657,7 +698,7 @@ mod tests {
.into_iter()
.map(|(provider, name)| ModelInfo {
model: Arc::new(TestLanguageModel::new(name, provider)),
icon: IconName::Ai,
icon: ProviderIcon::Name(IconName::Ai),
})
.collect()
}

View File

@@ -1682,98 +1682,6 @@ impl TextThreadEditor {
window: &mut Window,
cx: &mut Context<Self>,
) {
let editor_clipboard_selections = cx
.read_from_clipboard()
.and_then(|item| item.entries().first().cloned())
.and_then(|entry| match entry {
ClipboardEntry::String(text) => {
text.metadata_json::<Vec<editor::ClipboardSelection>>()
}
_ => None,
});
let has_file_context = editor_clipboard_selections
.as_ref()
.is_some_and(|selections| {
selections
.iter()
.any(|sel| sel.file_path.is_some() && sel.line_range.is_some())
});
if has_file_context {
if let Some(clipboard_item) = cx.read_from_clipboard() {
if let Some(ClipboardEntry::String(clipboard_text)) =
clipboard_item.entries().first()
{
if let Some(selections) = editor_clipboard_selections {
cx.stop_propagation();
let text = clipboard_text.text();
self.editor.update(cx, |editor, cx| {
let mut current_offset = 0;
let weak_editor = cx.entity().downgrade();
for selection in selections {
if let (Some(file_path), Some(line_range)) =
(selection.file_path, selection.line_range)
{
let selected_text =
&text[current_offset..current_offset + selection.len];
let fence = assistant_slash_commands::codeblock_fence_for_path(
file_path.to_str(),
Some(line_range.clone()),
);
let formatted_text = format!("{fence}{selected_text}\n```");
let insert_point = editor
.selections
.newest::<Point>(&editor.display_snapshot(cx))
.head();
let start_row = MultiBufferRow(insert_point.row);
editor.insert(&formatted_text, window, cx);
let snapshot = editor.buffer().read(cx).snapshot(cx);
let anchor_before = snapshot.anchor_after(insert_point);
let anchor_after = editor
.selections
.newest_anchor()
.head()
.bias_left(&snapshot);
editor.insert("\n", window, cx);
let crease_text = acp_thread::selection_name(
Some(file_path.as_ref()),
&line_range,
);
let fold_placeholder = quote_selection_fold_placeholder(
crease_text,
weak_editor.clone(),
);
let crease = Crease::inline(
anchor_before..anchor_after,
fold_placeholder,
render_quote_selection_output_toggle,
|_, _, _, _| Empty.into_any(),
);
editor.insert_creases(vec![crease], cx);
editor.fold_at(start_row, window, cx);
current_offset += selection.len;
if !selection.is_entire_line && current_offset < text.len() {
current_offset += 1;
}
}
}
});
return;
}
}
}
}
cx.stop_propagation();
let mut images = if let Some(item) = cx.read_from_clipboard() {
@@ -2189,7 +2097,8 @@ impl TextThreadEditor {
.default_model()
.map(|default| default.provider);
let provider_icon = match active_provider {
let provider_icon_path = active_provider.as_ref().and_then(|p| p.icon_path());
let provider_icon_name = match &active_provider {
Some(provider) => provider.icon(),
None => IconName::Ai,
};
@@ -2201,6 +2110,16 @@ impl TextThreadEditor {
(Color::Muted, IconName::ChevronDown)
};
let provider_icon_element = if let Some(icon_path) = provider_icon_path {
Icon::from_external_svg(icon_path)
.color(color)
.size(IconSize::XSmall)
} else {
Icon::new(provider_icon_name)
.color(color)
.size(IconSize::XSmall)
};
PickerPopoverMenu::new(
self.language_model_selector.clone(),
ButtonLike::new("active-model")
@@ -2208,7 +2127,7 @@ impl TextThreadEditor {
.child(
h_flex()
.gap_0p5()
.child(Icon::new(provider_icon).color(color).size(IconSize::XSmall))
.child(provider_icon_element)
.child(
Label::new(model_name)
.color(color)

View File

@@ -1,9 +1,25 @@
use gpui::{Action, IntoElement, ParentElement, RenderOnce, point};
use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
use language_model::{LanguageModelProvider, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
use ui::{Divider, List, ListBulletItem, prelude::*};
#[derive(Clone)]
enum ProviderIcon {
Name(IconName),
Path(SharedString),
}
impl ProviderIcon {
fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
if let Some(path) = provider.icon_path() {
Self::Path(path)
} else {
Self::Name(provider.icon())
}
}
}
pub struct ApiKeysWithProviders {
configured_providers: Vec<(IconName, SharedString)>,
configured_providers: Vec<(ProviderIcon, SharedString)>,
}
impl ApiKeysWithProviders {
@@ -13,7 +29,8 @@ impl ApiKeysWithProviders {
|this: &mut Self, _registry, event: &language_model::Event, cx| match event {
language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
| language_model::Event::RemovedProvider(_)
| language_model::Event::ProvidersChanged => {
this.configured_providers = Self::compute_configured_providers(cx)
}
_ => {}
@@ -26,14 +43,19 @@ impl ApiKeysWithProviders {
}
}
fn compute_configured_providers(cx: &App) -> Vec<(IconName, SharedString)> {
fn compute_configured_providers(cx: &App) -> Vec<(ProviderIcon, SharedString)> {
LanguageModelRegistry::read_global(cx)
.providers()
.visible_providers()
.iter()
.filter(|provider| {
provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID
})
.map(|provider| (provider.icon(), provider.name().0))
.map(|provider| {
(
ProviderIcon::from_provider(provider.as_ref()),
provider.name().0,
)
})
.collect()
}
}
@@ -47,7 +69,14 @@ impl Render for ApiKeysWithProviders {
.map(|(icon, name)| {
h_flex()
.gap_1p5()
.child(Icon::new(icon).size(IconSize::XSmall).color(Color::Muted))
.child(match icon {
ProviderIcon::Name(icon_name) => Icon::new(icon_name)
.size(IconSize::XSmall)
.color(Color::Muted),
ProviderIcon::Path(icon_path) => Icon::from_external_svg(icon_path)
.size(IconSize::XSmall)
.color(Color::Muted),
})
.child(Label::new(name))
});
div()

View File

@@ -11,7 +11,7 @@ use crate::{AgentPanelOnboardingCard, ApiKeysWithoutProviders, ZedAiOnboarding};
pub struct AgentPanelOnboarding {
user_store: Entity<UserStore>,
client: Arc<Client>,
configured_providers: Vec<(IconName, SharedString)>,
has_configured_providers: bool,
continue_with_zed_ai: Arc<dyn Fn(&mut Window, &mut App)>,
}
@@ -27,8 +27,9 @@ impl AgentPanelOnboarding {
|this: &mut Self, _registry, event: &language_model::Event, cx| match event {
language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
this.configured_providers = Self::compute_available_providers(cx)
| language_model::Event::RemovedProvider(_)
| language_model::Event::ProvidersChanged => {
this.has_configured_providers = Self::has_configured_providers(cx)
}
_ => {}
},
@@ -38,20 +39,16 @@ impl AgentPanelOnboarding {
Self {
user_store,
client,
configured_providers: Self::compute_available_providers(cx),
has_configured_providers: Self::has_configured_providers(cx),
continue_with_zed_ai: Arc::new(continue_with_zed_ai),
}
}
fn compute_available_providers(cx: &App) -> Vec<(IconName, SharedString)> {
fn has_configured_providers(cx: &App) -> bool {
LanguageModelRegistry::read_global(cx)
.providers()
.visible_providers()
.iter()
.filter(|provider| {
provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID
})
.map(|provider| (provider.icon(), provider.name().0))
.collect()
.any(|provider| provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID)
}
}
@@ -81,7 +78,7 @@ impl Render for AgentPanelOnboarding {
}),
)
.map(|this| {
if enrolled_in_trial || is_pro_user || !self.configured_providers.is_empty() {
if enrolled_in_trial || is_pro_user || self.has_configured_providers {
this
} else {
this.child(ApiKeysWithoutProviders::new())

View File

@@ -8,7 +8,7 @@ use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::B
use http_client::http::{self, HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode};
use serde::{Deserialize, Serialize};
pub use settings::{AnthropicAvailableModel as AvailableModel, ModelMode};
pub use settings::ModelMode;
use strum::{EnumIter, EnumString};
use thiserror::Error;

View File

@@ -53,7 +53,7 @@ text.workspace = true
thiserror.workspace = true
time.workspace = true
tiny_http.workspace = true
tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io"] }
tokio-socks.workspace = true
tokio.workspace = true
url.workspace = true
util.workspace = true

View File

@@ -1,18 +0,0 @@
[package]
name = "cloud_zeta2_prompt"
version = "0.1.0"
publish.workspace = true
edition.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/cloud_zeta2_prompt.rs"
[dependencies]
anyhow.workspace = true
cloud_llm_client.workspace = true
indoc.workspace = true
serde.workspace = true

View File

@@ -1,485 +0,0 @@
use anyhow::Result;
use cloud_llm_client::predict_edits_v3::{
self, DiffPathFmt, Event, Excerpt, Line, Point, PromptFormat, RelatedFile,
};
use indoc::indoc;
use std::cmp;
use std::fmt::Write;
use std::path::Path;
use std::sync::Arc;
pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;
pub const CURSOR_MARKER: &str = "<|user_cursor|>";
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#"
You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.
## Edit History
"#};
const MINIMAL_PROMPT_REMINDER: &str = indoc! {"
---
Please analyze the edit history and the files, then provide the unified diff for your predicted edits.
Do not include the cursor marker in your output.
If you're editing multiple files, be sure to reflect filename in the hunk's header.
"};
const XML_TAGS_INSTRUCTIONS: &str = indoc! {r#"
# Instructions
You are an edit prediction agent in a code editor.
Analyze the history of edits made by the user in order to infer what they are currently trying to accomplish.
Then complete the remainder of the current change if it is incomplete, or predict the next edit the user intends to make.
Always continue along the user's current trajectory, rather than changing course.
## Output Format
You should briefly explain your understanding of the user's overall goal in one sentence, then explain what the next change
along the users current trajectory will be in another, and finally specify the next edit using the following XML-like format:
<edits path="my-project/src/myapp/cli.py">
<old_text>
OLD TEXT 1 HERE
</old_text>
<new_text>
NEW TEXT 1 HERE
</new_text>
<old_text>
OLD TEXT 1 HERE
</old_text>
<new_text>
NEW TEXT 1 HERE
</new_text>
</edits>
- Specify the file to edit using the `path` attribute.
- Use `<old_text>` and `<new_text>` tags to replace content
- `<old_text>` must exactly match existing file content, including indentation
- `<old_text>` cannot be empty
- Do not escape quotes, newlines, or other characters within tags
- Always close all tags properly
- Don't include the <|user_cursor|> marker in your output.
## Edit History
"#};
const OLD_TEXT_NEW_TEXT_REMINDER: &str = indoc! {r#"
---
Remember that the edits in the edit history have already been applied.
"#};
pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<String> {
let prompt_data = PromptData {
events: request.events.clone(),
cursor_point: request.cursor_point,
cursor_path: request.excerpt_path.clone(),
included_files: request.related_files.clone(),
};
match request.prompt_format {
PromptFormat::MinimalQwen => {
return Ok(MinimalQwenPrompt.render(&prompt_data));
}
PromptFormat::SeedCoder1120 => {
return Ok(SeedCoder1120Prompt.render(&prompt_data));
}
_ => (),
};
let insertions = match request.prompt_format {
PromptFormat::Minimal | PromptFormat::OldTextNewText => {
vec![(request.cursor_point, CURSOR_MARKER)]
}
PromptFormat::OnlySnippets => vec![],
PromptFormat::MinimalQwen => unreachable!(),
PromptFormat::SeedCoder1120 => unreachable!(),
};
let mut prompt = match request.prompt_format {
PromptFormat::OldTextNewText => XML_TAGS_INSTRUCTIONS.to_string(),
PromptFormat::OnlySnippets => String::new(),
PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(),
PromptFormat::MinimalQwen => unreachable!(),
PromptFormat::SeedCoder1120 => unreachable!(),
};
if request.events.is_empty() {
prompt.push_str("(No edit history)\n\n");
} else {
let edit_preamble = if request.prompt_format == PromptFormat::Minimal {
"The following are the latest edits made by the user, from earlier to later.\n\n"
} else {
"Here are the latest edits made by the user, from earlier to later.\n\n"
};
prompt.push_str(edit_preamble);
push_events(&mut prompt, &request.events);
}
let excerpts_preamble = match request.prompt_format {
PromptFormat::Minimal => indoc! {"
## Part of the file under the cursor
(The cursor marker <|user_cursor|> indicates the current user cursor position.
The file is in current state, edits from edit history has been applied.
We only show part of the file around the cursor.
You can only edit exactly this part of the file.
We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.)
"},
PromptFormat::OldTextNewText => indoc! {"
## Code Excerpts
Here is some excerpts of code that you should take into account to predict the next edit.
The cursor position is marked by `<|user_cursor|>` as it stands after the last edit in the history.
In addition other excerpts are included to better understand what the edit will be, including the declaration
or references of symbols around the cursor, or other similar code snippets that may need to be updated
following patterns that appear in the edit history.
Consider each of them carefully in relation to the edit history, and that the user may not have navigated
to the next place they want to edit yet.
Lines starting with `…` indicate omitted line ranges. These may appear inside multi-line code constructs.
"},
PromptFormat::OnlySnippets | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
indoc! {"
## Code Excerpts
The cursor marker <|user_cursor|> indicates the current user cursor position.
The file is in current state, edits from edit history have been applied.
"}
}
};
prompt.push_str(excerpts_preamble);
prompt.push('\n');
let include_line_numbers = matches!(request.prompt_format, PromptFormat::Minimal);
for related_file in &request.related_files {
if request.prompt_format == PromptFormat::Minimal {
write_codeblock_with_filename(
&related_file.path,
&related_file.excerpts,
if related_file.path == request.excerpt_path {
&insertions
} else {
&[]
},
related_file.max_row,
include_line_numbers,
&mut prompt,
);
} else {
write_codeblock(
&related_file.path,
&related_file.excerpts,
if related_file.path == request.excerpt_path {
&insertions
} else {
&[]
},
related_file.max_row,
include_line_numbers,
&mut prompt,
);
}
}
match request.prompt_format {
PromptFormat::OldTextNewText => {
prompt.push_str(OLD_TEXT_NEW_TEXT_REMINDER);
}
PromptFormat::Minimal => {
prompt.push_str(MINIMAL_PROMPT_REMINDER);
}
_ => {}
}
Ok(prompt)
}
pub fn generation_params(prompt_format: PromptFormat) -> GenerationParams {
match prompt_format {
PromptFormat::SeedCoder1120 => SeedCoder1120Prompt::generation_params(),
_ => GenerationParams::default(),
}
}
pub fn write_codeblock<'a>(
path: &Path,
excerpts: impl IntoIterator<Item = &'a Excerpt>,
sorted_insertions: &[(Point, &str)],
file_line_count: Line,
include_line_numbers: bool,
output: &'a mut String,
) {
writeln!(output, "`````{}", DiffPathFmt(path)).unwrap();
write_excerpts(
excerpts,
sorted_insertions,
file_line_count,
include_line_numbers,
output,
);
write!(output, "`````\n\n").unwrap();
}
fn write_codeblock_with_filename<'a>(
path: &Path,
excerpts: impl IntoIterator<Item = &'a Excerpt>,
sorted_insertions: &[(Point, &str)],
file_line_count: Line,
include_line_numbers: bool,
output: &'a mut String,
) {
writeln!(output, "`````filename={}", DiffPathFmt(path)).unwrap();
write_excerpts(
excerpts,
sorted_insertions,
file_line_count,
include_line_numbers,
output,
);
write!(output, "`````\n\n").unwrap();
}
pub fn write_excerpts<'a>(
excerpts: impl IntoIterator<Item = &'a Excerpt>,
sorted_insertions: &[(Point, &str)],
file_line_count: Line,
include_line_numbers: bool,
output: &mut String,
) {
let mut current_row = Line(0);
let mut sorted_insertions = sorted_insertions.iter().peekable();
for excerpt in excerpts {
if excerpt.start_line > current_row {
writeln!(output, "").unwrap();
}
if excerpt.text.is_empty() {
return;
}
current_row = excerpt.start_line;
for mut line in excerpt.text.lines() {
if include_line_numbers {
write!(output, "{}|", current_row.0 + 1).unwrap();
}
while let Some((insertion_location, insertion_marker)) = sorted_insertions.peek() {
match current_row.cmp(&insertion_location.line) {
cmp::Ordering::Equal => {
let (prefix, suffix) = line.split_at(insertion_location.column as usize);
output.push_str(prefix);
output.push_str(insertion_marker);
line = suffix;
sorted_insertions.next();
}
cmp::Ordering::Less => break,
cmp::Ordering::Greater => {
sorted_insertions.next();
break;
}
}
}
output.push_str(line);
output.push('\n');
current_row.0 += 1;
}
}
if current_row < file_line_count {
writeln!(output, "").unwrap();
}
}
pub fn push_events(output: &mut String, events: &[Arc<predict_edits_v3::Event>]) {
if events.is_empty() {
return;
};
writeln!(output, "`````diff").unwrap();
for event in events {
writeln!(output, "{}", event).unwrap();
}
writeln!(output, "`````\n").unwrap();
}
struct PromptData {
events: Vec<Arc<Event>>,
cursor_point: Point,
cursor_path: Arc<Path>, // TODO: make a common struct with cursor_point
included_files: Vec<RelatedFile>,
}
#[derive(Default)]
pub struct GenerationParams {
pub temperature: Option<f32>,
pub top_p: Option<f32>,
pub stop: Option<Vec<String>>,
}
trait PromptFormatter {
fn render(&self, data: &PromptData) -> String;
fn generation_params() -> GenerationParams {
return GenerationParams::default();
}
}
struct MinimalQwenPrompt;
impl PromptFormatter for MinimalQwenPrompt {
fn render(&self, data: &PromptData) -> String {
let edit_history = self.fmt_edit_history(data);
let context = self.fmt_context(data);
format!(
"{instructions}\n\n{edit_history}\n\n{context}",
instructions = MinimalQwenPrompt::INSTRUCTIONS,
edit_history = edit_history,
context = context
)
}
}
impl MinimalQwenPrompt {
const INSTRUCTIONS: &str = "You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.\n";
fn fmt_edit_history(&self, data: &PromptData) -> String {
if data.events.is_empty() {
"(No edit history)\n\n".to_string()
} else {
let mut events_str = String::new();
push_events(&mut events_str, &data.events);
format!(
"The following are the latest edits made by the user, from earlier to later.\n\n{}",
events_str
)
}
}
fn fmt_context(&self, data: &PromptData) -> String {
let mut context = String::new();
let include_line_numbers = true;
for related_file in &data.included_files {
writeln!(context, "<|file_sep|>{}", DiffPathFmt(&related_file.path)).unwrap();
if related_file.path == data.cursor_path {
write!(context, "<|fim_prefix|>").unwrap();
write_excerpts(
&related_file.excerpts,
&[(data.cursor_point, "<|fim_suffix|>")],
related_file.max_row,
include_line_numbers,
&mut context,
);
writeln!(context, "<|fim_middle|>").unwrap();
} else {
write_excerpts(
&related_file.excerpts,
&[],
related_file.max_row,
include_line_numbers,
&mut context,
);
}
}
context
}
}
struct SeedCoder1120Prompt;
impl PromptFormatter for SeedCoder1120Prompt {
fn render(&self, data: &PromptData) -> String {
let edit_history = self.fmt_edit_history(data);
let context = self.fmt_context(data);
format!(
"# Edit History:\n{edit_history}\n\n{context}",
edit_history = edit_history,
context = context
)
}
fn generation_params() -> GenerationParams {
GenerationParams {
temperature: Some(0.2),
top_p: Some(0.9),
stop: Some(vec!["<[end_of_sentence]>".into()]),
}
}
}
impl SeedCoder1120Prompt {
fn fmt_edit_history(&self, data: &PromptData) -> String {
if data.events.is_empty() {
"(No edit history)\n\n".to_string()
} else {
let mut events_str = String::new();
push_events(&mut events_str, &data.events);
events_str
}
}
fn fmt_context(&self, data: &PromptData) -> String {
let mut context = String::new();
let include_line_numbers = true;
for related_file in &data.included_files {
writeln!(context, "# Path: {}\n", DiffPathFmt(&related_file.path)).unwrap();
if related_file.path == data.cursor_path {
let fim_prompt = self.fmt_fim(&related_file, data.cursor_point);
context.push_str(&fim_prompt);
} else {
write_excerpts(
&related_file.excerpts,
&[],
related_file.max_row,
include_line_numbers,
&mut context,
);
}
}
context
}
fn fmt_fim(&self, file: &RelatedFile, cursor_point: Point) -> String {
let mut buf = String::new();
const FIM_SUFFIX: &str = "<[fim-suffix]>";
const FIM_PREFIX: &str = "<[fim-prefix]>";
const FIM_MIDDLE: &str = "<[fim-middle]>";
write!(buf, "{}", FIM_PREFIX).unwrap();
write_excerpts(
&file.excerpts,
&[(cursor_point, FIM_SUFFIX)],
file.max_row,
true,
&mut buf,
);
// Swap prefix and suffix parts
let index = buf.find(FIM_SUFFIX).unwrap();
let prefix = &buf[..index];
let suffix = &buf[index..];
format!("{}{}{}", suffix, prefix, FIM_MIDDLE)
}
}

View File

@@ -33,6 +33,7 @@ smol.workspace = true
tempfile.workspace = true
url = { workspace = true, features = ["serde"] }
util.workspace = true
terminal.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }

View File

@@ -8,9 +8,12 @@ use futures::{
AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, Stream, StreamExt as _,
};
use gpui::AsyncApp;
use settings::Settings as _;
use smol::channel;
use smol::process::Child;
use terminal::terminal_settings::TerminalSettings;
use util::TryFutureExt as _;
use util::shell_builder::ShellBuilder;
use crate::client::ModelContextServerBinary;
use crate::transport::Transport;
@@ -28,9 +31,14 @@ impl StdioTransport {
working_directory: &Option<PathBuf>,
cx: &AsyncApp,
) -> Result<Self> {
let mut command = util::command::new_smol_command(&binary.executable);
let shell = cx.update(|cx| TerminalSettings::get(None, cx).shell.clone())?;
let builder = ShellBuilder::new(&shell, cfg!(windows));
let (command, args) =
builder.build(Some(binary.executable.display().to_string()), &binary.args);
let mut command = util::command::new_smol_command(command);
command
.args(&binary.args)
.args(args)
.envs(binary.env.unwrap_or_default())
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())

View File

@@ -21,7 +21,6 @@ arrayvec.workspace = true
brotli.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
cloud_zeta2_prompt.workspace = true
collections.workspace = true
copilot.workspace = true
credentials_provider.workspace = true
@@ -50,8 +49,6 @@ semver.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
strsim.workspace = true
strum.workspace = true
telemetry.workspace = true
telemetry_events.workspace = true
@@ -62,6 +59,7 @@ uuid.workspace = true
workspace.workspace = true
worktree.workspace = true
zed_actions.workspace = true
zeta_prompt.workspace = true
[dev-dependencies]
clock = { workspace = true, features = ["test-support"] }

View File

@@ -1,14 +1,13 @@
use anyhow::Result;
use arrayvec::ArrayVec;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
use cloud_llm_client::{
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
ZED_VERSION_HEADER_NAME,
};
use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
use collections::{HashMap, HashSet};
use db::kvp::{Dismissable, KEY_VALUE_STORE};
use edit_prediction_context::EditPredictionExcerptOptions;
@@ -16,10 +15,7 @@ use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, Rel
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
use futures::{
AsyncReadExt as _, FutureExt as _, StreamExt as _,
channel::{
mpsc::{self, UnboundedReceiver},
oneshot,
},
channel::mpsc::{self, UnboundedReceiver},
select_biased,
};
use gpui::BackgroundExecutor;
@@ -58,8 +54,10 @@ mod onboarding_modal;
pub mod open_ai_response;
mod prediction;
pub mod sweep_ai;
#[cfg(any(test, feature = "test-support", feature = "eval-support"))]
pub mod udiff;
mod xml_edits;
mod zed_edit_prediction_delegate;
pub mod zeta1;
pub mod zeta2;
@@ -72,7 +70,6 @@ use crate::mercury::Mercury;
use crate::onboarding_modal::ZedPredictModal;
pub use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
pub use crate::prediction::EditPredictionInputs;
use crate::prediction::EditPredictionResult;
pub use crate::sweep_ai::SweepAi;
pub use telemetry_events::EditPredictionRating;
@@ -112,7 +109,6 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
min_bytes: 128,
target_before_cursor_over_total_bytes: 0.5,
},
max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
prompt_format: PromptFormat::DEFAULT,
};
@@ -162,7 +158,6 @@ pub struct EditPredictionStore {
use_context: bool,
options: ZetaOptions,
update_required: bool,
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
#[cfg(feature = "eval-support")]
eval_cache: Option<Arc<dyn EvalCache>>,
edit_prediction_model: EditPredictionModel,
@@ -183,10 +178,22 @@ pub enum EditPredictionModel {
Mercury,
}
pub struct EditPredictionModelInput {
project: Entity<Project>,
buffer: Entity<Buffer>,
snapshot: BufferSnapshot,
position: Anchor,
events: Vec<Arc<zeta_prompt::Event>>,
related_files: Arc<[RelatedFile]>,
recent_paths: VecDeque<ProjectPath>,
trigger: PredictEditsRequestTrigger,
diagnostic_search_range: Range<Point>,
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
}
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
pub context: EditPredictionExcerptOptions,
pub max_prompt_bytes: usize,
pub prompt_format: predict_edits_v3::PromptFormat,
}
@@ -194,7 +201,8 @@ pub struct ZetaOptions {
pub enum DebugEvent {
ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
EditPredictionRequested(EditPredictionRequestedDebugEvent),
EditPredictionStarted(EditPredictionStartedDebugEvent),
EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
#[derive(Debug)]
@@ -212,27 +220,30 @@ pub struct ContextRetrievalFinishedDebugEvent {
}
#[derive(Debug)]
pub struct EditPredictionRequestedDebugEvent {
pub inputs: EditPredictionInputs,
pub retrieval_time: Duration,
pub struct EditPredictionStartedDebugEvent {
pub buffer: WeakEntity<Buffer>,
pub position: Anchor,
pub local_prompt: Result<String, String>,
pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
pub prompt: Option<String>,
}
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
pub buffer: WeakEntity<Buffer>,
pub position: Anchor,
pub model_output: Option<String>,
}
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
struct ProjectState {
events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
events: VecDeque<Arc<zeta_prompt::Event>>,
last_event: Option<LastEvent>,
recent_paths: VecDeque<ProjectPath>,
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
current_prediction: Option<CurrentEditPrediction>,
next_pending_prediction_id: usize,
pending_predictions: ArrayVec<PendingPrediction, 2>,
context_updates_tx: smol::channel::Sender<()>,
context_updates_rx: smol::channel::Receiver<()>,
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
last_prediction_refresh: Option<(EntityId, Instant)>,
cancelled_predictions: HashSet<usize>,
context: Entity<RelatedExcerptStore>,
@@ -241,7 +252,7 @@ struct ProjectState {
}
impl ProjectState {
pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
self.events
.iter()
.cloned()
@@ -376,7 +387,7 @@ impl LastEvent {
&self,
license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
cx: &App,
) -> Option<Arc<predict_edits_v3::Event>> {
) -> Option<Arc<zeta_prompt::Event>> {
let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
@@ -396,7 +407,7 @@ impl LastEvent {
if path == old_path && diff.is_empty() {
None
} else {
Some(Arc::new(predict_edits_v3::Event::BufferChange {
Some(Arc::new(zeta_prompt::Event::BufferChange {
old_path,
path,
diff,
@@ -481,7 +492,6 @@ impl EditPredictionStore {
},
),
update_required: false,
debug_tx: None,
#[cfg(feature = "eval-support")]
eval_cache: None,
edit_prediction_model: EditPredictionModel::Zeta2,
@@ -536,12 +546,6 @@ impl EditPredictionStore {
self.eval_cache = Some(cache);
}
pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<DebugEvent> {
let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
self.debug_tx = Some(debug_watch_tx);
debug_watch_rx
}
pub fn options(&self) -> &ZetaOptions {
&self.options
}
@@ -560,15 +564,35 @@ impl EditPredictionStore {
}
}
pub fn edit_history_for_project(
&self,
project: &Entity<Project>,
) -> Vec<Arc<zeta_prompt::Event>> {
self.projects
.get(&project.entity_id())
.map(|project_state| project_state.events.iter().cloned().collect())
.unwrap_or_default()
}
pub fn context_for_project<'a>(
&'a self,
project: &Entity<Project>,
cx: &'a App,
) -> &'a [RelatedFile] {
) -> Arc<[RelatedFile]> {
self.projects
.get(&project.entity_id())
.map(|project| project.context.read(cx).related_files())
.unwrap_or(&[])
.unwrap_or_else(|| vec![].into())
}
pub fn context_for_project_with_buffers<'a>(
&'a self,
project: &Entity<Project>,
cx: &'a App,
) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
self.projects
.get(&project.entity_id())
.map(|project| project.context.read(cx).related_files_with_buffers())
}
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
@@ -599,85 +623,21 @@ impl EditPredictionStore {
cx: &mut Context<Self>,
) -> &mut ProjectState {
let entity_id = project.entity_id();
let (context_updates_tx, context_updates_rx) = smol::channel::unbounded();
self.projects
.entry(entity_id)
.or_insert_with(|| ProjectState {
context: {
let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
cx.subscribe(
&related_excerpt_store,
move |this, _, event, _| match event {
RelatedExcerptStoreEvent::StartedRefresh => {
if let Some(debug_tx) = this.debug_tx.clone() {
debug_tx
.unbounded_send(DebugEvent::ContextRetrievalStarted(
ContextRetrievalStartedDebugEvent {
project_entity_id: entity_id,
timestamp: Instant::now(),
search_prompt: String::new(),
},
))
.ok();
}
}
RelatedExcerptStoreEvent::FinishedRefresh {
cache_hit_count,
cache_miss_count,
mean_definition_latency,
max_definition_latency,
} => {
if let Some(debug_tx) = this.debug_tx.clone() {
debug_tx
.unbounded_send(DebugEvent::ContextRetrievalFinished(
ContextRetrievalFinishedDebugEvent {
project_entity_id: entity_id,
timestamp: Instant::now(),
metadata: vec![
(
"Cache Hits",
format!(
"{}/{}",
cache_hit_count,
cache_hit_count + cache_miss_count
)
.into(),
),
(
"Max LSP Time",
format!(
"{} ms",
max_definition_latency.as_millis()
)
.into(),
),
(
"Mean LSP Time",
format!(
"{} ms",
mean_definition_latency.as_millis()
)
.into(),
),
],
},
))
.ok();
}
if let Some(project_state) = this.projects.get(&entity_id) {
project_state.context_updates_tx.send_blocking(()).ok();
}
}
},
)
cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
this.handle_excerpt_store_event(entity_id, event);
})
.detach();
related_excerpt_store
},
events: VecDeque::new(),
last_event: None,
recent_paths: VecDeque::new(),
context_updates_rx,
context_updates_tx,
debug_tx: None,
registered_buffers: HashMap::default(),
current_prediction: None,
cancelled_predictions: HashSet::default(),
@@ -689,12 +649,79 @@ impl EditPredictionStore {
})
}
pub fn project_context_updates(
&self,
pub fn remove_project(&mut self, project: &Entity<Project>) {
self.projects.remove(&project.entity_id());
}
fn handle_excerpt_store_event(
&mut self,
project_entity_id: EntityId,
event: &RelatedExcerptStoreEvent,
) {
if let Some(project_state) = self.projects.get(&project_entity_id) {
if let Some(debug_tx) = project_state.debug_tx.clone() {
match event {
RelatedExcerptStoreEvent::StartedRefresh => {
debug_tx
.unbounded_send(DebugEvent::ContextRetrievalStarted(
ContextRetrievalStartedDebugEvent {
project_entity_id: project_entity_id,
timestamp: Instant::now(),
search_prompt: String::new(),
},
))
.ok();
}
RelatedExcerptStoreEvent::FinishedRefresh {
cache_hit_count,
cache_miss_count,
mean_definition_latency,
max_definition_latency,
} => {
debug_tx
.unbounded_send(DebugEvent::ContextRetrievalFinished(
ContextRetrievalFinishedDebugEvent {
project_entity_id: project_entity_id,
timestamp: Instant::now(),
metadata: vec![
(
"Cache Hits",
format!(
"{}/{}",
cache_hit_count,
cache_hit_count + cache_miss_count
)
.into(),
),
(
"Max LSP Time",
format!("{} ms", max_definition_latency.as_millis())
.into(),
),
(
"Mean LSP Time",
format!("{} ms", mean_definition_latency.as_millis())
.into(),
),
],
},
))
.ok();
}
}
}
}
}
pub fn debug_info(
&mut self,
project: &Entity<Project>,
) -> Option<smol::channel::Receiver<()>> {
let project_state = self.projects.get(&project.entity_id())?;
Some(project_state.context_updates_rx.clone())
cx: &mut Context<Self>,
) -> mpsc::UnboundedReceiver<DebugEvent> {
let project_state = self.get_or_init_project(project, cx);
let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
project_state.debug_tx = Some(debug_watch_tx);
debug_watch_rx
}
fn handle_project_event(
@@ -1348,6 +1375,7 @@ impl EditPredictionStore {
let project_state = self.projects.get(&project.entity_id()).unwrap();
let events = project_state.events(cx);
let has_events = !events.is_empty();
let debug_tx = project_state.debug_tx.clone();
let snapshot = active_buffer.read(cx).snapshot();
let cursor_point = position.to_point(&snapshot);
@@ -1357,55 +1385,29 @@ impl EditPredictionStore {
Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
let related_files = if self.use_context {
self.context_for_project(&project, cx).to_vec()
self.context_for_project(&project, cx)
} else {
Vec::new()
Vec::new().into()
};
let inputs = EditPredictionModelInput {
project: project.clone(),
buffer: active_buffer.clone(),
snapshot: snapshot.clone(),
position,
events,
related_files,
recent_paths: project_state.recent_paths.clone(),
trigger,
diagnostic_search_range: diagnostic_search_range.clone(),
debug_tx,
};
let task = match self.edit_prediction_model {
EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(
self,
&project,
&active_buffer,
snapshot.clone(),
position,
events,
trigger,
cx,
),
EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
self,
&project,
&active_buffer,
snapshot.clone(),
position,
events,
related_files,
trigger,
cx,
),
EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
&project,
&active_buffer,
snapshot.clone(),
position,
events,
&project_state.recent_paths,
related_files,
diagnostic_search_range.clone(),
cx,
),
EditPredictionModel::Mercury => self.mercury.request_prediction(
&project,
&active_buffer,
snapshot.clone(),
position,
events,
&project_state.recent_paths,
related_files,
diagnostic_search_range.clone(),
cx,
),
EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
};
cx.spawn(async move |this, cx| {
@@ -1706,6 +1708,20 @@ impl EditPredictionStore {
}
}
#[cfg(feature = "eval-support")]
pub fn set_context_for_buffer(
&mut self,
project: &Entity<Project>,
related_files: Vec<RelatedFile>,
cx: &mut Context<Self>,
) {
self.get_or_init_project(project, cx)
.context
.update(cx, |store, _| {
store.set_related_files(related_files);
});
}
fn is_file_open_source(
&self,
project: &Entity<Project>,
@@ -1729,14 +1745,14 @@ impl EditPredictionStore {
self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
}
fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
if !self.data_collection_choice.is_enabled() {
return false;
}
events.iter().all(|event| {
matches!(
event.as_ref(),
Event::BufferChange {
zeta_prompt::Event::BufferChange {
in_open_source_repo: true,
..
}

View File

@@ -1,5 +1,5 @@
use super::*;
use crate::zeta1::MAX_EVENT_TOKENS;
use crate::{udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
use client::{UserStore, test::FakeServer};
use clock::{FakeSystemClock, ReplicaId};
use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
@@ -7,7 +7,6 @@ use cloud_llm_client::{
EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
RejectEditPredictionsBody,
};
use edit_prediction_context::Line;
use futures::{
AsyncReadExt, StreamExt,
channel::{mpsc, oneshot},
@@ -28,6 +27,7 @@ use settings::SettingsStore;
use std::{path::Path, sync::Arc, time::Duration};
use util::{path, rel_path::rel_path};
use uuid::Uuid;
use zeta_prompt::ZetaPromptInput;
use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
@@ -65,18 +65,21 @@ async fn test_current_state(cx: &mut TestAppContext) {
ep_store.update(cx, |ep_store, cx| {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
});
let (_request, respond_tx) = requests.predict.next().await.unwrap();
let (request, respond_tx) = requests.predict.next().await.unwrap();
respond_tx
.send(model_response(indoc! {r"
--- a/root/1.txt
+++ b/root/1.txt
@@ ... @@
Hello!
-How
+How are you?
Bye
"}))
.send(model_response(
request,
indoc! {r"
--- a/root/1.txt
+++ b/root/1.txt
@@ ... @@
Hello!
-How
+How are you?
Bye
"},
))
.unwrap();
cx.run_until_parked();
@@ -120,16 +123,20 @@ async fn test_current_state(cx: &mut TestAppContext) {
});
});
let (_request, respond_tx) = requests.predict.next().await.unwrap();
let (request, respond_tx) = requests.predict.next().await.unwrap();
respond_tx
.send(model_response(indoc! {r#"
--- a/root/2.txt
+++ b/root/2.txt
Hola!
-Como
+Como estas?
Adios
"#}))
.send(model_response(
request,
indoc! {r#"
--- a/root/2.txt
+++ b/root/2.txt
@@ ... @@
Hola!
-Como
+Como estas?
Adios
"#},
))
.unwrap();
cx.run_until_parked();
@@ -186,7 +193,7 @@ async fn test_simple_request(cx: &mut TestAppContext) {
ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
});
let (_, respond_tx) = requests.predict.next().await.unwrap();
let (request, respond_tx) = requests.predict.next().await.unwrap();
// TODO Put back when we have a structured request again
// assert_eq!(
@@ -202,15 +209,18 @@ async fn test_simple_request(cx: &mut TestAppContext) {
// );
respond_tx
.send(model_response(indoc! { r"
--- a/root/foo.md
+++ b/root/foo.md
@@ ... @@
Hello!
-How
+How are you?
Bye
"}))
.send(model_response(
request,
indoc! { r"
--- a/root/foo.md
+++ b/root/foo.md
@@ ... @@
Hello!
-How
+How are you?
Bye
"},
))
.unwrap();
let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
@@ -276,15 +286,18 @@ async fn test_request_events(cx: &mut TestAppContext) {
);
respond_tx
.send(model_response(indoc! {r#"
--- a/root/foo.md
+++ b/root/foo.md
@@ ... @@
Hello!
-How
+How are you?
Bye
"#}))
.send(model_response(
request,
indoc! {r#"
--- a/root/foo.md
+++ b/root/foo.md
@@ ... @@
Hello!
-How
+How are you?
Bye
"#},
))
.unwrap();
let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
@@ -324,18 +337,8 @@ async fn test_empty_prediction(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
const NO_OP_DIFF: &str = indoc! { r"
--- a/root/foo.md
+++ b/root/foo.md
@@ ... @@
Hello!
-How
+How
Bye
"};
let (_, respond_tx) = requests.predict.next().await.unwrap();
let response = model_response(NO_OP_DIFF);
let (request, respond_tx) = requests.predict.next().await.unwrap();
let response = model_response(request, "");
let id = response.id.clone();
respond_tx.send(response).unwrap();
@@ -389,13 +392,13 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_tx) = requests.predict.next().await.unwrap();
let (request, respond_tx) = requests.predict.next().await.unwrap();
buffer.update(cx, |buffer, cx| {
buffer.set_text("Hello!\nHow are you?\nBye", cx);
});
let response = model_response(SIMPLE_DIFF);
let response = model_response(request, SIMPLE_DIFF);
let id = response.id.clone();
respond_tx.send(response).unwrap();
@@ -459,8 +462,8 @@ async fn test_replace_current(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_tx) = requests.predict.next().await.unwrap();
let first_response = model_response(SIMPLE_DIFF);
let (request, respond_tx) = requests.predict.next().await.unwrap();
let first_response = model_response(request, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_tx.send(first_response).unwrap();
@@ -482,8 +485,8 @@ async fn test_replace_current(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_tx) = requests.predict.next().await.unwrap();
let second_response = model_response(SIMPLE_DIFF);
let (request, respond_tx) = requests.predict.next().await.unwrap();
let second_response = model_response(request, SIMPLE_DIFF);
let second_id = second_response.id.clone();
respond_tx.send(second_response).unwrap();
@@ -541,8 +544,8 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_tx) = requests.predict.next().await.unwrap();
let first_response = model_response(SIMPLE_DIFF);
let (request, respond_tx) = requests.predict.next().await.unwrap();
let first_response = model_response(request, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_tx.send(first_response).unwrap();
@@ -564,17 +567,20 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_tx) = requests.predict.next().await.unwrap();
let (request, respond_tx) = requests.predict.next().await.unwrap();
// worse than current prediction
let second_response = model_response(indoc! { r"
--- a/root/foo.md
+++ b/root/foo.md
@@ ... @@
Hello!
-How
+How are
Bye
"});
let second_response = model_response(
request,
indoc! { r"
--- a/root/foo.md
+++ b/root/foo.md
@@ ... @@
Hello!
-How
+How are
Bye
"},
);
let second_id = second_response.id.clone();
respond_tx.send(second_response).unwrap();
@@ -633,19 +639,19 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_first) = requests.predict.next().await.unwrap();
let (request1, respond_first) = requests.predict.next().await.unwrap();
ep_store.update(cx, |ep_store, cx| {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_second) = requests.predict.next().await.unwrap();
let (request, respond_second) = requests.predict.next().await.unwrap();
// wait for throttle
cx.run_until_parked();
// second responds first
let second_response = model_response(SIMPLE_DIFF);
let second_response = model_response(request, SIMPLE_DIFF);
let second_id = second_response.id.clone();
respond_second.send(second_response).unwrap();
@@ -663,7 +669,7 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
);
});
let first_response = model_response(SIMPLE_DIFF);
let first_response = model_response(request1, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_first.send(first_response).unwrap();
@@ -724,13 +730,13 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_first) = requests.predict.next().await.unwrap();
let (request1, respond_first) = requests.predict.next().await.unwrap();
ep_store.update(cx, |ep_store, cx| {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_second) = requests.predict.next().await.unwrap();
let (request2, respond_second) = requests.predict.next().await.unwrap();
// wait for throttle, so requests are sent
cx.run_until_parked();
@@ -754,9 +760,9 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
// wait for throttle
cx.run_until_parked();
let (_, respond_third) = requests.predict.next().await.unwrap();
let (request3, respond_third) = requests.predict.next().await.unwrap();
let first_response = model_response(SIMPLE_DIFF);
let first_response = model_response(request1, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_first.send(first_response).unwrap();
@@ -774,7 +780,7 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
);
});
let cancelled_response = model_response(SIMPLE_DIFF);
let cancelled_response = model_response(request2, SIMPLE_DIFF);
let cancelled_id = cancelled_response.id.clone();
respond_second.send(cancelled_response).unwrap();
@@ -792,7 +798,7 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
);
});
let third_response = model_response(SIMPLE_DIFF);
let third_response = model_response(request3, SIMPLE_DIFF);
let third_response_id = third_response.id.clone();
respond_third.send(third_response).unwrap();
@@ -1036,7 +1042,24 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
// );
// }
fn model_response(text: &str) -> open_ai::Response {
// Generate a model response that would apply the given diff to the active file.
fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response {
let prompt = match &request.messages[0] {
open_ai::RequestMessage::User {
content: open_ai::MessageContent::Plain(content),
} => content,
_ => panic!("unexpected request {request:?}"),
};
let open = "<editable_region>\n";
let close = "</editable_region>";
let cursor = "<|user_cursor|>";
let start_ix = open.len() + prompt.find(open).unwrap();
let end_ix = start_ix + &prompt[start_ix..].find(close).unwrap();
let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
open_ai::Response {
id: Uuid::new_v4().to_string(),
object: "response".into(),
@@ -1045,7 +1068,7 @@ fn model_response(text: &str) -> open_ai::Response {
choices: vec![open_ai::Choice {
index: 0,
message: open_ai::RequestMessage::Assistant {
content: Some(open_ai::MessageContent::Plain(text.to_string())),
content: Some(open_ai::MessageContent::Plain(new_excerpt)),
tool_calls: vec![],
},
finish_reason: None,
@@ -1160,20 +1183,19 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
.read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
.await;
let completion = EditPrediction {
let prediction = EditPrediction {
edits,
edit_preview,
buffer: buffer.clone(),
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
id: EditPredictionId("the-id".into()),
inputs: EditPredictionInputs {
inputs: ZetaPromptInput {
events: Default::default(),
included_files: Default::default(),
cursor_point: cloud_llm_client::predict_edits_v3::Point {
line: Line(0),
column: 0,
},
related_files: Default::default(),
cursor_path: Path::new("").into(),
cursor_excerpt: "".into(),
editable_range_in_excerpt: 0..0,
cursor_offset_in_excerpt: 0,
},
buffer_snapshotted_at: Instant::now(),
response_received_at: Instant::now(),
@@ -1182,7 +1204,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
cx.update(|cx| {
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1192,7 +1214,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1202,7 +1224,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.undo(cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1212,7 +1234,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1222,7 +1244,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1232,7 +1254,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1242,7 +1264,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1252,7 +1274,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1260,7 +1282,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
})
}

View File

@@ -1,20 +1,17 @@
use anyhow::{Context as _, Result};
use cloud_llm_client::predict_edits_v3::Event;
use credentials_provider::CredentialsProvider;
use edit_prediction_context::RelatedFile;
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
use gpui::{
App, AppContext as _, Entity, Task,
App, AppContext as _, Task,
http_client::{self, AsyncBody, Method},
};
use language::{Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _};
use project::{Project, ProjectPath};
use std::{
collections::VecDeque, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc, time::Instant,
};
use language::{OffsetRangeExt as _, ToOffset, ToPoint as _};
use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant};
use zeta_prompt::ZetaPromptInput;
use crate::{
EditPredictionId, EditPredictionInputs, open_ai_response::text_from_response,
DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
EditPredictionStartedDebugEvent, open_ai_response::text_from_response,
prediction::EditPredictionResult,
};
@@ -38,16 +35,17 @@ impl Mercury {
store_api_token_in_keychain(api_token, cx)
}
pub fn request_prediction(
pub(crate) fn request_prediction(
&self,
_project: &Entity<Project>,
active_buffer: &Entity<Buffer>,
snapshot: BufferSnapshot,
position: language::Anchor,
events: Vec<Arc<Event>>,
_recent_paths: &VecDeque<ProjectPath>,
related_files: Vec<RelatedFile>,
_diagnostic_search_range: Range<Point>,
EditPredictionModelInput {
buffer,
snapshot,
position,
events,
related_files,
debug_tx,
..
}: EditPredictionModelInput,
cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
@@ -62,6 +60,7 @@ impl Mercury {
let http_client = cx.http_client();
let cursor_point = position.to_point(&snapshot);
let buffer_snapshotted_at = Instant::now();
let active_buffer = buffer.clone();
let result = cx.background_spawn(async move {
let (editable_range, context_range) =
@@ -72,39 +71,39 @@ impl Mercury {
MAX_REWRITE_TOKENS,
);
let offset_range = editable_range.to_offset(&snapshot);
let prompt = build_prompt(
&events,
&related_files,
&snapshot,
full_path.as_ref(),
cursor_point,
editable_range,
context_range.clone(),
);
let context_offset_range = context_range.to_offset(&snapshot);
let inputs = EditPredictionInputs {
events: events,
included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
path: full_path.clone(),
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
start_line: cloud_llm_client::predict_edits_v3::Line(
context_range.start.row,
),
text: snapshot
.text_for_range(context_range.clone())
.collect::<String>()
.into(),
}],
}],
cursor_point: cloud_llm_client::predict_edits_v3::Point {
column: cursor_point.column,
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
},
let editable_offset_range = editable_range.to_offset(&snapshot);
let inputs = zeta_prompt::ZetaPromptInput {
events,
related_files,
cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot)
- context_range.start.to_offset(&snapshot),
cursor_path: full_path.clone(),
cursor_excerpt: snapshot
.text_for_range(context_range)
.collect::<String>()
.into(),
editable_range_in_excerpt: (editable_offset_range.start
- context_offset_range.start)
..(editable_offset_range.end - context_offset_range.start),
};
let prompt = build_prompt(&inputs);
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionStarted(
EditPredictionStartedDebugEvent {
buffer: active_buffer.downgrade(),
prompt: Some(prompt.clone()),
position,
},
))
.ok();
}
let request_body = open_ai::Request {
model: "mercury-coder".into(),
messages: vec![open_ai::RequestMessage::User {
@@ -160,6 +159,18 @@ impl Mercury {
let id = mem::take(&mut response.id);
let response_str = text_from_response(response).unwrap_or_default();
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionFinished(
EditPredictionFinishedDebugEvent {
buffer: active_buffer.downgrade(),
model_output: Some(response_str.clone()),
position,
},
))
.ok();
}
let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);
@@ -168,15 +179,16 @@ impl Mercury {
if response_str != NO_PREDICTION_OUTPUT {
let old_text = snapshot
.text_for_range(offset_range.clone())
.text_for_range(editable_offset_range.clone())
.collect::<String>();
edits.extend(
language::text_diff(&old_text, &response_str)
.into_iter()
.map(|(range, text)| {
(
snapshot.anchor_after(offset_range.start + range.start)
..snapshot.anchor_before(offset_range.start + range.end),
snapshot.anchor_after(editable_offset_range.start + range.start)
..snapshot
.anchor_before(editable_offset_range.start + range.end),
text,
)
}),
@@ -186,8 +198,6 @@ impl Mercury {
anyhow::Ok((id, edits, snapshot, response_received_at, inputs))
});
let buffer = active_buffer.clone();
cx.spawn(async move |cx| {
let (id, edits, old_snapshot, response_received_at, inputs) =
result.await.context("Mercury edit prediction failed")?;
@@ -208,15 +218,7 @@ impl Mercury {
}
}
fn build_prompt(
events: &[Arc<Event>],
related_files: &[RelatedFile],
cursor_buffer: &BufferSnapshot,
cursor_buffer_path: &Path,
cursor_point: Point,
editable_range: Range<Point>,
context_range: Range<Point>,
) -> String {
fn build_prompt(inputs: &ZetaPromptInput) -> String {
const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
@@ -237,14 +239,14 @@ fn build_prompt(
&mut prompt,
RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
|prompt| {
for related_file in related_files {
for related_file in inputs.related_files.iter() {
for related_excerpt in &related_file.excerpts {
push_delimited(
prompt,
RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
|prompt| {
prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
prompt.push_str(related_file.path.path.as_unix_str());
prompt.push_str(related_file.path.to_string_lossy().as_ref());
prompt.push('\n');
prompt.push_str(&related_excerpt.text.to_string());
},
@@ -259,21 +261,22 @@ fn build_prompt(
CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
|prompt| {
prompt.push_str(CURRENT_FILE_PATH_PREFIX);
prompt.push_str(cursor_buffer_path.as_os_str().to_string_lossy().as_ref());
prompt.push_str(inputs.cursor_path.as_os_str().to_string_lossy().as_ref());
prompt.push('\n');
let prefix_range = context_range.start..editable_range.start;
let suffix_range = editable_range.end..context_range.end;
prompt.extend(cursor_buffer.text_for_range(prefix_range));
prompt.push_str(&inputs.cursor_excerpt[0..inputs.editable_range_in_excerpt.start]);
push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
let range_before_cursor = editable_range.start..cursor_point;
let range_after_cursor = cursor_point..editable_range.end;
prompt.extend(cursor_buffer.text_for_range(range_before_cursor));
prompt.push_str(
&inputs.cursor_excerpt
[inputs.editable_range_in_excerpt.start..inputs.cursor_offset_in_excerpt],
);
prompt.push_str(CURSOR_TAG);
prompt.extend(cursor_buffer.text_for_range(range_after_cursor));
prompt.push_str(
&inputs.cursor_excerpt
[inputs.cursor_offset_in_excerpt..inputs.editable_range_in_excerpt.end],
);
});
prompt.extend(cursor_buffer.text_for_range(suffix_range));
prompt.push_str(&inputs.cursor_excerpt[inputs.editable_range_in_excerpt.end..]);
},
);
@@ -281,8 +284,8 @@ fn build_prompt(
&mut prompt,
EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
|prompt| {
for event in events {
writeln!(prompt, "{event}").unwrap();
for event in inputs.events.iter() {
zeta_prompt::write_event(prompt, &event);
}
},
);

View File

@@ -1,6 +1,5 @@
use std::{
ops::Range,
path::Path,
sync::Arc,
time::{Duration, Instant},
};
@@ -9,7 +8,7 @@ use cloud_llm_client::EditPredictionRejectReason;
use edit_prediction_types::interpolate_edits;
use gpui::{AsyncApp, Entity, SharedString};
use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
use serde::Serialize;
use zeta_prompt::ZetaPromptInput;
#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(pub SharedString);
@@ -40,7 +39,7 @@ impl EditPredictionResult {
edits: Arc<[(Range<Anchor>, Arc<str>)]>,
buffer_snapshotted_at: Instant,
response_received_at: Instant,
inputs: EditPredictionInputs,
inputs: ZetaPromptInput,
cx: &mut AsyncApp,
) -> Self {
if edits.is_empty() {
@@ -94,15 +93,7 @@ pub struct EditPrediction {
pub buffer: Entity<Buffer>,
pub buffer_snapshotted_at: Instant,
pub response_received_at: Instant,
pub inputs: EditPredictionInputs,
}
#[derive(Debug, Clone, Serialize)]
pub struct EditPredictionInputs {
pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
pub included_files: Vec<cloud_llm_client::predict_edits_v3::RelatedFile>,
pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
pub cursor_path: Arc<Path>,
pub inputs: zeta_prompt::ZetaPromptInput,
}
impl EditPrediction {
@@ -133,9 +124,12 @@ impl std::fmt::Debug for EditPrediction {
#[cfg(test)]
mod tests {
use std::path::Path;
use super::*;
use gpui::{App, Entity, TestAppContext, prelude::*};
use language::{Buffer, ToOffset as _};
use zeta_prompt::ZetaPromptInput;
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
@@ -154,14 +148,13 @@ mod tests {
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
buffer: buffer.clone(),
edit_preview,
inputs: EditPredictionInputs {
inputs: ZetaPromptInput {
events: vec![],
included_files: vec![],
cursor_point: cloud_llm_client::predict_edits_v3::Point {
line: cloud_llm_client::predict_edits_v3::Line(0),
column: 0,
},
related_files: vec![].into(),
cursor_path: Path::new("path.txt").into(),
cursor_offset_in_excerpt: 0,
cursor_excerpt: "".into(),
editable_range_in_excerpt: 0..0,
},
buffer_snapshotted_at: Instant::now(),
response_received_at: Instant::now(),

View File

@@ -1,26 +1,21 @@
use anyhow::{Context as _, Result};
use cloud_llm_client::predict_edits_v3::Event;
use credentials_provider::CredentialsProvider;
use edit_prediction_context::RelatedFile;
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
use gpui::{
App, AppContext as _, Entity, Task,
App, AppContext as _, Task,
http_client::{self, AsyncBody, Method},
};
use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
use language::{Point, ToOffset as _};
use lsp::DiagnosticSeverity;
use project::{Project, ProjectPath};
use serde::{Deserialize, Serialize};
use std::{
collections::VecDeque,
fmt::{self, Write as _},
ops::Range,
path::Path,
sync::Arc,
time::Instant,
};
use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionResult};
use crate::{EditPredictionId, EditPredictionModelInput, prediction::EditPredictionResult};
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
@@ -44,40 +39,34 @@ impl SweepAi {
pub fn request_prediction_with_sweep(
&self,
project: &Entity<Project>,
active_buffer: &Entity<Buffer>,
snapshot: BufferSnapshot,
position: language::Anchor,
events: Vec<Arc<Event>>,
recent_paths: &VecDeque<ProjectPath>,
related_files: Vec<RelatedFile>,
diagnostic_search_range: Range<Point>,
inputs: EditPredictionModelInput,
cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
let debug_info = self.debug_info.clone();
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
return Task::ready(Ok(None));
};
let full_path: Arc<Path> = snapshot
let full_path: Arc<Path> = inputs
.snapshot
.file()
.map(|file| file.full_path(cx))
.unwrap_or_else(|| "untitled".into())
.into();
let project_file = project::File::from_dyn(snapshot.file());
let project_file = project::File::from_dyn(inputs.snapshot.file());
let repo_name = project_file
.map(|file| file.worktree.read(cx).root_name_str())
.unwrap_or("untitled")
.into();
let offset = position.to_offset(&snapshot);
let offset = inputs.position.to_offset(&inputs.snapshot);
let recent_buffers = recent_paths.iter().cloned();
let recent_buffers = inputs.recent_paths.iter().cloned();
let http_client = cx.http_client();
let recent_buffer_snapshots = recent_buffers
.filter_map(|project_path| {
let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
if active_buffer == &buffer {
let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
if inputs.buffer == buffer {
None
} else {
Some(buffer.read(cx).snapshot())
@@ -86,14 +75,13 @@ impl SweepAi {
.take(3)
.collect::<Vec<_>>();
let cursor_point = position.to_point(&snapshot);
let buffer_snapshotted_at = Instant::now();
let result = cx.background_spawn(async move {
let text = snapshot.text();
let text = inputs.snapshot.text();
let mut recent_changes = String::new();
for event in &events {
for event in &inputs.events {
write_event(event.as_ref(), &mut recent_changes).unwrap();
}
@@ -122,20 +110,23 @@ impl SweepAi {
})
.collect::<Vec<_>>();
let retrieval_chunks = related_files
let retrieval_chunks = inputs
.related_files
.iter()
.flat_map(|related_file| {
related_file.excerpts.iter().map(|excerpt| FileChunk {
file_path: related_file.path.path.as_unix_str().to_string(),
start_line: excerpt.point_range.start.row as usize,
end_line: excerpt.point_range.end.row as usize,
file_path: related_file.path.to_string_lossy().to_string(),
start_line: excerpt.row_range.start as usize,
end_line: excerpt.row_range.end as usize,
content: excerpt.text.to_string(),
timestamp: None,
})
})
.collect();
let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
let diagnostic_entries = inputs
.snapshot
.diagnostics_in_range(inputs.diagnostic_search_range, false);
let mut diagnostic_content = String::new();
let mut diagnostic_count = 0;
@@ -195,21 +186,14 @@ impl SweepAi {
serde_json::to_writer(writer, &request_body)?;
let body: AsyncBody = buf.into();
let inputs = EditPredictionInputs {
events,
included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
path: full_path.clone(),
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
start_line: cloud_llm_client::predict_edits_v3::Line(0),
text: request_body.file_contents.into(),
}],
}],
cursor_point: cloud_llm_client::predict_edits_v3::Point {
column: cursor_point.column,
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
},
let ep_inputs = zeta_prompt::ZetaPromptInput {
events: inputs.events,
related_files: inputs.related_files.clone(),
cursor_path: full_path.clone(),
cursor_excerpt: request_body.file_contents.into(),
// we actually don't know
editable_range_in_excerpt: 0..inputs.snapshot.len(),
cursor_offset_in_excerpt: request_body.cursor_position,
};
let request = http_client::Request::builder()
@@ -237,15 +221,20 @@ impl SweepAi {
let response: AutocompleteResponse = serde_json::from_slice(&body)?;
let old_text = snapshot
let old_text = inputs
.snapshot
.text_for_range(response.start_index..response.end_index)
.collect::<String>();
let edits = language::text_diff(&old_text, &response.completion)
.into_iter()
.map(|(range, text)| {
(
snapshot.anchor_after(response.start_index + range.start)
..snapshot.anchor_before(response.start_index + range.end),
inputs
.snapshot
.anchor_after(response.start_index + range.start)
..inputs
.snapshot
.anchor_before(response.start_index + range.end),
text,
)
})
@@ -254,13 +243,13 @@ impl SweepAi {
anyhow::Ok((
response.autocomplete_id,
edits,
snapshot,
inputs.snapshot,
response_received_at,
inputs,
ep_inputs,
))
});
let buffer = active_buffer.clone();
let buffer = inputs.buffer.clone();
cx.spawn(async move |cx| {
let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
@@ -403,12 +392,9 @@ struct AdditionalCompletion {
pub finish_reason: Option<String>,
}
fn write_event(
event: &cloud_llm_client::predict_edits_v3::Event,
f: &mut impl fmt::Write,
) -> fmt::Result {
fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
match event {
cloud_llm_client::predict_edits_v3::Event::BufferChange {
zeta_prompt::Event::BufferChange {
old_path,
path,
diff,

View File

@@ -14,68 +14,18 @@ use anyhow::anyhow;
use collections::HashMap;
use gpui::AsyncApp;
use gpui::Entity;
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, TextBufferSnapshot};
use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot};
use project::Project;
pub async fn parse_diff<'a>(
diff_str: &'a str,
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
let mut diff = DiffParser::new(diff_str);
let mut edited_buffer = None;
let mut edits = Vec::new();
while let Some(event) = diff.next()? {
match event {
DiffEvent::Hunk {
path: file_path,
hunk,
} => {
let (buffer, ranges) = match edited_buffer {
None => {
edited_buffer = get_buffer(&Path::new(file_path.as_ref()));
edited_buffer
.as_ref()
.context("Model tried to edit a file that wasn't included")?
}
Some(ref current) => current,
};
edits.extend(
resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges)
.with_context(|| format!("Diff:\n{diff_str}"))?,
);
}
DiffEvent::FileEnd { renamed_to } => {
let (buffer, _) = edited_buffer
.take()
.context("Got a FileEnd event before an Hunk event")?;
if renamed_to.is_some() {
anyhow::bail!("edit predictions cannot rename files");
}
if diff.next()?.is_some() {
anyhow::bail!("Edited more than one file");
}
return Ok((buffer, edits));
}
}
}
Err(anyhow::anyhow!("No EOF"))
}
#[derive(Debug)]
pub struct OpenedBuffers<'a>(#[allow(unused)] HashMap<Cow<'a, str>, Entity<Buffer>>);
#[derive(Clone, Debug)]
pub struct OpenedBuffers(#[allow(unused)] HashMap<String, Entity<Buffer>>);
#[must_use]
pub async fn apply_diff<'a>(
diff_str: &'a str,
pub async fn apply_diff(
diff_str: &str,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<OpenedBuffers<'a>> {
) -> Result<OpenedBuffers> {
let mut included_files = HashMap::default();
for line in diff_str.lines() {
@@ -94,7 +44,7 @@ pub async fn apply_diff<'a>(
})??
.await?;
included_files.insert(path, buffer);
included_files.insert(path.to_string(), buffer);
}
}
@@ -113,7 +63,7 @@ pub async fn apply_diff<'a>(
let (buffer, ranges) = match current_file {
None => {
let buffer = included_files
.get_mut(&file_path)
.get_mut(file_path.as_ref())
.expect("Opened all files in diff");
current_file = Some((buffer, ranges.as_slice()));
@@ -167,6 +117,29 @@ pub async fn apply_diff<'a>(
Ok(OpenedBuffers(included_files))
}
pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
let mut diff = DiffParser::new(diff_str);
let mut text = text.to_string();
while let Some(event) = diff.next()? {
match event {
DiffEvent::Hunk { hunk, .. } => {
let hunk_offset = text
.find(&hunk.context)
.ok_or_else(|| anyhow!("couldn't result hunk {:?}", hunk.context))?;
for edit in hunk.edits.iter().rev() {
let range = (hunk_offset + edit.range.start)..(hunk_offset + edit.range.end);
text.replace_range(range, &edit.text);
}
}
DiffEvent::FileEnd { .. } => {}
}
}
Ok(text)
}
struct PatchFile<'a> {
old_path: Cow<'a, str>,
new_path: Cow<'a, str>,
@@ -492,7 +465,6 @@ mod tests {
use super::*;
use gpui::TestAppContext;
use indoc::indoc;
use language::Point;
use pretty_assertions::assert_eq;
use project::{FakeFs, Project};
use serde_json::json;
@@ -817,137 +789,6 @@ mod tests {
});
}
#[gpui::test]
async fn test_apply_diff_non_unique(cx: &mut TestAppContext) {
let fs = init_test(cx);
let buffer_1_text = indoc! {r#"
one
two
three
four
five
one
two
three
four
five
"# };
fs.insert_tree(
path!("/root"),
json!({
"file1": buffer_1_text,
}),
)
.await;
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
let buffer = project
.update(cx, |project, cx| {
project.open_local_buffer(path!("/root/file1"), cx)
})
.await
.unwrap();
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
let diff = indoc! {r#"
--- a/root/file1
+++ b/root/file1
one
two
-three
+3
four
five
"#};
let final_text = indoc! {r#"
one
two
three
four
five
one
two
3
four
five
"#};
apply_diff(diff, &project, &mut cx.to_async())
.await
.expect_err("Non-unique edits should fail");
let ranges = [buffer_snapshot.anchor_before(Point::new(1, 0))
..buffer_snapshot.anchor_after(buffer_snapshot.max_point())];
let (edited_snapshot, edits) = parse_diff(diff, |_path| Some((&buffer_snapshot, &ranges)))
.await
.unwrap();
assert_eq!(edited_snapshot.remote_id(), buffer_snapshot.remote_id());
buffer.update(cx, |buffer, cx| {
buffer.edit(edits, None, cx);
assert_eq!(buffer.text(), final_text);
});
}
#[gpui::test]
async fn test_parse_diff_with_edits_within_line(cx: &mut TestAppContext) {
let fs = init_test(cx);
let buffer_1_text = indoc! {r#"
one two three four
five six seven eight
nine ten eleven twelve
"# };
fs.insert_tree(
path!("/root"),
json!({
"file1": buffer_1_text,
}),
)
.await;
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
let buffer = project
.update(cx, |project, cx| {
project.open_local_buffer(path!("/root/file1"), cx)
})
.await
.unwrap();
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
let diff = indoc! {r#"
--- a/root/file1
+++ b/root/file1
one two three four
-five six seven eight
+five SIX seven eight!
nine ten eleven twelve
"#};
let (buffer, edits) = parse_diff(diff, |_path| {
Some((&buffer_snapshot, &[(Anchor::MIN..Anchor::MAX)] as &[_]))
})
.await
.unwrap();
let edits = edits
.into_iter()
.map(|(range, text)| (range.to_point(&buffer), text))
.collect::<Vec<_>>();
assert_eq!(
edits,
&[
(Point::new(1, 5)..Point::new(1, 8), "SIX".into()),
(Point::new(1, 20)..Point::new(1, 20), "!".into())
]
);
}
#[gpui::test]
async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) {
let fs = init_test(cx);

View File

@@ -1,637 +0,0 @@
use anyhow::{Context as _, Result};
use language::{Anchor, BufferSnapshot, OffsetRangeExt as _, Point};
use std::{cmp, ops::Range, path::Path, sync::Arc};
const EDITS_TAG_NAME: &'static str = "edits";
const OLD_TEXT_TAG_NAME: &'static str = "old_text";
const NEW_TEXT_TAG_NAME: &'static str = "new_text";
const XML_TAGS: &[&str] = &[EDITS_TAG_NAME, OLD_TEXT_TAG_NAME, NEW_TEXT_TAG_NAME];
pub async fn parse_xml_edits<'a>(
input: &'a str,
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
parse_xml_edits_inner(input, get_buffer)
.await
.with_context(|| format!("Failed to parse XML edits:\n{input}"))
}
async fn parse_xml_edits_inner<'a>(
input: &'a str,
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
let xml_edits = extract_xml_replacements(input)?;
let (buffer, context_ranges) = get_buffer(xml_edits.file_path.as_ref())
.with_context(|| format!("no buffer for file {}", xml_edits.file_path))?;
let mut all_edits = vec![];
for (old_text, new_text) in xml_edits.replacements {
let match_range = fuzzy_match_in_ranges(old_text, buffer, context_ranges)?;
let matched_old_text = buffer
.text_for_range(match_range.clone())
.collect::<String>();
let edits_within_hunk = language::text_diff(&matched_old_text, new_text);
all_edits.extend(
edits_within_hunk
.into_iter()
.map(move |(inner_range, inner_text)| {
(
buffer.anchor_after(match_range.start + inner_range.start)
..buffer.anchor_before(match_range.start + inner_range.end),
inner_text,
)
}),
);
}
Ok((buffer, all_edits))
}
fn fuzzy_match_in_ranges(
old_text: &str,
buffer: &BufferSnapshot,
context_ranges: &[Range<Anchor>],
) -> Result<Range<usize>> {
let mut state = FuzzyMatcher::new(buffer, old_text);
let mut best_match = None;
let mut tie_match_range = None;
for range in context_ranges {
let best_match_cost = best_match.as_ref().map(|(score, _)| *score);
match (best_match_cost, state.match_range(range.to_offset(buffer))) {
(Some(lowest_cost), Some((new_cost, new_range))) => {
if new_cost == lowest_cost {
tie_match_range = Some(new_range);
} else if new_cost < lowest_cost {
tie_match_range.take();
best_match = Some((new_cost, new_range));
}
}
(None, Some(new_match)) => {
best_match = Some(new_match);
}
(None, None) | (Some(_), None) => {}
};
}
if let Some((_, best_match_range)) = best_match {
if let Some(tie_match_range) = tie_match_range {
anyhow::bail!(
"Multiple ambiguous matches:\n{:?}:\n{}\n\n{:?}:\n{}",
best_match_range.clone(),
buffer.text_for_range(best_match_range).collect::<String>(),
tie_match_range.clone(),
buffer.text_for_range(tie_match_range).collect::<String>()
);
}
return Ok(best_match_range);
}
anyhow::bail!(
"Failed to fuzzy match `old_text`:\n{}\nin:\n```\n{}\n```",
old_text,
context_ranges
.iter()
.map(|range| buffer.text_for_range(range.clone()).collect::<String>())
.collect::<Vec<String>>()
.join("```\n```")
);
}
#[derive(Debug)]
struct XmlEdits<'a> {
file_path: &'a str,
/// Vec of (old_text, new_text) pairs
replacements: Vec<(&'a str, &'a str)>,
}
fn extract_xml_replacements(input: &str) -> Result<XmlEdits<'_>> {
let mut cursor = 0;
let (edits_body_start, edits_attrs) =
find_tag_open(input, &mut cursor, EDITS_TAG_NAME)?.context("No edits tag found")?;
let file_path = edits_attrs
.trim_start()
.strip_prefix("path")
.context("no path attribute on edits tag")?
.trim_end()
.strip_prefix('=')
.context("no value for path attribute")?
.trim()
.trim_start_matches('"')
.trim_end_matches('"');
cursor = edits_body_start;
let mut edits_list = Vec::new();
while let Some((old_body_start, _)) = find_tag_open(input, &mut cursor, OLD_TEXT_TAG_NAME)? {
let old_body_end = find_tag_close(input, &mut cursor)?;
let old_text = trim_surrounding_newlines(&input[old_body_start..old_body_end]);
let (new_body_start, _) = find_tag_open(input, &mut cursor, NEW_TEXT_TAG_NAME)?
.context("no new_text tag following old_text")?;
let new_body_end = find_tag_close(input, &mut cursor)?;
let new_text = trim_surrounding_newlines(&input[new_body_start..new_body_end]);
edits_list.push((old_text, new_text));
}
Ok(XmlEdits {
file_path,
replacements: edits_list,
})
}
/// Trims a single leading and trailing newline
fn trim_surrounding_newlines(input: &str) -> &str {
let start = input.strip_prefix('\n').unwrap_or(input);
let end = start.strip_suffix('\n').unwrap_or(start);
end
}
fn find_tag_open<'a>(
input: &'a str,
cursor: &mut usize,
expected_tag: &str,
) -> Result<Option<(usize, &'a str)>> {
let mut search_pos = *cursor;
while search_pos < input.len() {
let Some(tag_start) = input[search_pos..].find("<") else {
break;
};
let tag_start = search_pos + tag_start;
if !input[tag_start + 1..].starts_with(expected_tag) {
search_pos = search_pos + tag_start + 1;
continue;
};
let after_tag_name = tag_start + expected_tag.len() + 1;
let close_bracket = input[after_tag_name..]
.find('>')
.with_context(|| format!("missing > after <{}", expected_tag))?;
let attrs_end = after_tag_name + close_bracket;
let body_start = attrs_end + 1;
let attributes = input[after_tag_name..attrs_end].trim();
*cursor = body_start;
return Ok(Some((body_start, attributes)));
}
Ok(None)
}
fn find_tag_close(input: &str, cursor: &mut usize) -> Result<usize> {
let mut depth = 1;
let mut search_pos = *cursor;
while search_pos < input.len() && depth > 0 {
let Some(bracket_offset) = input[search_pos..].find('<') else {
break;
};
let bracket_pos = search_pos + bracket_offset;
if input[bracket_pos..].starts_with("</")
&& let Some(close_end) = input[bracket_pos + 2..].find('>')
{
let close_start = bracket_pos + 2;
let tag_name = input[close_start..close_start + close_end].trim();
if XML_TAGS.contains(&tag_name) {
depth -= 1;
if depth == 0 {
*cursor = close_start + close_end + 1;
return Ok(bracket_pos);
}
}
search_pos = close_start + close_end + 1;
continue;
} else if let Some(close_bracket_offset) = input[bracket_pos..].find('>') {
let close_bracket_pos = bracket_pos + close_bracket_offset;
let tag_name = &input[bracket_pos + 1..close_bracket_pos].trim();
if XML_TAGS.contains(&tag_name) {
depth += 1;
}
}
search_pos = bracket_pos + 1;
}
anyhow::bail!("no closing tag found")
}
const REPLACEMENT_COST: u32 = 1;
const INSERTION_COST: u32 = 3;
const DELETION_COST: u32 = 10;
/// A fuzzy matcher that can process text chunks incrementally
/// and return the best match found so far at each step.
struct FuzzyMatcher<'a> {
snapshot: &'a BufferSnapshot,
query_lines: Vec<&'a str>,
matrix: SearchMatrix,
}
impl<'a> FuzzyMatcher<'a> {
fn new(snapshot: &'a BufferSnapshot, old_text: &'a str) -> Self {
let query_lines = old_text.lines().collect();
Self {
snapshot,
query_lines,
matrix: SearchMatrix::new(0),
}
}
fn match_range(&mut self, range: Range<usize>) -> Option<(u32, Range<usize>)> {
let point_range = range.to_point(&self.snapshot);
let buffer_line_count = (point_range.end.row - point_range.start.row + 1) as usize;
self.matrix
.reset(self.query_lines.len() + 1, buffer_line_count + 1);
let query_line_count = self.query_lines.len();
for row in 0..query_line_count {
let query_line = self.query_lines[row].trim();
let leading_deletion_cost = (row + 1) as u32 * DELETION_COST;
self.matrix.set(
row + 1,
0,
SearchState::new(leading_deletion_cost, SearchDirection::Up),
);
let mut buffer_lines = self.snapshot.text_for_range(range.clone()).lines();
let mut col = 0;
while let Some(buffer_line) = buffer_lines.next() {
let buffer_line = buffer_line.trim();
let up = SearchState::new(
self.matrix
.get(row, col + 1)
.cost
.saturating_add(DELETION_COST),
SearchDirection::Up,
);
let left = SearchState::new(
self.matrix
.get(row + 1, col)
.cost
.saturating_add(INSERTION_COST),
SearchDirection::Left,
);
let diagonal = SearchState::new(
if query_line == buffer_line {
self.matrix.get(row, col).cost
} else if fuzzy_eq(query_line, buffer_line) {
self.matrix.get(row, col).cost + REPLACEMENT_COST
} else {
self.matrix
.get(row, col)
.cost
.saturating_add(DELETION_COST + INSERTION_COST)
},
SearchDirection::Diagonal,
);
self.matrix
.set(row + 1, col + 1, up.min(left).min(diagonal));
col += 1;
}
}
// Find all matches with the best cost
let mut best_cost = u32::MAX;
let mut matches_with_best_cost = Vec::new();
for col in 1..=buffer_line_count {
let cost = self.matrix.get(query_line_count, col).cost;
if cost < best_cost {
best_cost = cost;
matches_with_best_cost.clear();
matches_with_best_cost.push(col as u32);
} else if cost == best_cost {
matches_with_best_cost.push(col as u32);
}
}
// Find ranges for the matches
for &match_end_col in &matches_with_best_cost {
let mut matched_lines = 0;
let mut query_row = query_line_count;
let mut match_start_col = match_end_col;
while query_row > 0 && match_start_col > 0 {
let current = self.matrix.get(query_row, match_start_col as usize);
match current.direction {
SearchDirection::Diagonal => {
query_row -= 1;
match_start_col -= 1;
matched_lines += 1;
}
SearchDirection::Up => {
query_row -= 1;
}
SearchDirection::Left => {
match_start_col -= 1;
}
}
}
let buffer_row_start = match_start_col + point_range.start.row;
let buffer_row_end = match_end_col + point_range.start.row;
let matched_buffer_row_count = buffer_row_end - buffer_row_start;
let matched_ratio = matched_lines as f32
/ (matched_buffer_row_count as f32).max(query_line_count as f32);
if matched_ratio >= 0.8 {
let buffer_start_ix = self
.snapshot
.point_to_offset(Point::new(buffer_row_start, 0));
let buffer_end_ix = self.snapshot.point_to_offset(Point::new(
buffer_row_end - 1,
self.snapshot.line_len(buffer_row_end - 1),
));
return Some((best_cost, buffer_start_ix..buffer_end_ix));
}
}
None
}
}
fn fuzzy_eq(left: &str, right: &str) -> bool {
const THRESHOLD: f64 = 0.8;
let min_levenshtein = left.len().abs_diff(right.len());
let min_normalized_levenshtein =
1. - (min_levenshtein as f64 / cmp::max(left.len(), right.len()) as f64);
if min_normalized_levenshtein < THRESHOLD {
return false;
}
strsim::normalized_levenshtein(left, right) >= THRESHOLD
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum SearchDirection {
Up,
Left,
Diagonal,
}
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct SearchState {
cost: u32,
direction: SearchDirection,
}
impl SearchState {
fn new(cost: u32, direction: SearchDirection) -> Self {
Self { cost, direction }
}
}
struct SearchMatrix {
cols: usize,
rows: usize,
data: Vec<SearchState>,
}
impl SearchMatrix {
fn new(cols: usize) -> Self {
SearchMatrix {
cols,
rows: 0,
data: Vec::new(),
}
}
fn reset(&mut self, rows: usize, cols: usize) {
self.rows = rows;
self.cols = cols;
self.data
.fill(SearchState::new(0, SearchDirection::Diagonal));
self.data.resize(
self.rows * self.cols,
SearchState::new(0, SearchDirection::Diagonal),
);
}
fn get(&self, row: usize, col: usize) -> SearchState {
debug_assert!(row < self.rows);
debug_assert!(col < self.cols);
self.data[row * self.cols + col]
}
fn set(&mut self, row: usize, col: usize, state: SearchState) {
debug_assert!(row < self.rows && col < self.cols);
self.data[row * self.cols + col] = state;
}
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::TestAppContext;
use indoc::indoc;
use language::Point;
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use util::path;
#[test]
fn test_extract_xml_edits() {
let input = indoc! {r#"
<edits path="test.rs">
<old_text>
old content
</old_text>
<new_text>
new content
</new_text>
</edits>
"#};
let result = extract_xml_replacements(input).unwrap();
assert_eq!(result.file_path, "test.rs");
assert_eq!(result.replacements.len(), 1);
assert_eq!(result.replacements[0].0, "old content");
assert_eq!(result.replacements[0].1, "new content");
}
#[test]
fn test_extract_xml_edits_with_wrong_closing_tags() {
let input = indoc! {r#"
<edits path="test.rs">
<old_text>
old content
</new_text>
<new_text>
new content
</old_text>
</ edits >
"#};
let result = extract_xml_replacements(input).unwrap();
assert_eq!(result.file_path, "test.rs");
assert_eq!(result.replacements.len(), 1);
assert_eq!(result.replacements[0].0, "old content");
assert_eq!(result.replacements[0].1, "new content");
}
#[test]
fn test_extract_xml_edits_with_xml_like_content() {
let input = indoc! {r#"
<edits path="component.tsx">
<old_text>
<foo><bar></bar></foo>
</old_text>
<new_text>
<foo><bar><baz></baz></bar></foo>
</new_text>
</edits>
"#};
let result = extract_xml_replacements(input).unwrap();
assert_eq!(result.file_path, "component.tsx");
assert_eq!(result.replacements.len(), 1);
assert_eq!(result.replacements[0].0, "<foo><bar></bar></foo>");
assert_eq!(
result.replacements[0].1,
"<foo><bar><baz></baz></bar></foo>"
);
}
#[test]
fn test_extract_xml_edits_with_conflicting_content() {
let input = indoc! {r#"
<edits path="component.tsx">
<old_text>
<new_text></new_text>
</old_text>
<new_text>
<old_text></old_text>
</new_text>
</edits>
"#};
let result = extract_xml_replacements(input).unwrap();
assert_eq!(result.file_path, "component.tsx");
assert_eq!(result.replacements.len(), 1);
assert_eq!(result.replacements[0].0, "<new_text></new_text>");
assert_eq!(result.replacements[0].1, "<old_text></old_text>");
}
#[test]
fn test_extract_xml_edits_multiple_pairs() {
let input = indoc! {r#"
Some reasoning before edits. Lots of thinking going on here
<edits path="test.rs">
<old_text>
first old
</old_text>
<new_text>
first new
</new_text>
<old_text>
second old
</edits>
<new_text>
second new
</old_text>
</edits>
"#};
let result = extract_xml_replacements(input).unwrap();
assert_eq!(result.file_path, "test.rs");
assert_eq!(result.replacements.len(), 2);
assert_eq!(result.replacements[0].0, "first old");
assert_eq!(result.replacements[0].1, "first new");
assert_eq!(result.replacements[1].0, "second old");
assert_eq!(result.replacements[1].1, "second new");
}
#[test]
fn test_extract_xml_edits_unexpected_eof() {
let input = indoc! {r#"
<edits path="test.rs">
<old_text>
first old
</
"#};
extract_xml_replacements(input).expect_err("Unexpected end of file");
}
#[gpui::test]
async fn test_parse_xml_edits(cx: &mut TestAppContext) {
let fs = init_test(cx);
let buffer_1_text = indoc! {r#"
one two three four
five six seven eight
nine ten eleven twelve
thirteen fourteen fifteen
sixteen seventeen eighteen
"#};
fs.insert_tree(
path!("/root"),
json!({
"file1": buffer_1_text,
}),
)
.await;
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
let buffer = project
.update(cx, |project, cx| {
project.open_local_buffer(path!("/root/file1"), cx)
})
.await
.unwrap();
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
let edits = indoc! {r#"
<edits path="root/file1">
<old_text>
nine ten eleven twelve
</old_text>
<new_text>
nine TEN eleven twelve!
</new_text>
</edits>
"#};
let included_ranges = [(buffer_snapshot.anchor_before(Point::new(1, 0))..Anchor::MAX)];
let (buffer, edits) = parse_xml_edits(edits, |_path| {
Some((&buffer_snapshot, included_ranges.as_slice()))
})
.await
.unwrap();
let edits = edits
.into_iter()
.map(|(range, text)| (range.to_point(&buffer), text))
.collect::<Vec<_>>();
assert_eq!(
edits,
&[
(Point::new(2, 5)..Point::new(2, 8), "TEN".into()),
(Point::new(2, 22)..Point::new(2, 22), "!".into())
]
);
}
fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
});
FakeFs::new(cx.background_executor.clone())
}
}

View File

@@ -1,22 +1,23 @@
use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
use crate::{
EditPredictionId, EditPredictionStore, ZedUpdateRequiredError,
DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
prediction::{EditPredictionInputs, EditPredictionResult},
prediction::EditPredictionResult,
};
use anyhow::{Context as _, Result};
use cloud_llm_client::{
PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
predict_edits_v3::Event,
};
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
use language::{
Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
};
use project::{Project, ProjectPath};
use release_channel::AppVersion;
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
use zeta_prompt::{Event, ZetaPromptInput};
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
@@ -29,24 +30,27 @@ pub(crate) const MAX_EVENT_TOKENS: usize = 500;
pub(crate) fn request_prediction_with_zeta1(
store: &mut EditPredictionStore,
project: &Entity<Project>,
buffer: &Entity<Buffer>,
snapshot: BufferSnapshot,
position: language::Anchor,
events: Vec<Arc<Event>>,
trigger: PredictEditsRequestTrigger,
EditPredictionModelInput {
project,
buffer,
snapshot,
position,
events,
trigger,
debug_tx,
..
}: EditPredictionModelInput,
cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
let buffer = buffer.clone();
let buffer_snapshotted_at = Instant::now();
let client = store.client.clone();
let llm_token = store.llm_token.clone();
let app_version = AppVersion::global(cx);
let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
let can_collect_file = store.can_collect_file(project, file, cx);
let can_collect_file = store.can_collect_file(&project, file, cx);
let git_info = if can_collect_file {
git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
} else {
None
};
@@ -120,33 +124,33 @@ pub(crate) fn request_prediction_with_zeta1(
)
.await;
let inputs = EditPredictionInputs {
let context_start_offset = context_range.start.to_offset(&snapshot);
let editable_offset_range = editable_range.to_offset(&snapshot);
let inputs = ZetaPromptInput {
events: included_events.into(),
included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
path: full_path.clone(),
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row),
text: snapshot
.text_for_range(context_range)
.collect::<String>()
.into(),
}],
}],
cursor_point: cloud_llm_client::predict_edits_v3::Point {
column: cursor_point.column,
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
},
related_files: vec![].into(),
cursor_path: full_path,
cursor_excerpt: snapshot
.text_for_range(context_range)
.collect::<String>()
.into(),
editable_range_in_excerpt: (editable_range.start - context_start_offset)
..(editable_offset_range.end - context_start_offset),
cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
};
// let response = perform_predict_edits(PerformPredictEditsParams {
// client,
// llm_token,
// app_version,
// body,
// })
// .await;
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionStarted(
EditPredictionStartedDebugEvent {
buffer: buffer.downgrade(),
prompt: Some(serde_json::to_string(&inputs).unwrap()),
position,
},
))
.ok();
}
let (response, usage) = match response {
Ok(response) => response,
@@ -189,6 +193,18 @@ pub(crate) fn request_prediction_with_zeta1(
.ok();
}
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionFinished(
EditPredictionFinishedDebugEvent {
buffer: buffer.downgrade(),
model_output: Some(response.output_excerpt.clone()),
position,
},
))
.ok();
}
let edit_prediction = process_completion_response(
response,
buffer,
@@ -226,7 +242,7 @@ fn process_completion_response(
buffer: Entity<Buffer>,
snapshot: &BufferSnapshot,
editable_range: Range<usize>,
inputs: EditPredictionInputs,
inputs: ZetaPromptInput,
buffer_snapshotted_at: Instant,
received_response_at: Instant,
cx: &AsyncApp,

View File

@@ -3,46 +3,39 @@ use crate::EvalCacheEntryKind;
use crate::open_ai_response::text_from_response;
use crate::prediction::EditPredictionResult;
use crate::{
DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs,
EditPredictionRequestedDebugEvent, EditPredictionStore,
DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent, EditPredictionId,
EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
};
use anyhow::{Result, anyhow, bail};
use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
use cloud_llm_client::{EditPredictionRejectReason, PredictEditsRequestTrigger};
use cloud_zeta2_prompt::CURSOR_MARKER;
use edit_prediction_context::{EditPredictionExcerpt, Line};
use edit_prediction_context::{RelatedExcerpt, RelatedFile};
use futures::channel::oneshot;
use gpui::{Entity, Task, prelude::*};
use language::{Anchor, BufferSnapshot};
use language::{Buffer, Point, ToOffset as _, ToPoint};
use project::{Project, ProjectItem as _};
use anyhow::{Result, anyhow};
use cloud_llm_client::EditPredictionRejectReason;
use gpui::{Task, prelude::*};
use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
use release_channel::AppVersion;
use std::{
env,
path::Path,
sync::Arc,
time::{Duration, Instant},
};
use std::{path::Path, sync::Arc, time::Instant};
use zeta_prompt::CURSOR_MARKER;
use zeta_prompt::format_zeta_prompt;
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
pub fn request_prediction_with_zeta2(
store: &mut EditPredictionStore,
project: &Entity<Project>,
active_buffer: &Entity<Buffer>,
active_snapshot: BufferSnapshot,
position: Anchor,
events: Vec<Arc<Event>>,
mut included_files: Vec<RelatedFile>,
trigger: PredictEditsRequestTrigger,
EditPredictionModelInput {
buffer,
snapshot,
position,
related_files,
events,
debug_tx,
..
}: EditPredictionModelInput,
cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
let options = store.options.clone();
let buffer_snapshotted_at = Instant::now();
let Some((excerpt_path, active_project_path)) = active_snapshot
let Some(excerpt_path) = snapshot
.file()
.map(|file| -> Arc<Path> { file.full_path(cx).into() })
.zip(active_buffer.read(cx).project_path(cx))
else {
return Task::ready(Err(anyhow!("No file path for excerpt")));
};
@@ -50,148 +43,35 @@ pub fn request_prediction_with_zeta2(
let client = store.client.clone();
let llm_token = store.llm_token.clone();
let app_version = AppVersion::global(cx);
let debug_tx = store.debug_tx.clone();
let file = active_buffer.read(cx).file();
let active_file_full_path = file.as_ref().map(|f| f.full_path(cx));
// TODO data collection
let can_collect_data = file
.as_ref()
.map_or(false, |file| store.can_collect_file(project, file, cx));
#[cfg(feature = "eval-support")]
let eval_cache = store.eval_cache.clone();
let request_task = cx.background_spawn({
let active_buffer = active_buffer.clone();
async move {
let cursor_offset = position.to_offset(&active_snapshot);
let cursor_point = cursor_offset.to_point(&active_snapshot);
let before_retrieval = Instant::now();
let excerpt_options = options.context;
let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
cursor_point,
&active_snapshot,
&excerpt_options,
) else {
return Ok((None, None));
};
let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
..active_snapshot.anchor_before(excerpt.range.end);
let related_excerpt = RelatedExcerpt {
anchor_range: excerpt_anchor_range.clone(),
point_range: Point::new(excerpt.line_range.start.0, 0)
..Point::new(excerpt.line_range.end.0, 0),
text: active_snapshot.as_rope().slice(excerpt.range),
};
if let Some(buffer_ix) = included_files
.iter()
.position(|file| file.buffer.entity_id() == active_buffer.entity_id())
{
let file = &mut included_files[buffer_ix];
file.excerpts.push(related_excerpt);
file.merge_excerpts();
let last_ix = included_files.len() - 1;
included_files.swap(buffer_ix, last_ix);
} else {
let active_file = RelatedFile {
path: active_project_path,
buffer: active_buffer.downgrade(),
excerpts: vec![related_excerpt],
max_row: active_snapshot.max_point().row,
};
included_files.push(active_file);
}
let included_files = included_files
.iter()
.map(|related_file| predict_edits_v3::RelatedFile {
path: Arc::from(related_file.path.path.as_std_path()),
max_row: Line(related_file.max_row),
excerpts: related_file
.excerpts
.iter()
.map(|excerpt| predict_edits_v3::Excerpt {
start_line: Line(excerpt.point_range.start.row),
text: excerpt.text.to_string().into(),
})
.collect(),
})
.collect::<Vec<_>>();
let cloud_request = predict_edits_v3::PredictEditsRequest {
excerpt_path,
excerpt: String::new(),
excerpt_line_range: Line(0)..Line(0),
excerpt_range: 0..0,
cursor_point: predict_edits_v3::Point {
line: predict_edits_v3::Line(cursor_point.row),
column: cursor_point.column,
},
related_files: included_files,
let cursor_offset = position.to_offset(&snapshot);
let (editable_offset_range, prompt_input) = zeta2_prompt_input(
&snapshot,
related_files,
events,
can_collect_data,
debug_info: debug_tx.is_some(),
prompt_max_bytes: Some(options.max_prompt_bytes),
prompt_format: options.prompt_format,
excerpt_parent: None,
git_info: None,
trigger,
};
excerpt_path,
cursor_offset,
);
let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
let inputs = EditPredictionInputs {
included_files: cloud_request.related_files,
events: cloud_request.events,
cursor_point: cloud_request.cursor_point,
cursor_path: cloud_request.excerpt_path,
};
let retrieval_time = Instant::now() - before_retrieval;
let debug_response_tx = if let Some(debug_tx) = &debug_tx {
let (response_tx, response_rx) = oneshot::channel();
let prompt = format_zeta_prompt(&prompt_input);
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionRequested(
EditPredictionRequestedDebugEvent {
inputs: inputs.clone(),
retrieval_time,
buffer: active_buffer.downgrade(),
local_prompt: match prompt_result.as_ref() {
Ok(prompt) => Ok(prompt.clone()),
Err(err) => Err(err.to_string()),
},
.unbounded_send(DebugEvent::EditPredictionStarted(
EditPredictionStartedDebugEvent {
buffer: buffer.downgrade(),
prompt: Some(prompt.clone()),
position,
response_rx,
},
))
.ok();
Some(response_tx)
} else {
None
};
if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
if let Some(debug_response_tx) = debug_response_tx {
debug_response_tx
.send((Err("Request skipped".to_string()), Duration::ZERO))
.ok();
}
anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
}
let prompt = prompt_result?;
let generation_params =
cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
let request = open_ai::Request {
model: EDIT_PREDICTIONS_MODEL_ID.clone(),
messages: vec![open_ai::RequestMessage::User {
@@ -199,8 +79,8 @@ pub fn request_prediction_with_zeta2(
}],
stream: false,
max_completion_tokens: None,
stop: generation_params.stop.unwrap_or_default(),
temperature: generation_params.temperature.or(Some(0.7)),
stop: Default::default(),
temperature: Default::default(),
tool_choice: None,
parallel_tool_calls: None,
tools: vec![],
@@ -210,7 +90,6 @@ pub fn request_prediction_with_zeta2(
log::trace!("Sending edit prediction request");
let before_request = Instant::now();
let response = EditPredictionStore::send_raw_llm_request(
request,
client,
@@ -223,68 +102,53 @@ pub fn request_prediction_with_zeta2(
)
.await;
let received_response_at = Instant::now();
let request_time = received_response_at - before_request;
log::trace!("Got edit prediction response");
if let Some(debug_response_tx) = debug_response_tx {
debug_response_tx
.send((
response
.as_ref()
.map_err(|err| err.to_string())
.map(|response| response.0.clone()),
request_time,
))
.ok();
}
let (res, usage) = response?;
let request_id = EditPredictionId(res.id.clone().into());
let Some(mut output_text) = text_from_response(res) else {
return Ok((Some((request_id, None)), usage));
};
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionFinished(
EditPredictionFinishedDebugEvent {
buffer: buffer.downgrade(),
position,
model_output: Some(output_text.clone()),
},
))
.ok();
}
if output_text.contains(CURSOR_MARKER) {
log::trace!("Stripping out {CURSOR_MARKER} from response");
output_text = output_text.replace(CURSOR_MARKER, "");
}
let get_buffer_from_context = |path: &Path| {
if Some(path) == active_file_full_path.as_deref() {
Some((
&active_snapshot,
std::slice::from_ref(&excerpt_anchor_range),
))
} else {
None
}
};
let (_, edits) = match options.prompt_format {
PromptFormat::Minimal | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
if output_text.contains("--- a/\n+++ b/\nNo edits") {
let edits = vec![];
(&active_snapshot, edits)
} else {
crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
}
}
PromptFormat::OldTextNewText => {
crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context).await?
}
_ => {
bail!("unsupported prompt format {}", options.prompt_format)
}
};
let old_text = snapshot
.text_for_range(editable_offset_range.clone())
.collect::<String>();
let edits: Vec<_> = language::text_diff(&old_text, &output_text)
.into_iter()
.map(|(range, text)| {
(
snapshot.anchor_after(editable_offset_range.start + range.start)
..snapshot.anchor_before(editable_offset_range.start + range.end),
text,
)
})
.collect();
anyhow::Ok((
Some((
request_id,
Some((
inputs,
active_buffer,
active_snapshot.clone(),
prompt_input,
buffer,
snapshot.clone(),
edits,
received_response_at,
)),
@@ -325,3 +189,40 @@ pub fn request_prediction_with_zeta2(
))
})
}
pub fn zeta2_prompt_input(
snapshot: &language::BufferSnapshot,
related_files: Arc<[zeta_prompt::RelatedFile]>,
events: Vec<Arc<zeta_prompt::Event>>,
excerpt_path: Arc<Path>,
cursor_offset: usize,
) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
let cursor_point = cursor_offset.to_point(snapshot);
let (editable_range, context_range) =
crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
cursor_point,
snapshot,
MAX_CONTEXT_TOKENS,
MAX_REWRITE_TOKENS,
);
let context_start_offset = context_range.start.to_offset(snapshot);
let editable_offset_range = editable_range.to_offset(snapshot);
let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
..(editable_offset_range.end - context_start_offset);
let prompt_input = zeta_prompt::ZetaPromptInput {
cursor_path: excerpt_path,
cursor_excerpt: snapshot
.text_for_range(context_range)
.collect::<String>()
.into(),
editable_range_in_excerpt,
cursor_offset_in_excerpt,
events,
related_files,
};
(editable_offset_range, prompt_input)
}

View File

@@ -9,7 +9,7 @@ license = "GPL-3.0-or-later"
workspace = true
[[bin]]
name = "ep_cli"
name = "ep"
path = "src/main.rs"
[dependencies]
@@ -20,10 +20,9 @@ chrono.workspace = true
clap.workspace = true
client.workspace = true
cloud_llm_client.workspace= true
cloud_zeta2_prompt.workspace = true
collections.workspace = true
debug_adapter_extension.workspace = true
edit_prediction_context.workspace = true
dirs.workspace = true
extension.workspace = true
fs.workspace = true
futures.workspace = true
@@ -51,12 +50,21 @@ smol.workspace = true
sqlez.workspace = true
sqlez_macros.workspace = true
terminal_view.workspace = true
toml.workspace = true
util.workspace = true
watch.workspace = true
edit_prediction = { workspace = true, features = ["eval-support"] }
wasmtime.workspace = true
zeta_prompt.workspace = true
zlog.workspace = true
# Wasmtime is included as a dependency in order to enable the same
# features that are enabled in Zed.
#
# If we don't enable these features we get crashes when creating
# a Tree-sitter WasmStore.
[package.metadata.cargo-machete]
ignored = ["wasmtime"]
[dev-dependencies]
indoc.workspace = true
gpui = { workspace = true, features = ["test-support"] }

View File

@@ -5,11 +5,13 @@ use anthropic::{
use anyhow::Result;
use http_client::HttpClient;
use indoc::indoc;
use reqwest_client::ReqwestClient;
use sqlez::bindable::Bind;
use sqlez::bindable::StaticColumnCount;
use sqlez_macros::sql;
use std::hash::Hash;
use std::hash::Hasher;
use std::path::Path;
use std::sync::Arc;
pub struct PlainLlmClient {
@@ -18,7 +20,8 @@ pub struct PlainLlmClient {
}
impl PlainLlmClient {
fn new(http_client: Arc<dyn HttpClient>) -> Result<Self> {
fn new() -> Result<Self> {
let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
let api_key = std::env::var("ANTHROPIC_API_KEY")
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
Ok(Self {
@@ -29,12 +32,12 @@ impl PlainLlmClient {
async fn generate(
&self,
model: String,
model: &str,
max_tokens: u64,
messages: Vec<Message>,
) -> Result<AnthropicResponse> {
let request = AnthropicRequest {
model,
model: model.to_string(),
max_tokens,
messages,
tools: Vec::new(),
@@ -105,11 +108,12 @@ struct SerializableMessage {
}
impl BatchingLlmClient {
fn new(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
fn new(cache_path: &Path) -> Result<Self> {
let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
let api_key = std::env::var("ANTHROPIC_API_KEY")
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
let connection = sqlez::connection::Connection::open_file(&cache_path);
let connection = sqlez::connection::Connection::open_file(&cache_path.to_str().unwrap());
let mut statement = sqlez::statement::Statement::prepare(
&connection,
indoc! {"
@@ -182,16 +186,16 @@ impl BatchingLlmClient {
async fn generate(
&self,
model: String,
model: &str,
max_tokens: u64,
messages: Vec<Message>,
) -> Result<Option<AnthropicResponse>> {
let response = self.lookup(&model, max_tokens, &messages)?;
let response = self.lookup(model, max_tokens, &messages)?;
if let Some(response) = response {
return Ok(Some(response));
}
self.mark_for_batch(&model, max_tokens, &messages)?;
self.mark_for_batch(model, max_tokens, &messages)?;
Ok(None)
}
@@ -258,7 +262,7 @@ impl BatchingLlmClient {
}
}
}
log::info!("Uploaded {} successful requests", success_count);
log::info!("Downloaded {} successful requests", success_count);
}
}
@@ -363,23 +367,20 @@ fn message_content_to_string(content: &[RequestContent]) -> String {
.join("\n")
}
pub enum LlmClient {
pub enum AnthropicClient {
// No batching
Plain(PlainLlmClient),
Batch(BatchingLlmClient),
Dummy,
}
impl LlmClient {
pub fn plain(http_client: Arc<dyn HttpClient>) -> Result<Self> {
Ok(Self::Plain(PlainLlmClient::new(http_client)?))
impl AnthropicClient {
pub fn plain() -> Result<Self> {
Ok(Self::Plain(PlainLlmClient::new()?))
}
pub fn batch(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
Ok(Self::Batch(BatchingLlmClient::new(
cache_path,
http_client,
)?))
pub fn batch(cache_path: &Path) -> Result<Self> {
Ok(Self::Batch(BatchingLlmClient::new(cache_path)?))
}
#[allow(dead_code)]
@@ -389,29 +390,29 @@ impl LlmClient {
pub async fn generate(
&self,
model: String,
model: &str,
max_tokens: u64,
messages: Vec<Message>,
) -> Result<Option<AnthropicResponse>> {
match self {
LlmClient::Plain(plain_llm_client) => plain_llm_client
AnthropicClient::Plain(plain_llm_client) => plain_llm_client
.generate(model, max_tokens, messages)
.await
.map(Some),
LlmClient::Batch(batching_llm_client) => {
AnthropicClient::Batch(batching_llm_client) => {
batching_llm_client
.generate(model, max_tokens, messages)
.await
}
LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
}
}
pub async fn sync_batches(&self) -> Result<()> {
match self {
LlmClient::Plain(_) => Ok(()),
LlmClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
AnthropicClient::Plain(_) => Ok(()),
AnthropicClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
}
}
}

View File

@@ -1,641 +0,0 @@
use crate::metrics::{self, Scores};
use std::{
collections::HashMap,
io::{IsTerminal, Write},
sync::Arc,
};
use anyhow::Result;
use edit_prediction::{EditPredictionStore, udiff::DiffLine};
use gpui::{AsyncApp, Entity};
use project::Project;
use util::ResultExt as _;
use crate::{
EvaluateArguments, PredictionOptions,
example::{Example, NamedExample},
headless::ZetaCliAppState,
paths::print_run_data_dir,
predict::{PredictionDetails, perform_predict, setup_store},
};
#[derive(Debug)]
pub(crate) struct ExecutionData {
execution_id: String,
diff: String,
reasoning: String,
}
pub async fn run_evaluate(
args: EvaluateArguments,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) {
if args.example_paths.is_empty() {
eprintln!("No examples provided");
return;
}
let all_tasks = args.example_paths.into_iter().map(|path| {
let options = args.options.clone();
let app_state = app_state.clone();
let example = NamedExample::load(&path).expect("Failed to load example");
cx.spawn(async move |cx| {
let project = example.setup_project(&app_state, cx).await.unwrap();
let providers = (0..args.repetitions)
.map(|_| setup_store(args.options.provider, &project, &app_state, cx).unwrap())
.collect::<Vec<_>>();
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
let tasks = providers
.into_iter()
.enumerate()
.map(move |(repetition_ix, store)| {
let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
let example = example.clone();
let project = project.clone();
let options = options.clone();
cx.spawn(async move |cx| {
let name = example.name.clone();
run_evaluate_one(
example,
repetition_ix,
project,
store,
options,
!args.skip_prediction,
cx,
)
.await
.map_err(|err| (err, name, repetition_ix))
})
});
futures::future::join_all(tasks).await
})
});
let all_results = futures::future::join_all(all_tasks).await;
write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
if let Some(mut output_file) =
std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
{
write_aggregated_scores(&mut output_file, &all_results).log_err();
};
if args.repetitions > 1 {
if let Err(e) = write_bucketed_analysis(&all_results) {
eprintln!("Failed to write bucketed analysis: {:?}", e);
}
}
print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
}
fn write_aggregated_scores(
w: &mut impl std::io::Write,
all_results: &Vec<
Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
>,
) -> Result<()> {
let mut successful = Vec::new();
let mut failed_count = 0;
for result in all_results.iter().flatten() {
match result {
Ok((eval_result, _execution_data)) => successful.push(eval_result),
Err((err, name, repetition_ix)) => {
if failed_count == 0 {
writeln!(w, "## Errors\n")?;
}
failed_count += 1;
writeln!(w, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
}
}
}
if successful.len() > 1 {
let edit_scores = successful
.iter()
.filter_map(|r| r.edit_scores.clone())
.collect::<Vec<_>>();
let has_edit_predictions = edit_scores.len() > 0;
let aggregated_result = EvaluationResult {
context_scores: Scores::aggregate(successful.iter().map(|r| &r.context_scores)),
edit_scores: has_edit_predictions.then(|| EditScores::aggregate(&edit_scores)),
prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
/ successful.len(),
};
writeln!(w, "\n{}", "-".repeat(80))?;
writeln!(w, "\n## TOTAL SCORES")?;
writeln!(w, "{:#}", aggregated_result)?;
}
if successful.len() + failed_count > 1 {
writeln!(
w,
"\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
successful.len(),
successful.len() + failed_count,
(successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
)?;
}
Ok(())
}
pub async fn run_evaluate_one(
example: NamedExample,
repetition_ix: Option<u16>,
project: Entity<Project>,
store: Entity<EditPredictionStore>,
prediction_options: PredictionOptions,
predict: bool,
cx: &mut AsyncApp,
) -> Result<(EvaluationResult, ExecutionData)> {
let predict_result = perform_predict(
example.clone(),
project,
store,
repetition_ix,
prediction_options,
cx,
)
.await?;
let evaluation_result = evaluate(&example.example, &predict_result, predict);
if repetition_ix.is_none() {
write_eval_result(
&example,
&predict_result,
&evaluation_result,
&mut std::io::stdout(),
std::io::stdout().is_terminal(),
predict,
)?;
}
if let Some(mut results_file) =
std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
{
write_eval_result(
&example,
&predict_result,
&evaluation_result,
&mut results_file,
false,
predict,
)
.log_err();
}
let execution_data = ExecutionData {
execution_id: if let Some(rep_ix) = repetition_ix {
format!("{:03}", rep_ix)
} else {
example.name.clone()
},
diff: predict_result.diff.clone(),
reasoning: std::fs::read_to_string(
predict_result
.run_example_dir
.join("prediction_response.md"),
)
.unwrap_or_default(),
};
anyhow::Ok((evaluation_result, execution_data))
}
fn write_eval_result(
example: &NamedExample,
predictions: &PredictionDetails,
evaluation_result: &EvaluationResult,
out: &mut impl Write,
use_color: bool,
predict: bool,
) -> Result<()> {
if predict {
writeln!(
out,
"## Expected edit prediction:\n\n```diff\n{}\n```\n",
compare_diffs(
&example.example.expected_patch,
&predictions.diff,
use_color
)
)?;
writeln!(
out,
"## Actual edit prediction:\n\n```diff\n{}\n```\n",
compare_diffs(
&predictions.diff,
&example.example.expected_patch,
use_color
)
)?;
}
writeln!(out, "{:#}", evaluation_result)?;
anyhow::Ok(())
}
#[derive(Debug, Default, Clone)]
pub struct EditScores {
pub line_match: Scores,
pub chr_f: f64,
}
impl EditScores {
pub fn aggregate(scores: &[EditScores]) -> EditScores {
let line_match = Scores::aggregate(scores.iter().map(|s| &s.line_match));
let chr_f = scores.iter().map(|s| s.chr_f).sum::<f64>() / scores.len() as f64;
EditScores { line_match, chr_f }
}
}
#[derive(Debug, Default)]
pub struct EvaluationResult {
pub edit_scores: Option<EditScores>,
pub context_scores: Scores,
pub prompt_len: usize,
pub generated_len: usize,
}
impl std::fmt::Display for EvaluationResult {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if f.alternate() {
self.fmt_table(f)
} else {
self.fmt_markdown(f)
}
}
}
impl EvaluationResult {
fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
r#"
### Context Scores
{}
"#,
self.context_scores.to_markdown(),
)?;
if let Some(scores) = &self.edit_scores {
write!(
f,
r#"
### Edit Prediction Scores
{}"#,
scores.line_match.to_markdown()
)?;
}
Ok(())
}
fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "#### Prompt Statistics")?;
writeln!(f, "─────────────────────────")?;
writeln!(f, "Prompt_len Generated_len")?;
writeln!(f, "─────────────────────────")?;
writeln!(f, "{:<11} {:<14}", self.prompt_len, self.generated_len,)?;
writeln!(f)?;
writeln!(f)?;
writeln!(f, "#### Performance Scores")?;
writeln!(
f,
"──────────────────────────────────────────────────────────────────"
)?;
writeln!(
f,
" TP FP FN Precision Recall F1"
)?;
writeln!(
f,
"──────────────────────────────────────────────────────────────────"
)?;
writeln!(
f,
"Context Retrieval {:<6} {:<6} {:<6} {:>8.2} {:>7.2} {:>6.2}",
self.context_scores.true_positives,
self.context_scores.false_positives,
self.context_scores.false_negatives,
self.context_scores.precision() * 100.0,
self.context_scores.recall() * 100.0,
self.context_scores.f1_score() * 100.0
)?;
if let Some(edit_scores) = &self.edit_scores {
let line_match = &edit_scores.line_match;
writeln!(f, "Edit Prediction")?;
writeln!(
f,
" ├─ exact lines {:<6} {:<6} {:<6} {:>8.2} {:>7.2} {:>6.2}",
line_match.true_positives,
line_match.false_positives,
line_match.false_negatives,
line_match.precision() * 100.0,
line_match.recall() * 100.0,
line_match.f1_score() * 100.0
)?;
writeln!(
f,
" └─ diff chrF {:<6} {:<6} {:<6} {:>8} {:>8} {:>6.2}",
"-", "-", "-", "-", "-", edit_scores.chr_f
)?;
}
Ok(())
}
}
fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
let mut eval_result = EvaluationResult {
prompt_len: preds.prompt_len,
generated_len: preds.generated_len,
..Default::default()
};
if predict {
// todo: alternatives for patches
let expected_patch = example
.expected_patch
.lines()
.map(DiffLine::parse)
.collect::<Vec<_>>();
let actual_patch = preds.diff.lines().map(DiffLine::parse).collect::<Vec<_>>();
let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
let chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch);
eval_result.edit_scores = Some(EditScores { line_match, chr_f });
}
eval_result
}
/// Return annotated `patch_a` so that:
/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
/// Additions and deletions that are present in `patch_b` will be highlighted in green.
pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
let green = if use_color { "\x1b[32m✓ " } else { "" };
let red = if use_color { "\x1b[31m✗ " } else { "" };
let neutral = if use_color { " " } else { "" };
let reset = if use_color { "\x1b[0m" } else { "" };
let lines_a = patch_a.lines().map(DiffLine::parse);
let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
let annotated = lines_a
.map(|line| match line {
DiffLine::Addition(_) | DiffLine::Deletion(_) => {
if lines_b.contains(&line) {
format!("{green}{line}{reset}")
} else {
format!("{red}{line}{reset}")
}
}
_ => format!("{neutral}{line}{reset}"),
})
.collect::<Vec<String>>();
annotated.join("\n")
}
fn write_bucketed_analysis(
all_results: &Vec<
Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
>,
) -> Result<()> {
#[derive(Debug)]
struct EditBucket {
diff: String,
is_correct: bool,
execution_indices: Vec<String>,
reasoning_samples: Vec<String>,
}
let mut total_executions = 0;
let mut empty_predictions = Vec::new();
let mut errors = Vec::new();
let mut buckets: HashMap<String, EditBucket> = HashMap::new();
for result in all_results.iter().flatten() {
total_executions += 1;
let (evaluation_result, execution_data) = match result {
Ok((eval_result, execution_data)) => {
if execution_data.diff.is_empty() {
empty_predictions.push(execution_data);
continue;
}
(eval_result, execution_data)
}
Err(err) => {
errors.push(err);
continue;
}
};
buckets
.entry(execution_data.diff.clone())
.and_modify(|bucket| {
bucket
.execution_indices
.push(execution_data.execution_id.clone());
bucket
.reasoning_samples
.push(execution_data.reasoning.clone());
})
.or_insert_with(|| EditBucket {
diff: execution_data.diff.clone(),
is_correct: {
evaluation_result
.edit_scores
.as_ref()
.map_or(false, |edit_scores| {
edit_scores.line_match.false_positives == 0
&& edit_scores.line_match.false_negatives == 0
&& edit_scores.line_match.true_positives > 0
})
},
execution_indices: vec![execution_data.execution_id.clone()],
reasoning_samples: vec![execution_data.reasoning.clone()],
});
}
let mut sorted_buckets = buckets.into_values().collect::<Vec<_>>();
sorted_buckets.sort_by(|a, b| match (a.is_correct, b.is_correct) {
(true, false) => std::cmp::Ordering::Less,
(false, true) => std::cmp::Ordering::Greater,
_ => b.execution_indices.len().cmp(&a.execution_indices.len()),
});
let output_path = crate::paths::RUN_DIR.join("bucketed_analysis.md");
let mut output = std::fs::File::create(&output_path)?;
writeln!(output, "# Bucketed Edit Analysis\n")?;
writeln!(output, "## Summary\n")?;
writeln!(output, "- **Total executions**: {}", total_executions)?;
let correct_count: usize = sorted_buckets
.iter()
.filter(|b| b.is_correct)
.map(|b| b.execution_indices.len())
.sum();
let incorrect_count: usize = sorted_buckets
.iter()
.filter(|b| !b.is_correct)
.map(|b| b.execution_indices.len())
.sum();
writeln!(
output,
"- **Correct predictions**: {} ({:.1}%)",
correct_count,
(correct_count as f64 / total_executions as f64) * 100.0
)?;
writeln!(
output,
"- **Incorrect predictions**: {} ({:.1}%)",
incorrect_count,
(incorrect_count as f64 / total_executions as f64) * 100.0
)?;
writeln!(
output,
"- **No Predictions**: {} ({:.1}%)",
empty_predictions.len(),
(empty_predictions.len() as f64 / total_executions as f64) * 100.0
)?;
let unique_incorrect = sorted_buckets.iter().filter(|b| !b.is_correct).count();
writeln!(
output,
"- **Unique incorrect edit patterns**: {}\n",
unique_incorrect
)?;
writeln!(output, "---\n")?;
for (idx, bucket) in sorted_buckets.iter().filter(|b| b.is_correct).enumerate() {
if idx == 0 {
writeln!(
output,
"## Correct Predictions ({} occurrences)\n",
bucket.execution_indices.len()
)?;
}
writeln!(output, "**Predicted Edit:**\n")?;
writeln!(output, "```diff")?;
writeln!(output, "{}", bucket.diff)?;
writeln!(output, "```\n")?;
writeln!(
output,
"**Executions:** {}\n",
bucket.execution_indices.join(", ")
)?;
writeln!(output, "---\n")?;
}
for (idx, bucket) in sorted_buckets.iter().filter(|b| !b.is_correct).enumerate() {
writeln!(
output,
"## Incorrect Prediction #{} ({} occurrences)\n",
idx + 1,
bucket.execution_indices.len()
)?;
writeln!(output, "**Predicted Edit:**\n")?;
writeln!(output, "```diff")?;
writeln!(output, "{}", bucket.diff)?;
writeln!(output, "```\n")?;
writeln!(
output,
"**Executions:** {}\n",
bucket.execution_indices.join(", ")
)?;
for (exec_id, reasoning) in bucket
.execution_indices
.iter()
.zip(bucket.reasoning_samples.iter())
{
writeln!(output, "{}", fmt_execution(exec_id, reasoning))?;
}
writeln!(output, "\n---\n")?;
}
if !empty_predictions.is_empty() {
writeln!(
output,
"## No Predictions ({} occurrences)\n",
empty_predictions.len()
)?;
for execution_data in &empty_predictions {
writeln!(
output,
"{}",
fmt_execution(&execution_data.execution_id, &execution_data.reasoning)
)?;
}
writeln!(output, "\n---\n")?;
}
if !errors.is_empty() {
writeln!(output, "## Errors ({} occurrences)\n", errors.len())?;
for (err, name, repetition_ix) in &errors {
writeln!(output, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
}
writeln!(output, "\n---\n")?;
}
fn fmt_execution(exec_id: &str, reasoning: &str) -> String {
let exec_content = format!(
"\n### Execution {} `{}/{}/prediction_response.md`{}",
exec_id,
crate::paths::RUN_DIR.display(),
exec_id,
indent_text(&format!("\n\n```\n{}\n```\n", reasoning,), 2)
);
indent_text(&exec_content, 2)
}
fn indent_text(text: &str, spaces: usize) -> String {
let indent = " ".repeat(spaces);
text.lines()
.collect::<Vec<_>>()
.join(&format!("\n{}", indent))
}
Ok(())
}
fn fmt_evaluation_error(err: &anyhow::Error, name: &str, repetition_ix: &Option<u16>) -> String {
let err = format!("{err:?}")
.replace("<edits", "```xml\n<edits")
.replace("</edits>", "</edits>\n```");
format!(
"### ERROR {name}{}\n\n{err}\n",
repetition_ix
.map(|ix| format!(" [RUN {ix:03}]"))
.unwrap_or_default()
)
}

View File

@@ -1,59 +1,103 @@
use crate::{
PredictionProvider, PromptFormat,
metrics::ClassificationMetrics,
paths::{REPOS_DIR, WORKTREES_DIR},
};
use anyhow::{Context as _, Result};
use edit_prediction::udiff::OpenedBuffers;
use gpui::Entity;
use http_client::Url;
use language::{Anchor, Buffer};
use project::Project;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::{
borrow::Cow,
cell::RefCell,
fmt::{self, Display},
fs,
hash::Hash,
hash::Hasher,
io::Write,
io::{Read, Write},
mem,
path::{Path, PathBuf},
sync::{Arc, OnceLock},
};
use zeta_prompt::RelatedFile;
use crate::headless::ZetaCliAppState;
use anyhow::{Context as _, Result, anyhow};
use clap::ValueEnum;
use cloud_zeta2_prompt::CURSOR_MARKER;
use collections::HashMap;
use edit_prediction::udiff::OpenedBuffers;
use futures::{
AsyncWriteExt as _,
lock::{Mutex, OwnedMutexGuard},
};
use futures::{FutureExt as _, future::Shared};
use gpui::{AsyncApp, Entity, Task, http_client::Url};
use language::{Anchor, Buffer};
use project::{Project, ProjectPath};
use pulldown_cmark::CowStr;
use serde::{Deserialize, Serialize};
use util::{paths::PathStyle, rel_path::RelPath};
use crate::paths::{REPOS_DIR, WORKTREES_DIR};
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
const REPOSITORY_URL_FIELD: &str = "repository_url";
const REVISION_FIELD: &str = "revision";
#[derive(Debug, Clone)]
pub struct NamedExample {
pub name: String,
pub example: Example,
}
#[derive(Clone, Debug, Hash, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
#[serde(default)]
pub name: String,
pub repository_url: String,
pub revision: String,
pub uncommitted_diff: String,
pub cursor_path: PathBuf,
pub cursor_path: Arc<Path>,
pub cursor_position: String,
pub edit_history: String,
pub expected_patch: String,
/// The full content of the file where an edit is being predicted, and the
/// actual cursor offset.
#[serde(skip_serializing_if = "Option::is_none")]
pub buffer: Option<ExampleBuffer>,
/// The context retrieved for the prediction. This requires the worktree to
/// be loaded and the language server to be started.
#[serde(skip_serializing_if = "Option::is_none")]
pub context: Option<ExampleContext>,
/// The input and expected output from the edit prediction model.
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt: Option<ExamplePrompt>,
/// The actual predictions from the model.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub predictions: Vec<ExamplePrediction>,
/// The scores, for how well the actual predictions match the expected
/// predictions.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub score: Vec<ExampleScore>,
/// The application state used to process this example.
#[serde(skip)]
pub state: Option<ExampleState>,
}
#[derive(Clone, Debug)]
pub struct ExampleState {
pub project: Entity<Project>,
pub buffer: Entity<Buffer>,
pub cursor_position: Anchor,
pub _open_buffers: OpenedBuffers,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleContext {
pub files: Arc<[RelatedFile]>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleBuffer {
pub content: String,
pub cursor_row: u32,
pub cursor_column: u32,
pub cursor_offset: usize,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
pub input: String,
pub expected_output: String,
pub format: PromptFormat,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
pub actual_patch: String,
pub actual_output: String,
pub provider: PredictionProvider,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
pub delta_chr_f: f32,
pub line_match: ClassificationMetrics,
}
impl Example {
@@ -90,485 +134,244 @@ impl Example {
}
}
pub async fn setup_worktree(&self, file_name: String) -> Result<PathBuf> {
let (repo_owner, repo_name) = self.repo_name()?;
pub fn worktree_path(&self) -> PathBuf {
WORKTREES_DIR
.join(&self.name)
.join(self.repo_name().unwrap().1.as_ref())
}
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
let repo_lock = lock_repo(&repo_dir).await;
pub fn repo_path(&self) -> PathBuf {
let (repo_owner, repo_name) = self.repo_name().expect("failed to get repo name");
REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref())
}
}
if !repo_dir.is_dir() {
fs::create_dir_all(&repo_dir)?;
run_git(&repo_dir, &["init"]).await?;
run_git(
&repo_dir,
&["remote", "add", "origin", &self.repository_url],
)
.await?;
}
pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
let mut examples = Vec::new();
// Resolve the example to a revision, fetching it if needed.
let revision = run_git(
&repo_dir,
&["rev-parse", &format!("{}^{{commit}}", self.revision)],
)
.await;
let revision = if let Ok(revision) = revision {
revision
let stdin_path: PathBuf = PathBuf::from("-");
let inputs = if inputs.is_empty() {
&[stdin_path]
} else {
inputs
};
for path in inputs {
let is_stdin = path.as_path() == Path::new("-");
let content = if is_stdin {
let mut buffer = String::new();
std::io::stdin()
.read_to_string(&mut buffer)
.expect("Failed to read from stdin");
buffer
} else {
if run_git(
&repo_dir,
&["fetch", "--depth", "1", "origin", &self.revision],
)
.await
.is_err()
{
run_git(&repo_dir, &["fetch", "origin"]).await?;
}
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
if revision != self.revision {
run_git(&repo_dir, &["tag", &self.revision, &revision]).await?;
}
revision
std::fs::read_to_string(path)
.unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
};
let filename = path.file_stem().unwrap().to_string_lossy().to_string();
let ext = if !is_stdin {
path.extension()
.map(|ext| ext.to_string_lossy().to_string())
.unwrap_or_else(|| panic!("{} should have an extension", path.display()))
} else {
"jsonl".to_string()
};
// Create the worktree for this example if needed.
let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
if worktree_path.is_dir() {
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
} else {
let worktree_path_string = worktree_path.to_string_lossy();
run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
run_git(
&repo_dir,
&["worktree", "add", "-f", &worktree_path_string, &file_name],
)
.await?;
}
drop(repo_lock);
// Apply the uncommitted diff for this example.
if !self.uncommitted_diff.is_empty() {
let mut apply_process = smol::process::Command::new("git")
.current_dir(&worktree_path)
.args(&["apply", "-"])
.stdin(std::process::Stdio::piped())
.spawn()?;
let mut stdin = apply_process.stdin.take().unwrap();
stdin.write_all(self.uncommitted_diff.as_bytes()).await?;
stdin.close().await?;
drop(stdin);
let apply_result = apply_process.output().await?;
if !apply_result.status.success() {
anyhow::bail!(
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
apply_result.status,
String::from_utf8_lossy(&apply_result.stderr),
String::from_utf8_lossy(&apply_result.stdout),
);
match ext.as_ref() {
"json" => {
let mut example =
serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
panic!("Failed to parse example file: {}\n{error}", path.display())
});
if example.name.is_empty() {
example.name = filename;
}
examples.push(example);
}
"jsonl" => examples.extend(
content
.lines()
.enumerate()
.map(|(line_ix, line)| {
let mut example =
serde_json::from_str::<Example>(line).unwrap_or_else(|_| {
panic!(
"Failed to parse example on {}:{}",
path.display(),
line_ix + 1
)
});
if example.name.is_empty() {
example.name = format!("{filename}-{line_ix}")
}
example
})
.collect::<Vec<Example>>(),
),
"md" => {
examples.push(parse_markdown_example(filename, &content).unwrap());
}
ext => {
panic!("{} has invalid example extension `{ext}`", path.display())
}
}
Ok(worktree_path)
}
examples
}
pub fn unique_name(&self) -> String {
let mut hasher = std::hash::DefaultHasher::new();
self.hash(&mut hasher);
let disambiguator = hasher.finish();
let hash = format!("{:04x}", disambiguator);
format!("{}_{}", &self.revision[..8], &hash[..4])
pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
let mut content = String::new();
for example in examples {
let line = serde_json::to_string(example).unwrap();
content.push_str(&line);
content.push('\n');
}
if let Some(output_path) = output_path {
std::fs::write(output_path, content).expect("Failed to write examples");
} else {
std::io::stdout().write_all(&content.as_bytes()).unwrap();
}
}
pub type ActualExcerpt = Excerpt;
fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Excerpt {
pub path: PathBuf,
pub text: String,
}
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
const REPOSITORY_URL_FIELD: &str = "repository_url";
const REVISION_FIELD: &str = "revision";
#[derive(ValueEnum, Debug, Clone)]
pub enum ExampleFormat {
Json,
Toml,
Md,
}
let parser = Parser::new(input);
impl NamedExample {
pub fn load(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref();
let content = std::fs::read_to_string(path)?;
let ext = path.extension();
let mut example = Example {
name: id,
repository_url: String::new(),
revision: String::new(),
uncommitted_diff: String::new(),
cursor_path: PathBuf::new().into(),
cursor_position: String::new(),
edit_history: String::new(),
expected_patch: String::new(),
buffer: None,
context: None,
prompt: None,
predictions: Vec::new(),
score: Vec::new(),
state: None,
};
match ext.and_then(|s| s.to_str()) {
Some("json") => Ok(Self {
name: path.file_stem().unwrap_or_default().display().to_string(),
example: serde_json::from_str(&content)?,
}),
Some("toml") => Ok(Self {
name: path.file_stem().unwrap_or_default().display().to_string(),
example: toml::from_str(&content)?,
}),
Some("md") => Self::parse_md(&content),
Some(_) => {
anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display());
}
None => {
anyhow::bail!(
"Failed to determine example type since the file does not have an extension."
);
}
}
let mut name = String::new();
let mut text = String::new();
let mut block_info: CowStr = "".into();
#[derive(PartialEq)]
enum Section {
UncommittedDiff,
EditHistory,
CursorPosition,
ExpectedExcerpts,
ExpectedPatch,
Other,
}
pub fn parse_md(input: &str) -> Result<Self> {
use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd};
let mut current_section = Section::Other;
let parser = Parser::new(input);
for event in parser {
match event {
Event::Text(line) => {
text.push_str(&line);
let mut named = NamedExample {
name: String::new(),
example: Example {
repository_url: String::new(),
revision: String::new(),
uncommitted_diff: String::new(),
cursor_path: PathBuf::new(),
cursor_position: String::new(),
edit_history: String::new(),
expected_patch: String::new(),
},
};
let mut text = String::new();
let mut block_info: CowStr = "".into();
#[derive(PartialEq)]
enum Section {
UncommittedDiff,
EditHistory,
CursorPosition,
ExpectedExcerpts,
ExpectedPatch,
Other,
}
let mut current_section = Section::Other;
for event in parser {
match event {
Event::Text(line) => {
text.push_str(&line);
if !named.name.is_empty()
&& current_section == Section::Other
// in h1 section
&& let Some((field, value)) = line.split_once('=')
{
match field.trim() {
REPOSITORY_URL_FIELD => {
named.example.repository_url = value.trim().to_string();
}
REVISION_FIELD => {
named.example.revision = value.trim().to_string();
}
_ => {}
if let Some((field, value)) = line.split_once('=') {
match field.trim() {
REPOSITORY_URL_FIELD => {
example.repository_url = value.trim().to_string();
}
REVISION_FIELD => {
example.revision = value.trim().to_string();
}
_ => {}
}
}
Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
if !named.name.is_empty() {
anyhow::bail!(
"Found multiple H1 headings. There should only be one with the name of the example."
);
}
named.name = mem::take(&mut text);
}
Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
let title = mem::take(&mut text);
current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
Section::UncommittedDiff
} else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
Section::EditHistory
} else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
Section::CursorPosition
} else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
Section::ExpectedPatch
} else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
Section::ExpectedExcerpts
} else {
Section::Other
};
}
Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
mem::take(&mut text);
}
Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
mem::take(&mut text);
}
Event::End(TagEnd::Heading(level)) => {
anyhow::bail!("Unexpected heading level: {level}");
}
Event::Start(Tag::CodeBlock(kind)) => {
match kind {
CodeBlockKind::Fenced(info) => {
block_info = info;
}
CodeBlockKind::Indented => {
anyhow::bail!("Unexpected indented codeblock");
}
};
}
Event::Start(_) => {
text.clear();
block_info = "".into();
}
Event::End(TagEnd::CodeBlock) => {
let block_info = block_info.trim();
match current_section {
Section::UncommittedDiff => {
named.example.uncommitted_diff = mem::take(&mut text);
}
Section::EditHistory => {
named.example.edit_history.push_str(&mem::take(&mut text));
}
Section::CursorPosition => {
named.example.cursor_path = block_info.into();
named.example.cursor_position = mem::take(&mut text);
}
Section::ExpectedExcerpts => {
mem::take(&mut text);
}
Section::ExpectedPatch => {
named.example.expected_patch = mem::take(&mut text);
}
Section::Other => {}
}
}
_ => {}
}
}
if named.example.cursor_path.as_path() == Path::new("")
|| named.example.cursor_position.is_empty()
{
anyhow::bail!("Missing cursor position codeblock");
}
Ok(named)
}
pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> {
match format {
ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?),
ExampleFormat::Toml => {
Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?)
Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
if !name.is_empty() {
anyhow::bail!(
"Found multiple H1 headings. There should only be one with the name of the example."
);
}
name = mem::take(&mut text);
}
ExampleFormat::Md => Ok(write!(out, "{}", self)?),
}
}
pub async fn setup_project(
&self,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<Entity<Project>> {
let worktree_path = self.setup_worktree().await?;
static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
AUTHENTICATED
.get_or_init(|| {
let client = app_state.client.clone();
cx.spawn(async move |cx| {
client
.sign_in_with_optional_connect(true, cx)
.await
.unwrap();
})
.shared()
})
.clone()
.await;
let project = cx.update(|cx| {
Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
app_state.fs.clone(),
None,
cx,
)
})?;
let worktree = project
.update(cx, |project, cx| {
project.create_worktree(&worktree_path, true, cx)
})?
.await?;
worktree
.read_with(cx, |worktree, _cx| {
worktree.as_local().unwrap().scan_complete()
})?
.await;
anyhow::Ok(project)
}
pub async fn setup_worktree(&self) -> Result<PathBuf> {
self.example.setup_worktree(self.file_name()).await
}
pub fn file_name(&self) -> String {
self.name
.chars()
.map(|c| {
if c.is_whitespace() {
'-'
Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
let title = mem::take(&mut text);
current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
Section::UncommittedDiff
} else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
Section::EditHistory
} else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
Section::CursorPosition
} else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
Section::ExpectedPatch
} else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
Section::ExpectedExcerpts
} else {
c.to_ascii_lowercase()
Section::Other
};
}
Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
mem::take(&mut text);
}
Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
mem::take(&mut text);
}
Event::End(TagEnd::Heading(level)) => {
anyhow::bail!("Unexpected heading level: {level}");
}
Event::Start(Tag::CodeBlock(kind)) => {
match kind {
CodeBlockKind::Fenced(info) => {
block_info = info;
}
CodeBlockKind::Indented => {
anyhow::bail!("Unexpected indented codeblock");
}
};
}
Event::Start(_) => {
text.clear();
block_info = "".into();
}
Event::End(TagEnd::CodeBlock) => {
let block_info = block_info.trim();
match current_section {
Section::UncommittedDiff => {
example.uncommitted_diff = mem::take(&mut text);
}
Section::EditHistory => {
example.edit_history.push_str(&mem::take(&mut text));
}
Section::CursorPosition => {
example.cursor_path = Path::new(block_info).into();
example.cursor_position = mem::take(&mut text);
}
Section::ExpectedExcerpts => {
mem::take(&mut text);
}
Section::ExpectedPatch => {
example.expected_patch = mem::take(&mut text);
}
Section::Other => {}
}
})
.collect()
}
pub async fn cursor_position(
&self,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<(Entity<Buffer>, Anchor)> {
let worktree = project.read_with(cx, |project, cx| {
project.visible_worktrees(cx).next().unwrap()
})?;
let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc();
let cursor_buffer = project
.update(cx, |project, cx| {
project.open_buffer(
ProjectPath {
worktree_id: worktree.read(cx).id(),
path: cursor_path,
},
cx,
)
})?
.await?;
let cursor_offset_within_excerpt = self
.example
.cursor_position
.find(CURSOR_MARKER)
.ok_or_else(|| anyhow!("missing cursor marker"))?;
let mut cursor_excerpt = self.example.cursor_position.clone();
cursor_excerpt.replace_range(
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
"",
);
let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
let text = buffer.text();
let mut matches = text.match_indices(&cursor_excerpt);
let Some((excerpt_offset, _)) = matches.next() else {
anyhow::bail!(
"\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
);
};
assert!(matches.next().is_none());
Ok(excerpt_offset)
})??;
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
let cursor_anchor =
cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
Ok((cursor_buffer, cursor_anchor))
}
#[must_use]
pub async fn apply_edit_history(
&self,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<OpenedBuffers<'_>> {
edit_prediction::udiff::apply_diff(&self.example.edit_history, project, cx).await
}
}
async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
let output = smol::process::Command::new("git")
.current_dir(repo_path)
.args(args)
.output()
.await?;
anyhow::ensure!(
output.status.success(),
"`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
args.join(" "),
repo_path.display(),
output.status,
String::from_utf8_lossy(&output.stderr),
String::from_utf8_lossy(&output.stdout),
);
Ok(String::from_utf8(output.stdout)?.trim().to_string())
}
impl Display for NamedExample {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "# {}\n\n", self.name)?;
write!(
f,
"{REPOSITORY_URL_FIELD} = {}\n",
self.example.repository_url
)?;
write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
write!(f, "`````diff\n")?;
write!(f, "{}", self.example.uncommitted_diff)?;
write!(f, "`````\n")?;
if !self.example.edit_history.is_empty() {
write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
}
_ => {}
}
write!(
f,
"## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
self.example.cursor_path.display(),
self.example.cursor_position
)?;
write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
if !self.example.expected_patch.is_empty() {
write!(
f,
"\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
self.example.expected_patch
)?;
}
Ok(())
}
}
if example.cursor_path.as_ref() == Path::new("") || example.cursor_position.is_empty() {
anyhow::bail!("Missing cursor position codeblock");
}
thread_local! {
static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
}
#[must_use]
pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
REPO_LOCKS
.with(|cell| {
cell.borrow_mut()
.entry(path.as_ref().to_path_buf())
.or_default()
.clone()
})
.lock_owned()
.await
Ok(example)
}

View File

@@ -0,0 +1,280 @@
use crate::{
PromptFormat,
example::{Example, ExamplePrompt},
headless::EpAppState,
retrieve_context::run_context_retrieval,
};
use edit_prediction::{EditPredictionStore, zeta2::zeta2_prompt_input};
use gpui::AsyncApp;
use std::sync::Arc;
use zeta_prompt::format_zeta_prompt;
pub async fn run_format_prompt(
example: &mut Example,
prompt_format: PromptFormat,
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
) {
run_context_retrieval(example, app_state, cx.clone()).await;
let prompt = match prompt_format {
PromptFormat::Teacher => TeacherPrompt::format(example),
PromptFormat::Zeta2 => {
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
.unwrap();
let state = example.state.as_ref().unwrap();
let snapshot = state
.buffer
.read_with(&cx, |buffer, _| buffer.snapshot())
.unwrap();
let project = state.project.clone();
let (_, input) = ep_store
.update(&mut cx, |ep_store, _cx| {
zeta2_prompt_input(
&snapshot,
example.context.as_ref().unwrap().files.clone(),
ep_store.edit_history_for_project(&project),
example.cursor_path.clone(),
example.buffer.as_ref().unwrap().cursor_offset,
)
})
.unwrap();
format_zeta_prompt(&input)
}
};
example.prompt = Some(ExamplePrompt {
input: prompt,
expected_output: example.expected_patch.clone(), // TODO
format: prompt_format,
});
}
pub trait PromptFormatter {
fn format(example: &Example) -> String;
}
pub trait PromptParser {
/// Return unified diff patch of prediction given raw LLM response
fn parse(example: &Example, response: &str) -> String;
}
pub struct TeacherPrompt;
impl PromptFormatter for TeacherPrompt {
fn format(example: &Example) -> String {
let edit_history = Self::format_edit_history(&example.edit_history);
let context = Self::format_context(example);
let editable_region = Self::format_editable_region(example);
let prompt = Self::PROMPT
.replace("{{context}}", &context)
.replace("{{edit_history}}", &edit_history)
.replace("{{editable_region}}", &editable_region);
prompt
}
}
impl TeacherPrompt {
const PROMPT: &str = include_str!("teacher.prompt.md");
pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
/// Truncate edit history to this number of last lines
const MAX_HISTORY_LINES: usize = 128;
fn format_edit_history(edit_history: &str) -> String {
// Strip comments ("garbage lines") from edit history
let lines = edit_history
.lines()
.filter(|&s| Self::is_udiff_content_line(s))
.collect::<Vec<_>>();
let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
&lines[lines.len() - Self::MAX_HISTORY_LINES..]
} else {
&lines
};
if history_lines.is_empty() {
return "(No edit history)".to_string();
}
history_lines.join("\n")
}
fn format_context(example: &Example) -> String {
if example.context.is_none() {
panic!("Missing context retriever step");
}
let mut prompt = String::new();
zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
prompt
}
fn format_editable_region(example: &Example) -> String {
let mut result = String::new();
let path_str = example.cursor_path.to_string_lossy();
result.push_str(&format!("`````path=\"{path_str}\"\n"));
result.push_str(Self::EDITABLE_REGION_START);
// TODO: control number of lines around cursor
result.push_str(&example.cursor_position);
if !example.cursor_position.ends_with('\n') {
result.push('\n');
}
result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END));
result.push_str("`````");
result
}
fn extract_editable_region(text: &str) -> String {
let start = text
.find(Self::EDITABLE_REGION_START)
.map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
let region = &text[start..end];
region.replace("<|user_cursor|>", "")
}
fn is_udiff_content_line(s: &str) -> bool {
s.starts_with("-")
|| s.starts_with("+")
|| s.starts_with(" ")
|| s.starts_with("---")
|| s.starts_with("+++")
|| s.starts_with("@@")
}
}
impl PromptParser for TeacherPrompt {
fn parse(example: &Example, response: &str) -> String {
// Ideally, we should always be able to find cursor position in the retrieved context.
// In reality, sometimes we don't find it for these reasons:
// 1. `example.cursor_position` contains _more_ context than included in the retrieved context
// (can be fixed by getting cursor coordinates at the load_example stage)
// 2. Context retriever just didn't include cursor line.
//
// In that case, fallback to using `cursor_position` as excerpt.
let cursor_file = &example
.buffer
.as_ref()
.expect("`buffer` should be filled in in the context collection step")
.content;
// Extract updated (new) editable region from the model response
let new_editable_region = extract_last_codeblock(response);
// Reconstruct old editable region we sent to the model
let old_editable_region = Self::format_editable_region(example);
let old_editable_region = Self::extract_editable_region(&old_editable_region);
if !cursor_file.contains(&old_editable_region) {
panic!("Something's wrong: editable_region is not found in the cursor file")
}
// Apply editable region to a larger context and compute diff.
// This is needed to get a better context lines around the editable region
let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
let diff = language::unified_diff(&cursor_file, &edited_file);
let diff = indoc::formatdoc! {"
--- a/{path}
+++ b/{path}
{diff}
",
path = example.cursor_path.to_string_lossy(),
diff = diff,
};
diff
}
}
fn extract_last_codeblock(text: &str) -> String {
let mut last_block = None;
let mut search_start = 0;
while let Some(start) = text[search_start..].find("```") {
let start = start + search_start;
let bytes = text.as_bytes();
let mut backtick_end = start;
while backtick_end < bytes.len() && bytes[backtick_end] == b'`' {
backtick_end += 1;
}
let backtick_count = backtick_end - start;
let closing_backticks = "`".repeat(backtick_count);
while backtick_end < bytes.len() && bytes[backtick_end] != b'\n' {
backtick_end += 1;
}
if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1];
last_block = Some(code_block.to_string());
search_start = backtick_end + end_pos + backtick_count;
} else {
break;
}
}
last_block.unwrap_or_else(|| text.to_string())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_extract_last_code_block() {
let text = indoc::indoc! {"
Some thinking
```
first block
```
`````path='something' lines=1:2
last block
`````
"};
let last_block = extract_last_codeblock(text);
assert_eq!(last_block, "last block");
}
#[test]
fn test_extract_editable_region() {
let text = indoc::indoc! {"
some lines
are
here
<|editable_region_start|>
one
two three
<|editable_region_end|>
more
lines here
"};
let parsed = TeacherPrompt::extract_editable_region(text);
assert_eq!(
parsed,
indoc::indoc! {"
one
two three
"}
);
}
}

View File

@@ -16,7 +16,7 @@ use std::sync::Arc;
use util::ResultExt as _;
/// Headless subset of `workspace::AppState`.
pub struct ZetaCliAppState {
pub struct EpAppState {
pub languages: Arc<LanguageRegistry>,
pub client: Arc<Client>,
pub user_store: Entity<UserStore>,
@@ -25,7 +25,7 @@ pub struct ZetaCliAppState {
}
// TODO: dedupe with crates/eval/src/eval.rs
pub fn init(cx: &mut App) -> ZetaCliAppState {
pub fn init(cx: &mut App) -> EpAppState {
let app_commit_sha = option_env!("ZED_COMMIT_SHA").map(|s| AppCommitSha::new(s.to_owned()));
let app_version = AppVersion::load(
@@ -112,7 +112,7 @@ pub fn init(cx: &mut App) -> ZetaCliAppState {
prompt_store::init(cx);
terminal_view::init(cx);
ZetaCliAppState {
EpAppState {
languages,
client,
user_store,

View File

@@ -0,0 +1,320 @@
use crate::{
example::{Example, ExampleBuffer, ExampleState},
headless::EpAppState,
};
use anyhow::{Result, anyhow};
use collections::HashMap;
use edit_prediction::EditPredictionStore;
use edit_prediction::udiff::OpenedBuffers;
use futures::{
AsyncWriteExt as _,
lock::{Mutex, OwnedMutexGuard},
};
use gpui::{AsyncApp, Entity};
use language::{Anchor, Buffer, ToOffset, ToPoint};
use project::buffer_store::BufferStoreEvent;
use project::{Project, ProjectPath};
use std::{
cell::RefCell,
fs,
path::{Path, PathBuf},
sync::Arc,
};
use util::{paths::PathStyle, rel_path::RelPath};
use zeta_prompt::CURSOR_MARKER;
pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>, mut cx: AsyncApp) {
if example.state.is_some() {
return;
}
let project = setup_project(example, &app_state, &mut cx).await;
let buffer_store = project
.read_with(&cx, |project, _| project.buffer_store().clone())
.unwrap();
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
.unwrap();
cx.subscribe(&buffer_store, {
let project = project.clone();
move |_, event, cx| match event {
BufferStoreEvent::BufferAdded(buffer) => {
ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
}
_ => {}
}
})
.unwrap()
.detach();
let _open_buffers = apply_edit_history(example, &project, &mut cx)
.await
.unwrap();
let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await;
example.buffer = buffer
.read_with(&cx, |buffer, _cx| {
let cursor_point = cursor_position.to_point(&buffer);
Some(ExampleBuffer {
content: buffer.text(),
cursor_row: cursor_point.row,
cursor_column: cursor_point.column,
cursor_offset: cursor_position.to_offset(&buffer),
})
})
.unwrap();
example.state = Some(ExampleState {
buffer,
project,
cursor_position,
_open_buffers,
});
}
async fn cursor_position(
example: &Example,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> (Entity<Buffer>, Anchor) {
let worktree = project
.read_with(cx, |project, cx| {
project.visible_worktrees(cx).next().unwrap()
})
.unwrap();
let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix)
.unwrap()
.into_arc();
let cursor_buffer = project
.update(cx, |project, cx| {
project.open_buffer(
ProjectPath {
worktree_id: worktree.read(cx).id(),
path: cursor_path,
},
cx,
)
})
.unwrap()
.await
.unwrap();
let cursor_offset_within_excerpt = example
.cursor_position
.find(CURSOR_MARKER)
.ok_or_else(|| anyhow!("missing cursor marker"))
.unwrap();
let mut cursor_excerpt = example.cursor_position.clone();
cursor_excerpt.replace_range(
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
"",
);
let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
let text = buffer.text();
let mut matches = text.match_indices(&cursor_excerpt);
let (excerpt_offset, _) = matches.next().unwrap_or_else(|| {
panic!(
"\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
);
});
assert!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
excerpt_offset
}).unwrap();
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
let cursor_anchor = cursor_buffer
.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))
.unwrap();
(cursor_buffer, cursor_anchor)
}
async fn setup_project(
example: &mut Example,
app_state: &Arc<EpAppState>,
cx: &mut AsyncApp,
) -> Entity<Project> {
setup_worktree(example).await;
let project = cx
.update(|cx| {
Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
app_state.fs.clone(),
None,
cx,
)
})
.unwrap();
let worktree = project
.update(cx, |project, cx| {
project.create_worktree(&example.worktree_path(), true, cx)
})
.unwrap()
.await
.unwrap();
worktree
.read_with(cx, |worktree, _cx| {
worktree.as_local().unwrap().scan_complete()
})
.unwrap()
.await;
project
}
pub async fn setup_worktree(example: &Example) {
let repo_dir = example.repo_path();
let repo_lock = lock_repo(&repo_dir).await;
if !repo_dir.is_dir() {
fs::create_dir_all(&repo_dir).unwrap();
run_git(&repo_dir, &["init"]).await.unwrap();
run_git(
&repo_dir,
&["remote", "add", "origin", &example.repository_url],
)
.await
.unwrap();
}
// Resolve the example to a revision, fetching it if needed.
let revision = run_git(
&repo_dir,
&["rev-parse", &format!("{}^{{commit}}", example.revision)],
)
.await;
let revision = if let Ok(revision) = revision {
revision
} else {
if run_git(
&repo_dir,
&["fetch", "--depth", "1", "origin", &example.revision],
)
.await
.is_err()
{
run_git(&repo_dir, &["fetch", "origin"]).await.unwrap();
}
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"])
.await
.unwrap();
if revision != example.revision {
run_git(&repo_dir, &["tag", &example.revision, &revision])
.await
.unwrap();
}
revision
};
// Create the worktree for this example if needed.
let worktree_path = example.worktree_path();
if worktree_path.is_dir() {
run_git(&worktree_path, &["clean", "--force", "-d"])
.await
.unwrap();
run_git(&worktree_path, &["reset", "--hard", "HEAD"])
.await
.unwrap();
run_git(&worktree_path, &["checkout", revision.as_str()])
.await
.unwrap();
} else {
let worktree_path_string = worktree_path.to_string_lossy();
run_git(
&repo_dir,
&["branch", "-f", &example.name, revision.as_str()],
)
.await
.unwrap();
run_git(
&repo_dir,
&[
"worktree",
"add",
"-f",
&worktree_path_string,
&example.name,
],
)
.await
.unwrap();
}
drop(repo_lock);
// Apply the uncommitted diff for this example.
if !example.uncommitted_diff.is_empty() {
let mut apply_process = smol::process::Command::new("git")
.current_dir(&worktree_path)
.args(&["apply", "-"])
.stdin(std::process::Stdio::piped())
.spawn()
.unwrap();
let mut stdin = apply_process.stdin.take().unwrap();
stdin
.write_all(example.uncommitted_diff.as_bytes())
.await
.unwrap();
stdin.close().await.unwrap();
drop(stdin);
let apply_result = apply_process.output().await.unwrap();
if !apply_result.status.success() {
panic!(
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
apply_result.status,
String::from_utf8_lossy(&apply_result.stderr),
String::from_utf8_lossy(&apply_result.stdout),
);
}
}
}
async fn apply_edit_history(
example: &Example,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<OpenedBuffers> {
edit_prediction::udiff::apply_diff(&example.edit_history, project, cx).await
}
thread_local! {
static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
}
#[must_use]
pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
REPO_LOCKS
.with(|cell| {
cell.borrow_mut()
.entry(path.as_ref().to_path_buf())
.or_default()
.clone()
})
.lock_owned()
.await
}
async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
let output = smol::process::Command::new("git")
.current_dir(repo_path)
.args(args)
.output()
.await?;
anyhow::ensure!(
output.status.success(),
"`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
args.join(" "),
repo_path.display(),
output.status,
String::from_utf8_lossy(&output.stderr),
String::from_utf8_lossy(&output.stdout),
);
Ok(String::from_utf8(output.stdout)?.trim().to_string())
}

View File

@@ -1,522 +1,196 @@
mod evaluate;
mod anthropic_client;
mod example;
mod format_prompt;
mod headless;
mod load_project;
mod metrics;
mod paths;
mod predict;
mod source_location;
mod training;
mod util;
mod retrieve_context;
mod score;
use crate::{
evaluate::run_evaluate,
example::{ExampleFormat, NamedExample},
headless::ZetaCliAppState,
predict::run_predict,
source_location::SourceLocation,
training::{context::ContextType, distill::run_distill},
util::{open_buffer, open_buffer_with_language_server},
};
use ::util::{ResultExt, paths::PathStyle};
use anyhow::{Result, anyhow};
use clap::{Args, Parser, Subcommand, ValueEnum};
use cloud_llm_client::predict_edits_v3;
use edit_prediction::udiff::DiffLine;
use edit_prediction_context::EditPredictionExcerptOptions;
use gpui::{Application, AsyncApp, Entity, prelude::*};
use language::{Bias, Buffer, BufferSnapshot, Point};
use metrics::delta_chr_f;
use project::{Project, Worktree, lsp_store::OpenLspBufferHandle};
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use edit_prediction::EditPredictionStore;
use gpui::Application;
use reqwest_client::ReqwestClient;
use std::io::{self};
use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, sync::Arc};
use crate::example::{read_examples, write_examples};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::predict::run_prediction;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
#[command(name = "ep")]
struct EpArgs {
#[arg(long, default_value_t = false)]
printenv: bool,
#[clap(long, default_value_t = 10)]
max_parallelism: usize,
#[command(subcommand)]
command: Option<Command>,
#[clap(global = true)]
inputs: Vec<PathBuf>,
#[arg(long, short, global = true)]
output: Option<PathBuf>,
#[arg(long, short, global = true)]
in_place: bool,
}
#[derive(Subcommand, Debug)]
enum Command {
Context(ContextArgs),
Predict(PredictArguments),
Eval(EvaluateArguments),
Distill(DistillArguments),
ConvertExample {
path: PathBuf,
#[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
output_format: ExampleFormat,
},
Score {
golden_patch: PathBuf,
actual_patch: PathBuf,
},
/// Parse markdown examples and output a combined .jsonl file
ParseExample,
/// Create git worktrees for each example and load file contents
LoadBuffer,
/// Retrieve context for input examples.
Context,
/// Generate a prompt string for a specific model
FormatPrompt(FormatPromptArgs),
/// Runs edit prediction
Predict(PredictArgs),
/// Computes a score based on actual and expected patches
Score(PredictArgs),
/// Print aggregated scores
Eval(PredictArgs),
/// Remove git repositories and worktrees
Clean,
}
#[derive(Debug, Args)]
struct ContextArgs {
#[arg(long)]
provider: ContextProvider,
#[arg(long)]
worktree: PathBuf,
#[arg(long)]
cursor: SourceLocation,
#[arg(long)]
use_language_server: bool,
#[arg(long)]
edit_history: Option<FileOrStdin>,
#[clap(flatten)]
zeta2_args: Zeta2Args,
struct FormatPromptArgs {
#[clap(long)]
prompt_format: PromptFormat,
}
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum ContextProvider {
Zeta1,
#[default]
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
Teacher,
Zeta2,
}
#[derive(Clone, Debug, Args)]
struct Zeta2Args {
#[arg(long, default_value_t = 8192)]
max_prompt_bytes: usize,
#[arg(long, default_value_t = 2048)]
max_excerpt_bytes: usize,
#[arg(long, default_value_t = 1024)]
min_excerpt_bytes: usize,
#[arg(long, default_value_t = 0.66)]
target_before_cursor_over_total_bytes: f32,
#[arg(long, default_value_t = 1024)]
max_diagnostic_bytes: usize,
#[arg(long, value_enum, default_value_t = PromptFormat::default())]
prompt_format: PromptFormat,
#[arg(long, value_enum, default_value_t = Default::default())]
output_format: OutputFormat,
#[arg(long, default_value_t = 42)]
file_indexing_parallelism: usize,
#[arg(long, default_value_t = false)]
disable_imports_gathering: bool,
#[arg(long, default_value_t = u8::MAX)]
max_retrieved_definitions: u8,
}
#[derive(Debug, Args)]
pub struct PredictArguments {
#[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
format: PredictionsOutputFormat,
example_path: PathBuf,
#[clap(flatten)]
options: PredictionOptions,
}
#[derive(Debug, Args)]
pub struct DistillArguments {
split_commit_dataset: PathBuf,
#[clap(long, value_enum, default_value_t = ContextType::CurrentFile)]
context_type: ContextType,
#[clap(long)]
batch: Option<String>,
}
#[derive(Clone, Debug, Args)]
pub struct PredictionOptions {
#[clap(flatten)]
zeta2: Zeta2Args,
struct PredictArgs {
#[clap(long)]
provider: PredictionProvider,
#[clap(long, value_enum, default_value_t = CacheMode::default())]
cache: CacheMode,
#[clap(long, default_value_t = 1)]
repetitions: usize,
}
#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
pub enum CacheMode {
/// Use cached LLM requests and responses, except when multiple repetitions are requested
#[default]
Auto,
/// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
#[value(alias = "request")]
Requests,
/// Ignore existing cache entries for both LLM and search.
Skip,
/// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
/// Useful for reproducing results and fixing bugs outside of search queries
Force,
}
impl CacheMode {
fn use_cached_llm_responses(&self) -> bool {
self.assert_not_auto();
matches!(self, CacheMode::Requests | CacheMode::Force)
}
fn use_cached_search_results(&self) -> bool {
self.assert_not_auto();
matches!(self, CacheMode::Force)
}
fn assert_not_auto(&self) {
assert_ne!(
*self,
CacheMode::Auto,
"Cache mode should not be auto at this point!"
);
}
}
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
Json,
Md,
Diff,
}
#[derive(Debug, Args)]
pub struct EvaluateArguments {
example_paths: Vec<PathBuf>,
#[clap(flatten)]
options: PredictionOptions,
#[clap(short, long, default_value_t = 1, alias = "repeat")]
repetitions: u16,
#[arg(long)]
skip_prediction: bool,
}
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
Zeta1,
#[default]
Zeta2,
Sweep,
Mercury,
Zeta1,
Zeta2,
Teacher,
}
fn zeta2_args_to_options(args: &Zeta2Args) -> edit_prediction::ZetaOptions {
edit_prediction::ZetaOptions {
context: EditPredictionExcerptOptions {
max_bytes: args.max_excerpt_bytes,
min_bytes: args.min_excerpt_bytes,
target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
},
max_prompt_bytes: args.max_prompt_bytes,
prompt_format: args.prompt_format.into(),
}
}
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum PromptFormat {
OnlySnippets,
#[default]
OldTextNewText,
Minimal,
MinimalQwen,
SeedCoder1120,
}
impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
fn into(self) -> predict_edits_v3::PromptFormat {
match self {
Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
}
}
}
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
#[default]
Prompt,
Request,
Full,
}
#[derive(Debug, Clone)]
enum FileOrStdin {
File(PathBuf),
Stdin,
}
impl FileOrStdin {
async fn read_to_string(&self) -> Result<String, std::io::Error> {
match self {
FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
}
}
}
impl FromStr for FileOrStdin {
type Err = <PathBuf as FromStr>::Err;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"-" => Ok(Self::Stdin),
_ => Ok(Self::File(PathBuf::from_str(s)?)),
}
}
}
struct LoadedContext {
full_path_str: String,
snapshot: BufferSnapshot,
clipped_cursor: Point,
worktree: Entity<Worktree>,
project: Entity<Project>,
buffer: Entity<Buffer>,
lsp_open_handle: Option<OpenLspBufferHandle>,
}
async fn load_context(
args: &ContextArgs,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<LoadedContext> {
let ContextArgs {
worktree: worktree_path,
cursor,
use_language_server,
..
} = args;
let worktree_path = worktree_path.canonicalize()?;
let project = cx.update(|cx| {
Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
app_state.fs.clone(),
None,
cx,
)
})?;
let worktree = project
.update(cx, |project, cx| {
project.create_worktree(&worktree_path, true, cx)
})?
.await?;
let mut ready_languages = HashSet::default();
let (lsp_open_handle, buffer) = if *use_language_server {
let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
project.clone(),
worktree.clone(),
cursor.path.clone(),
&mut ready_languages,
cx,
)
.await?;
(Some(lsp_open_handle), buffer)
} else {
let buffer =
open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
(None, buffer)
};
let full_path_str = worktree
.read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
.display(PathStyle::local())
.to_string();
let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
if clipped_cursor != cursor.point {
let max_row = snapshot.max_point().row;
if cursor.point.row < max_row {
return Err(anyhow!(
"Cursor position {:?} is out of bounds (line length is {})",
cursor.point,
snapshot.line_len(cursor.point.row)
));
impl EpArgs {
fn output_path(&self) -> Option<PathBuf> {
if self.in_place {
if self.inputs.len() == 1 {
self.inputs.first().cloned()
} else {
panic!("--in-place requires exactly one input file")
}
} else {
return Err(anyhow!(
"Cursor position {:?} is out of bounds (max row is {})",
cursor.point,
max_row
));
self.output.clone()
}
}
Ok(LoadedContext {
full_path_str,
snapshot,
clipped_cursor,
worktree,
project,
buffer,
lsp_open_handle,
})
}
async fn zeta2_context(
args: ContextArgs,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<String> {
let LoadedContext {
worktree,
project,
buffer,
clipped_cursor,
lsp_open_handle: _handle,
..
} = load_context(&args, app_state, cx).await?;
// wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
// the whole worktree.
worktree
.read_with(cx, |worktree, _cx| {
worktree.as_local().unwrap().scan_complete()
})?
.await;
let output = cx
.update(|cx| {
let store = cx.new(|cx| {
edit_prediction::EditPredictionStore::new(
app_state.client.clone(),
app_state.user_store.clone(),
cx,
)
});
store.update(cx, |store, cx| {
store.set_options(zeta2_args_to_options(&args.zeta2_args));
store.register_buffer(&buffer, &project, cx);
});
cx.spawn(async move |cx| {
let updates_rx = store.update(cx, |store, cx| {
let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
store.set_use_context(true);
store.refresh_context(&project, &buffer, cursor, cx);
store.project_context_updates(&project).unwrap()
})?;
updates_rx.recv().await.ok();
let context = store.update(cx, |store, cx| {
store.context_for_project(&project, cx).to_vec()
})?;
anyhow::Ok(serde_json::to_string_pretty(&context).unwrap())
})
})?
.await?;
Ok(output)
}
async fn zeta1_context(
args: ContextArgs,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<edit_prediction::zeta1::GatherContextOutput> {
let LoadedContext {
full_path_str,
snapshot,
clipped_cursor,
..
} = load_context(&args, app_state, cx).await?;
let events = match args.edit_history {
Some(events) => events.read_to_string().await?,
None => String::new(),
};
let prompt_for_events = move || (events, 0);
cx.update(|cx| {
edit_prediction::zeta1::gather_context(
full_path_str,
&snapshot,
clipped_cursor,
prompt_for_events,
cloud_llm_client::PredictEditsRequestTrigger::Cli,
cx,
)
})?
.await
}
fn main() {
zlog::init();
zlog::init_output_stderr();
let args = ZetaCliArgs::parse();
let args = EpArgs::parse();
if args.printenv {
::util::shell_env::print_env();
return;
}
let output = args.output_path();
let command = match args.command {
Some(cmd) => cmd,
None => {
EpArgs::command().print_help().unwrap();
return;
}
};
match &command {
Command::Clean => {
std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
return;
}
_ => {}
}
let mut examples = read_examples(&args.inputs);
let http_client = Arc::new(ReqwestClient::new());
let app = Application::headless().with_http_client(http_client);
app.run(move |cx| {
let app_state = Arc::new(headless::init(cx));
EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
cx.spawn(async move |cx| {
match args.command {
None => {
if args.printenv {
::util::shell_env::print_env();
} else {
panic!("Expected a command");
}
}
Some(Command::Context(context_args)) => {
let result = match context_args.provider {
ContextProvider::Zeta1 => {
let context =
zeta1_context(context_args, &app_state, cx).await.unwrap();
serde_json::to_string_pretty(&context.body).unwrap()
match &command {
Command::Predict(args) => predict::sync_batches(&args.provider).await,
_ => (),
};
for data in examples.chunks_mut(args.max_parallelism) {
let mut futures = Vec::new();
for example in data.iter_mut() {
let cx = cx.clone();
let app_state = app_state.clone();
futures.push(async {
match &command {
Command::ParseExample => {}
Command::LoadBuffer => {
run_load_project(example, app_state.clone(), cx).await;
}
Command::Context => {
run_context_retrieval(example, app_state, cx).await;
}
Command::FormatPrompt(args) => {
run_format_prompt(example, args.prompt_format, app_state, cx).await;
}
Command::Predict(args) => {
run_prediction(
example,
Some(args.provider),
args.repetitions,
app_state.clone(),
cx,
)
.await;
}
Command::Score(args) | Command::Eval(args) => {
run_scoring(example, &args, app_state, cx).await;
}
Command::Clean => {
unreachable!()
}
}
ContextProvider::Zeta2 => {
zeta2_context(context_args, &app_state, cx).await.unwrap()
}
};
println!("{}", result);
});
}
Some(Command::Predict(arguments)) => {
run_predict(arguments, &app_state, cx).await;
}
Some(Command::Eval(arguments)) => {
run_evaluate(arguments, &app_state, cx).await;
}
Some(Command::Distill(arguments)) => {
let _guard = cx
.update(|cx| gpui_tokio::Tokio::handle(cx))
.unwrap()
.enter();
run_distill(arguments).await.log_err();
}
Some(Command::ConvertExample {
path,
output_format,
}) => {
let example = NamedExample::load(path).unwrap();
example.write(output_format, io::stdout()).unwrap();
}
Some(Command::Score {
golden_patch,
actual_patch,
}) => {
let golden_content = std::fs::read_to_string(golden_patch).unwrap();
let actual_content = std::fs::read_to_string(actual_patch).unwrap();
futures::future::join_all(futures).await;
}
let golden_diff: Vec<DiffLine> = golden_content
.lines()
.map(|line| DiffLine::parse(line))
.collect();
if args.output.is_some() || !matches!(command, Command::Eval(_)) {
write_examples(&examples, output.as_ref());
}
let actual_diff: Vec<DiffLine> = actual_content
.lines()
.map(|line| DiffLine::parse(line))
.collect();
let score = delta_chr_f(&golden_diff, &actual_diff);
println!("{:.2}", score);
}
Some(Command::Clean) => {
std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
}
match &command {
Command::Predict(args) => predict::sync_batches(&args.provider).await,
Command::Eval(_) => score::print_report(&examples),
_ => (),
};
let _ = cx.update(|cx| cx.quit());

View File

@@ -1,30 +1,34 @@
use collections::{HashMap, HashSet};
use edit_prediction::udiff::DiffLine;
use serde::{Deserialize, Serialize};
type Counts = HashMap<String, usize>;
type CountsDelta = HashMap<String, isize>;
#[derive(Default, Debug, Clone)]
pub struct Scores {
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct ClassificationMetrics {
pub true_positives: usize,
pub false_positives: usize,
pub false_negatives: usize,
}
impl Scores {
pub fn from_sets(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
impl ClassificationMetrics {
pub fn from_sets(
expected: &HashSet<String>,
actual: &HashSet<String>,
) -> ClassificationMetrics {
let true_positives = expected.intersection(actual).count();
let false_positives = actual.difference(expected).count();
let false_negatives = expected.difference(actual).count();
Scores {
ClassificationMetrics {
true_positives,
false_positives,
false_negatives,
}
}
pub fn from_counts(expected: &Counts, actual: &Counts) -> Scores {
pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
let mut true_positives = 0;
let mut false_positives = 0;
let mut false_negatives = 0;
@@ -45,32 +49,16 @@ impl Scores {
}
}
Scores {
ClassificationMetrics {
true_positives,
false_positives,
false_negatives,
}
}
pub fn to_markdown(&self) -> String {
format!(
"
Precision : {:.4}
Recall : {:.4}
F1 Score : {:.4}
True Positives : {}
False Positives : {}
False Negatives : {}",
self.precision(),
self.recall(),
self.f1_score(),
self.true_positives,
self.false_positives,
self.false_negatives
)
}
pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
pub fn aggregate<'a>(
scores: impl Iterator<Item = &'a ClassificationMetrics>,
) -> ClassificationMetrics {
let mut true_positives = 0;
let mut false_positives = 0;
let mut false_negatives = 0;
@@ -81,7 +69,7 @@ False Negatives : {}",
false_negatives += score.false_negatives;
}
Scores {
ClassificationMetrics {
true_positives,
false_positives,
false_negatives,
@@ -115,7 +103,10 @@ False Negatives : {}",
}
}
pub fn line_match_score(expected_patch: &[DiffLine], actual_patch: &[DiffLine]) -> Scores {
pub fn line_match_score(
expected_patch: &[DiffLine],
actual_patch: &[DiffLine],
) -> ClassificationMetrics {
let expected_change_lines = expected_patch
.iter()
.filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
@@ -128,7 +119,7 @@ pub fn line_match_score(expected_patch: &[DiffLine], actual_patch: &[DiffLine])
.map(|line| line.to_string())
.collect();
Scores::from_sets(&expected_change_lines, &actual_change_lines)
ClassificationMetrics::from_sets(&expected_change_lines, &actual_change_lines)
}
enum ChrfWhitespace {
@@ -204,7 +195,7 @@ pub fn delta_chr_f(expected: &[DiffLine], actual: &[DiffLine]) -> f64 {
let expected_counts = ngram_delta_to_counts(&expected_delta);
let actual_counts = ngram_delta_to_counts(&actual_delta);
let score = Scores::from_counts(&expected_counts, &actual_counts);
let score = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
total_precision += score.precision();
total_recall += score.recall();
}

View File

@@ -1,57 +1,25 @@
use std::{env, path::PathBuf, sync::LazyLock};
use std::{
path::{Path, PathBuf},
sync::LazyLock,
};
pub static TARGET_ZETA_DIR: LazyLock<PathBuf> =
LazyLock::new(|| env::current_dir().unwrap().join("target/zeta"));
pub static CACHE_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("cache"));
pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("repos"));
pub static WORKTREES_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("worktrees"));
pub static DATA_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
let dir = dirs::home_dir().unwrap().join(".zed_ep");
ensure_dir(&dir)
});
pub static CACHE_DIR: LazyLock<PathBuf> = LazyLock::new(|| ensure_dir(&DATA_DIR.join("cache")));
pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| ensure_dir(&DATA_DIR.join("repos")));
pub static WORKTREES_DIR: LazyLock<PathBuf> =
LazyLock::new(|| ensure_dir(&DATA_DIR.join("worktrees")));
pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
TARGET_ZETA_DIR
DATA_DIR
.join("runs")
.join(chrono::Local::now().format("%d-%m-%y-%H_%M_%S").to_string())
});
pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> =
LazyLock::new(|| TARGET_ZETA_DIR.join("latest"));
pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("latest"));
pub static LLM_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("llm_cache.sqlite"));
pub fn print_run_data_dir(deep: bool, use_color: bool) {
println!("\n## Run Data\n");
let mut files = Vec::new();
let current_dir = std::env::current_dir().unwrap();
for file in std::fs::read_dir(&*RUN_DIR).unwrap() {
let file = file.unwrap();
if file.file_type().unwrap().is_dir() && deep {
for file in std::fs::read_dir(file.path()).unwrap() {
let path = file.unwrap().path();
let path = path.strip_prefix(&current_dir).unwrap_or(&path);
files.push(format!(
"- {}/{}{}{}",
path.parent().unwrap().display(),
if use_color { "\x1b[34m" } else { "" },
path.file_name().unwrap().display(),
if use_color { "\x1b[0m" } else { "" },
));
}
} else {
let path = file.path();
let path = path.strip_prefix(&current_dir).unwrap_or(&path);
files.push(format!(
"- {}/{}{}{}",
path.parent().unwrap().display(),
if use_color { "\x1b[34m" } else { "" },
path.file_name().unwrap().display(),
if use_color { "\x1b[0m" } else { "" }
));
}
}
files.sort();
for file in files {
println!("{}", file);
}
println!(
"\n💡 Tip of the day: {} always points to the latest run\n",
LATEST_EXAMPLE_RUN_DIR.display()
);
fn ensure_dir(path: &Path) -> PathBuf {
std::fs::create_dir_all(path).expect("Failed to create directory");
path.to_path_buf()
}

View File

@@ -1,374 +1,271 @@
use crate::example::{ActualExcerpt, NamedExample};
use crate::headless::ZetaCliAppState;
use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
use crate::{
CacheMode, PredictArguments, PredictionOptions, PredictionProvider, PredictionsOutputFormat,
PredictionProvider, PromptFormat,
anthropic_client::AnthropicClient,
example::{Example, ExamplePrediction},
format_prompt::{PromptParser, TeacherPrompt, run_format_prompt},
headless::EpAppState,
load_project::run_load_project,
paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
retrieve_context::run_context_retrieval,
};
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
fs,
sync::{
Arc, Mutex, OnceLock,
atomic::{AtomicUsize, Ordering::SeqCst},
},
};
use ::serde::Serialize;
use anyhow::{Context, Result, anyhow};
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
use edit_prediction::{EditPredictionStore, EvalCache, EvalCacheEntryKind, EvalCacheKey};
use futures::StreamExt as _;
use gpui::{AppContext, AsyncApp, Entity};
use project::Project;
use project::buffer_store::BufferStoreEvent;
use serde::Deserialize;
use std::fs;
use std::io::{IsTerminal, Write};
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::{Duration, Instant};
pub async fn run_predict(
args: PredictArguments,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
pub async fn run_prediction(
example: &mut Example,
provider: Option<PredictionProvider>,
repetition_count: usize,
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
) {
let example = NamedExample::load(args.example_path).unwrap();
let project = example.setup_project(app_state, cx).await.unwrap();
let store = setup_store(args.options.provider, &project, app_state, cx).unwrap();
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
let result = perform_predict(example, project, store, None, args.options, cx)
.await
if !example.predictions.is_empty() {
return;
}
run_load_project(example, app_state.clone(), cx.clone()).await;
run_context_retrieval(example, app_state.clone(), cx.clone()).await;
let provider = provider.unwrap();
if matches!(provider, PredictionProvider::Teacher) {
if example.prompt.is_none() {
run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await;
}
let batched = true;
return predict_anthropic(example, repetition_count, batched).await;
}
if matches!(
provider,
PredictionProvider::Zeta1 | PredictionProvider::Zeta2
) {
static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
AUTHENTICATED
.get_or_init(|| {
let client = app_state.client.clone();
cx.spawn(async move |cx| {
client
.sign_in_with_optional_connect(true, cx)
.await
.unwrap();
})
.shared()
})
.clone()
.await;
}
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
.unwrap();
result.write(args.format, std::io::stdout()).unwrap();
print_run_data_dir(true, std::io::stdout().is_terminal());
}
ep_store
.update(&mut cx, |store, _cx| {
let model = match provider {
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
PredictionProvider::Teacher => unreachable!(),
};
store.set_edit_prediction_model(model);
})
.unwrap();
let state = example.state.as_ref().unwrap();
let run_dir = RUN_DIR.join(&example.name);
pub fn setup_store(
provider: PredictionProvider,
project: &Entity<Project>,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<Entity<EditPredictionStore>> {
let store = cx.new(|cx| {
edit_prediction::EditPredictionStore::new(
app_state.client.clone(),
app_state.user_store.clone(),
cx,
)
})?;
let updated_example = Arc::new(Mutex::new(example.clone()));
let current_run_ix = Arc::new(AtomicUsize::new(0));
store.update(cx, |store, _cx| {
let model = match provider {
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
};
store.set_edit_prediction_model(model);
})?;
let mut debug_rx = ep_store
.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))
.unwrap();
let debug_task = cx.background_spawn({
let updated_example = updated_example.clone();
let current_run_ix = current_run_ix.clone();
let run_dir = run_dir.clone();
async move {
while let Some(event) = debug_rx.next().await {
let run_ix = current_run_ix.load(SeqCst);
let mut updated_example = updated_example.lock().unwrap();
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
let run_dir = if repetition_count > 1 {
run_dir.join(format!("{:03}", run_ix))
} else {
run_dir.clone()
};
cx.subscribe(&buffer_store, {
let project = project.clone();
let store = store.clone();
move |_, event, cx| match event {
BufferStoreEvent::BufferAdded(buffer) => {
store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
}
_ => {}
}
})?
.detach();
match event {
DebugEvent::EditPredictionStarted(request) => {
assert_eq!(updated_example.predictions.len(), run_ix + 1);
anyhow::Ok(store)
}
pub async fn perform_predict(
example: NamedExample,
project: Entity<Project>,
store: Entity<EditPredictionStore>,
repetition_ix: Option<u16>,
options: PredictionOptions,
cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
let mut cache_mode = options.cache;
if repetition_ix.is_some() {
if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip {
panic!("Repetitions are not supported in Auto cache mode");
} else {
cache_mode = CacheMode::Skip;
}
} else if cache_mode == CacheMode::Auto {
cache_mode = CacheMode::Requests;
}
let mut example_run_dir = RUN_DIR.join(&example.file_name());
if let Some(repetition_ix) = repetition_ix {
example_run_dir = example_run_dir.join(format!("{:03}", repetition_ix));
}
fs::create_dir_all(&example_run_dir)?;
if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
}
#[cfg(unix)]
std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
.context("creating latest link")?;
#[cfg(windows)]
std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
.context("creating latest link")?;
store.update(cx, |store, _cx| {
store.with_eval_cache(Arc::new(RunCache {
example_run_dir: example_run_dir.clone(),
cache_mode,
}));
})?;
let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone())));
let prompt_format = options.zeta2.prompt_format;
store.update(cx, |store, _cx| {
let mut options = store.options().clone();
options.prompt_format = prompt_format.into();
store.set_options(options);
})?;
let mut debug_task = gpui::Task::ready(Ok(()));
if options.provider == crate::PredictionProvider::Zeta2 {
let mut debug_rx = store.update(cx, |store, _| store.debug_info())?;
debug_task = cx.background_spawn({
let result = result.clone();
async move {
let mut start_time = None;
let mut retrieval_finished_at = None;
while let Some(event) = debug_rx.next().await {
match event {
edit_prediction::DebugEvent::ContextRetrievalStarted(info) => {
start_time = Some(info.timestamp);
fs::write(
example_run_dir.join("search_prompt.md"),
&info.search_prompt,
)?;
if let Some(prompt) = request.prompt {
fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
}
edit_prediction::DebugEvent::ContextRetrievalFinished(info) => {
retrieval_finished_at = Some(info.timestamp);
for (key, value) in &info.metadata {
if *key == "search_queries" {
fs::write(
example_run_dir.join("search_queries.json"),
value.as_bytes(),
)?;
}
}
}
DebugEvent::EditPredictionFinished(request) => {
assert_eq!(updated_example.predictions.len(), run_ix + 1);
if let Some(output) = request.model_output {
fs::write(run_dir.join("prediction_response.md"), &output)?;
updated_example
.predictions
.last_mut()
.unwrap()
.actual_output = output;
}
edit_prediction::DebugEvent::EditPredictionRequested(request) => {
let prediction_started_at = Instant::now();
start_time.get_or_insert(prediction_started_at);
let prompt = request.local_prompt.unwrap_or_default();
fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
{
let mut result = result.lock().unwrap();
result.prompt_len = prompt.chars().count();
for included_file in request.inputs.included_files {
let insertions =
vec![(request.inputs.cursor_point, CURSOR_MARKER)];
result.excerpts.extend(included_file.excerpts.iter().map(
|excerpt| ActualExcerpt {
path: included_file.path.components().skip(1).collect(),
text: String::from(excerpt.text.as_ref()),
},
));
write_codeblock(
&included_file.path,
included_file.excerpts.iter(),
if included_file.path == request.inputs.cursor_path {
&insertions
} else {
&[]
},
included_file.max_row,
false,
&mut result.excerpts_text,
);
}
}
let response =
request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
let response =
edit_prediction::open_ai_response::text_from_response(response)
.unwrap_or_default();
let prediction_finished_at = Instant::now();
fs::write(example_run_dir.join("prediction_response.md"), &response)?;
let mut result = result.lock().unwrap();
result.generated_len = response.chars().count();
result.retrieval_time =
retrieval_finished_at.unwrap() - start_time.unwrap();
result.prediction_time = prediction_finished_at - prediction_started_at;
result.total_time = prediction_finished_at - start_time.unwrap();
if run_ix >= repetition_count {
break;
}
}
_ => {}
}
anyhow::Ok(())
}
});
anyhow::Ok(())
}
});
store.update(cx, |store, cx| {
store.refresh_context(&project, &cursor_buffer, cursor_anchor, cx)
})?;
}
let prediction = store
.update(cx, |store, cx| {
store.request_prediction(
&project,
&cursor_buffer,
cursor_anchor,
cloud_llm_client::PredictEditsRequestTrigger::Cli,
cx,
)
})?
.await?;
debug_task.await?;
let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
result.diff = prediction
.and_then(|prediction| {
let prediction = prediction.prediction.ok()?;
prediction.edit_preview.as_unified_diff(&prediction.edits)
})
.unwrap_or_default();
anyhow::Ok(result)
}
struct RunCache {
cache_mode: CacheMode,
example_run_dir: PathBuf,
}
impl RunCache {
fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
CACHE_DIR.join(format!("{kind}_out_{key:x}.json",))
}
fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
CACHE_DIR.join(format!("{kind}_in_{key:x}.json",))
}
fn link_to_run(&self, key: &EvalCacheKey) {
let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();
let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
}
}
impl EvalCache for RunCache {
fn read(&self, key: EvalCacheKey) -> Option<String> {
let path = RunCache::output_cache_path(&key);
if path.exists() {
let use_cache = match key.0 {
EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
self.cache_mode.use_cached_llm_responses()
}
};
if use_cache {
log::info!("Using cache entry: {}", path.display());
self.link_to_run(&key);
Some(fs::read_to_string(path).unwrap())
} else {
log::trace!("Skipping cached entry: {}", path.display());
None
}
} else if matches!(self.cache_mode, CacheMode::Force) {
panic!(
"No cached entry found for {:?}. Run without `--cache force` at least once.",
key.0
);
for ix in 0..repetition_count {
current_run_ix.store(ix, SeqCst);
let run_dir = if repetition_count > 1 {
run_dir.join(format!("{:03}", ix))
} else {
None
}
}
fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
fs::create_dir_all(&*CACHE_DIR).unwrap();
let input_path = RunCache::input_cache_path(&key);
fs::write(&input_path, input).unwrap();
let output_path = RunCache::output_cache_path(&key);
log::trace!("Writing cache entry: {}", output_path.display());
fs::write(&output_path, output).unwrap();
self.link_to_run(&key);
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PredictionDetails {
pub diff: String,
pub excerpts: Vec<ActualExcerpt>,
pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
pub retrieval_time: Duration,
pub prediction_time: Duration,
pub total_time: Duration,
pub run_example_dir: PathBuf,
pub prompt_len: usize,
pub generated_len: usize,
}
impl PredictionDetails {
pub fn new(run_example_dir: PathBuf) -> Self {
Self {
diff: Default::default(),
excerpts: Default::default(),
excerpts_text: Default::default(),
retrieval_time: Default::default(),
prediction_time: Default::default(),
total_time: Default::default(),
run_example_dir,
prompt_len: 0,
generated_len: 0,
}
}
pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
let formatted = match format {
PredictionsOutputFormat::Md => self.to_markdown(),
PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
PredictionsOutputFormat::Diff => self.diff.clone(),
run_dir.clone()
};
Ok(out.write_all(formatted.as_bytes())?)
fs::create_dir_all(&run_dir).unwrap();
if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap();
}
#[cfg(unix)]
std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
#[cfg(windows)]
std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
updated_example
.lock()
.unwrap()
.predictions
.push(ExamplePrediction {
actual_patch: String::new(),
actual_output: String::new(),
provider,
});
let prediction = ep_store
.update(&mut cx, |store, cx| {
store.request_prediction(
&state.project,
&state.buffer,
state.cursor_position,
cloud_llm_client::PredictEditsRequestTrigger::Cli,
cx,
)
})
.unwrap()
.await
.unwrap();
updated_example
.lock()
.unwrap()
.predictions
.last_mut()
.unwrap()
.actual_patch = prediction
.and_then(|prediction| {
let prediction = prediction.prediction.ok()?;
prediction.edit_preview.as_unified_diff(&prediction.edits)
})
.unwrap_or_default();
}
pub fn to_markdown(&self) -> String {
format!(
"## Excerpts\n\n\
{}\n\n\
## Prediction\n\n\
{}\n\n\
## Time\n\n\
Retrieval: {}ms\n\
Prediction: {}ms\n\n\
Total: {}ms\n",
self.excerpts_text,
self.diff,
self.retrieval_time.as_millis(),
self.prediction_time.as_millis(),
self.total_time.as_millis(),
)
ep_store
.update(&mut cx, |store, _| {
store.remove_project(&state.project);
})
.unwrap();
debug_task.await.unwrap();
*example = Arc::into_inner(updated_example)
.unwrap()
.into_inner()
.unwrap();
}
async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) {
let llm_model_name = "claude-sonnet-4-5";
let max_tokens = 16384;
let llm_client = if batched {
AnthropicClient::batch(&crate::paths::LLM_CACHE_DB.as_ref())
} else {
AnthropicClient::plain()
};
let llm_client = llm_client.expect("Failed to create LLM client");
let prompt = example
.prompt
.as_ref()
.unwrap_or_else(|| panic!("Prompt is required for an example {}", &example.name));
let messages = vec![anthropic::Message {
role: anthropic::Role::User,
content: vec![anthropic::RequestContent::Text {
text: prompt.input.clone(),
cache_control: None,
}],
}];
let Some(response) = llm_client
.generate(llm_model_name, max_tokens, messages)
.await
.unwrap()
else {
// Request stashed for batched processing
return;
};
let actual_output = response
.content
.into_iter()
.filter_map(|content| match content {
anthropic::ResponseContent::Text { text } => Some(text),
_ => None,
})
.collect::<Vec<String>>()
.join("\n");
let actual_patch = TeacherPrompt::parse(example, &actual_output);
let prediction = ExamplePrediction {
actual_patch,
actual_output,
provider: PredictionProvider::Teacher,
};
example.predictions.push(prediction);
}
pub async fn sync_batches(provider: &PredictionProvider) {
match provider {
PredictionProvider::Teacher => {
let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
let llm_client =
AnthropicClient::batch(cache_path).expect("Failed to create LLM client");
llm_client
.sync_batches()
.await
.expect("Failed to sync batches");
}
_ => (),
}
}

View File

@@ -1,106 +1,136 @@
use anyhow::{Result, anyhow};
use futures::channel::mpsc;
use futures::{FutureExt as _, StreamExt as _};
use crate::{
example::{Example, ExampleContext},
headless::EpAppState,
load_project::run_load_project,
};
use anyhow::Result;
use collections::HashSet;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
use gpui::{AsyncApp, Entity, Task};
use language::{Buffer, LanguageId, LanguageNotFound, LanguageServerId, ParseStatus};
use project::lsp_store::OpenLspBufferHandle;
use project::{Project, ProjectPath, Worktree};
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use util::rel_path::RelPath;
use language::{Buffer, LanguageNotFound};
use project::Project;
use std::{sync::Arc, time::Duration};
pub fn open_buffer(
project: Entity<Project>,
worktree: Entity<Worktree>,
path: Arc<RelPath>,
cx: &AsyncApp,
) -> Task<Result<Entity<Buffer>>> {
cx.spawn(async move |cx| {
let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
worktree_id: worktree.id(),
path,
})?;
pub async fn run_context_retrieval(
example: &mut Example,
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
) {
if example.context.is_some() {
return;
}
let buffer = project
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?;
run_load_project(example, app_state.clone(), cx.clone()).await;
let mut parse_status = buffer.read_with(cx, |buffer, _cx| buffer.parse_status())?;
while *parse_status.borrow() != ParseStatus::Idle {
parse_status.changed().await?;
let state = example.state.as_ref().unwrap();
let project = state.project.clone();
let _lsp_handle = project
.update(&mut cx, |project, cx| {
project.register_buffer_with_language_servers(&state.buffer, cx)
})
.unwrap();
wait_for_language_server_to_start(example, &project, &state.buffer, &mut cx).await;
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
.unwrap();
let mut events = ep_store
.update(&mut cx, |store, cx| {
store.register_buffer(&state.buffer, &project, cx);
store.set_use_context(true);
store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
store.debug_info(&project, cx)
})
.unwrap();
while let Some(event) = events.next().await {
match event {
DebugEvent::ContextRetrievalFinished(_) => {
break;
}
_ => {}
}
}
Ok(buffer)
})
let context_files = ep_store
.update(&mut cx, |store, cx| store.context_for_project(&project, cx))
.unwrap();
example.context = Some(ExampleContext {
files: context_files,
});
}
pub async fn open_buffer_with_language_server(
project: Entity<Project>,
worktree: Entity<Worktree>,
path: Arc<RelPath>,
ready_languages: &mut HashSet<LanguageId>,
async fn wait_for_language_server_to_start(
example: &Example,
project: &Entity<Project>,
buffer: &Entity<Buffer>,
cx: &mut AsyncApp,
) -> Result<(OpenLspBufferHandle, LanguageServerId, Entity<Buffer>)> {
let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?;
let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
(
project.register_buffer_with_language_servers(&buffer, cx),
project.path_style(cx),
)
})?;
let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
) {
let language_registry = project
.read_with(cx, |project, _| project.languages().clone())
.unwrap();
let result = language_registry
.load_language_for_file_path(path.as_std_path())
.load_language_for_file_path(&example.cursor_path)
.await;
if let Err(error) = result
&& !error.is::<LanguageNotFound>()
{
anyhow::bail!(error);
panic!("Failed to load language for file path: {}", error);
}
let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
buffer.language().map(|language| language.id())
})?
let Some(language_id) = buffer
.read_with(cx, |buffer, _cx| {
buffer.language().map(|language| language.id())
})
.unwrap()
else {
return Err(anyhow!("No language for {}", path.display(path_style)));
panic!("No language for {:?}", example.cursor_path);
};
let log_prefix = format!("{} | ", path.display(path_style));
let mut ready_languages = HashSet::default();
let log_prefix = format!("{} | ", example.name);
if !ready_languages.contains(&language_id) {
wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
wait_for_lang_server(&project, &buffer, log_prefix, cx)
.await
.unwrap();
ready_languages.insert(language_id);
}
let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
let lsp_store = project
.read_with(cx, |project, _cx| project.lsp_store())
.unwrap();
// hacky wait for buffer to be registered with the language server
for _ in 0..100 {
let Some(language_server_id) = lsp_store.update(cx, |lsp_store, cx| {
buffer.update(cx, |buffer, cx| {
lsp_store
.language_servers_for_local_buffer(&buffer, cx)
.next()
.map(|(_, language_server)| language_server.server_id())
if lsp_store
.update(cx, |lsp_store, cx| {
buffer.update(cx, |buffer, cx| {
lsp_store
.language_servers_for_local_buffer(&buffer, cx)
.next()
.map(|(_, language_server)| language_server.server_id())
})
})
})?
else {
.unwrap()
.is_some()
{
return;
} else {
cx.background_executor()
.timer(Duration::from_millis(10))
.await;
continue;
};
return Ok((lsp_open_handle, language_server_id, buffer));
}
}
return Err(anyhow!("No language server found for buffer"));
panic!("No language server found for buffer");
}
// TODO: Dedupe with similar function in crates/eval/src/instance.rs
pub fn wait_for_lang_server(
project: &Entity<Project>,
buffer: &Entity<Buffer>,

View File

@@ -0,0 +1,119 @@
use crate::{
PredictArgs,
example::{Example, ExampleScore},
headless::EpAppState,
metrics::{self, ClassificationMetrics},
predict::run_prediction,
};
use edit_prediction::udiff::DiffLine;
use gpui::AsyncApp;
use std::sync::Arc;
pub async fn run_scoring(
example: &mut Example,
args: &PredictArgs,
app_state: Arc<EpAppState>,
cx: AsyncApp,
) {
run_prediction(
example,
Some(args.provider),
args.repetitions,
app_state,
cx,
)
.await;
let expected_patch = parse_patch(&example.expected_patch);
let mut scores = vec![];
for pred in &example.predictions {
let actual_patch = parse_patch(&pred.actual_patch);
let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;
scores.push(ExampleScore {
delta_chr_f,
line_match,
});
}
example.score = scores;
}
fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
patch.lines().map(DiffLine::parse).collect()
}
pub fn print_report(examples: &[Example]) {
eprintln!(
"──────────────────────────────────────────────────────────────────────────────────────"
);
eprintln!(
"{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
"Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
);
eprintln!(
"──────────────────────────────────────────────────────────────────────────────────────"
);
let mut all_line_match_scores = Vec::new();
let mut all_delta_chr_f_scores = Vec::new();
for example in examples {
for score in example.score.iter() {
let line_match = &score.line_match;
eprintln!(
"{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
truncate_name(&example.name, 30),
line_match.true_positives,
line_match.false_positives,
line_match.false_negatives,
line_match.precision() * 100.0,
line_match.recall() * 100.0,
line_match.f1_score() * 100.0,
score.delta_chr_f
);
all_line_match_scores.push(line_match.clone());
all_delta_chr_f_scores.push(score.delta_chr_f);
}
}
eprintln!(
"──────────────────────────────────────────────────────────────────────────────────────"
);
if !all_line_match_scores.is_empty() {
let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
let avg_delta_chr_f: f32 =
all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
eprintln!(
"{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
"TOTAL",
total_line_match.true_positives,
total_line_match.false_positives,
total_line_match.false_negatives,
total_line_match.precision() * 100.0,
total_line_match.recall() * 100.0,
total_line_match.f1_score() * 100.0,
avg_delta_chr_f
);
eprintln!(
"──────────────────────────────────────────────────────────────────────────────────────"
);
}
eprintln!("\n");
}
fn truncate_name(name: &str, max_len: usize) -> String {
if name.len() <= max_len {
name.to_string()
} else {
format!("{}...", &name[..max_len - 3])
}
}

View File

@@ -1,70 +0,0 @@
use std::{fmt, fmt::Display, path::Path, str::FromStr, sync::Arc};
use ::util::{paths::PathStyle, rel_path::RelPath};
use anyhow::{Result, anyhow};
use language::Point;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct SourceLocation {
pub path: Arc<RelPath>,
pub point: Point,
}
impl Serialize for SourceLocation {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&self.to_string())
}
}
impl<'de> Deserialize<'de> for SourceLocation {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
s.parse().map_err(serde::de::Error::custom)
}
}
impl Display for SourceLocation {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"{}:{}:{}",
self.path.display(PathStyle::Posix),
self.point.row + 1,
self.point.column + 1
)
}
}
impl FromStr for SourceLocation {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self> {
let parts: Vec<&str> = s.split(':').collect();
if parts.len() != 3 {
return Err(anyhow!(
"Invalid source location. Expected 'file.rs:line:column', got '{}'",
s
));
}
let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
let line: u32 = parts[1]
.parse()
.map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
let column: u32 = parts[2]
.parse()
.map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
// Convert from 1-based to 0-based indexing
let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
Ok(SourceLocation { path, point })
}
}

View File

@@ -46,3 +46,7 @@ Output example:
## Code Context
{{context}}
## Editable region
{{editable_region}}

View File

@@ -1,89 +0,0 @@
use std::path::Path;
use crate::{source_location::SourceLocation, training::teacher::TeacherModel};
#[derive(Debug, Clone, Default, clap::ValueEnum)]
pub enum ContextType {
#[default]
CurrentFile,
}
const MAX_CONTEXT_SIZE: usize = 32768;
pub fn collect_context(
context_type: &ContextType,
worktree_dir: &Path,
cursor: SourceLocation,
) -> String {
let context = match context_type {
ContextType::CurrentFile => {
let file_path = worktree_dir.join(cursor.path.as_std_path());
let context = std::fs::read_to_string(&file_path).unwrap_or_default();
let context = add_special_tags(&context, worktree_dir, cursor);
context
}
};
let region_end_offset = context.find(TeacherModel::REGION_END);
if context.len() <= MAX_CONTEXT_SIZE {
return context;
}
if let Some(region_end_offset) = region_end_offset
&& region_end_offset + TeacherModel::REGION_END.len() > MAX_CONTEXT_SIZE
{
let to_truncate = context.len() - MAX_CONTEXT_SIZE;
format!(
"[...{} bytes truncated]\n{}\n",
to_truncate,
&context[to_truncate..]
)
} else {
format!(
"{}\n[...{} bytes truncated]\n",
&context[..MAX_CONTEXT_SIZE],
context.len() - MAX_CONTEXT_SIZE
)
}
}
/// Add <|editable_region_start/end|> tags
fn add_special_tags(context: &str, worktree_dir: &Path, cursor: SourceLocation) -> String {
let path = worktree_dir.join(cursor.path.as_std_path());
let file = std::fs::read_to_string(&path).unwrap_or_default();
let lines = file.lines().collect::<Vec<_>>();
let cursor_row = cursor.point.row as usize;
let start_line = cursor_row.saturating_sub(TeacherModel::LEFT_CONTEXT_SIZE);
let end_line = (cursor_row + TeacherModel::RIGHT_CONTEXT_SIZE).min(lines.len());
let snippet = lines[start_line..end_line].join("\n");
if context.contains(&snippet) {
let mut cursor_line = lines[cursor_row].to_string();
cursor_line.insert_str(cursor.point.column as usize, TeacherModel::USER_CURSOR);
let mut snippet_with_tags_lines = vec![];
snippet_with_tags_lines.push(TeacherModel::REGION_START);
snippet_with_tags_lines.extend(&lines[start_line..cursor_row]);
snippet_with_tags_lines.push(&cursor_line);
snippet_with_tags_lines.extend(&lines[cursor_row + 1..end_line]);
snippet_with_tags_lines.push(TeacherModel::REGION_END);
let snippet_with_tags = snippet_with_tags_lines.join("\n");
context.replace(&snippet, &snippet_with_tags)
} else {
log::warn!(
"Can't find area around the cursor in the context; proceeding without special tags"
);
context.to_string()
}
}
pub fn strip_special_tags(context: &str) -> String {
context
.replace(TeacherModel::REGION_START, "")
.replace(TeacherModel::REGION_END, "")
.replace(TeacherModel::USER_CURSOR, "")
}

View File

@@ -1,94 +0,0 @@
use serde::Deserialize;
use std::sync::Arc;
use crate::{
DistillArguments,
example::Example,
source_location::SourceLocation,
training::{
context::ContextType,
llm_client::LlmClient,
teacher::{TeacherModel, TeacherOutput},
},
};
use anyhow::Result;
use reqwest_client::ReqwestClient;
#[derive(Debug, Deserialize)]
pub struct SplitCommit {
repo_url: String,
commit_sha: String,
edit_history: String,
expected_patch: String,
cursor_position: String,
}
pub async fn run_distill(arguments: DistillArguments) -> Result<()> {
let split_commits: Vec<SplitCommit> = std::fs::read_to_string(&arguments.split_commit_dataset)
.expect("Failed to read split commit dataset")
.lines()
.map(|line| serde_json::from_str(line).expect("Failed to parse JSON line"))
.collect();
let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
let llm_client = if let Some(cache_path) = arguments.batch {
LlmClient::batch(&cache_path, http_client)?
} else {
LlmClient::plain(http_client)?
};
let mut teacher = TeacherModel::new(
"claude-sonnet-4-5".to_string(),
ContextType::CurrentFile,
llm_client,
);
let mut num_marked_for_batching = 0;
for commit in split_commits {
if let Some(distilled) = distill_one(&mut teacher, commit).await? {
println!("{}", serde_json::to_string(&distilled)?);
} else {
if num_marked_for_batching == 0 {
log::warn!("Marked for batching");
}
num_marked_for_batching += 1;
}
}
eprintln!(
"{} requests are marked for batching",
num_marked_for_batching
);
let llm_client = teacher.client;
llm_client.sync_batches().await?;
Ok(())
}
pub async fn distill_one(
teacher: &mut TeacherModel,
commit: SplitCommit,
) -> Result<Option<TeacherOutput>> {
let cursor: SourceLocation = commit
.cursor_position
.parse()
.expect("Failed to parse cursor position");
let path = cursor.path.to_rel_path_buf();
let example = Example {
repository_url: commit.repo_url,
revision: commit.commit_sha,
uncommitted_diff: commit.edit_history.clone(),
cursor_path: path.as_std_path().to_path_buf(),
cursor_position: commit.cursor_position,
edit_history: commit.edit_history, // todo: trim
expected_patch: commit.expected_patch,
};
let prediction = teacher.predict(example).await;
prediction
}

View File

@@ -1,4 +0,0 @@
pub mod context;
pub mod distill;
pub mod llm_client;
pub mod teacher;

View File

@@ -1,266 +0,0 @@
use crate::{
example::Example,
source_location::SourceLocation,
training::{
context::{ContextType, collect_context, strip_special_tags},
llm_client::LlmClient,
},
};
use anthropic::{Message, RequestContent, ResponseContent, Role};
use anyhow::Result;
pub struct TeacherModel {
pub llm_name: String,
pub context: ContextType,
pub client: LlmClient,
}
#[derive(Debug, serde::Serialize)]
pub struct TeacherOutput {
parsed_output: String,
prompt: String,
raw_llm_response: String,
context: String,
diff: String,
}
impl TeacherModel {
const PROMPT: &str = include_str!("teacher.prompt.md");
pub(crate) const REGION_START: &str = "<|editable_region_start|>\n";
pub(crate) const REGION_END: &str = "<|editable_region_end|>";
pub(crate) const USER_CURSOR: &str = "<|user_cursor|>";
/// Number of lines to include before the cursor position
pub(crate) const LEFT_CONTEXT_SIZE: usize = 5;
/// Number of lines to include after the cursor position
pub(crate) const RIGHT_CONTEXT_SIZE: usize = 5;
/// Truncate edit history to this number of last lines
const MAX_HISTORY_LINES: usize = 128;
pub fn new(llm_name: String, context: ContextType, client: LlmClient) -> Self {
TeacherModel {
llm_name,
context,
client,
}
}
pub async fn predict(&self, input: Example) -> Result<Option<TeacherOutput>> {
let name = input.unique_name();
let worktree_dir = input.setup_worktree(name).await?;
let cursor: SourceLocation = input
.cursor_position
.parse()
.expect("Failed to parse cursor position");
let context = collect_context(&self.context, &worktree_dir, cursor.clone());
let edit_history = Self::format_edit_history(&input.edit_history);
let prompt = Self::PROMPT
.replace("{{context}}", &context)
.replace("{{edit_history}}", &edit_history);
let messages = vec![Message {
role: Role::User,
content: vec![RequestContent::Text {
text: prompt.clone(),
cache_control: None,
}],
}];
let Some(response) = self
.client
.generate(self.llm_name.clone(), 16384, messages)
.await?
else {
return Ok(None);
};
let response_text = response
.content
.into_iter()
.filter_map(|content| match content {
ResponseContent::Text { text } => Some(text),
_ => None,
})
.collect::<Vec<String>>()
.join("\n");
let parsed_output = self.parse_response(&response_text);
let original_editable_region = Self::extract_editable_region(&context);
let context_after_edit = context.replace(&original_editable_region, &parsed_output);
let context_after_edit = strip_special_tags(&context_after_edit);
let context_before_edit = strip_special_tags(&context);
let diff = language::unified_diff(&context_before_edit, &context_after_edit);
// zeta distill --batch batch_results.txt
// zeta distill
// 1. Run `zeta distill <2000 examples <- all examples>` for the first time
// - store LLM requests in a batch, don't actual send the request
// - send the batch (2000 requests) after all inputs are processed
// 2. `zeta send-batches`
// - upload the batch to Anthropic
// https://platform.claude.com/docs/en/build-with-claude/batch-processing
// https://crates.io/crates/anthropic-sdk-rust
// - poll for results
// - when ready, store results in cache (a database)
// 3. `zeta distill` again
// - use the cached results this time
Ok(Some(TeacherOutput {
parsed_output,
prompt,
raw_llm_response: response_text,
context,
diff,
}))
}
fn parse_response(&self, content: &str) -> String {
let codeblock = Self::extract_last_codeblock(content);
let editable_region = Self::extract_editable_region(&codeblock);
editable_region
}
/// Extract content from the last code-fenced block if any, or else return content as is
fn extract_last_codeblock(text: &str) -> String {
let mut last_block = None;
let mut search_start = 0;
while let Some(start) = text[search_start..].find("```") {
let start = start + search_start;
let bytes = text.as_bytes();
let mut backtick_end = start;
while backtick_end < bytes.len() && bytes[backtick_end] == b'`' {
backtick_end += 1;
}
let backtick_count = backtick_end - start;
let closing_backticks = "`".repeat(backtick_count);
if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1];
last_block = Some(code_block.to_string());
search_start = backtick_end + end_pos + backtick_count;
} else {
break;
}
}
last_block.unwrap_or_else(|| text.to_string())
}
fn extract_editable_region(text: &str) -> String {
let start = text
.find(Self::REGION_START)
.map_or(0, |pos| pos + Self::REGION_START.len());
let end = text.find(Self::REGION_END).unwrap_or(text.len());
text[start..end].to_string()
}
/// Truncates edit history to a maximum length and removes comments (unified diff garbage lines)
fn format_edit_history(edit_history: &str) -> String {
let lines = edit_history
.lines()
.filter(|&s| Self::is_content_line(s))
.collect::<Vec<_>>();
let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
&lines[lines.len() - Self::MAX_HISTORY_LINES..]
} else {
&lines
};
history_lines.join("\n")
}
fn is_content_line(s: &str) -> bool {
s.starts_with("-")
|| s.starts_with("+")
|| s.starts_with(" ")
|| s.starts_with("---")
|| s.starts_with("+++")
|| s.starts_with("@@")
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_response() {
let teacher = TeacherModel::new(
"test".to_string(),
ContextType::CurrentFile,
LlmClient::dummy(),
);
let response = "This is a test response.";
let parsed = teacher.parse_response(response);
assert_eq!(parsed, response.to_string());
let response = indoc::indoc! {"
Some thinking
`````
actual response
`````
"};
let parsed = teacher.parse_response(response);
assert_eq!(parsed, "actual response");
}
#[test]
fn test_extract_last_code_block() {
let text = indoc::indoc! {"
Some thinking
```
first block
```
`````
last block
`````
"};
let last_block = TeacherModel::extract_last_codeblock(text);
assert_eq!(last_block, "last block");
}
#[test]
fn test_extract_editable_region() {
let teacher = TeacherModel::new(
"test".to_string(),
ContextType::CurrentFile,
LlmClient::dummy(),
);
let response = indoc::indoc! {"
some lines
are
here
<|editable_region_start|>
one
two three
<|editable_region_end|>
more
lines here
"};
let parsed = teacher.parse_response(response);
assert_eq!(
parsed,
indoc::indoc! {"
one
two three
"}
);
}
}

View File

@@ -26,6 +26,7 @@ serde.workspace = true
smallvec.workspace = true
tree-sitter.workspace = true
util.workspace = true
zeta_prompt.workspace = true
[dev-dependencies]
env_logger.workspace = true

View File

@@ -1,6 +1,6 @@
use crate::RelatedExcerpt;
use language::{BufferSnapshot, OffsetRangeExt as _, Point};
use std::ops::Range;
use zeta_prompt::RelatedExcerpt;
#[cfg(not(test))]
const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 512;
@@ -76,14 +76,9 @@ pub fn assemble_excerpts(
input_ranges
.into_iter()
.map(|range| {
let offset_range = range.to_offset(buffer);
RelatedExcerpt {
point_range: range,
anchor_range: buffer.anchor_before(offset_range.start)
..buffer.anchor_after(offset_range.end),
text: buffer.as_rope().slice(offset_range),
}
.map(|range| RelatedExcerpt {
row_range: range.start.row..range.end.row,
text: buffer.text_for_range(range).collect(),
})
.collect()
}

View File

@@ -3,13 +3,13 @@ use anyhow::Result;
use collections::HashMap;
use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _};
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset as _};
use project::{LocationLink, Project, ProjectPath};
use serde::{Serialize, Serializer};
use smallvec::SmallVec;
use std::{
collections::hash_map,
ops::Range,
path::Path,
sync::Arc,
time::{Duration, Instant},
};
@@ -24,12 +24,14 @@ mod fake_definition_lsp;
pub use cloud_llm_client::predict_edits_v3::Line;
pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
pub use zeta_prompt::{RelatedExcerpt, RelatedFile};
const IDENTIFIER_LINE_COUNT: u32 = 3;
pub struct RelatedExcerptStore {
project: WeakEntity<Project>,
related_files: Vec<RelatedFile>,
related_files: Arc<[RelatedFile]>,
related_file_buffers: Vec<Entity<Buffer>>,
cache: HashMap<Identifier, Arc<CacheEntry>>,
update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
identifier_line_count: u32,
@@ -68,82 +70,6 @@ struct CachedDefinition {
anchor_range: Range<Anchor>,
}
#[derive(Clone, Debug, Serialize)]
pub struct RelatedFile {
#[serde(serialize_with = "serialize_project_path")]
pub path: ProjectPath,
#[serde(skip)]
pub buffer: WeakEntity<Buffer>,
pub excerpts: Vec<RelatedExcerpt>,
pub max_row: u32,
}
impl RelatedFile {
pub fn merge_excerpts(&mut self) {
self.excerpts.sort_unstable_by(|a, b| {
a.point_range
.start
.cmp(&b.point_range.start)
.then(b.point_range.end.cmp(&a.point_range.end))
});
let mut index = 1;
while index < self.excerpts.len() {
if self.excerpts[index - 1]
.point_range
.end
.cmp(&self.excerpts[index].point_range.start)
.is_ge()
{
let removed = self.excerpts.remove(index);
if removed
.point_range
.end
.cmp(&self.excerpts[index - 1].point_range.end)
.is_gt()
{
self.excerpts[index - 1].point_range.end = removed.point_range.end;
self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end;
}
} else {
index += 1;
}
}
}
}
#[derive(Clone, Debug, Serialize)]
pub struct RelatedExcerpt {
#[serde(skip)]
pub anchor_range: Range<Anchor>,
#[serde(serialize_with = "serialize_point_range")]
pub point_range: Range<Point>,
#[serde(serialize_with = "serialize_rope")]
pub text: Rope,
}
fn serialize_project_path<S: Serializer>(
project_path: &ProjectPath,
serializer: S,
) -> Result<S::Ok, S::Error> {
project_path.path.serialize(serializer)
}
fn serialize_rope<S: Serializer>(rope: &Rope, serializer: S) -> Result<S::Ok, S::Error> {
rope.to_string().serialize(serializer)
}
fn serialize_point_range<S: Serializer>(
range: &Range<Point>,
serializer: S,
) -> Result<S::Ok, S::Error> {
[
[range.start.row, range.start.column],
[range.end.row, range.end.column],
]
.serialize(serializer)
}
const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);
impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
@@ -179,7 +105,8 @@ impl RelatedExcerptStore {
RelatedExcerptStore {
project: project.downgrade(),
update_tx,
related_files: Vec::new(),
related_files: Vec::new().into(),
related_file_buffers: Vec::new(),
cache: Default::default(),
identifier_line_count: IDENTIFIER_LINE_COUNT,
}
@@ -193,8 +120,21 @@ impl RelatedExcerptStore {
self.update_tx.unbounded_send((buffer, position)).ok();
}
pub fn related_files(&self) -> &[RelatedFile] {
&self.related_files
pub fn related_files(&self) -> Arc<[RelatedFile]> {
self.related_files.clone()
}
pub fn related_files_with_buffers(
&self,
) -> impl Iterator<Item = (RelatedFile, Entity<Buffer>)> {
self.related_files
.iter()
.cloned()
.zip(self.related_file_buffers.iter().cloned())
}
pub fn set_related_files(&mut self, files: Vec<RelatedFile>) {
self.related_files = files.into();
}
async fn fetch_excerpts(
@@ -297,7 +237,8 @@ impl RelatedExcerptStore {
}
mean_definition_latency /= cache_miss_count.max(1) as u32;
let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?;
let (new_cache, related_files, related_file_buffers) =
rebuild_related_files(&project, new_cache, cx).await?;
if let Some(file) = &file {
log::debug!(
@@ -309,7 +250,8 @@ impl RelatedExcerptStore {
this.update(cx, |this, cx| {
this.cache = new_cache;
this.related_files = related_files;
this.related_files = related_files.into();
this.related_file_buffers = related_file_buffers;
cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
cache_hit_count,
cache_miss_count,
@@ -323,10 +265,16 @@ impl RelatedExcerptStore {
}
async fn rebuild_related_files(
project: &Entity<Project>,
new_entries: HashMap<Identifier, Arc<CacheEntry>>,
cx: &mut AsyncApp,
) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedFile>)> {
) -> Result<(
HashMap<Identifier, Arc<CacheEntry>>,
Vec<RelatedFile>,
Vec<Entity<Buffer>>,
)> {
let mut snapshots = HashMap::default();
let mut worktree_root_names = HashMap::default();
for entry in new_entries.values() {
for definition in &entry.definitions {
if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
@@ -340,12 +288,22 @@ async fn rebuild_related_files(
.read_with(cx, |buffer, _| buffer.snapshot())?,
);
}
let worktree_id = definition.path.worktree_id;
if let hash_map::Entry::Vacant(e) =
worktree_root_names.entry(definition.path.worktree_id)
{
project.read_with(cx, |project, cx| {
if let Some(worktree) = project.worktree_for_id(worktree_id, cx) {
e.insert(worktree.read(cx).root_name().as_unix_str().to_string());
}
})?;
}
}
}
Ok(cx
.background_spawn(async move {
let mut files = Vec::<RelatedFile>::new();
let mut files = Vec::new();
let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
let mut paths_by_buffer = HashMap::default();
for entry in new_entries.values() {
@@ -369,16 +327,31 @@ async fn rebuild_related_files(
continue;
};
let excerpts = assemble_excerpts(snapshot, ranges);
files.push(RelatedFile {
path: project_path.clone(),
buffer: buffer.downgrade(),
excerpts,
max_row: snapshot.max_point().row,
});
let Some(root_name) = worktree_root_names.get(&project_path.worktree_id) else {
continue;
};
let path = Path::new(&format!(
"{}/{}",
root_name,
project_path.path.as_unix_str()
))
.into();
files.push((
buffer,
RelatedFile {
path,
excerpts,
max_row: snapshot.max_point().row,
},
));
}
files.sort_by_key(|file| file.path.clone());
(new_entries, files)
files.sort_by_key(|(_, file)| file.path.clone());
let (related_buffers, related_files) = files.into_iter().unzip();
(new_entries, related_files, related_buffers)
})
.await)
}

View File

@@ -48,7 +48,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
&excerpts,
&[
(
"src/company.rs",
"root/src/company.rs",
&[indoc! {"
pub struct Company {
owner: Arc<Person>,
@@ -56,7 +56,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
}"}],
),
(
"src/main.rs",
"root/src/main.rs",
&[
indoc! {"
pub struct Session {
@@ -71,7 +71,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
],
),
(
"src/person.rs",
"root/src/person.rs",
&[
indoc! {"
impl Person {
@@ -446,7 +446,7 @@ fn assert_related_files(actual_files: &[RelatedFile], expected_files: &[(&str, &
.iter()
.map(|excerpt| excerpt.text.to_string())
.collect::<Vec<_>>();
(file.path.path.as_unix_str(), excerpts)
(file.path.to_str().unwrap(), excerpts)
})
.collect::<Vec<_>>();
let expected_excerpts = expected_files
@@ -492,10 +492,10 @@ fn format_excerpts(buffer: &Buffer, excerpts: &[RelatedExcerpt]) -> String {
if excerpt.text.is_empty() {
continue;
}
if current_row < excerpt.point_range.start.row {
if current_row < excerpt.row_range.start {
writeln!(&mut output, "").unwrap();
}
current_row = excerpt.point_range.start.row;
current_row = excerpt.row_range.start;
for line in excerpt.text.to_string().lines() {
output.push_str(line);

View File

@@ -17,7 +17,6 @@ anyhow.workspace = true
buffer_diff.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
cloud_zeta2_prompt.workspace = true
codestral.workspace = true
command_palette_hooks.workspace = true
copilot.workspace = true
@@ -46,6 +45,7 @@ ui_input.workspace = true
util.workspace = true
workspace.workspace = true
zed_actions.workspace = true
zeta_prompt.workspace = true
[dev-dependencies]
copilot = { workspace = true, features = ["test-support"] }

View File

@@ -17,7 +17,7 @@ use gpui::{
};
use multi_buffer::MultiBuffer;
use project::Project;
use text::OffsetRangeExt;
use text::Point;
use ui::{
ButtonCommon, Clickable, Disableable, FluentBuilder as _, IconButton, IconName,
StyledTypography as _, h_flex, v_flex,
@@ -66,7 +66,7 @@ impl EditPredictionContextView {
) -> Self {
let store = EditPredictionStore::global(client, user_store, cx);
let mut debug_rx = store.update(cx, |store, _| store.debug_info());
let mut debug_rx = store.update(cx, |store, cx| store.debug_info(&project, cx));
let _update_task = cx.spawn_in(window, async move |this, cx| {
while let Some(event) = debug_rx.next().await {
this.update_in(cx, |this, window, cx| {
@@ -103,7 +103,8 @@ impl EditPredictionContextView {
self.handle_context_retrieval_finished(info, window, cx);
}
}
DebugEvent::EditPredictionRequested(_) => {}
DebugEvent::EditPredictionStarted(_) => {}
DebugEvent::EditPredictionFinished(_) => {}
}
}
@@ -152,12 +153,11 @@ impl EditPredictionContextView {
run.finished_at = Some(info.timestamp);
run.metadata = info.metadata;
let project = self.project.clone();
let related_files = self
.store
.read(cx)
.context_for_project(&self.project, cx)
.to_vec();
.context_for_project_with_buffers(&self.project, cx)
.map_or(Vec::new(), |files| files.collect());
let editor = run.editor.clone();
let multibuffer = run.editor.read(cx).buffer().clone();
@@ -168,33 +168,14 @@ impl EditPredictionContextView {
cx.spawn_in(window, async move |this, cx| {
let mut paths = Vec::new();
for related_file in related_files {
let (buffer, point_ranges): (_, Vec<_>) =
if let Some(buffer) = related_file.buffer.upgrade() {
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
(
buffer,
related_file
.excerpts
.iter()
.map(|excerpt| excerpt.anchor_range.to_point(&snapshot))
.collect(),
)
} else {
(
project
.update(cx, |project, cx| {
project.open_buffer(related_file.path.clone(), cx)
})?
.await?,
related_file
.excerpts
.iter()
.map(|excerpt| excerpt.point_range.clone())
.collect(),
)
};
for (related_file, buffer) in related_files {
let point_ranges = related_file
.excerpts
.iter()
.map(|excerpt| {
Point::new(excerpt.row_range.start, 0)..Point::new(excerpt.row_range.end, 0)
})
.collect::<Vec<_>>();
cx.update(|_, cx| {
let path = PathKey::for_buffer(&buffer, cx);
paths.push((path, buffer, point_ranges));

View File

@@ -1,5 +1,4 @@
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use cloud_zeta2_prompt::write_codeblock;
use edit_prediction::{EditPrediction, EditPredictionRating, EditPredictionStore};
use editor::{Editor, ExcerptRange, MultiBuffer};
use feature_flags::FeatureFlag;
@@ -362,14 +361,14 @@ impl RatePredictionsModal {
write!(&mut formatted_inputs, "## Events\n\n").unwrap();
for event in &prediction.inputs.events {
write!(&mut formatted_inputs, "```diff\n{event}```\n\n").unwrap();
formatted_inputs.push_str("```diff\n");
zeta_prompt::write_event(&mut formatted_inputs, event.as_ref());
formatted_inputs.push_str("```\n\n");
}
write!(&mut formatted_inputs, "## Included files\n\n").unwrap();
for included_file in &prediction.inputs.included_files {
let cursor_insertions = &[(prediction.inputs.cursor_point, "<|CURSOR|>")];
write!(&mut formatted_inputs, "## Related files\n\n").unwrap();
for included_file in prediction.inputs.related_files.as_ref() {
write!(
&mut formatted_inputs,
"### {}\n\n",
@@ -377,20 +376,28 @@ impl RatePredictionsModal {
)
.unwrap();
write_codeblock(
&included_file.path,
&included_file.excerpts,
if included_file.path == prediction.inputs.cursor_path {
cursor_insertions.as_slice()
} else {
&[]
},
included_file.max_row,
false,
&mut formatted_inputs,
);
for excerpt in included_file.excerpts.iter() {
write!(
&mut formatted_inputs,
"```{}\n{}\n```\n",
included_file.path.display(),
excerpt.text
)
.unwrap();
}
}
write!(&mut formatted_inputs, "## Cursor Excerpt\n\n").unwrap();
writeln!(
&mut formatted_inputs,
"```{}\n{}<CURSOR>{}\n```\n",
prediction.inputs.cursor_path.display(),
&prediction.inputs.cursor_excerpt[..prediction.inputs.cursor_offset_in_excerpt],
&prediction.inputs.cursor_excerpt[prediction.inputs.cursor_offset_in_excerpt..],
)
.unwrap();
self.active_prediction = Some(ActivePrediction {
prediction,
feedback_editor: cx.new(|cx| {

View File

@@ -280,7 +280,11 @@ pub fn deploy_context_menu(
"Copy Permalink",
Box::new(CopyPermalinkToLine),
)
.action_disabled_when(!has_git_repo, "File History", Box::new(git::FileHistory));
.action_disabled_when(
!has_git_repo,
"View File History",
Box::new(git::FileHistory),
);
match focus {
Some(focus) => builder.context(focus),
None => builder,

View File

@@ -29,6 +29,7 @@ pub struct ExtensionHostProxy {
slash_command_proxy: RwLock<Option<Arc<dyn ExtensionSlashCommandProxy>>>,
context_server_proxy: RwLock<Option<Arc<dyn ExtensionContextServerProxy>>>,
debug_adapter_provider_proxy: RwLock<Option<Arc<dyn ExtensionDebugAdapterProviderProxy>>>,
language_model_provider_proxy: RwLock<Option<Arc<dyn ExtensionLanguageModelProviderProxy>>>,
}
impl ExtensionHostProxy {
@@ -54,6 +55,7 @@ impl ExtensionHostProxy {
slash_command_proxy: RwLock::default(),
context_server_proxy: RwLock::default(),
debug_adapter_provider_proxy: RwLock::default(),
language_model_provider_proxy: RwLock::default(),
}
}
@@ -90,6 +92,15 @@ impl ExtensionHostProxy {
.write()
.replace(Arc::new(proxy));
}
pub fn register_language_model_provider_proxy(
&self,
proxy: impl ExtensionLanguageModelProviderProxy,
) {
self.language_model_provider_proxy
.write()
.replace(Arc::new(proxy));
}
}
pub trait ExtensionThemeProxy: Send + Sync + 'static {
@@ -375,6 +386,49 @@ pub trait ExtensionContextServerProxy: Send + Sync + 'static {
fn unregister_context_server(&self, server_id: Arc<str>, cx: &mut App);
}
/// A function that registers a language model provider with the registry.
/// This allows extension_host to create the provider (which requires WasmExtension)
/// and pass a registration closure to the language_models crate.
pub type LanguageModelProviderRegistration = Box<dyn FnOnce(&mut App) + Send + Sync + 'static>;
pub trait ExtensionLanguageModelProviderProxy: Send + Sync + 'static {
/// Register an LLM provider from an extension.
/// The `register_fn` closure will be called with the App context and should
/// register the provider with the LanguageModelRegistry.
fn register_language_model_provider(
&self,
provider_id: Arc<str>,
register_fn: LanguageModelProviderRegistration,
cx: &mut App,
);
/// Unregister an LLM provider when an extension is unloaded.
fn unregister_language_model_provider(&self, provider_id: Arc<str>, cx: &mut App);
}
impl ExtensionLanguageModelProviderProxy for ExtensionHostProxy {
fn register_language_model_provider(
&self,
provider_id: Arc<str>,
register_fn: LanguageModelProviderRegistration,
cx: &mut App,
) {
let Some(proxy) = self.language_model_provider_proxy.read().clone() else {
return;
};
proxy.register_language_model_provider(provider_id, register_fn, cx)
}
fn unregister_language_model_provider(&self, provider_id: Arc<str>, cx: &mut App) {
let Some(proxy) = self.language_model_provider_proxy.read().clone() else {
return;
};
proxy.unregister_language_model_provider(provider_id, cx)
}
}
impl ExtensionContextServerProxy for ExtensionHostProxy {
fn register_context_server(
&self,

View File

@@ -93,6 +93,8 @@ pub struct ExtensionManifest {
pub debug_adapters: BTreeMap<Arc<str>, DebugAdapterManifestEntry>,
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub debug_locators: BTreeMap<Arc<str>, DebugLocatorManifestEntry>,
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub language_model_providers: BTreeMap<Arc<str>, LanguageModelProviderManifestEntry>,
}
impl ExtensionManifest {
@@ -288,6 +290,71 @@ pub struct DebugAdapterManifestEntry {
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
pub struct DebugLocatorManifestEntry {}
/// Manifest entry for a language model provider.
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
pub struct LanguageModelProviderManifestEntry {
/// Display name for the provider.
pub name: String,
/// Path to an SVG icon file relative to the extension root (e.g., "icons/provider.svg").
#[serde(default)]
pub icon: Option<String>,
/// Default models to show even before API connection.
#[serde(default)]
pub models: Vec<LanguageModelManifestEntry>,
/// Authentication configuration.
#[serde(default)]
pub auth: Option<LanguageModelAuthConfig>,
}
/// Manifest entry for a language model.
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
pub struct LanguageModelManifestEntry {
/// Unique identifier for the model.
pub id: String,
/// Display name for the model.
pub name: String,
/// Maximum input token count.
#[serde(default)]
pub max_token_count: u64,
/// Maximum output tokens (optional).
#[serde(default)]
pub max_output_tokens: Option<u64>,
/// Whether the model supports image inputs.
#[serde(default)]
pub supports_images: bool,
/// Whether the model supports tool/function calling.
#[serde(default)]
pub supports_tools: bool,
/// Whether the model supports extended thinking/reasoning.
#[serde(default)]
pub supports_thinking: bool,
}
/// Authentication configuration for a language model provider.
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
pub struct LanguageModelAuthConfig {
/// Environment variable name for the API key.
#[serde(default)]
pub env_var: Option<String>,
/// Human-readable name for the credential shown in the UI input field (e.g., "API Key", "Access Token").
#[serde(default)]
pub credential_label: Option<String>,
/// OAuth configuration for web-based authentication flows.
#[serde(default)]
pub oauth: Option<OAuthConfig>,
}
/// OAuth configuration for web-based authentication.
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
pub struct OAuthConfig {
/// The text to display on the sign-in button (e.g., "Sign in with GitHub").
#[serde(default)]
pub sign_in_button_label: Option<String>,
/// The icon to display on the sign-in button (e.g., "github").
#[serde(default)]
pub sign_in_button_icon: Option<String>,
}
impl ExtensionManifest {
pub async fn load(fs: Arc<dyn Fs>, extension_dir: &Path) -> Result<Self> {
let extension_name = extension_dir
@@ -358,6 +425,7 @@ fn manifest_from_old_manifest(
capabilities: Vec::new(),
debug_adapters: Default::default(),
debug_locators: Default::default(),
language_model_providers: Default::default(),
}
}
@@ -391,6 +459,7 @@ mod tests {
capabilities: vec![],
debug_adapters: Default::default(),
debug_locators: Default::default(),
language_model_providers: BTreeMap::default(),
}
}

View File

@@ -29,6 +29,27 @@ pub use wit::{
GithubRelease, GithubReleaseAsset, GithubReleaseOptions, github_release_by_tag_name,
latest_github_release,
},
zed::extension::llm_provider::{
CacheConfiguration as LlmCacheConfiguration, CompletionEvent as LlmCompletionEvent,
CompletionRequest as LlmCompletionRequest, CredentialType as LlmCredentialType,
ImageData as LlmImageData, MessageContent as LlmMessageContent,
MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities,
ModelInfo as LlmModelInfo, OauthHttpRequest as LlmOauthHttpRequest,
OauthHttpResponse as LlmOauthHttpResponse, OauthWebAuthConfig as LlmOauthWebAuthConfig,
OauthWebAuthResult as LlmOauthWebAuthResult, ProviderInfo as LlmProviderInfo,
RequestMessage as LlmRequestMessage, StopReason as LlmStopReason,
ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage,
ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition,
ToolInputFormat as LlmToolInputFormat, ToolResult as LlmToolResult,
ToolResultContent as LlmToolResultContent, ToolUse as LlmToolUse,
ToolUseJsonParseError as LlmToolUseJsonParseError,
delete_credential as llm_delete_credential, get_credential as llm_get_credential,
get_env_var as llm_get_env_var, oauth_open_browser as llm_oauth_open_browser,
oauth_start_web_auth as llm_oauth_start_web_auth,
request_credential as llm_request_credential,
send_oauth_http_request as llm_oauth_http_request,
store_credential as llm_store_credential,
},
zed::extension::nodejs::{
node_binary_path, npm_install_package, npm_package_installed_version,
npm_package_latest_version,
@@ -259,6 +280,94 @@ pub trait Extension: Send + Sync {
) -> Result<DebugRequest, String> {
Err("`run_dap_locator` not implemented".to_string())
}
/// Returns information about language model providers offered by this extension.
fn llm_providers(&self) -> Vec<LlmProviderInfo> {
Vec::new()
}
/// Returns the models available for a provider.
fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
Ok(Vec::new())
}
/// Returns markdown content to display in the provider's settings UI.
/// This can include setup instructions, links to documentation, etc.
fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
None
}
/// Check if the provider is authenticated.
fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
false
}
/// Start an OAuth device flow sign-in.
/// This is called when the user explicitly clicks "Sign in with GitHub" or similar.
/// Opens the browser to the verification URL and returns the user code that should
/// be displayed to the user.
fn llm_provider_start_device_flow_sign_in(
&mut self,
_provider_id: &str,
) -> Result<String, String> {
Err("`llm_provider_start_device_flow_sign_in` not implemented".to_string())
}
/// Poll for device flow sign-in completion.
/// This is called after llm_provider_start_device_flow_sign_in returns the user code.
/// The extension should poll the OAuth provider until the user authorizes or the flow times out.
fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> {
Err("`llm_provider_poll_device_flow_sign_in` not implemented".to_string())
}
/// Reset credentials for the provider.
fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
Err("`llm_provider_reset_credentials` not implemented".to_string())
}
/// Count tokens for a request.
fn llm_count_tokens(
&self,
_provider_id: &str,
_model_id: &str,
_request: &LlmCompletionRequest,
) -> Result<u64, String> {
Err("`llm_count_tokens` not implemented".to_string())
}
/// Start streaming a completion from the model.
/// Returns a stream ID that can be used with `llm_stream_completion_next` and `llm_stream_completion_close`.
fn llm_stream_completion_start(
&mut self,
_provider_id: &str,
_model_id: &str,
_request: &LlmCompletionRequest,
) -> Result<String, String> {
Err("`llm_stream_completion_start` not implemented".to_string())
}
/// Get the next event from a completion stream.
/// Returns `Ok(None)` when the stream is complete.
fn llm_stream_completion_next(
&mut self,
_stream_id: &str,
) -> Result<Option<LlmCompletionEvent>, String> {
Err("`llm_stream_completion_next` not implemented".to_string())
}
/// Close a completion stream and release its resources.
fn llm_stream_completion_close(&mut self, _stream_id: &str) {
// Default implementation does nothing
}
/// Get cache configuration for a model (if prompt caching is supported).
fn llm_cache_configuration(
&self,
_provider_id: &str,
_model_id: &str,
) -> Option<LlmCacheConfiguration> {
None
}
}
/// Registers the provided type as a Zed extension.
@@ -518,6 +627,65 @@ impl wit::Guest for Component {
) -> Result<DebugRequest, String> {
extension().run_dap_locator(locator_name, build_task)
}
fn llm_providers() -> Vec<LlmProviderInfo> {
extension().llm_providers()
}
fn llm_provider_models(provider_id: String) -> Result<Vec<LlmModelInfo>, String> {
extension().llm_provider_models(&provider_id)
}
fn llm_provider_settings_markdown(provider_id: String) -> Option<String> {
extension().llm_provider_settings_markdown(&provider_id)
}
fn llm_provider_is_authenticated(provider_id: String) -> bool {
extension().llm_provider_is_authenticated(&provider_id)
}
fn llm_provider_start_device_flow_sign_in(provider_id: String) -> Result<String, String> {
extension().llm_provider_start_device_flow_sign_in(&provider_id)
}
fn llm_provider_poll_device_flow_sign_in(provider_id: String) -> Result<(), String> {
extension().llm_provider_poll_device_flow_sign_in(&provider_id)
}
fn llm_provider_reset_credentials(provider_id: String) -> Result<(), String> {
extension().llm_provider_reset_credentials(&provider_id)
}
fn llm_count_tokens(
provider_id: String,
model_id: String,
request: LlmCompletionRequest,
) -> Result<u64, String> {
extension().llm_count_tokens(&provider_id, &model_id, &request)
}
fn llm_stream_completion_start(
provider_id: String,
model_id: String,
request: LlmCompletionRequest,
) -> Result<String, String> {
extension().llm_stream_completion_start(&provider_id, &model_id, &request)
}
fn llm_stream_completion_next(stream_id: String) -> Result<Option<LlmCompletionEvent>, String> {
extension().llm_stream_completion_next(&stream_id)
}
fn llm_stream_completion_close(stream_id: String) {
extension().llm_stream_completion_close(&stream_id)
}
fn llm_cache_configuration(
provider_id: String,
model_id: String,
) -> Option<LlmCacheConfiguration> {
extension().llm_cache_configuration(&provider_id, &model_id)
}
}
/// The ID of a language server.

View File

@@ -8,6 +8,7 @@ world extension {
import platform;
import process;
import nodejs;
import llm-provider;
use common.{env-vars, range};
use context-server.{context-server-configuration};
@@ -15,6 +16,10 @@ world extension {
use lsp.{completion, symbol};
use process.{command};
use slash-command.{slash-command, slash-command-argument-completion, slash-command-output};
use llm-provider.{
provider-info, model-info, completion-request,
credential-type, cache-configuration, completion-event, token-usage
};
/// Initializes the extension.
export init-extension: func();
@@ -164,4 +169,74 @@ world extension {
export dap-config-to-scenario: func(config: debug-config) -> result<debug-scenario, string>;
export dap-locator-create-scenario: func(locator-name: string, build-config-template: build-task-template, resolved-label: string, debug-adapter-name: string) -> option<debug-scenario>;
export run-dap-locator: func(locator-name: string, config: resolved-task) -> result<debug-request, string>;
/// Returns information about language model providers offered by this extension.
export llm-providers: func() -> list<provider-info>;
/// Returns the models available for a provider.
export llm-provider-models: func(provider-id: string) -> result<list<model-info>, string>;
/// Returns markdown content to display in the provider's settings UI.
/// This can include setup instructions, links to documentation, etc.
export llm-provider-settings-markdown: func(provider-id: string) -> option<string>;
/// Check if the provider is authenticated.
export llm-provider-is-authenticated: func(provider-id: string) -> bool;
/// Start an OAuth device flow sign-in.
/// This is called when the user explicitly clicks "Sign in with GitHub" or similar.
///
/// The device flow works as follows:
/// 1. Extension requests a device code from the OAuth provider
/// 2. Extension opens the verification URL in the browser
/// 3. Extension returns the user code to display to the user
/// 4. Host displays the user code and calls llm-provider-poll-device-flow-sign-in
/// 5. Extension polls for the access token while user authorizes in browser
/// 6. Once authorized, extension stores the credential and returns success
///
/// Returns the user code that should be displayed to the user while they
/// complete authorization in the browser.
export llm-provider-start-device-flow-sign-in: func(provider-id: string) -> result<string, string>;
/// Poll for device flow sign-in completion.
/// This is called after llm-provider-start-device-flow-sign-in returns the user code.
/// The extension should poll the OAuth provider until the user authorizes or the flow times out.
/// Returns Ok(()) on successful authentication, or an error message on failure.
export llm-provider-poll-device-flow-sign-in: func(provider-id: string) -> result<_, string>;
/// Reset credentials for the provider.
export llm-provider-reset-credentials: func(provider-id: string) -> result<_, string>;
/// Count tokens for a request.
export llm-count-tokens: func(
provider-id: string,
model-id: string,
request: completion-request
) -> result<u64, string>;
/// Start streaming a completion from the model.
/// Returns a stream ID that can be used with llm-stream-next and llm-stream-close.
export llm-stream-completion-start: func(
provider-id: string,
model-id: string,
request: completion-request
) -> result<string, string>;
/// Get the next event from a completion stream.
/// Returns None when the stream is complete.
export llm-stream-completion-next: func(
stream-id: string
) -> result<option<completion-event>, string>;
/// Close a completion stream and release its resources.
export llm-stream-completion-close: func(
stream-id: string
);
/// Get cache configuration for a model (if prompt caching is supported).
export llm-cache-configuration: func(
provider-id: string,
model-id: string
) -> option<cache-configuration>;
}

View File

@@ -0,0 +1,348 @@
interface llm-provider {
/// Information about a language model provider.
record provider-info {
/// Unique identifier for the provider (e.g., "my-extension.my-provider").
id: string,
/// Display name for the provider.
name: string,
/// Path to an SVG icon file relative to the extension root (e.g., "icons/provider.svg").
icon: option<string>,
}
/// Capabilities of a language model.
record model-capabilities {
/// Whether the model supports image inputs.
supports-images: bool,
/// Whether the model supports tool/function calling.
supports-tools: bool,
/// Whether the model supports the "auto" tool choice.
supports-tool-choice-auto: bool,
/// Whether the model supports the "any" tool choice.
supports-tool-choice-any: bool,
/// Whether the model supports the "none" tool choice.
supports-tool-choice-none: bool,
/// Whether the model supports extended thinking/reasoning.
supports-thinking: bool,
/// The format for tool input schemas.
tool-input-format: tool-input-format,
}
/// Format for tool input schemas.
enum tool-input-format {
/// Standard JSON Schema format.
json-schema,
/// Simplified schema format for certain providers.
simplified,
}
/// Information about a specific model.
record model-info {
/// Unique identifier for the model.
id: string,
/// Display name for the model.
name: string,
/// Maximum input token count.
max-token-count: u64,
/// Maximum output tokens (optional).
max-output-tokens: option<u64>,
/// Model capabilities.
capabilities: model-capabilities,
/// Whether this is the default model for the provider.
is-default: bool,
/// Whether this is the default fast model.
is-default-fast: bool,
}
/// The role of a message participant.
enum message-role {
/// User message.
user,
/// Assistant message.
assistant,
/// System message.
system,
}
/// A message in a completion request.
record request-message {
/// The role of the message sender.
role: message-role,
/// The content of the message.
content: list<message-content>,
/// Whether to cache this message for prompt caching.
cache: bool,
}
/// Content within a message.
variant message-content {
/// Plain text content.
text(string),
/// Image content.
image(image-data),
/// A tool use request from the assistant.
tool-use(tool-use),
/// A tool result from the user.
tool-result(tool-result),
/// Thinking/reasoning content.
thinking(thinking-content),
/// Redacted/encrypted thinking content.
redacted-thinking(string),
}
/// Image data for vision models.
record image-data {
/// Base64-encoded image data.
source: string,
/// Image width in pixels (optional).
width: option<u32>,
/// Image height in pixels (optional).
height: option<u32>,
}
/// A tool use request from the model.
record tool-use {
/// Unique identifier for this tool use.
id: string,
/// The name of the tool being used.
name: string,
/// JSON string of the tool input arguments.
input: string,
/// Thought signature for providers that support it (e.g., Anthropic).
thought-signature: option<string>,
}
/// A tool result to send back to the model.
record tool-result {
/// The ID of the tool use this is a result for.
tool-use-id: string,
/// The name of the tool.
tool-name: string,
/// Whether this result represents an error.
is-error: bool,
/// The content of the result.
content: tool-result-content,
}
/// Content of a tool result.
variant tool-result-content {
/// Text result.
text(string),
/// Image result.
image(image-data),
}
/// Thinking/reasoning content from models that support extended thinking.
record thinking-content {
/// The thinking text.
text: string,
/// Signature for the thinking block (provider-specific).
signature: option<string>,
}
/// A tool definition for function calling.
record tool-definition {
/// The name of the tool.
name: string,
/// Description of what the tool does.
description: string,
/// JSON Schema for input parameters.
input-schema: string,
}
/// Tool choice preference for the model.
enum tool-choice {
/// Let the model decide whether to use tools.
auto,
/// Force the model to use at least one tool.
any,
/// Prevent the model from using tools.
none,
}
/// A completion request to send to the model.
record completion-request {
/// The messages in the conversation.
messages: list<request-message>,
/// Available tools for the model to use.
tools: list<tool-definition>,
/// Tool choice preference.
tool-choice: option<tool-choice>,
/// Stop sequences to end generation.
stop-sequences: list<string>,
/// Temperature for sampling (0.0-1.0).
temperature: option<f32>,
/// Whether thinking/reasoning is allowed.
thinking-allowed: bool,
/// Maximum tokens to generate.
max-tokens: option<u64>,
}
/// Events emitted during completion streaming.
variant completion-event {
/// Completion has started.
started,
/// Text content chunk.
text(string),
/// Thinking/reasoning content chunk.
thinking(thinking-content),
/// Redacted thinking (encrypted) chunk.
redacted-thinking(string),
/// Tool use request from the model.
tool-use(tool-use),
/// JSON parse error when parsing tool input.
tool-use-json-parse-error(tool-use-json-parse-error),
/// Completion stopped.
stop(stop-reason),
/// Token usage update.
usage(token-usage),
/// Reasoning details (provider-specific JSON).
reasoning-details(string),
}
/// Error information when tool use JSON parsing fails.
record tool-use-json-parse-error {
/// The tool use ID.
id: string,
/// The tool name.
tool-name: string,
/// The raw input that failed to parse.
raw-input: string,
/// The parse error message.
error: string,
}
/// Reason the completion stopped.
enum stop-reason {
/// The model finished generating.
end-turn,
/// Maximum tokens reached.
max-tokens,
/// The model wants to use a tool.
tool-use,
/// The model refused to respond.
refusal,
}
/// Token usage statistics.
record token-usage {
/// Number of input tokens used.
input-tokens: u64,
/// Number of output tokens generated.
output-tokens: u64,
/// Tokens used for cache creation (if supported).
cache-creation-input-tokens: option<u64>,
/// Tokens read from cache (if supported).
cache-read-input-tokens: option<u64>,
}
/// Credential types that can be requested.
enum credential-type {
/// An API key.
api-key,
/// An OAuth token.
oauth-token,
}
/// Cache configuration for prompt caching.
record cache-configuration {
/// Maximum number of cache anchors.
max-cache-anchors: u32,
/// Whether caching should be applied to tool definitions.
should-cache-tool-definitions: bool,
/// Minimum token count for a message to be cached.
min-total-token-count: u64,
}
/// Configuration for starting an OAuth web authentication flow.
record oauth-web-auth-config {
/// The URL to open in the user's browser to start authentication.
/// This should include client_id, redirect_uri, scope, state, etc.
/// Use `{port}` as a placeholder in the URL - it will be replaced with
/// the actual localhost port before opening the browser.
/// Example: "https://example.com/oauth?redirect_uri=http://127.0.0.1:{port}/callback"
auth-url: string,
/// The path to listen on for the OAuth callback (e.g., "/callback").
/// A localhost server will be started to receive the redirect.
callback-path: string,
/// Timeout in seconds to wait for the callback (default: 300 = 5 minutes).
timeout-secs: option<u32>,
}
/// Result of an OAuth web authentication flow.
record oauth-web-auth-result {
/// The full callback URL that was received, including query parameters.
/// The extension is responsible for parsing the code, state, etc.
callback-url: string,
/// The port that was used for the localhost callback server.
port: u32,
}
/// A generic HTTP request for OAuth token exchange.
record oauth-http-request {
/// The URL to request.
url: string,
/// HTTP method (e.g., "POST", "GET").
method: string,
/// Request headers as key-value pairs.
headers: list<tuple<string, string>>,
/// Request body as a string (for form-encoded or JSON bodies).
body: string,
}
/// Response from an OAuth HTTP request.
record oauth-http-response {
/// HTTP status code.
status: u16,
/// Response headers as key-value pairs.
headers: list<tuple<string, string>>,
/// Response body as a string.
body: string,
}
/// Request a credential from the user.
/// Returns true if the credential was provided, false if the user cancelled.
request-credential: func(
provider-id: string,
credential-type: credential-type,
label: string,
placeholder: string
) -> result<bool, string>;
/// Get a stored credential for this provider.
get-credential: func(provider-id: string) -> option<string>;
/// Store a credential for this provider.
store-credential: func(provider-id: string, value: string) -> result<_, string>;
/// Delete a stored credential for this provider.
delete-credential: func(provider-id: string) -> result<_, string>;
/// Read an environment variable.
get-env-var: func(name: string) -> option<string>;
/// Start an OAuth web authentication flow.
///
/// This will:
/// 1. Start a localhost server to receive the OAuth callback
/// 2. Open the auth URL in the user's default browser
/// 3. Wait for the callback (up to the timeout)
/// 4. Return the callback URL with query parameters
///
/// The extension is responsible for:
/// - Constructing the auth URL with client_id, redirect_uri, scope, state, etc.
/// - Parsing the callback URL to extract the authorization code
/// - Exchanging the code for tokens using oauth-http-request
oauth-start-web-auth: func(config: oauth-web-auth-config) -> result<oauth-web-auth-result, string>;
/// Make an HTTP request for OAuth token exchange.
///
/// This is a simple HTTP client for OAuth flows, allowing the extension
/// to handle token exchange with full control over serialization.
send-oauth-http-request: func(request: oauth-http-request) -> result<oauth-http-response, string>;
/// Open a URL in the user's default browser.
///
/// Useful for OAuth flows that need to open a browser but handle the
/// callback differently (e.g., polling-based flows).
oauth-open-browser: func(url: string) -> result<_, string>;
}

View File

@@ -255,6 +255,21 @@ async fn copy_extension_resources(
}
}
for (_, provider_entry) in &manifest.language_model_providers {
if let Some(icon_path) = &provider_entry.icon {
let source_icon = extension_path.join(icon_path);
let dest_icon = output_dir.join(icon_path);
// Create parent directory if needed
if let Some(parent) = dest_icon.parent() {
fs::create_dir_all(parent)?;
}
fs::copy(&source_icon, &dest_icon)
.with_context(|| format!("failed to copy LLM provider icon '{}'", icon_path))?;
}
}
if !manifest.languages.is_empty() {
let output_languages_dir = output_dir.join("languages");
fs::create_dir_all(&output_languages_dir)?;

View File

@@ -22,7 +22,10 @@ async-tar.workspace = true
async-trait.workspace = true
client.workspace = true
collections.workspace = true
credentials_provider.workspace = true
dap.workspace = true
dirs.workspace = true
editor.workspace = true
extension.workspace = true
fs.workspace = true
futures.workspace = true
@@ -30,8 +33,11 @@ gpui.workspace = true
gpui_tokio.workspace = true
http_client.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
markdown.workspace = true
lsp.workspace = true
menu.workspace = true
moka.workspace = true
node_runtime.workspace = true
paths.workspace = true
@@ -43,10 +49,13 @@ serde.workspace = true
serde_json.workspace = true
serde_json_lenient.workspace = true
settings.workspace = true
smol.workspace = true
task.workspace = true
telemetry.workspace = true
tempfile.workspace = true
theme.workspace = true
toml.workspace = true
ui.workspace = true
url.workspace = true
util.workspace = true
wasmparser.workspace = true

View File

@@ -148,6 +148,7 @@ fn manifest() -> ExtensionManifest {
)],
debug_adapters: Default::default(),
debug_locators: Default::default(),
language_model_providers: BTreeMap::default(),
}
}

View File

@@ -0,0 +1,124 @@
use credentials_provider::CredentialsProvider;
use gpui::App;
const ANTHROPIC_EXTENSION_ID: &str = "anthropic";
const ANTHROPIC_PROVIDER_ID: &str = "anthropic";
const ANTHROPIC_DEFAULT_API_URL: &str = "https://api.anthropic.com";
/// Migrates Anthropic API credentials from the old built-in provider location
/// to the new extension-based location.
///
/// This should only be called during auto-install of the extension.
pub fn migrate_anthropic_credentials_if_needed(extension_id: &str, cx: &mut App) {
if extension_id != ANTHROPIC_EXTENSION_ID {
return;
}
let extension_credential_key = format!(
"extension-llm-{}:{}",
ANTHROPIC_EXTENSION_ID, ANTHROPIC_PROVIDER_ID
);
let credentials_provider = <dyn CredentialsProvider>::global(cx);
cx.spawn(async move |cx| {
// Read from old location
let old_credential = credentials_provider
.read_credentials(ANTHROPIC_DEFAULT_API_URL, &cx)
.await
.ok()
.flatten();
let api_key = match old_credential {
Some((_, key_bytes)) => match String::from_utf8(key_bytes) {
Ok(key) if !key.is_empty() => key,
Ok(_) => {
log::debug!("Existing Anthropic API key is empty, nothing to migrate");
return;
}
Err(_) => {
log::error!("Failed to decode Anthropic API key as UTF-8");
return;
}
},
None => {
log::debug!("No existing Anthropic API key found to migrate");
return;
}
};
log::info!("Migrating existing Anthropic API key to Anthropic extension");
match credentials_provider
.write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
.await
{
Ok(()) => {
log::info!("Successfully migrated Anthropic API key to extension");
}
Err(err) => {
log::error!("Failed to migrate Anthropic API key: {}", err);
}
}
})
.detach();
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::TestAppContext;
#[gpui::test]
async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
let api_key = "sk-ant-test-key-12345";
cx.write_credentials(ANTHROPIC_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
cx.update(|cx| {
migrate_anthropic_credentials_if_needed(ANTHROPIC_EXTENSION_ID, cx);
});
cx.run_until_parked();
let migrated = cx.read_credentials("extension-llm-anthropic:anthropic");
assert!(migrated.is_some(), "Credentials should have been migrated");
let (username, password) = migrated.unwrap();
assert_eq!(username, "Bearer");
assert_eq!(String::from_utf8(password).unwrap(), api_key);
}
#[gpui::test]
async fn test_no_migration_if_no_old_credentials(cx: &mut TestAppContext) {
cx.update(|cx| {
migrate_anthropic_credentials_if_needed(ANTHROPIC_EXTENSION_ID, cx);
});
cx.run_until_parked();
let credentials = cx.read_credentials("extension-llm-anthropic:anthropic");
assert!(
credentials.is_none(),
"Should not create credentials if none existed"
);
}
#[gpui::test]
async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
let api_key = "sk-ant-test-key";
cx.write_credentials(ANTHROPIC_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
cx.update(|cx| {
migrate_anthropic_credentials_if_needed("some-other-extension", cx);
});
cx.run_until_parked();
let credentials = cx.read_credentials("extension-llm-anthropic:anthropic");
assert!(
credentials.is_none(),
"Should not migrate for other extensions"
);
}
}

View File

@@ -113,6 +113,7 @@ mod tests {
capabilities: vec![],
debug_adapters: Default::default(),
debug_locators: Default::default(),
language_model_providers: BTreeMap::default(),
}
}

View File

@@ -0,0 +1,216 @@
use credentials_provider::CredentialsProvider;
use gpui::App;
use std::path::PathBuf;
const COPILOT_CHAT_EXTENSION_ID: &str = "copilot-chat";
const COPILOT_CHAT_PROVIDER_ID: &str = "copilot-chat";
/// Migrates Copilot OAuth credentials from the GitHub Copilot config files
/// to the new extension-based credential location.
///
/// This should only be called during auto-install of the extension.
pub fn migrate_copilot_credentials_if_needed(extension_id: &str, cx: &mut App) {
if extension_id != COPILOT_CHAT_EXTENSION_ID {
return;
}
let credential_key = format!(
"extension-llm-{}:{}",
COPILOT_CHAT_EXTENSION_ID, COPILOT_CHAT_PROVIDER_ID
);
let credentials_provider = <dyn CredentialsProvider>::global(cx);
cx.spawn(async move |_cx| {
// Read from copilot config files
let oauth_token = match read_copilot_oauth_token().await {
Some(token) if !token.is_empty() => token,
_ => {
log::debug!("No existing Copilot OAuth token found to migrate");
return;
}
};
log::info!("Migrating existing Copilot OAuth token to Copilot Chat extension");
match credentials_provider
.write_credentials(&credential_key, "api_key", oauth_token.as_bytes(), &_cx)
.await
{
Ok(()) => {
log::info!("Successfully migrated Copilot OAuth token to Copilot Chat extension");
}
Err(err) => {
log::error!("Failed to migrate Copilot OAuth token: {}", err);
}
}
})
.detach();
}
async fn read_copilot_oauth_token() -> Option<String> {
let config_paths = copilot_config_paths();
for path in config_paths {
if let Some(token) = read_oauth_token_from_file(&path).await {
return Some(token);
}
}
None
}
fn copilot_config_paths() -> Vec<PathBuf> {
let config_dir = if cfg!(target_os = "windows") {
dirs::data_local_dir()
} else {
std::env::var("XDG_CONFIG_HOME")
.map(PathBuf::from)
.ok()
.or_else(|| dirs::home_dir().map(|h| h.join(".config")))
};
let Some(config_dir) = config_dir else {
return Vec::new();
};
let copilot_dir = config_dir.join("github-copilot");
vec![
copilot_dir.join("hosts.json"),
copilot_dir.join("apps.json"),
]
}
async fn read_oauth_token_from_file(path: &PathBuf) -> Option<String> {
let contents = match smol::fs::read_to_string(path).await {
Ok(contents) => contents,
Err(_) => return None,
};
extract_oauth_token(&contents, "github.com")
}
fn extract_oauth_token(contents: &str, domain: &str) -> Option<String> {
let value: serde_json::Value = serde_json::from_str(contents).ok()?;
let obj = value.as_object()?;
for (key, value) in obj.iter() {
if key.starts_with(domain) {
if let Some(token) = value.get("oauth_token").and_then(|v| v.as_str()) {
return Some(token.to_string());
}
}
}
None
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::TestAppContext;
#[test]
fn test_extract_oauth_token_from_hosts_json() {
let contents = r#"{
"github.com": {
"oauth_token": "ghu_test_token_12345"
}
}"#;
let token = extract_oauth_token(contents, "github.com");
assert_eq!(token, Some("ghu_test_token_12345".to_string()));
}
#[test]
fn test_extract_oauth_token_with_user_suffix() {
let contents = r#"{
"github.com:user": {
"oauth_token": "ghu_another_token"
}
}"#;
let token = extract_oauth_token(contents, "github.com");
assert_eq!(token, Some("ghu_another_token".to_string()));
}
#[test]
fn test_extract_oauth_token_wrong_domain() {
let contents = r#"{
"gitlab.com": {
"oauth_token": "some_token"
}
}"#;
let token = extract_oauth_token(contents, "github.com");
assert_eq!(token, None);
}
#[test]
fn test_extract_oauth_token_invalid_json() {
let contents = "not valid json";
let token = extract_oauth_token(contents, "github.com");
assert_eq!(token, None);
}
#[test]
fn test_extract_oauth_token_missing_oauth_token_field() {
let contents = r#"{
"github.com": {
"user": "testuser"
}
}"#;
let token = extract_oauth_token(contents, "github.com");
assert_eq!(token, None);
}
#[test]
fn test_extract_oauth_token_multiple_entries_picks_first_match() {
let contents = r#"{
"gitlab.com": {
"oauth_token": "gitlab_token"
},
"github.com": {
"oauth_token": "github_token"
}
}"#;
let token = extract_oauth_token(contents, "github.com");
assert_eq!(token, Some("github_token".to_string()));
}
#[gpui::test]
async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
cx.update(|cx| {
migrate_copilot_credentials_if_needed("some-other-extension", cx);
});
cx.run_until_parked();
let credentials = cx.read_credentials("extension-llm-copilot-chat:copilot-chat");
assert!(
credentials.is_none(),
"Should not create credentials for other extensions"
);
}
// Note: Unlike the other migrations, copilot migration reads from the filesystem
// (copilot config files), not from the credentials provider. In tests, these files
// don't exist, so no migration occurs.
#[gpui::test]
async fn test_no_credentials_when_no_copilot_config_exists(cx: &mut TestAppContext) {
cx.update(|cx| {
migrate_copilot_credentials_if_needed(COPILOT_CHAT_EXTENSION_ID, cx);
});
cx.run_until_parked();
let credentials = cx.read_credentials("extension-llm-copilot-chat:copilot-chat");
assert!(
credentials.is_none(),
"No credentials should be written when copilot config doesn't exist"
);
}
}

View File

@@ -1,6 +1,11 @@
mod anthropic_migration;
mod capability_granter;
mod copilot_migration;
pub mod extension_settings;
mod google_ai_migration;
pub mod headless_host;
mod open_router_migration;
mod openai_migration;
pub mod wasm_host;
#[cfg(test)]
@@ -12,13 +17,14 @@ use async_tar::Archive;
use client::ExtensionProvides;
use client::{Client, ExtensionMetadata, GetExtensionsResponse, proto, telemetry::Telemetry};
use collections::{BTreeMap, BTreeSet, HashSet, btree_map};
pub use extension::ExtensionManifest;
use extension::extension_builder::{CompileExtensionOptions, ExtensionBuilder};
use extension::{
ExtensionContextServerProxy, ExtensionDebugAdapterProviderProxy, ExtensionEvents,
ExtensionGrammarProxy, ExtensionHostProxy, ExtensionLanguageProxy,
ExtensionLanguageServerProxy, ExtensionSlashCommandProxy, ExtensionSnippetProxy,
ExtensionThemeProxy,
ExtensionGrammarProxy, ExtensionHostProxy, ExtensionLanguageModelProviderProxy,
ExtensionLanguageProxy, ExtensionLanguageServerProxy, ExtensionSlashCommandProxy,
ExtensionSnippetProxy, ExtensionThemeProxy,
};
use fs::{Fs, RemoveOptions};
use futures::future::join_all;
@@ -32,8 +38,8 @@ use futures::{
select_biased,
};
use gpui::{
App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, Task, WeakEntity,
actions,
App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, SharedString, Task,
WeakEntity, actions,
};
use http_client::{AsyncBody, HttpClient, HttpClientWithUrl};
use language::{
@@ -53,15 +59,24 @@ use std::{
cmp::Ordering,
path::{self, Path, PathBuf},
sync::Arc,
time::{Duration, Instant},
time::Duration,
};
use url::Url;
use util::{ResultExt, paths::RemotePathBuf};
use wasm_host::llm_provider::ExtensionLanguageModelProvider;
use wasm_host::{
WasmExtension, WasmHost,
wit::{is_supported_wasm_api_version, wasm_api_version_range},
wit::{LlmModelInfo, LlmProviderInfo, is_supported_wasm_api_version, wasm_api_version_range},
};
struct LlmProviderWithModels {
provider_info: LlmProviderInfo,
models: Vec<LlmModelInfo>,
is_authenticated: bool,
icon_path: Option<SharedString>,
auth_config: Option<extension::LanguageModelAuthConfig>,
}
pub use extension::{
ExtensionLibraryKind, GrammarManifestEntry, OldExtensionManifest, SchemaVersion,
};
@@ -70,6 +85,79 @@ pub use extension_settings::ExtensionSettings;
pub const RELOAD_DEBOUNCE_DURATION: Duration = Duration::from_millis(200);
const FS_WATCH_LATENCY: Duration = Duration::from_millis(100);
/// Extension IDs that are being migrated from hardcoded LLM providers.
/// For backwards compatibility, if the user has the corresponding env var set,
/// we automatically enable env var reading for these extensions on first install.
const LEGACY_LLM_EXTENSION_IDS: &[&str] = &[
"anthropic",
"copilot-chat",
"google-ai",
"openrouter",
"openai",
];
/// Migrates legacy LLM provider extensions by auto-enabling env var reading
/// if the env var is currently present in the environment.
///
/// This is idempotent: if the provider is already in `allowed_env_var_providers`,
/// we skip. This means if a user explicitly removes it, it will be re-added on
/// next launch if the env var is still set - but that's predictable behavior.
fn migrate_legacy_llm_provider_env_var(manifest: &ExtensionManifest, cx: &mut App) {
// Only apply migration to known legacy LLM extensions
if !LEGACY_LLM_EXTENSION_IDS.contains(&manifest.id.as_ref()) {
return;
}
// Check each provider in the manifest
for (provider_id, provider_entry) in &manifest.language_model_providers {
let Some(auth_config) = &provider_entry.auth else {
continue;
};
let Some(env_var_name) = &auth_config.env_var else {
continue;
};
let full_provider_id: Arc<str> = format!("{}:{}", manifest.id, provider_id).into();
// Check if the env var is present and non-empty
let env_var_is_set = std::env::var(env_var_name)
.map(|v| !v.is_empty())
.unwrap_or(false);
// If env var isn't set, no need to do anything
if !env_var_is_set {
continue;
}
// Check if already enabled in settings
let already_enabled = ExtensionSettings::get_global(cx)
.allowed_env_var_providers
.contains(full_provider_id.as_ref());
if already_enabled {
continue;
}
// Enable env var reading since the env var is set
settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, {
let full_provider_id = full_provider_id.clone();
move |settings, _| {
let providers = settings
.extension
.allowed_env_var_providers
.get_or_insert_with(Vec::new);
if !providers
.iter()
.any(|id| id.as_ref() == full_provider_id.as_ref())
{
providers.push(full_provider_id);
}
}
});
}
}
/// The current extension [`SchemaVersion`] supported by Zed.
const CURRENT_SCHEMA_VERSION: SchemaVersion = SchemaVersion(1);
@@ -131,6 +219,8 @@ pub struct ExtensionStore {
pub enum ExtensionOperation {
Upgrade,
Install,
/// Auto-install from settings - triggers legacy LLM provider migrations
AutoInstall,
Remove,
}
@@ -606,15 +696,68 @@ impl ExtensionStore {
.extension_index
.extensions
.contains_key(extension_id.as_ref());
!is_already_installed && !SUPPRESSED_EXTENSIONS.contains(&extension_id.as_ref())
let dominated = SUPPRESSED_EXTENSIONS.contains(&extension_id.as_ref());
!is_already_installed && !dominated
})
.cloned()
.collect::<Vec<_>>();
cx.spawn(async move |this, cx| {
for extension_id in extensions_to_install {
// When enabled, this checks if an extension exists locally in the repo's extensions/
// directory and installs it as a dev extension instead of fetching from the registry.
// This is useful for testing auto-installed extensions before they've been published.
// Set to `true` only during local development/testing of new auto-install extensions.
#[cfg(debug_assertions)]
const DEBUG_ALLOW_UNPUBLISHED_AUTO_EXTENSIONS: bool = false;
#[cfg(debug_assertions)]
if DEBUG_ALLOW_UNPUBLISHED_AUTO_EXTENSIONS {
let local_extension_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.parent()
.unwrap()
.parent()
.unwrap()
.join("extensions")
.join(extension_id.as_ref());
if local_extension_path.exists() {
// Force-remove existing extension directory if it exists and isn't a symlink
// This handles the case where the extension was previously installed from the registry
if let Some(installed_dir) = this
.update(cx, |this, _cx| this.installed_dir.clone())
.ok()
{
let existing_path = installed_dir.join(extension_id.as_ref());
if existing_path.exists() {
let metadata = std::fs::symlink_metadata(&existing_path);
let is_symlink = metadata.map(|m| m.is_symlink()).unwrap_or(false);
if !is_symlink {
if let Err(e) = std::fs::remove_dir_all(&existing_path) {
log::error!(
"Failed to remove existing extension directory {:?}: {}",
existing_path,
e
);
}
}
}
}
if let Some(task) = this
.update(cx, |this, cx| {
this.install_dev_extension(local_extension_path, cx)
})
.ok()
{
task.await.log_err();
}
continue;
}
}
this.update(cx, |this, cx| {
this.install_latest_extension(extension_id.clone(), cx);
this.auto_install_latest_extension(extension_id.clone(), cx);
})
.ok();
}
@@ -769,7 +912,10 @@ impl ExtensionStore {
this.update(cx, |this, cx| this.reload(Some(extension_id.clone()), cx))?
.await;
if let ExtensionOperation::Install = operation {
if matches!(
operation,
ExtensionOperation::Install | ExtensionOperation::AutoInstall
) {
this.update(cx, |this, cx| {
cx.emit(Event::ExtensionInstalled(extension_id.clone()));
if let Some(events) = ExtensionEvents::try_global(cx)
@@ -779,6 +925,27 @@ impl ExtensionStore {
this.emit(extension::Event::ExtensionInstalled(manifest.clone()), cx)
});
}
// Run legacy LLM provider migrations only for auto-installed extensions
if matches!(operation, ExtensionOperation::AutoInstall) {
if let Some(manifest) = this.extension_manifest_for_id(&extension_id) {
migrate_legacy_llm_provider_env_var(&manifest, cx);
}
copilot_migration::migrate_copilot_credentials_if_needed(&extension_id, cx);
anthropic_migration::migrate_anthropic_credentials_if_needed(
&extension_id,
cx,
);
google_ai_migration::migrate_google_ai_credentials_if_needed(
&extension_id,
cx,
);
openai_migration::migrate_openai_credentials_if_needed(&extension_id, cx);
open_router_migration::migrate_open_router_credentials_if_needed(
&extension_id,
cx,
);
}
})
.ok();
}
@@ -788,8 +955,24 @@ impl ExtensionStore {
}
pub fn install_latest_extension(&mut self, extension_id: Arc<str>, cx: &mut Context<Self>) {
log::info!("installing extension {extension_id} latest version");
self.install_latest_extension_with_operation(extension_id, ExtensionOperation::Install, cx);
}
/// Auto-install an extension, triggering legacy LLM provider migrations.
fn auto_install_latest_extension(&mut self, extension_id: Arc<str>, cx: &mut Context<Self>) {
self.install_latest_extension_with_operation(
extension_id,
ExtensionOperation::AutoInstall,
cx,
);
}
fn install_latest_extension_with_operation(
&mut self,
extension_id: Arc<str>,
operation: ExtensionOperation,
cx: &mut Context<Self>,
) {
let schema_versions = schema_version_range();
let wasm_api_versions = wasm_api_version_range(ReleaseChannel::global(cx));
@@ -812,13 +995,8 @@ impl ExtensionStore {
return;
};
self.install_or_upgrade_extension_at_endpoint(
extension_id,
url,
ExtensionOperation::Install,
cx,
)
.detach_and_log_err(cx);
self.install_or_upgrade_extension_at_endpoint(extension_id, url, operation, cx)
.detach_and_log_err(cx);
}
pub fn upgrade_extension(
@@ -837,7 +1015,6 @@ impl ExtensionStore {
operation: ExtensionOperation,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
log::info!("installing extension {extension_id} {version}");
let Some(url) = self
.http_client
.build_zed_api_url(
@@ -1134,18 +1311,6 @@ impl ExtensionStore {
return Task::ready(());
}
let reload_count = extensions_to_unload
.iter()
.filter(|id| extensions_to_load.contains(id))
.count();
log::info!(
"extensions updated. loading {}, reloading {}, unloading {}",
extensions_to_load.len() - reload_count,
reload_count,
extensions_to_unload.len() - reload_count
);
let extension_ids = extensions_to_load
.iter()
.filter_map(|id| {
@@ -1220,6 +1385,11 @@ impl ExtensionStore {
for command_name in extension.manifest.slash_commands.keys() {
self.proxy.unregister_slash_command(command_name.clone());
}
for provider_id in extension.manifest.language_model_providers.keys() {
let full_provider_id: Arc<str> = format!("{}:{}", extension_id, provider_id).into();
self.proxy
.unregister_language_model_provider(full_provider_id, cx);
}
}
self.wasm_extensions
@@ -1358,7 +1528,11 @@ impl ExtensionStore {
})
.await;
let mut wasm_extensions = Vec::new();
let mut wasm_extensions: Vec<(
Arc<ExtensionManifest>,
WasmExtension,
Vec<LlmProviderWithModels>,
)> = Vec::new();
for extension in extension_entries {
if extension.manifest.lib.kind.is_none() {
continue;
@@ -1376,7 +1550,122 @@ impl ExtensionStore {
match wasm_extension {
Ok(wasm_extension) => {
wasm_extensions.push((extension.manifest.clone(), wasm_extension))
// Query for LLM providers if the manifest declares any
let mut llm_providers_with_models = Vec::new();
if !extension.manifest.language_model_providers.is_empty() {
let providers_result = wasm_extension
.call(|ext, store| {
async move { ext.call_llm_providers(store).await }.boxed()
})
.await;
if let Ok(Ok(providers)) = providers_result {
for provider_info in providers {
let models_result = wasm_extension
.call({
let provider_id = provider_info.id.clone();
|ext, store| {
async move {
ext.call_llm_provider_models(store, &provider_id)
.await
}
.boxed()
}
})
.await;
let models: Vec<LlmModelInfo> = match models_result {
Ok(Ok(Ok(models))) => models,
Ok(Ok(Err(e))) => {
log::error!(
"Failed to get models for LLM provider {} in extension {}: {}",
provider_info.id,
extension.manifest.id,
e
);
Vec::new()
}
Ok(Err(e)) => {
log::error!(
"Wasm error calling llm_provider_models for {} in extension {}: {:?}",
provider_info.id,
extension.manifest.id,
e
);
Vec::new()
}
Err(e) => {
log::error!(
"Extension call failed for llm_provider_models {} in extension {}: {:?}",
provider_info.id,
extension.manifest.id,
e
);
Vec::new()
}
};
// Query initial authentication state
let is_authenticated = wasm_extension
.call({
let provider_id = provider_info.id.clone();
|ext, store| {
async move {
ext.call_llm_provider_is_authenticated(
store,
&provider_id,
)
.await
}
.boxed()
}
})
.await
.unwrap_or(Ok(false))
.unwrap_or(false);
// Resolve icon path if provided
let icon_path = provider_info.icon.as_ref().map(|icon| {
let icon_file_path = extension_path.join(icon);
// Canonicalize to resolve symlinks (dev extensions are symlinked)
let absolute_icon_path = icon_file_path
.canonicalize()
.unwrap_or(icon_file_path)
.to_string_lossy()
.to_string();
SharedString::from(absolute_icon_path)
});
let provider_id_arc: Arc<str> =
provider_info.id.as_str().into();
let auth_config = extension
.manifest
.language_model_providers
.get(&provider_id_arc)
.and_then(|entry| entry.auth.clone());
llm_providers_with_models.push(LlmProviderWithModels {
provider_info,
models,
is_authenticated,
icon_path,
auth_config,
});
}
} else {
log::error!(
"Failed to get LLM providers from extension {}: {:?}",
extension.manifest.id,
providers_result
);
}
}
wasm_extensions.push((
extension.manifest.clone(),
wasm_extension,
llm_providers_with_models,
))
}
Err(e) => {
log::error!(
@@ -1395,7 +1684,7 @@ impl ExtensionStore {
this.update(cx, |this, cx| {
this.reload_complete_senders.clear();
for (manifest, wasm_extension) in &wasm_extensions {
for (manifest, wasm_extension, llm_providers_with_models) in &wasm_extensions {
let extension = Arc::new(wasm_extension.clone());
for (language_server_id, language_server_config) in &manifest.language_servers {
@@ -1449,9 +1738,41 @@ impl ExtensionStore {
this.proxy
.register_debug_locator(extension.clone(), debug_adapter.clone());
}
// Register LLM providers
for llm_provider in llm_providers_with_models {
let provider_id: Arc<str> =
format!("{}:{}", manifest.id, llm_provider.provider_info.id).into();
let wasm_ext = extension.as_ref().clone();
let pinfo = llm_provider.provider_info.clone();
let mods = llm_provider.models.clone();
let auth = llm_provider.is_authenticated;
let icon = llm_provider.icon_path.clone();
let auth_config = llm_provider.auth_config.clone();
this.proxy.register_language_model_provider(
provider_id.clone(),
Box::new(move |cx: &mut App| {
let provider = Arc::new(ExtensionLanguageModelProvider::new(
wasm_ext, pinfo, mods, auth, icon, auth_config, cx,
));
language_model::LanguageModelRegistry::global(cx).update(
cx,
|registry, cx| {
registry.register_provider(provider, cx);
},
);
}),
cx,
);
}
}
this.wasm_extensions.extend(wasm_extensions);
let wasm_extensions_without_llm: Vec<_> = wasm_extensions
.into_iter()
.map(|(manifest, ext, _)| (manifest, ext))
.collect();
this.wasm_extensions.extend(wasm_extensions_without_llm);
this.proxy.set_extensions_loaded();
this.proxy.reload_current_theme(cx);
this.proxy.reload_current_icon_theme(cx);
@@ -1473,7 +1794,6 @@ impl ExtensionStore {
let index_path = self.index_path.clone();
let proxy = self.proxy.clone();
cx.background_spawn(async move {
let start_time = Instant::now();
let mut index = ExtensionIndex::default();
fs.create_dir(&work_dir).await.log_err();
@@ -1511,7 +1831,6 @@ impl ExtensionStore {
.log_err();
}
log::info!("rebuilt extension index in {:?}", start_time.elapsed());
index
})
}
@@ -1785,11 +2104,6 @@ impl ExtensionStore {
})?,
path_style,
);
log::info!(
"Uploading extension {} to {:?}",
missing_extension.clone().id,
dest_dir
);
client
.update(cx, |client, cx| {
@@ -1797,11 +2111,6 @@ impl ExtensionStore {
})?
.await?;
log::info!(
"Finished uploading extension {}",
missing_extension.clone().id
);
let result = client
.update(cx, |client, _cx| {
client.proto_client().request(proto::InstallExtension {

View File

@@ -1,4 +1,4 @@
use collections::HashMap;
use collections::{HashMap, HashSet};
use extension::{
DownloadFileCapability, ExtensionCapability, NpmInstallPackageCapability, ProcessExecCapability,
};
@@ -16,6 +16,10 @@ pub struct ExtensionSettings {
pub auto_install_extensions: HashMap<Arc<str>, bool>,
pub auto_update_extensions: HashMap<Arc<str>, bool>,
pub granted_capabilities: Vec<ExtensionCapability>,
/// The extension language model providers that are allowed to read API keys
/// from environment variables. Each entry is a provider ID in the format
/// "extension_id:provider_id".
pub allowed_env_var_providers: HashSet<Arc<str>>,
}
impl ExtensionSettings {
@@ -60,6 +64,13 @@ impl Settings for ExtensionSettings {
}
})
.collect(),
allowed_env_var_providers: content
.extension
.allowed_env_var_providers
.clone()
.unwrap_or_default()
.into_iter()
.collect(),
}
}
}

View File

@@ -165,6 +165,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
capabilities: Vec::new(),
debug_adapters: Default::default(),
debug_locators: Default::default(),
language_model_providers: BTreeMap::default(),
}),
dev: false,
},
@@ -196,6 +197,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
capabilities: Vec::new(),
debug_adapters: Default::default(),
debug_locators: Default::default(),
language_model_providers: BTreeMap::default(),
}),
dev: false,
},
@@ -376,6 +378,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
capabilities: Vec::new(),
debug_adapters: Default::default(),
debug_locators: Default::default(),
language_model_providers: BTreeMap::default(),
}),
dev: false,
},

View File

@@ -0,0 +1,124 @@
use credentials_provider::CredentialsProvider;
use gpui::App;
const GOOGLE_AI_EXTENSION_ID: &str = "google-ai";
const GOOGLE_AI_PROVIDER_ID: &str = "google-ai";
const GOOGLE_AI_DEFAULT_API_URL: &str = "https://generativelanguage.googleapis.com";
/// Migrates Google AI API credentials from the old built-in provider location
/// to the new extension-based location.
///
/// This should only be called during auto-install of the extension.
pub fn migrate_google_ai_credentials_if_needed(extension_id: &str, cx: &mut App) {
if extension_id != GOOGLE_AI_EXTENSION_ID {
return;
}
let extension_credential_key = format!(
"extension-llm-{}:{}",
GOOGLE_AI_EXTENSION_ID, GOOGLE_AI_PROVIDER_ID
);
let credentials_provider = <dyn CredentialsProvider>::global(cx);
cx.spawn(async move |cx| {
// Read from old location
let old_credential = credentials_provider
.read_credentials(GOOGLE_AI_DEFAULT_API_URL, &cx)
.await
.ok()
.flatten();
let api_key = match old_credential {
Some((_, key_bytes)) => match String::from_utf8(key_bytes) {
Ok(key) if !key.is_empty() => key,
Ok(_) => {
log::debug!("Existing Google AI API key is empty, nothing to migrate");
return;
}
Err(_) => {
log::error!("Failed to decode Google AI API key as UTF-8");
return;
}
},
None => {
log::debug!("No existing Google AI API key found to migrate");
return;
}
};
log::info!("Migrating existing Google AI API key to Google AI extension");
match credentials_provider
.write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
.await
{
Ok(()) => {
log::info!("Successfully migrated Google AI API key to extension");
}
Err(err) => {
log::error!("Failed to migrate Google AI API key: {}", err);
}
}
})
.detach();
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::TestAppContext;
#[gpui::test]
async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
let api_key = "AIzaSy-test-key-12345";
cx.write_credentials(GOOGLE_AI_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
cx.update(|cx| {
migrate_google_ai_credentials_if_needed(GOOGLE_AI_EXTENSION_ID, cx);
});
cx.run_until_parked();
let migrated = cx.read_credentials("extension-llm-google-ai:google-ai");
assert!(migrated.is_some(), "Credentials should have been migrated");
let (username, password) = migrated.unwrap();
assert_eq!(username, "Bearer");
assert_eq!(String::from_utf8(password).unwrap(), api_key);
}
#[gpui::test]
async fn test_no_migration_if_no_old_credentials(cx: &mut TestAppContext) {
cx.update(|cx| {
migrate_google_ai_credentials_if_needed(GOOGLE_AI_EXTENSION_ID, cx);
});
cx.run_until_parked();
let credentials = cx.read_credentials("extension-llm-google-ai:google-ai");
assert!(
credentials.is_none(),
"Should not create credentials if none existed"
);
}
#[gpui::test]
async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
let api_key = "AIzaSy-test-key";
cx.write_credentials(GOOGLE_AI_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
cx.update(|cx| {
migrate_google_ai_credentials_if_needed("some-other-extension", cx);
});
cx.run_until_parked();
let credentials = cx.read_credentials("extension-llm-google-ai:google-ai");
assert!(
credentials.is_none(),
"Should not migrate for other extensions"
);
}
}

View File

@@ -0,0 +1,124 @@
use credentials_provider::CredentialsProvider;
use gpui::App;
const OPEN_ROUTER_EXTENSION_ID: &str = "openrouter";
const OPEN_ROUTER_PROVIDER_ID: &str = "openrouter";
const OPEN_ROUTER_DEFAULT_API_URL: &str = "https://openrouter.ai/api/v1";
/// Migrates OpenRouter API credentials from the old built-in provider location
/// to the new extension-based location.
///
/// This should only be called during auto-install of the extension.
pub fn migrate_open_router_credentials_if_needed(extension_id: &str, cx: &mut App) {
if extension_id != OPEN_ROUTER_EXTENSION_ID {
return;
}
let extension_credential_key = format!(
"extension-llm-{}:{}",
OPEN_ROUTER_EXTENSION_ID, OPEN_ROUTER_PROVIDER_ID
);
let credentials_provider = <dyn CredentialsProvider>::global(cx);
cx.spawn(async move |cx| {
// Read from old location
let old_credential = credentials_provider
.read_credentials(OPEN_ROUTER_DEFAULT_API_URL, &cx)
.await
.ok()
.flatten();
let api_key = match old_credential {
Some((_, key_bytes)) => match String::from_utf8(key_bytes) {
Ok(key) if !key.is_empty() => key,
Ok(_) => {
log::debug!("Existing OpenRouter API key is empty, nothing to migrate");
return;
}
Err(_) => {
log::error!("Failed to decode OpenRouter API key as UTF-8");
return;
}
},
None => {
log::debug!("No existing OpenRouter API key found to migrate");
return;
}
};
log::info!("Migrating existing OpenRouter API key to OpenRouter extension");
match credentials_provider
.write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
.await
{
Ok(()) => {
log::info!("Successfully migrated OpenRouter API key to extension");
}
Err(err) => {
log::error!("Failed to migrate OpenRouter API key: {}", err);
}
}
})
.detach();
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::TestAppContext;
#[gpui::test]
async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
let api_key = "sk-or-test-key-12345";
cx.write_credentials(OPEN_ROUTER_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
cx.update(|cx| {
migrate_open_router_credentials_if_needed(OPEN_ROUTER_EXTENSION_ID, cx);
});
cx.run_until_parked();
let migrated = cx.read_credentials("extension-llm-openrouter:openrouter");
assert!(migrated.is_some(), "Credentials should have been migrated");
let (username, password) = migrated.unwrap();
assert_eq!(username, "Bearer");
assert_eq!(String::from_utf8(password).unwrap(), api_key);
}
#[gpui::test]
async fn test_no_migration_if_no_old_credentials(cx: &mut TestAppContext) {
cx.update(|cx| {
migrate_open_router_credentials_if_needed(OPEN_ROUTER_EXTENSION_ID, cx);
});
cx.run_until_parked();
let credentials = cx.read_credentials("extension-llm-openrouter:openrouter");
assert!(
credentials.is_none(),
"Should not create credentials if none existed"
);
}
#[gpui::test]
async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
let api_key = "sk-or-test-key";
cx.write_credentials(OPEN_ROUTER_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
cx.update(|cx| {
migrate_open_router_credentials_if_needed("some-other-extension", cx);
});
cx.run_until_parked();
let credentials = cx.read_credentials("extension-llm-openrouter:openrouter");
assert!(
credentials.is_none(),
"Should not migrate for other extensions"
);
}
}

View File

@@ -0,0 +1,124 @@
use credentials_provider::CredentialsProvider;
use gpui::App;
const OPENAI_EXTENSION_ID: &str = "openai";
const OPENAI_PROVIDER_ID: &str = "openai";
const OPENAI_DEFAULT_API_URL: &str = "https://api.openai.com/v1";
/// Migrates OpenAI API credentials from the old built-in provider location
/// to the new extension-based location.
///
/// This should only be called during auto-install of the extension.
pub fn migrate_openai_credentials_if_needed(extension_id: &str, cx: &mut App) {
if extension_id != OPENAI_EXTENSION_ID {
return;
}
let extension_credential_key = format!(
"extension-llm-{}:{}",
OPENAI_EXTENSION_ID, OPENAI_PROVIDER_ID
);
let credentials_provider = <dyn CredentialsProvider>::global(cx);
cx.spawn(async move |cx| {
// Read from old location
let old_credential = credentials_provider
.read_credentials(OPENAI_DEFAULT_API_URL, &cx)
.await
.ok()
.flatten();
let api_key = match old_credential {
Some((_, key_bytes)) => match String::from_utf8(key_bytes) {
Ok(key) if !key.is_empty() => key,
Ok(_) => {
log::debug!("Existing OpenAI API key is empty, nothing to migrate");
return;
}
Err(_) => {
log::error!("Failed to decode OpenAI API key as UTF-8");
return;
}
},
None => {
log::debug!("No existing OpenAI API key found to migrate");
return;
}
};
log::info!("Migrating existing OpenAI API key to OpenAI extension");
match credentials_provider
.write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
.await
{
Ok(()) => {
log::info!("Successfully migrated OpenAI API key to extension");
}
Err(err) => {
log::error!("Failed to migrate OpenAI API key: {}", err);
}
}
})
.detach();
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::TestAppContext;
#[gpui::test]
async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
let api_key = "sk-test-key-12345";
cx.write_credentials(OPENAI_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
cx.update(|cx| {
migrate_openai_credentials_if_needed(OPENAI_EXTENSION_ID, cx);
});
cx.run_until_parked();
let migrated = cx.read_credentials("extension-llm-openai:openai");
assert!(migrated.is_some(), "Credentials should have been migrated");
let (username, password) = migrated.unwrap();
assert_eq!(username, "Bearer");
assert_eq!(String::from_utf8(password).unwrap(), api_key);
}
#[gpui::test]
async fn test_no_migration_if_no_old_credentials(cx: &mut TestAppContext) {
cx.update(|cx| {
migrate_openai_credentials_if_needed(OPENAI_EXTENSION_ID, cx);
});
cx.run_until_parked();
let credentials = cx.read_credentials("extension-llm-openai:openai");
assert!(
credentials.is_none(),
"Should not create credentials if none existed"
);
}
#[gpui::test]
async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
let api_key = "sk-test-key";
cx.write_credentials(OPENAI_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
cx.update(|cx| {
migrate_openai_credentials_if_needed("some-other-extension", cx);
});
cx.run_until_parked();
let credentials = cx.read_credentials("extension-llm-openai:openai");
assert!(
credentials.is_none(),
"Should not migrate for other extensions"
);
}
}

View File

@@ -1,9 +1,11 @@
pub mod llm_provider;
pub mod wit;
use crate::capability_granter::CapabilityGranter;
use crate::{ExtensionManifest, ExtensionSettings};
use anyhow::{Context as _, Result, anyhow, bail};
use async_trait::async_trait;
use dap::{DebugRequest, StartDebuggingRequestArgumentsRequest};
use extension::{
CodeLabel, Command, Completion, ContextServerConfiguration, DebugAdapterBinary,
@@ -64,7 +66,7 @@ pub struct WasmHost {
#[derive(Clone, Debug)]
pub struct WasmExtension {
tx: UnboundedSender<ExtensionCall>,
tx: Arc<UnboundedSender<ExtensionCall>>,
pub manifest: Arc<ExtensionManifest>,
pub work_dir: Arc<Path>,
#[allow(unused)]
@@ -74,7 +76,10 @@ pub struct WasmExtension {
impl Drop for WasmExtension {
fn drop(&mut self) {
self.tx.close_channel();
// Only close the channel when this is the last clone holding the sender
if Arc::strong_count(&self.tx) == 1 {
self.tx.close_channel();
}
}
}
@@ -671,7 +676,7 @@ impl WasmHost {
Ok(WasmExtension {
manifest,
work_dir,
tx,
tx: Arc::new(tx),
zed_api_version,
_task: task,
})

File diff suppressed because it is too large Load Diff

View File

@@ -16,7 +16,7 @@ use lsp::LanguageServerName;
use release_channel::ReleaseChannel;
use task::{DebugScenario, SpawnInTerminal, TaskTemplate, ZedDebugConfig};
use crate::wasm_host::wit::since_v0_6_0::dap::StartDebuggingRequestArgumentsRequest;
use crate::wasm_host::wit::since_v0_8_0::dap::StartDebuggingRequestArgumentsRequest;
use super::{WasmState, wasm_engine};
use anyhow::{Context as _, Result, anyhow};
@@ -33,6 +33,19 @@ pub use latest::CodeLabelSpanLiteral;
pub use latest::{
CodeLabel, CodeLabelSpan, Command, DebugAdapterBinary, ExtensionProject, Range, SlashCommand,
zed::extension::context_server::ContextServerConfiguration,
zed::extension::llm_provider::{
CacheConfiguration as LlmCacheConfiguration, CompletionEvent as LlmCompletionEvent,
CompletionRequest as LlmCompletionRequest, CredentialType as LlmCredentialType,
ImageData as LlmImageData, MessageContent as LlmMessageContent,
MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities,
ModelInfo as LlmModelInfo, ProviderInfo as LlmProviderInfo,
RequestMessage as LlmRequestMessage, StopReason as LlmStopReason,
ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage,
ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition,
ToolInputFormat as LlmToolInputFormat, ToolResult as LlmToolResult,
ToolResultContent as LlmToolResultContent, ToolUse as LlmToolUse,
ToolUseJsonParseError as LlmToolUseJsonParseError,
},
zed::extension::lsp::{
Completion, CompletionKind, CompletionLabelDetails, InsertTextFormat, Symbol, SymbolKind,
},
@@ -1007,6 +1020,20 @@ impl Extension {
resource: Resource<Arc<dyn WorktreeDelegate>>,
) -> Result<Result<DebugAdapterBinary, String>> {
match self {
Extension::V0_8_0(ext) => {
let dap_binary = ext
.call_get_dap_binary(
store,
&adapter_name,
&task.try_into()?,
user_installed_path.as_ref().and_then(|p| p.to_str()),
resource,
)
.await?
.map_err(|e| anyhow!("{e:?}"))?;
Ok(Ok(dap_binary))
}
Extension::V0_6_0(ext) => {
let dap_binary = ext
.call_get_dap_binary(
@@ -1032,6 +1059,16 @@ impl Extension {
config: serde_json::Value,
) -> Result<Result<StartDebuggingRequestArgumentsRequest, String>> {
match self {
Extension::V0_8_0(ext) => {
let config =
serde_json::to_string(&config).context("Adapter config is not a valid JSON")?;
let result = ext
.call_dap_request_kind(store, &adapter_name, &config)
.await?
.map_err(|e| anyhow!("{e:?}"))?;
Ok(Ok(result))
}
Extension::V0_6_0(ext) => {
let config =
serde_json::to_string(&config).context("Adapter config is not a valid JSON")?;
@@ -1052,6 +1089,15 @@ impl Extension {
config: ZedDebugConfig,
) -> Result<Result<DebugScenario, String>> {
match self {
Extension::V0_8_0(ext) => {
let config = config.into();
let result = ext
.call_dap_config_to_scenario(store, &config)
.await?
.map_err(|e| anyhow!("{e:?}"))?;
Ok(Ok(result.try_into()?))
}
Extension::V0_6_0(ext) => {
let config = config.into();
let dap_binary = ext
@@ -1074,6 +1120,20 @@ impl Extension {
debug_adapter_name: String,
) -> Result<Option<DebugScenario>> {
match self {
Extension::V0_8_0(ext) => {
let build_config_template = build_config_template.into();
let result = ext
.call_dap_locator_create_scenario(
store,
&locator_name,
&build_config_template,
&resolved_label,
&debug_adapter_name,
)
.await?;
Ok(result.map(TryInto::try_into).transpose()?)
}
Extension::V0_6_0(ext) => {
let build_config_template = build_config_template.into();
let dap_binary = ext
@@ -1099,6 +1159,15 @@ impl Extension {
resolved_build_task: SpawnInTerminal,
) -> Result<Result<DebugRequest, String>> {
match self {
Extension::V0_8_0(ext) => {
let build_config_template = resolved_build_task.try_into()?;
let dap_request = ext
.call_run_dap_locator(store, &locator_name, &build_config_template)
.await?
.map_err(|e| anyhow!("{e:?}"))?;
Ok(Ok(dap_request.into()))
}
Extension::V0_6_0(ext) => {
let build_config_template = resolved_build_task.try_into()?;
let dap_request = ext
@@ -1111,6 +1180,174 @@ impl Extension {
_ => anyhow::bail!("`dap_locator_create_scenario` not available prior to v0.6.0"),
}
}
pub async fn call_llm_providers(
&self,
store: &mut Store<WasmState>,
) -> Result<Vec<latest::llm_provider::ProviderInfo>> {
match self {
Extension::V0_8_0(ext) => ext.call_llm_providers(store).await,
_ => Ok(Vec::new()),
}
}
pub async fn call_llm_provider_models(
&self,
store: &mut Store<WasmState>,
provider_id: &str,
) -> Result<Result<Vec<latest::llm_provider::ModelInfo>, String>> {
match self {
Extension::V0_8_0(ext) => ext.call_llm_provider_models(store, provider_id).await,
_ => anyhow::bail!("`llm_provider_models` not available prior to v0.8.0"),
}
}
pub async fn call_llm_provider_settings_markdown(
&self,
store: &mut Store<WasmState>,
provider_id: &str,
) -> Result<Option<String>> {
match self {
Extension::V0_8_0(ext) => {
ext.call_llm_provider_settings_markdown(store, provider_id)
.await
}
_ => Ok(None),
}
}
pub async fn call_llm_provider_is_authenticated(
&self,
store: &mut Store<WasmState>,
provider_id: &str,
) -> Result<bool> {
match self {
Extension::V0_8_0(ext) => {
ext.call_llm_provider_is_authenticated(store, provider_id)
.await
}
_ => Ok(false),
}
}
pub async fn call_llm_provider_start_device_flow_sign_in(
&self,
store: &mut Store<WasmState>,
provider_id: &str,
) -> Result<Result<String, String>> {
match self {
Extension::V0_8_0(ext) => {
ext.call_llm_provider_start_device_flow_sign_in(store, provider_id)
.await
}
_ => {
anyhow::bail!(
"`llm_provider_start_device_flow_sign_in` not available prior to v0.8.0"
)
}
}
}
pub async fn call_llm_provider_poll_device_flow_sign_in(
&self,
store: &mut Store<WasmState>,
provider_id: &str,
) -> Result<Result<(), String>> {
match self {
Extension::V0_8_0(ext) => {
ext.call_llm_provider_poll_device_flow_sign_in(store, provider_id)
.await
}
_ => {
anyhow::bail!(
"`llm_provider_poll_device_flow_sign_in` not available prior to v0.8.0"
)
}
}
}
pub async fn call_llm_provider_reset_credentials(
&self,
store: &mut Store<WasmState>,
provider_id: &str,
) -> Result<Result<(), String>> {
match self {
Extension::V0_8_0(ext) => {
ext.call_llm_provider_reset_credentials(store, provider_id)
.await
}
_ => anyhow::bail!("`llm_provider_reset_credentials` not available prior to v0.8.0"),
}
}
pub async fn call_llm_count_tokens(
&self,
store: &mut Store<WasmState>,
provider_id: &str,
model_id: &str,
request: &latest::llm_provider::CompletionRequest,
) -> Result<Result<u64, String>> {
match self {
Extension::V0_8_0(ext) => {
ext.call_llm_count_tokens(store, provider_id, model_id, request)
.await
}
_ => anyhow::bail!("`llm_count_tokens` not available prior to v0.8.0"),
}
}
pub async fn call_llm_stream_completion_start(
&self,
store: &mut Store<WasmState>,
provider_id: &str,
model_id: &str,
request: &latest::llm_provider::CompletionRequest,
) -> Result<Result<String, String>> {
match self {
Extension::V0_8_0(ext) => {
ext.call_llm_stream_completion_start(store, provider_id, model_id, request)
.await
}
_ => anyhow::bail!("`llm_stream_completion_start` not available prior to v0.8.0"),
}
}
pub async fn call_llm_stream_completion_next(
&self,
store: &mut Store<WasmState>,
stream_id: &str,
) -> Result<Result<Option<latest::llm_provider::CompletionEvent>, String>> {
match self {
Extension::V0_8_0(ext) => ext.call_llm_stream_completion_next(store, stream_id).await,
_ => anyhow::bail!("`llm_stream_completion_next` not available prior to v0.8.0"),
}
}
pub async fn call_llm_stream_completion_close(
&self,
store: &mut Store<WasmState>,
stream_id: &str,
) -> Result<()> {
match self {
Extension::V0_8_0(ext) => ext.call_llm_stream_completion_close(store, stream_id).await,
_ => anyhow::bail!("`llm_stream_completion_close` not available prior to v0.8.0"),
}
}
pub async fn call_llm_cache_configuration(
&self,
store: &mut Store<WasmState>,
provider_id: &str,
model_id: &str,
) -> Result<Option<latest::llm_provider::CacheConfiguration>> {
match self {
Extension::V0_8_0(ext) => {
ext.call_llm_cache_configuration(store, provider_id, model_id)
.await
}
_ => Ok(None),
}
}
}
trait ToWasmtimeResult<T> {

View File

@@ -32,8 +32,6 @@ wasmtime::component::bindgen!({
},
});
pub use self::zed::extension::*;
mod settings {
#![allow(dead_code)]
include!(concat!(env!("OUT_DIR"), "/since_v0.6.0/settings.rs"));

View File

@@ -1,11 +1,11 @@
use crate::wasm_host::wit::since_v0_6_0::{
use crate::wasm_host::wit::since_v0_8_0::{
dap::{
AttachRequest, BuildTaskDefinition, BuildTaskDefinitionTemplatePayload, LaunchRequest,
StartDebuggingRequestArguments, TcpArguments, TcpArgumentsTemplate,
},
lsp::{CompletionKind, CompletionLabelDetails, InsertTextFormat, SymbolKind},
slash_command::SlashCommandOutputSection,
};
use crate::wasm_host::wit::{CompletionKind, CompletionLabelDetails, InsertTextFormat, SymbolKind};
use crate::wasm_host::{WasmState, wit::ToWasmtimeResult};
use ::http_client::{AsyncBody, HttpRequestExt};
use ::settings::{Settings, WorktreeId};
@@ -13,6 +13,7 @@ use anyhow::{Context as _, Result, bail};
use async_compression::futures::bufread::GzipDecoder;
use async_tar::Archive;
use async_trait::async_trait;
use credentials_provider::CredentialsProvider;
use extension::{
ExtensionLanguageServerProxy, KeyValueStoreDelegate, ProjectDelegate, WorktreeDelegate,
};
@@ -22,12 +23,14 @@ use gpui::{BackgroundExecutor, SharedString};
use language::{BinaryStatus, LanguageName, language_settings::AllLanguageSettings};
use project::project_settings::ProjectSettings;
use semver::Version;
use smol::net::TcpListener;
use std::{
env,
net::Ipv4Addr,
path::{Path, PathBuf},
str::FromStr,
sync::{Arc, OnceLock},
time::Duration,
};
use task::{SpawnInTerminal, ZedDebugConfig};
use url::Url;
@@ -1107,3 +1110,361 @@ impl ExtensionImports for WasmState {
.to_wasmtime_result()
}
}
impl llm_provider::Host for WasmState {
async fn request_credential(
&mut self,
_provider_id: String,
_credential_type: llm_provider::CredentialType,
_label: String,
_placeholder: String,
) -> wasmtime::Result<Result<bool, String>> {
// For now, credential requests return false (not provided)
// Extensions should use get_env_var to check for env vars first,
// then store_credential/get_credential for manual storage
// Full UI credential prompting will be added in a future phase
Ok(Ok(false))
}
async fn get_credential(&mut self, provider_id: String) -> wasmtime::Result<Option<String>> {
let extension_id = self.manifest.id.clone();
// Check if this provider has an env var configured and if the user has allowed it
let env_var_name = self
.manifest
.language_model_providers
.get(&Arc::<str>::from(provider_id.as_str()))
.and_then(|entry| entry.auth.as_ref())
.and_then(|auth| auth.env_var.clone());
if let Some(env_var_name) = env_var_name {
let full_provider_id: Arc<str> = format!("{}:{}", extension_id, provider_id).into();
// Read settings dynamically to get current allowed_env_var_providers
let is_allowed = self
.on_main_thread({
let full_provider_id = full_provider_id.clone();
move |cx| {
async move {
cx.update(|cx| {
crate::extension_settings::ExtensionSettings::get_global(cx)
.allowed_env_var_providers
.contains(&full_provider_id)
})
}
.boxed_local()
}
})
.await
.unwrap_or(false);
if is_allowed {
if let Ok(value) = env::var(&env_var_name) {
if !value.is_empty() {
return Ok(Some(value));
}
}
}
}
// Fall back to credential store
let credential_key = format!("extension-llm-{}:{}", extension_id, provider_id);
self.on_main_thread(move |cx| {
async move {
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx))?;
let result = credentials_provider
.read_credentials(&credential_key, cx)
.await
.ok()
.flatten();
Ok(result.map(|(_, password)| String::from_utf8_lossy(&password).to_string()))
}
.boxed_local()
})
.await
}
async fn store_credential(
&mut self,
provider_id: String,
value: String,
) -> wasmtime::Result<Result<(), String>> {
let extension_id = self.manifest.id.clone();
let credential_key = format!("extension-llm-{}:{}", extension_id, provider_id);
self.on_main_thread(move |cx| {
async move {
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx))?;
credentials_provider
.write_credentials(&credential_key, "api_key", value.as_bytes(), cx)
.await
.map_err(|e| anyhow::anyhow!("{}", e))
}
.boxed_local()
})
.await
.to_wasmtime_result()
}
async fn delete_credential(
&mut self,
provider_id: String,
) -> wasmtime::Result<Result<(), String>> {
let extension_id = self.manifest.id.clone();
let credential_key = format!("extension-llm-{}:{}", extension_id, provider_id);
self.on_main_thread(move |cx| {
async move {
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx))?;
credentials_provider
.delete_credentials(&credential_key, cx)
.await
.map_err(|e| anyhow::anyhow!("{}", e))
}
.boxed_local()
})
.await
.to_wasmtime_result()
}
async fn get_env_var(&mut self, name: String) -> wasmtime::Result<Option<String>> {
let extension_id = self.manifest.id.clone();
// Find which provider (if any) declares this env var in its auth config
let mut allowed_provider_id: Option<Arc<str>> = None;
for (provider_id, provider_entry) in &self.manifest.language_model_providers {
if let Some(auth_config) = &provider_entry.auth {
if auth_config.env_var.as_deref() == Some(&name) {
allowed_provider_id = Some(provider_id.clone());
break;
}
}
}
// If no provider declares this env var, deny access
let Some(provider_id) = allowed_provider_id else {
log::warn!(
"Extension {} attempted to read env var {} which is not declared in any provider auth config",
extension_id,
name
);
return Ok(None);
};
// Check if the user has allowed this provider to read env vars
// Read settings dynamically to get current allowed_env_var_providers
let full_provider_id: Arc<str> = format!("{}:{}", extension_id, provider_id).into();
let is_allowed = self
.on_main_thread({
let full_provider_id = full_provider_id.clone();
move |cx| {
async move {
cx.update(|cx| {
crate::extension_settings::ExtensionSettings::get_global(cx)
.allowed_env_var_providers
.contains(&full_provider_id)
})
}
.boxed_local()
}
})
.await
.unwrap_or(false);
if !is_allowed {
log::debug!(
"Extension {} provider {} is not allowed to read env var {}",
extension_id,
provider_id,
name
);
return Ok(None);
}
Ok(env::var(&name).ok())
}
async fn oauth_start_web_auth(
&mut self,
config: llm_provider::OauthWebAuthConfig,
) -> wasmtime::Result<Result<llm_provider::OauthWebAuthResult, String>> {
let auth_url = config.auth_url;
let callback_path = config.callback_path;
let timeout_secs = config.timeout_secs.unwrap_or(300);
self.on_main_thread(move |cx| {
async move {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.map_err(|e| anyhow::anyhow!("Failed to bind localhost server: {}", e))?;
let port = listener
.local_addr()
.map_err(|e| anyhow::anyhow!("Failed to get local address: {}", e))?
.port();
let auth_url_with_port = auth_url.replace("{port}", &port.to_string());
cx.update(|cx| {
cx.open_url(&auth_url_with_port);
})?;
let accept_future = async {
let (mut stream, _) = listener
.accept()
.await
.map_err(|e| anyhow::anyhow!("Failed to accept connection: {}", e))?;
let mut request_line = String::new();
{
let mut reader = smol::io::BufReader::new(&mut stream);
smol::io::AsyncBufReadExt::read_line(&mut reader, &mut request_line)
.await
.map_err(|e| anyhow::anyhow!("Failed to read request: {}", e))?;
}
let callback_url = if let Some(path_start) = request_line.find(' ') {
if let Some(path_end) = request_line[path_start + 1..].find(' ') {
let path = &request_line[path_start + 1..path_start + 1 + path_end];
if path.starts_with(&callback_path) || path.starts_with(&format!("/{}", callback_path.trim_start_matches('/'))) {
format!("http://localhost:{}{}", port, path)
} else {
return Err(anyhow::anyhow!(
"Unexpected callback path: {}",
path
));
}
} else {
return Err(anyhow::anyhow!("Malformed HTTP request"));
}
} else {
return Err(anyhow::anyhow!("Malformed HTTP request"));
};
let response = "HTTP/1.1 200 OK\r\n\
Content-Type: text/html\r\n\
Connection: close\r\n\
\r\n\
<!DOCTYPE html>\
<html><head><title>Authentication Complete</title></head>\
<body style=\"font-family: system-ui, sans-serif; display: flex; justify-content: center; align-items: center; height: 100vh; margin: 0;\">\
<div style=\"text-align: center;\">\
<h1>Authentication Complete</h1>\
<p>You can close this window and return to Zed.</p>\
</div></body></html>";
smol::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes())
.await
.ok();
smol::io::AsyncWriteExt::flush(&mut stream).await.ok();
Ok(callback_url)
};
let timeout_duration = Duration::from_secs(timeout_secs as u64);
let callback_url = smol::future::or(
accept_future,
async {
smol::Timer::after(timeout_duration).await;
Err(anyhow::anyhow!(
"OAuth callback timed out after {} seconds",
timeout_secs
))
},
)
.await?;
Ok(llm_provider::OauthWebAuthResult {
callback_url,
port: port as u32,
})
}
.boxed_local()
})
.await
.to_wasmtime_result()
}
async fn send_oauth_http_request(
&mut self,
request: llm_provider::OauthHttpRequest,
) -> wasmtime::Result<Result<llm_provider::OauthHttpResponse, String>> {
let http_client = self.host.http_client.clone();
self.on_main_thread(move |_cx| {
async move {
let method = match request.method.to_uppercase().as_str() {
"GET" => ::http_client::Method::GET,
"POST" => ::http_client::Method::POST,
"PUT" => ::http_client::Method::PUT,
"DELETE" => ::http_client::Method::DELETE,
"PATCH" => ::http_client::Method::PATCH,
_ => {
return Err(anyhow::anyhow!(
"Unsupported HTTP method: {}",
request.method
));
}
};
let mut builder = ::http_client::Request::builder()
.method(method)
.uri(&request.url);
for (key, value) in &request.headers {
builder = builder.header(key.as_str(), value.as_str());
}
let body = if request.body.is_empty() {
AsyncBody::empty()
} else {
AsyncBody::from(request.body.into_bytes())
};
let http_request = builder
.body(body)
.map_err(|e| anyhow::anyhow!("Failed to build request: {}", e))?;
let mut response = http_client
.send(http_request)
.await
.map_err(|e| anyhow::anyhow!("HTTP request failed: {}", e))?;
let status = response.status().as_u16();
let headers: Vec<(String, String)> = response
.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
let mut body_bytes = Vec::new();
futures::AsyncReadExt::read_to_end(response.body_mut(), &mut body_bytes)
.await
.map_err(|e| anyhow::anyhow!("Failed to read response body: {}", e))?;
let body = String::from_utf8_lossy(&body_bytes).to_string();
Ok(llm_provider::OauthHttpResponse {
status,
headers,
body,
})
}
.boxed_local()
})
.await
.to_wasmtime_result()
}
async fn oauth_open_browser(&mut self, url: String) -> wasmtime::Result<Result<(), String>> {
self.on_main_thread(move |cx| {
async move {
cx.update(|cx| {
cx.open_url(&url);
})?;
Ok(())
}
.boxed_local()
})
.await
.to_wasmtime_result()
}
}

View File

@@ -442,7 +442,9 @@ impl ExtensionsPage {
let extension_store = ExtensionStore::global(cx).read(cx);
match extension_store.outstanding_operations().get(extension_id) {
Some(ExtensionOperation::Install) => ExtensionStatus::Installing,
Some(ExtensionOperation::Install) | Some(ExtensionOperation::AutoInstall) => {
ExtensionStatus::Installing
}
Some(ExtensionOperation::Remove) => ExtensionStatus::Removing,
Some(ExtensionOperation::Upgrade) => ExtensionStatus::Upgrading,
None => match extension_store.installed_extensions().get(extension_id) {

View File

@@ -232,14 +232,12 @@ impl From<Oid> for usize {
#[derive(Copy, Clone, Debug)]
pub enum RunHook {
PreCommit,
PrePush,
}
impl RunHook {
pub fn as_str(&self) -> &str {
match self {
Self::PreCommit => "pre-commit",
Self::PrePush => "pre-push",
}
}
@@ -250,7 +248,6 @@ impl RunHook {
pub fn from_proto(value: i32) -> Option<Self> {
match value {
0 => Some(Self::PreCommit),
1 => Some(Self::PrePush),
_ => None,
}
}

View File

@@ -652,6 +652,7 @@ pub struct RealGitRepository {
pub repository: Arc<Mutex<git2::Repository>>,
pub system_git_binary_path: Option<PathBuf>,
pub any_git_binary_path: PathBuf,
any_git_binary_help_output: Arc<Mutex<Option<SharedString>>>,
executor: BackgroundExecutor,
}
@@ -670,6 +671,7 @@ impl RealGitRepository {
system_git_binary_path,
any_git_binary_path,
executor,
any_git_binary_help_output: Arc::new(Mutex::new(None)),
})
}
@@ -680,6 +682,27 @@ impl RealGitRepository {
.context("failed to read git work directory")
.map(Path::to_path_buf)
}
async fn any_git_binary_help_output(&self) -> SharedString {
if let Some(output) = self.any_git_binary_help_output.lock().clone() {
return output;
}
let git_binary_path = self.any_git_binary_path.clone();
let executor = self.executor.clone();
let working_directory = self.working_directory();
let output: SharedString = self
.executor
.spawn(async move {
GitBinary::new(git_binary_path, working_directory?, executor)
.run(["help", "-a"])
.await
})
.await
.unwrap_or_default()
.into();
*self.any_git_binary_help_output.lock() = Some(output.clone());
output
}
}
#[derive(Clone, Debug)]
@@ -2290,18 +2313,47 @@ impl GitRepository for RealGitRepository {
env: Arc<HashMap<String, String>>,
) -> BoxFuture<'_, Result<()>> {
let working_directory = self.working_directory();
let repository = self.repository.clone();
let git_binary_path = self.any_git_binary_path.clone();
let executor = self.executor.clone();
self.executor
.spawn(async move {
let working_directory = working_directory?;
let git = GitBinary::new(git_binary_path, working_directory, executor)
.envs(HashMap::clone(&env));
git.run(&["hook", "run", "--ignore-missing", hook.as_str()])
.await?;
Ok(())
})
.boxed()
let help_output = self.any_git_binary_help_output();
// Note: Do not spawn these commands on the background thread, as this causes some git hooks to hang.
async move {
let working_directory = working_directory?;
if !help_output
.await
.lines()
.any(|line| line.trim().starts_with("hook "))
{
let hook_abs_path = repository.lock().path().join("hooks").join(hook.as_str());
if hook_abs_path.is_file() {
let output = new_smol_command(&hook_abs_path)
.envs(env.iter())
.current_dir(&working_directory)
.output()
.await?;
if !output.status.success() {
return Err(GitBinaryCommandError {
stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
status: output.status,
}
.into());
}
}
return Ok(());
}
let git = GitBinary::new(git_binary_path, working_directory, executor)
.envs(HashMap::clone(&env));
git.run(&["hook", "run", "--ignore-missing", hook.as_str()])
.await?;
Ok(())
}
.boxed()
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -24,6 +24,7 @@ pub struct GitPanelSettings {
pub fallback_branch_name: String,
pub sort_by_path: bool,
pub collapse_untracked_diff: bool,
pub tree_view: bool,
}
impl ScrollbarVisibility for GitPanelSettings {
@@ -56,6 +57,7 @@ impl Settings for GitPanelSettings {
fallback_branch_name: git_panel.fallback_branch_name.unwrap(),
sort_by_path: git_panel.sort_by_path.unwrap(),
collapse_untracked_diff: git_panel.collapse_untracked_diff.unwrap(),
tree_view: git_panel.tree_view.unwrap(),
}
}
}

Some files were not shown because too many files have changed in this diff Show More