Compare commits

..

52 Commits

Author SHA1 Message Date
Oleksiy Syvokon
610536201b Merge branch 'main' into ep-distill 2025-12-12 21:03:52 +02:00
Agus Zubiaga
60f4aa333b edit prediction cli: Improve error handling (#44718)
We were panicking whenever something went wrong with an example in the
CLI. This can be very disruptive when running many examples, and e.g a
single request fails. Instead, if running more than one example, errors
will now be logged alongside instructions to explore and re-run the
example by itself.

<img width="1454" height="744" alt="CleanShot 2025-12-12 at 13 32 04@2x"
src="https://github.com/user-attachments/assets/87c59e64-08b9-4461-af5b-03af5de94152"></img>


You can still opt in to stop as soon as en error occurs with the new
`--failfast` argument.

Release Notes:

- N/A
2025-12-12 14:15:58 -03:00
localcc
a698f1bf63 Fix Bounds::contains (#44711)
Closes #11643 

Release Notes:

- Fixed double hover state on windows

Co-authored-by: Kirill Bulatov <mail4score@gmail.com>
2025-12-12 14:49:29 +00:00
localcc
636d11ebec Multiple priority scheduler (#44701)
Improves the scheduler by allowing tasks to have a set priority which
will significantly improve responsiveness.

Release notes:

- N/A

---------

Co-authored-by: Yara <git@yara.blue>
Co-authored-by: dvdsk <noreply@davidsk.dev>
2025-12-12 06:32:30 -08:00
Agus Zubiaga
4d0e760b04 edit prediction cli: Progress output cleanup (#44708)
- Limit status lines to 10 in case `max_parallelism` is specified with a
grater value
- Handle logging gracefully rather than writing over it when clearing
status lines

Release Notes:

- N/A
2025-12-12 14:03:08 +00:00
localcc
8bd4d866b9 Windows/send keystrokes (#44707)
Closes #41176 

Release Notes:

- Fixed SendKeystrokes mapping on windows

Co-authored-by: Kirill Bulatov <mail4score@gmail.com>
2025-12-12 05:51:11 -08:00
Piotr Osiewicz
47c30b6da7 git: Revert "Ignore whitespace in git blame invocation" (#44648)
Reverts zed-industries/zed#35960
cc @cole-miller

---------

Co-authored-by: Cole Miller <cole@zed.dev>
2025-12-12 14:28:25 +01:00
Lukas Wirth
18d344e118 language: Make TreeSitterData only shared between snapshots of the same version (#44198)
Currently we have a single cache for this data shared between all
snapshots which is incorrect, as we might update the cache to a new
version while having old snapshots around which then may try to access
new data with old offsets/rows.

Release Notes:

- N/A *or* Added/Fixed/Improved ...
2025-12-12 14:15:50 +01:00
Agus Zubiaga
610cc1b138 edit prediction cli: Cargo-style progress output (#44675)
Release Notes:

- N/A
2025-12-12 09:43:16 -03:00
Xiaobo Liu
a07ea1a272 util: Avoid redundant Arc allocation in SanitizedPath::from_arc (#44479)
Release Notes:

- N/A

Signed-off-by: Xiaobo Liu <cppcoffee@gmail.com>
2025-12-12 13:33:49 +01:00
Lukas Wirth
e03fa114a7 remote: Remove unnecessary and incorrect single quote in MasterProcess (#44697)
Closes https://github.com/zed-industries/zed/issues/43992

Release Notes:

- Fixed remoting not working on some linux and mac systems
2025-12-12 11:53:15 +00:00
Dino
17db7b0e99 Add keymap field to bug report issue template (#44564)
Update the issue template used for "Report a bug" to include a field
specifically for the user's keymap file, as we've seen multiple cases
where we end up asking the users for their custom keymap, to ensure that
they're not overriding existing defaults.

Release Notes:

- N/A
2025-12-12 11:17:15 +00:00
Kirill Bulatov
1afe29422b Move servers back from the background thread (#44696)
Partial revert of https://github.com/zed-industries/zed/pull/44631

With this and `sccache` enabled, I get 
<img width="3456" height="1096" alt="image"
src="https://github.com/user-attachments/assets/937760fb-8b53-49f8-ae63-4df1d31b292b"
/>

and r-a infinitely hangs waiting on this.

Release Notes:

- N/A
2025-12-12 11:16:17 +00:00
Lukas Wirth
a8aa7622b7 util: Fix shell builder quoting regressions (#44685)
Follow up to https://github.com/zed-industries/zed/pull/42382

Release Notes:

- N/A *or* Added/Fixed/Improved ...
2025-12-12 11:06:49 +00:00
Agus Zubiaga
a66854e435 commit view: Reuse avatar asset (#44554) 2025-12-12 07:42:05 -03:00
Dino
12073e10f8 Fix missing buffer font features in Blame UI, Hover Popover and Markdown Preview (#44657)
- Fix missing font features in 
  `git_ui::blame_ui::GitBlameRenderer.render_blame_entry`
- Fix missing buffer font features in
`markdown_preview::markdown_renderer`
- Update the way that the markdown style is built for hover popovers so
  that, for code blocks, the buffer font features are used.
- Introduce `gpui::Styled.font_features` to allow callers to also set
  the font's features, similar to how `gpui::Styled.font_family` already
  exists.

Relates to #44209

Release Notes:

- Fixed wrong font features in Blame UI, Hover Popover and Markdown
Preview
2025-12-12 09:55:06 +00:00
Oleksiy Syvokon
a2a96e4038 Merge branch 'main' into ep-distill 2025-12-12 11:28:15 +02:00
Smit Barmase
1186b50ca4 git_ui: Fix commit and amend not working via keybinds in commit modal (#44690)
Closes #41567

We were using the git panel editor to check the focus where the commit
modal has its only editor.

Release Notes:

- Fixed an issue where commit and amend actions wouldn’t trigger when
using keybinds in the commit modal.
2025-12-12 14:41:48 +05:30
Lukas Wirth
65130a9ca9 windows: Fix more VSCode keybinds (#44684)
Release Notes:

- N/A *or* Added/Fixed/Improved ...
2025-12-12 07:04:55 +00:00
Dino
23d18fde8c git_ui: Always use latest commit message on amend (#44553)
Update the behavior of `git::Amend` to ensure that the latest head
commit message, if available, is always loaded into the commit message
editor, regardless of its state. The previous text, if any, is now also
restored after the amend is finished.

- Update `FakeGitRepository.show` to include a message in the returned
`CommitDetails` so we can assert that this specific commit message is
set in the commit message editor.
- Add default implementation for `FakeGitRepository.commit` and
`FakeGitRepository.run_hook` to ensure that tests are able to run and
don't panic on `unimplemented!()`
- Refactor `GitPanel.load_last_commit_message_if_empty` to
`GitPanel.load_last_commit_message`, ensuring that the head commit
message is always loaded, regardless of whether the commit message
editor is empty.
- Update `GitPanel.commit_changes` to ensure that the pending amend
state is only updated if the editor managed to actually commit the
changes. This also ensures that we don't restore the commit message
editor's contents when amending a commit, before the amend is actually
processed.
- Update `CommitModal.amend`, removing the call to
`GitPanel.set_amend_pending` as that is now handled by the background
task created in `GitPanel.commit_changes`.
- Split the `commit` and `amend` methods from the event handlers so that
the methods can be called directly, as is now being done by
`CommitModal.on_commit` and `CommitModal.on_amend`.

Release Notes:

- Updated the ‎`git: amend` command to always load the latest head
commit message, and to restore any previously entered text in the commit
message editor after the amend completes
2025-12-12 12:16:43 +05:30
Conrad Irwin
332c0d03d1 Terminal regex perf improvements (#44679)
Closes #44510

Release Notes:

- Improve performance of terminal link matching even more
2025-12-11 22:40:48 -07:00
Max Brunsfeld
b871130220 Restructure concurrency in EP CLI to allow running many examples in big rust repos (#44673)
Release Notes:

- N/A
2025-12-12 01:58:53 +00:00
Conrad Irwin
0a1e5f93a0 Allow triggering after release workflow manually (#44671)
Release Notes:

- N/A
2025-12-11 16:54:10 -07:00
Piotr Osiewicz
8d0fff688f rust: Change cwd of cargo run-esque tasks to use package root, not dirname of current file as cwd (#44672)
This also applies to `cargo clean` one.

Closes #20873

Release Notes:

- rust: Changed cwd of tasks that spawn a binary target to the root of a
current package (which used to be a directory of the current source
file).
2025-12-11 23:47:40 +00:00
Kirill Bulatov
717d898692 Show an underlying reason on file opening (#44664)
Based on the debug attempt from
https://github.com/zed-industries/zed/issues/44370

Release Notes:

- N/A
2025-12-11 23:20:25 +00:00
Max Brunsfeld
1cd7563f04 Add ep distill command, for generating edit prediction training examples (#44670)
Release Notes:

- N/A

---------

Co-authored-by: Oleksiy Syvokon <oleksiy@zed.dev>
Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-12-11 14:57:58 -08:00
Agus Zubiaga
fc6ca38989 edit prediction cli: Improve language server reliability (#44666)
We weren't waiting for ALL language servers of a buffer to start, only
the first one.

Release Notes:

- N/A
2025-12-11 22:30:51 +00:00
Yara 🏳️‍⚧️
1029a8fbaf Add support for manual spans, expand instrumentation (#44663)
Release Notes:

- N/A

---------

Co-authored-by: Cameron <cameron@zed.dev>
2025-12-11 22:29:47 +00:00
KyleBarton
07748b7bae Add scrolling functionality to markdown preview mode (#44585)
Closes #21324

Adds four new commands:
- `markdown::MoveUp`, `markdown::MoveDown` - these scroll up and down in
markdown preview mode, by no more than the height of a large headline.

- `markdown::MoveUpByItem`, and `markdown::MoveDownByItem` - these
scroll up and down by the height of the item at the top of the markdown
preview window. So headlines and large codeblocks, for instance, scroll
further than individual paragraph lines.

Also attempts to create sensible defaults:
`down` -> `markdown::ScrollDown`
`up` -> `markdown::ScrollUp`
`alt-down` -> `markdown::ScrollDownByItem`
`alt-up` -> `markdown::ScrollUpByItem`

And in Vim:

`ctrl-u` -> `markdown::ScrollPageUp`
`ctrl-d` -> `markdown::ScrollPageDown`
`ctrl-e` -> `markdown::ScrollDown`
`ctrl-y` -> `markdown::ScrollUp`


Release Notes:

- Added commands `markdown::ScrollUp`, `markdown::ScrollDown`,
`markdown::ScrollUpByItem`, and `markdown::ScrollDownByItem`
- Changed commands `markdown::MovePageUp` to `markdown::ScrollPageUp`
and `markdown::MovePageDown` to `markdown::ScrollPageDown`
2025-12-11 22:18:38 +00:00
Agus Zubiaga
37f2ac24b8 edit prediction cli: Skip worktree scan (#44658)
Release Notes:

- N/A

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
2025-12-11 21:05:50 +00:00
Richard Feldman
b5a0a3322d Add GPT-5.2 support (#44656)
<img width="429" height="188" alt="Screenshot 2025-12-11 at 3 45 26 PM"
src="https://github.com/user-attachments/assets/fe9f1b86-7268-4c63-a8c2-75ac671012c9"
/>


Release Notes:

- Added GPT-5.2 support when using your own OpenAI key
2025-12-11 15:49:10 -05:00
Kirill Bulatov
eb7da26d19 Disable word completions in markdown and plaintext files (#44654)
Reformat on save had also added trailing commas.

Release Notes:

- Disable word completions in plaintext and markdown files, see
https://zed.dev/docs/configuring-zed?highlight=word%20completio#words on
how to enable it back in the language settings
2025-12-11 20:15:38 +00:00
Zachiah Sawyer
9c099e7ed3 Update file vs folder open keymaps on macos/linux to match windows (#44598)
Closes #44597

Matches what was done here:

55dfbaca68 (diff-cc832e840d61526768bb4acec7645a71e8b160a65a30e7ce9e9c51762b58199a)

Release Notes:

- Standardize Cmd-O = open file, Cmd-K Cmd-O = open folder across
operating systems.

---------

Co-authored-by: Lukas Wirth <me@lukaswirth.dev>
2025-12-11 19:40:47 +00:00
Danilo Leal
7669b05268 image viewer: Make image metadata not a button (#44651)
Tiny thing I noticed; the image metadata showing on the status bar was
previously a button, but given that nothing happens when you click it,
it doesn't need to be one. Having hover, active, and all other states
was confusing.

Release Notes:

- N/A
2025-12-11 19:14:36 +00:00
Oleksiy Syvokon
ec26556dab Parse expected output for the zeta2 prompt
Co-authored-by: Agus Zubiaga <agus@zed.dev>
    Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
2025-12-11 20:13:44 +02:00
Agus Zubiaga
2098b67304 edit prediction: Respect enabled settings when refreshing from diagnostics (#44640)
Release Notes:

- N/A
2025-12-11 17:39:57 +00:00
Oleksiy Syvokon
1a8d8e9572 Add ep distill command
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
2025-12-11 19:36:30 +02:00
Lukas Wirth
5a6198cc39 language: Spawn language servers on background threads (#44631)
Closes https://github.com/zed-industries/zed/issues/39056

Leverages a new `await_on_background` API that spawns the future on the
background but blocks the current task, allowing to borrow from the
surrounding scope.

Release Notes:

- N/A *or* Added/Fixed/Improved ...
2025-12-11 17:23:27 +00:00
Oleksiy Syvokon
ab893ca754 ep_cli fixes, non-batched teacher, and other
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
2025-12-11 19:20:25 +02:00
Siame Rafiq
cda78c12ab git: Make permalinks aware of current diffs (#41915)
Addressing #22546, we want git permalinks to be aware of the current
changes within the buffer.

This change calculates how many lines have been added/deleted between
the start and end of the selection and uses those values to offset the
selection.

This is done within `Editor::get_permalink_to_line` so that it can be
passed to any git_store.

Example:

<img width="284" height="316" alt="image"
src="https://github.com/user-attachments/assets/268043a0-2fc8-41c1-b094-d650fd4e0ae0"
/>

Where this selections permalink would previously return L3-L9, it now
returns L2-L7.

Release Notes:

- git: make permalinks aware of current diffs

Closes #22546

---

This is my first PR into the zed repository so very happy for any
feedback on how I've implemented this. Thanks!
2025-12-11 10:53:20 -05:00
Smit Barmase
f4378672b8 editor: Fix auto-indent cases in Markdown (#44616)
Builds on https://github.com/zed-industries/zed/pull/40794 and
https://github.com/zed-industries/zed/pull/44381

- Fixes the case where creating a new line inside a nested list puts the
cursor correctly under that nested list item.
- Fixes the case where typing a new list item at the expected indent no
longer auto-indents or outdents incorrectly.

Release Notes:

- Fixed an issue in Markdown where new list items weren’t respecting the
expected indentation on type.
2025-12-11 21:14:15 +05:30
Yara 🏳️‍⚧️
ecb8d3d4dd Revert "Multiple priority scheduler" (#44637)
Reverts zed-industries/zed#44575
2025-12-11 16:16:43 +01:00
localcc
95dbc0efc2 Multiple priority scheduler (#44575)
Improves the scheduler by allowing tasks to have a set priority which
will significantly improve responsiveness.

Release notes:

- N/A

---------

Co-authored-by: Yara <git@yara.blue>
2025-12-11 13:22:39 +00:00
Gaauwe Rombouts
8572c19a02 Improve TS/TSX/JS syntax highlighting for parameters, types, and punctuation (#44532)
Relands https://github.com/zed-industries/zed/pull/43437

Release Notes:

- Refined syntax highlighting in JavaScript and TypeScript for better
visual distinction of types, parameters, and JSDoc elements

---------

Co-authored-by: MrSubidubi <dev@bahn.sh>
Co-authored-by: Clay Tercek <30105080+claytercek@users.noreply.github.com>
2025-12-11 12:02:28 +01:00
Lukas Wirth
045c14593f util: Honor shell args for shell env fetching on windows (#44615)
Closes https://github.com/zed-industries/zed/issues/40464

Release Notes:

- Fixed shell environment fetching on windows discarding specified
arguments in settings
2025-12-11 10:34:37 +00:00
Lukas Wirth
0ff3b68a5e windows: Fix incorrect cursor insertion keybinds (#44608)
Release Notes:

- N/A *or* Added/Fixed/Improved ...
2025-12-11 09:38:44 +00:00
Lukas Wirth
a6b9524d78 gpui: Retain maximized and fullscreen state for new windows derived from previous windows (#44605)
Release Notes:

- Fixed new windows underflowing the taskbar on windows
- Improved new windows spawned from maximized or fullscreened windows by
copying the maximized and fullscreened states
2025-12-11 09:38:38 +00:00
CharlesChen0823
7ed5d42696 git: Fix git hook hang with prek (#44212)
Fix git hook hang when using with `prek`. Can see
[comments](https://github.com/zed-industries/zed/issues/44057#issuecomment-3606837089),
this is easy test, should using release build, debug build sometimes not
hang.

The issue existing long time, see issue #37293 , and then in commit
#42239 this issue had fixed. but in commit #43285 broken again. So I
reference the implementation in #42239, then this code work.

I MUST CLAIM, I really don't known what happend, and why this code work.
But it worked.

Release Notes:

- N/A

---------

Co-authored-by: Cole Miller <cole@zed.dev>
2025-12-11 03:17:13 +00:00
Max Brunsfeld
25d74480aa Rework edit prediction CLI (#44562)
This PR restructures the commands of the Edit Prediction CLI (now called
`ep`), to support some flows that are important for the training
process:
* generating zeta2 prompt and expected output, without running
predictions
* scoring outputs that are generated by a system other than the
production code (to evaluate the model during training)

To achieve this, we've restructured the CLI commands so that they all
take as input, and produce as output, a consistent, uniform data format:
a set of one or more `Example` structs, expressible either as the
original markdown format, or as a JSON lines. The `Example` struct
starts with the basic fields that are in human-readable eval format, but
contain a number of optional fields that are filled in by different
steps in the processing pipeline (`context`, `predict`, `format-prompt`,
and `score`).

### To do

* [x] Adjust the teacher model output parsing to use the full buffer
contents
* [x] Move udiff to cli
* [x] Align `format-prompt` with Zeta2's production code
* [x] Change score output to assume same provider
* [x] Move pretty reporting to `eval` command
* [x] Store cursor point in addition to cursor offset
* [x] Rename `edit_prediction_cli2` -> `edit_prediction_cli` (nuke the
old one)

Release Notes:

- N/A

---------

Co-authored-by: Oleksiy Syvokon <oleksiy@zed.dev>
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Ben Kunkle <ben@zed.dev>
2025-12-10 17:36:51 -08:00
Cole Miller
37077a8ebb git: Avoid calling git help -a on every commit (#44586)
Updates #43993 

Release Notes:

- N/A
2025-12-11 01:03:35 +00:00
Finn Evers
7c4a85f5f1 ci: Explicitly set git committer information in protobuf check (#44582)
This should hopefully fix the flakes for good.

Release Notes:

- N/A
2025-12-10 23:35:02 +00:00
Cole Miller
d21628c349 Revert "Increase askpass timeout for git operations (#42946)" (#44578)
This reverts commit a74aac88c9.

cc @11happy, we need to do a bit more than just running `git hook
pre-push` before pushing, as described
[here](https://github.com/zed-industries/zed/pull/42946#issuecomment-3550570438).
Right now this is also running the pre-push hook twice.

Release Notes:

- N/A
2025-12-10 18:07:01 -05:00
156 changed files with 6749 additions and 6177 deletions

View File

@@ -75,6 +75,22 @@ body:
</details>
validations:
required: false
- type: textarea
attributes:
label: Relevant Keymap
description: |
Open the command palette in Zed, then type “zed: open keymap file” and copy/paste the file's contents.
value: |
<details><summary>keymap.json</summary>
<!-- Paste your keymap file inside the code block. -->
```json
```
</details>
validations:
required: false
- type: textarea
attributes:
label: (for AI issues) Model provider details

View File

@@ -5,13 +5,27 @@ on:
release:
types:
- published
workflow_dispatch:
inputs:
tag_name:
description: tag_name
required: true
type: string
prerelease:
description: prerelease
required: true
type: boolean
body:
description: body
type: string
default: ''
jobs:
rebuild_releases_page:
if: (github.repository_owner == 'zed-industries' || github.repository_owner == 'zed-extensions')
runs-on: namespace-profile-2x4-ubuntu-2404
steps:
- name: after_release::rebuild_releases_page::refresh_cloud_releases
run: curl -fX POST https://cloud.zed.dev/releases/refresh?expect_tag=${{ github.event.release.tag_name }}
run: curl -fX POST https://cloud.zed.dev/releases/refresh?expect_tag=${{ github.event.release.tag_name || inputs.tag_name }}
shell: bash -euxo pipefail {0}
- name: after_release::rebuild_releases_page::redeploy_zed_dev
run: npm exec --yes -- vercel@37 --token="$VERCEL_TOKEN" --scope zed-industries redeploy https://zed.dev
@@ -27,7 +41,7 @@ jobs:
- id: get-release-url
name: after_release::post_to_discord::get_release_url
run: |
if [ "${{ github.event.release.prerelease }}" == "true" ]; then
if [ "${{ github.event.release.prerelease || inputs.prerelease }}" == "true" ]; then
URL="https://zed.dev/releases/preview"
else
URL="https://zed.dev/releases/stable"
@@ -40,9 +54,9 @@ jobs:
uses: 2428392/gh-truncate-string-action@b3ff790d21cf42af3ca7579146eedb93c8fb0757
with:
stringToTruncate: |
📣 Zed [${{ github.event.release.tag_name }}](<${{ steps.get-release-url.outputs.URL }}>) was just released!
📣 Zed [${{ github.event.release.tag_name || inputs.tag_name }}](<${{ steps.get-release-url.outputs.URL }}>) was just released!
${{ github.event.release.body }}
${{ github.event.release.body || inputs.body }}
maxLength: 2000
truncationSymbol: '...'
- name: after_release::post_to_discord::discord_webhook_action
@@ -56,7 +70,7 @@ jobs:
- id: set-package-name
name: after_release::publish_winget::set_package_name
run: |
if ("${{ github.event.release.prerelease }}" -eq "true") {
if ("${{ github.event.release.prerelease || inputs.prerelease }}" -eq "true") {
$PACKAGE_NAME = "ZedIndustries.Zed.Preview"
} else {
$PACKAGE_NAME = "ZedIndustries.Zed"
@@ -68,6 +82,7 @@ jobs:
uses: vedantmgoyal9/winget-releaser@19e706d4c9121098010096f9c495a70a7518b30f
with:
identifier: ${{ steps.set-package-name.outputs.PACKAGE_NAME }}
release-tag: ${{ github.event.release.tag_name || inputs.tag_name }}
max-versions-to-keep: 5
token: ${{ secrets.WINGET_TOKEN }}
create_sentry_release:

View File

@@ -497,6 +497,8 @@ jobs:
env:
GIT_AUTHOR_NAME: Protobuf Action
GIT_AUTHOR_EMAIL: ci@zed.dev
GIT_COMMITTER_NAME: Protobuf Action
GIT_COMMITTER_EMAIL: ci@zed.dev
steps:
- name: steps::checkout_repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683

39
Cargo.lock generated
View File

@@ -3111,16 +3111,6 @@ dependencies = [
"uuid",
]
[[package]]
name = "cloud_zeta2_prompt"
version = "0.1.0"
dependencies = [
"anyhow",
"cloud_llm_client",
"indoc",
"serde",
]
[[package]]
name = "cmake"
version = "0.1.54"
@@ -5119,7 +5109,6 @@ dependencies = [
"clock",
"cloud_api_types",
"cloud_llm_client",
"cloud_zeta2_prompt",
"collections",
"copilot",
"credentials_provider",
@@ -5150,8 +5139,6 @@ dependencies = [
"serde",
"serde_json",
"settings",
"smol",
"strsim",
"strum 0.27.2",
"telemetry",
"telemetry_events",
@@ -5162,6 +5149,7 @@ dependencies = [
"workspace",
"worktree",
"zed_actions",
"zeta_prompt",
"zlog",
]
@@ -5175,11 +5163,10 @@ dependencies = [
"clap",
"client",
"cloud_llm_client",
"cloud_zeta2_prompt",
"collections",
"debug_adapter_extension",
"dirs 4.0.0",
"edit_prediction",
"edit_prediction_context",
"extension",
"fs",
"futures 0.3.31",
@@ -5192,6 +5179,7 @@ dependencies = [
"language_model",
"language_models",
"languages",
"libc",
"log",
"node_runtime",
"paths",
@@ -5209,10 +5197,10 @@ dependencies = [
"sqlez",
"sqlez_macros",
"terminal_view",
"toml 0.8.23",
"util",
"wasmtime",
"watch",
"zlog",
"zeta_prompt",
]
[[package]]
@@ -5239,6 +5227,7 @@ dependencies = [
"text",
"tree-sitter",
"util",
"zeta_prompt",
"zlog",
]
@@ -5260,7 +5249,6 @@ dependencies = [
"buffer_diff",
"client",
"cloud_llm_client",
"cloud_zeta2_prompt",
"codestral",
"command_palette_hooks",
"copilot",
@@ -5291,6 +5279,7 @@ dependencies = [
"util",
"workspace",
"zed_actions",
"zeta_prompt",
]
[[package]]
@@ -7250,6 +7239,7 @@ dependencies = [
"libc",
"log",
"lyon",
"mach2 0.5.0",
"media",
"metal",
"naga",
@@ -14456,12 +14446,14 @@ dependencies = [
"settings",
"smol",
"theme",
"tracing",
"ui",
"unindent",
"util",
"util_macros",
"workspace",
"zed_actions",
"ztracing",
]
[[package]]
@@ -16374,13 +16366,13 @@ dependencies = [
"alacritty_terminal",
"anyhow",
"collections",
"fancy-regex",
"futures 0.3.31",
"gpui",
"itertools 0.14.0",
"libc",
"log",
"rand 0.9.2",
"regex",
"release_channel",
"schemars",
"serde",
@@ -18108,6 +18100,7 @@ dependencies = [
"language",
"log",
"lsp",
"markdown_preview",
"menu",
"multi_buffer",
"nvim-rs",
@@ -20933,6 +20926,13 @@ dependencies = [
"syn 2.0.106",
]
[[package]]
name = "zeta_prompt"
version = "0.1.0"
dependencies = [
"serde",
]
[[package]]
name = "zip"
version = "0.6.6"
@@ -21025,6 +21025,7 @@ dependencies = [
"tracing",
"tracing-subscriber",
"tracing-tracy",
"zlog",
"ztracing_macro",
]

View File

@@ -32,7 +32,6 @@ members = [
"crates/cloud_api_client",
"crates/cloud_api_types",
"crates/cloud_llm_client",
"crates/cloud_zeta2_prompt",
"crates/collab",
"crates/collab_ui",
"crates/collections",
@@ -202,6 +201,7 @@ members = [
"crates/zed_actions",
"crates/zed_env_vars",
"crates/edit_prediction_cli",
"crates/zeta_prompt",
"crates/zlog",
"crates/zlog_settings",
"crates/ztracing",
@@ -266,7 +266,6 @@ clock = { path = "crates/clock" }
cloud_api_client = { path = "crates/cloud_api_client" }
cloud_api_types = { path = "crates/cloud_api_types" }
cloud_llm_client = { path = "crates/cloud_llm_client" }
cloud_zeta2_prompt = { path = "crates/cloud_zeta2_prompt" }
collab_ui = { path = "crates/collab_ui" }
collections = { path = "crates/collections", version = "0.1.0" }
command_palette = { path = "crates/command_palette" }
@@ -425,6 +424,7 @@ zed = { path = "crates/zed" }
zed_actions = { path = "crates/zed_actions" }
zed_env_vars = { path = "crates/zed_env_vars" }
edit_prediction = { path = "crates/edit_prediction" }
zeta_prompt = { path = "crates/zeta_prompt" }
zlog = { path = "crates/zlog" }
zlog_settings = { path = "crates/zlog_settings" }
ztracing = { path = "crates/ztracing" }
@@ -631,7 +631,7 @@ shellexpand = "2.1.0"
shlex = "1.3.0"
simplelog = "0.12.2"
slotmap = "1.0.6"
smallvec = { version = "1.6", features = ["union"] }
smallvec = { version = "1.6", features = ["union", "const_new"] }
smol = "2.0"
sqlformat = "0.2"
stacksafe = "0.1"
@@ -657,6 +657,7 @@ time = { version = "0.3", features = [
tiny_http = "0.8"
tokio = { version = "1" }
tokio-tungstenite = { version = "0.26", features = ["__rustls-tls"] }
tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io", "tokio"] }
toml = "0.8"
toml_edit = { version = "0.22", default-features = false, features = ["display", "parse", "serde"] }
tower-http = "0.4.4"

View File

@@ -25,7 +25,8 @@
"ctrl-shift-w": "workspace::CloseWindow",
"shift-escape": "workspace::ToggleZoom",
"open": "workspace::Open",
"ctrl-o": "workspace::Open",
"ctrl-o": "workspace::OpenFiles",
"ctrl-k ctrl-o": "workspace::Open",
"ctrl-=": ["zed::IncreaseBufferFontSize", { "persist": false }],
"ctrl-+": ["zed::IncreaseBufferFontSize", { "persist": false }],
"ctrl--": ["zed::DecreaseBufferFontSize", { "persist": false }],
@@ -814,7 +815,6 @@
"ctrl-]": "agent::CycleNextInlineAssist",
"ctrl-shift-enter": "inline_assistant::ThumbsUpResult",
"ctrl-shift-backspace": "inline_assistant::ThumbsDownResult"
}
},
{
@@ -1192,8 +1192,12 @@
{
"context": "MarkdownPreview",
"bindings": {
"pageup": "markdown::MovePageUp",
"pagedown": "markdown::MovePageDown"
"pageup": "markdown::ScrollPageUp",
"pagedown": "markdown::ScrollPageDown",
"up": "markdown::ScrollUp",
"down": "markdown::ScrollDown",
"alt-up": "markdown::ScrollUpByItem",
"alt-down": "markdown::ScrollDownByItem"
}
},
{

View File

@@ -1296,8 +1296,12 @@
{
"context": "MarkdownPreview",
"bindings": {
"pageup": "markdown::MovePageUp",
"pagedown": "markdown::MovePageDown"
"pageup": "markdown::ScrollPageUp",
"pagedown": "markdown::ScrollPageDown",
"up": "markdown::ScrollUp",
"down": "markdown::ScrollDown",
"alt-up": "markdown::ScrollUpByItem",
"alt-down": "markdown::ScrollDownByItem"
}
},
{

View File

@@ -489,8 +489,8 @@
"bindings": {
"ctrl-[": "editor::Outdent",
"ctrl-]": "editor::Indent",
"ctrl-shift-alt-up": ["editor::AddSelectionAbove", { "skip_soft_wrap": true }], // Insert Cursor Above
"ctrl-shift-alt-down": ["editor::AddSelectionBelow", { "skip_soft_wrap": true }], // Insert Cursor Below
"ctrl-alt-up": ["editor::AddSelectionAbove", { "skip_soft_wrap": true }], // Insert Cursor Above
"ctrl-alt-down": ["editor::AddSelectionBelow", { "skip_soft_wrap": true }], // Insert Cursor Below
"ctrl-shift-k": "editor::DeleteLine",
"alt-up": "editor::MoveLineUp",
"alt-down": "editor::MoveLineDown",
@@ -501,9 +501,12 @@
"ctrl-shift-l": "editor::SelectAllMatches", // Select all occurrences of current selection
"ctrl-f2": "editor::SelectAllMatches", // Select all occurrences of current word
"ctrl-d": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand
"ctrl-f3": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand
"ctrl-k ctrl-d": ["editor::SelectNext", { "replace_newest": true }], // editor.action.moveSelectionToNextFindMatch / find_under_expand_skip
"ctrl-shift-f3": ["editor::SelectPrevious", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand
"ctrl-k ctrl-i": "editor::Hover",
"ctrl-k ctrl-b": "editor::BlameHover",
"ctrl-k ctrl-f": "editor::FormatSelections",
"ctrl-/": ["editor::ToggleComments", { "advance_downwards": false }],
"f8": ["editor::GoToDiagnostic", { "severity": { "min": "hint", "max": "error" } }],
"shift-f8": ["editor::GoToPreviousDiagnostic", { "severity": { "min": "hint", "max": "error" } }],
@@ -536,7 +539,7 @@
"ctrl-k p": "editor::CopyPath",
"ctrl-\\": "pane::SplitRight",
"alt-.": "editor::GoToHunk",
"alt-,": "editor::GoToPreviousHunk"
"alt-,": "editor::GoToPreviousHunk",
}
},
{
@@ -1220,8 +1223,12 @@
"context": "MarkdownPreview",
"use_key_equivalents": true,
"bindings": {
"pageup": "markdown::MovePageUp",
"pagedown": "markdown::MovePageDown"
"pageup": "markdown::ScrollPageUp",
"pagedown": "markdown::ScrollPageDown",
"up": "markdown::ScrollUp",
"down": "markdown::ScrollDown",
"alt-up": "markdown::ScrollUpByItem",
"alt-down": "markdown::ScrollDownByItem"
}
},
{

View File

@@ -1046,5 +1046,14 @@
"g g": "settings_editor::FocusFirstNavEntry",
"shift-g": "settings_editor::FocusLastNavEntry"
}
},
{
"context": "MarkdownPreview",
"bindings": {
"ctrl-u": "markdown::ScrollPageUp",
"ctrl-d": "markdown::ScrollPageDown",
"ctrl-y": "markdown::ScrollUp",
"ctrl-e": "markdown::ScrollDown"
}
}
]

View File

@@ -12,7 +12,7 @@
"theme": {
"mode": "system",
"light": "One Light",
"dark": "One Dark"
"dark": "One Dark",
},
"icon_theme": "Zed (Default)",
// The name of a base set of key bindings to use.
@@ -29,7 +29,7 @@
// Features that can be globally enabled or disabled
"features": {
// Which edit prediction provider to use.
"edit_prediction_provider": "zed"
"edit_prediction_provider": "zed",
},
// The name of a font to use for rendering text in the editor
// ".ZedMono" currently aliases to Lilex
@@ -69,7 +69,7 @@
// The OpenType features to enable for text in the UI
"ui_font_features": {
// Disable ligatures:
"calt": false
"calt": false,
},
// The weight of the UI font in standard CSS units from 100 to 900.
"ui_font_weight": 400,
@@ -87,7 +87,7 @@
"border_size": 0.0,
// Opacity of the inactive panes. 0 means transparent, 1 means opaque.
// Values are clamped to the [0.0, 1.0] range.
"inactive_opacity": 1.0
"inactive_opacity": 1.0,
},
// Layout mode of the bottom dock. Defaults to "contained"
// choices: contained, full, left_aligned, right_aligned
@@ -103,12 +103,12 @@
"left_padding": 0.2,
// The relative width of the right padding of the central pane from the
// workspace when the centered layout is used.
"right_padding": 0.2
"right_padding": 0.2,
},
// Image viewer settings
"image_viewer": {
// The unit for image file sizes: "binary" (KiB, MiB) or decimal (KB, MB)
"unit": "binary"
"unit": "binary",
},
// Determines the modifier to be used to add multiple cursors with the mouse. The open hover link mouse gestures will adapt such that it do not conflict with the multicursor modifier.
//
@@ -296,7 +296,7 @@
// When true, enables drag and drop text selection in buffer.
"enabled": true,
// The delay in milliseconds that must elapse before drag and drop is allowed. Otherwise, a new text selection is created.
"delay": 300
"delay": 300,
},
// What to do when go to definition yields no results.
//
@@ -400,14 +400,14 @@
// Visible characters used to render whitespace when show_whitespaces is enabled.
"whitespace_map": {
"space": "•",
"tab": "→"
"tab": "→",
},
// Settings related to calls in Zed
"calls": {
// Join calls with the microphone live by default
"mute_on_join": false,
// Share your project when you are the first to join a channel
"share_on_join": false
"share_on_join": false,
},
// Toolbar related settings
"toolbar": {
@@ -420,7 +420,7 @@
// Whether to show agent review buttons in the editor toolbar.
"agent_review": true,
// Whether to show code action buttons in the editor toolbar.
"code_actions": false
"code_actions": false,
},
// Whether to allow windows to tab together based on the users tabbing preference (macOS only).
"use_system_window_tabs": false,
@@ -439,7 +439,7 @@
// Whether to show the sign in button in the titlebar.
"show_sign_in": true,
// Whether to show the menus in the titlebar.
"show_menus": false
"show_menus": false,
},
"audio": {
// Opt into the new audio system.
@@ -472,7 +472,7 @@
// the future we will migrate by setting this to false
//
// You need to rejoin a call for this setting to apply
"experimental.legacy_audio_compatible": true
"experimental.legacy_audio_compatible": true,
},
// Scrollbar related settings
"scrollbar": {
@@ -511,8 +511,8 @@
// When false, forcefully disables the horizontal scrollbar. Otherwise, obey other settings.
"horizontal": true,
// When false, forcefully disables the vertical scrollbar. Otherwise, obey other settings.
"vertical": true
}
"vertical": true,
},
},
// Minimap related settings
"minimap": {
@@ -560,7 +560,7 @@
// 3. "gutter" or "none" to not highlight the current line in the minimap.
"current_line_highlight": null,
// Maximum number of columns to display in the minimap.
"max_width_columns": 80
"max_width_columns": 80,
},
// Enable middle-click paste on Linux.
"middle_click_paste": true,
@@ -583,7 +583,7 @@
// Whether to show fold buttons in the gutter.
"folds": true,
// Minimum number of characters to reserve space for in the gutter.
"min_line_number_digits": 4
"min_line_number_digits": 4,
},
"indent_guides": {
// Whether to show indent guides in the editor.
@@ -604,7 +604,7 @@
//
// 1. "disabled"
// 2. "indent_aware"
"background_coloring": "disabled"
"background_coloring": "disabled",
},
// Whether the editor will scroll beyond the last line.
"scroll_beyond_last_line": "one_page",
@@ -623,7 +623,7 @@
"fast_scroll_sensitivity": 4.0,
"sticky_scroll": {
// Whether to stick scopes to the top of the editor.
"enabled": false
"enabled": false,
},
"relative_line_numbers": "disabled",
// If 'search_wrap' is disabled, search result do not wrap around the end of the file.
@@ -641,7 +641,7 @@
// Whether to interpret the search query as a regular expression.
"regex": false,
// Whether to center the cursor on each search match when navigating.
"center_on_match": false
"center_on_match": false,
},
// When to populate a new search's query based on the text under the cursor.
// This setting can take the following three values:
@@ -684,8 +684,8 @@
"shift": false,
"alt": false,
"platform": false,
"function": false
}
"function": false,
},
},
// Whether to resize all the panels in a dock when resizing the dock.
// Can be a combination of "left", "right" and "bottom".
@@ -733,7 +733,7 @@
// "always"
// 5. Never show the scrollbar:
// "never"
"show": null
"show": null,
},
// Which files containing diagnostic errors/warnings to mark in the project panel.
// This setting can take the following three values:
@@ -756,7 +756,7 @@
// "always"
// 2. Never show indent guides:
// "never"
"show": "always"
"show": "always",
},
// Sort order for entries in the project panel.
// This setting can take three values:
@@ -781,8 +781,8 @@
// Whether to automatically open files after pasting or duplicating them.
"on_paste": true,
// Whether to automatically open files dropped from external sources.
"on_drop": true
}
"on_drop": true,
},
},
"outline_panel": {
// Whether to show the outline panel button in the status bar
@@ -815,7 +815,7 @@
// "always"
// 2. Never show indent guides:
// "never"
"show": "always"
"show": "always",
},
// Scrollbar-related settings
"scrollbar": {
@@ -832,11 +832,11 @@
// "always"
// 5. Never show the scrollbar:
// "never"
"show": null
"show": null,
},
// Default depth to expand outline items in the current file.
// Set to 0 to collapse all items that have children, 1 or higher to collapse items at that depth or deeper.
"expand_outlines_with_depth": 100
"expand_outlines_with_depth": 100,
},
"collaboration_panel": {
// Whether to show the collaboration panel button in the status bar.
@@ -844,7 +844,7 @@
// Where to dock the collaboration panel. Can be 'left' or 'right'.
"dock": "left",
// Default width of the collaboration panel.
"default_width": 240
"default_width": 240,
},
"git_panel": {
// Whether to show the git panel button in the status bar.
@@ -880,12 +880,12 @@
// Choices: always, auto, never, system
// Default: inherits editor scrollbar settings
// "show": null
}
},
},
"message_editor": {
// Whether to automatically replace emoji shortcodes with emoji characters.
// For example: typing `:wave:` gets replaced with `👋`.
"auto_replace_emoji_shortcode": true
"auto_replace_emoji_shortcode": true,
},
"notification_panel": {
// Whether to show the notification panel button in the status bar.
@@ -893,7 +893,7 @@
// Where to dock the notification panel. Can be 'left' or 'right'.
"dock": "right",
// Default width of the notification panel.
"default_width": 380
"default_width": 380,
},
"agent": {
// Whether the agent is enabled.
@@ -915,7 +915,7 @@
// The provider to use.
"provider": "zed.dev",
// The model to use.
"model": "claude-sonnet-4"
"model": "claude-sonnet-4",
},
// Additional parameters for language model requests. When making a request to a model, parameters will be taken
// from the last entry in this list that matches the model's provider and name. In each entry, both provider
@@ -970,8 +970,8 @@
"grep": true,
"terminal": true,
"thinking": true,
"web_search": true
}
"web_search": true,
},
},
"ask": {
"name": "Ask",
@@ -988,14 +988,14 @@
"open": true,
"grep": true,
"thinking": true,
"web_search": true
}
"web_search": true,
},
},
"minimal": {
"name": "Minimal",
"enable_all_context_servers": false,
"tools": {}
}
"tools": {},
},
},
// Where to show notifications when the agent has either completed
// its response, or else needs confirmation before it can run a
@@ -1024,7 +1024,7 @@
// Minimum number of lines to display in the agent message editor.
//
// Default: 4
"message_editor_min_lines": 4
"message_editor_min_lines": 4,
},
// Whether the screen sharing icon is shown in the os status bar.
"show_call_status_icon": true,
@@ -1059,7 +1059,7 @@
// Whether or not to show the navigation history buttons.
"show_nav_history_buttons": true,
// Whether or not to show the tab bar buttons.
"show_tab_bar_buttons": true
"show_tab_bar_buttons": true,
},
// Settings related to the editor's tabs
"tabs": {
@@ -1098,7 +1098,7 @@
// "errors"
// 3. Mark files with errors and warnings:
// "all"
"show_diagnostics": "off"
"show_diagnostics": "off",
},
// Settings related to preview tabs.
"preview_tabs": {
@@ -1119,7 +1119,7 @@
"enable_preview_file_from_code_navigation": true,
// Whether to keep tabs in preview mode when code navigation is used to navigate away from them.
// If `enable_preview_file_from_code_navigation` or `enable_preview_multibuffer_from_code_navigation` is also true, the new tab may replace the existing one.
"enable_keep_preview_on_code_navigation": false
"enable_keep_preview_on_code_navigation": false,
},
// Settings related to the file finder.
"file_finder": {
@@ -1163,7 +1163,7 @@
// * "all": Use all gitignored files
// * "indexed": Use only the files Zed had indexed
// * "smart": Be smart and search for ignored when called from a gitignored worktree
"include_ignored": "smart"
"include_ignored": "smart",
},
// Whether or not to remove any trailing whitespace from lines of a buffer
// before saving it.
@@ -1234,7 +1234,7 @@
// Send debug info like crash reports.
"diagnostics": true,
// Send anonymized usage data like what languages you're using Zed with.
"metrics": true
"metrics": true,
},
// Whether to disable all AI features in Zed.
//
@@ -1268,7 +1268,7 @@
"enabled": true,
// Minimum time to wait before pulling diagnostics from the language server(s).
// 0 turns the debounce off.
"debounce_ms": 50
"debounce_ms": 50,
},
// Settings for inline diagnostics
"inline": {
@@ -1286,8 +1286,8 @@
"min_column": 0,
// The minimum severity of the diagnostics to show inline.
// Inherits editor's diagnostics' max severity settings when `null`.
"max_severity": null
}
"max_severity": null,
},
},
// Files or globs of files that will be excluded by Zed entirely. They will be skipped during file
// scans, file searches, and not be displayed in the project file tree. Takes precedence over `file_scan_inclusions`.
@@ -1301,7 +1301,7 @@
"**/.DS_Store",
"**/Thumbs.db",
"**/.classpath",
"**/.settings"
"**/.settings",
],
// Files or globs of files that will be included by Zed, even when ignored by git. This is useful
// for files that are not tracked by git, but are still important to your project. Note that globs
@@ -1336,14 +1336,14 @@
// Whether or not to display the git commit summary on the same line.
"show_commit_summary": false,
// The minimum column number to show the inline blame information at
"min_column": 0
"min_column": 0,
},
"blame": {
"show_avatar": true
"show_avatar": true,
},
// Control which information is shown in the branch picker.
"branch_picker": {
"show_author_name": true
"show_author_name": true,
},
// How git hunks are displayed visually in the editor.
// This setting can take two values:
@@ -1355,7 +1355,7 @@
"hunk_style": "staged_hollow",
// Should the name or path be displayed first in the git view.
// "path_style": "file_name_first" or "file_path_first"
"path_style": "file_name_first"
"path_style": "file_name_first",
},
// The list of custom Git hosting providers.
"git_hosting_providers": [
@@ -1389,7 +1389,7 @@
"**/secrets.yml",
"**/.zed/settings.json", // zed project settings
"/**/zed/settings.json", // zed user settings
"/**/zed/keymap.json"
"/**/zed/keymap.json",
],
// When to show edit predictions previews in buffer.
// This setting takes two possible values:
@@ -1407,15 +1407,15 @@
"copilot": {
"enterprise_uri": null,
"proxy": null,
"proxy_no_verify": null
"proxy_no_verify": null,
},
"codestral": {
"model": null,
"max_tokens": null
"max_tokens": null,
},
// Whether edit predictions are enabled when editing text threads in the agent panel.
// This setting has no effect if globally disabled.
"enabled_in_text_threads": true
"enabled_in_text_threads": true,
},
// Settings specific to journaling
"journal": {
@@ -1425,7 +1425,7 @@
// May take 2 values:
// 1. hour12
// 2. hour24
"hour_format": "hour12"
"hour_format": "hour12",
},
// Status bar-related settings.
"status_bar": {
@@ -1436,7 +1436,7 @@
// Whether to show the cursor position button in the status bar.
"cursor_position_button": true,
// Whether to show active line endings button in the status bar.
"line_endings_button": false
"line_endings_button": false,
},
// Settings specific to the terminal
"terminal": {
@@ -1557,8 +1557,8 @@
// Preferred Conda manager to use when activating Conda environments.
// Values: "auto", "conda", "mamba", "micromamba"
// Default: "auto"
"conda_manager": "auto"
}
"conda_manager": "auto",
},
},
"toolbar": {
// Whether to display the terminal title in its toolbar's breadcrumbs.
@@ -1566,7 +1566,7 @@
//
// The shell running in the terminal needs to be configured to emit the title.
// Example: `echo -e "\e]2;New Title\007";`
"breadcrumbs": false
"breadcrumbs": false,
},
// Scrollbar-related settings
"scrollbar": {
@@ -1583,7 +1583,7 @@
// "always"
// 5. Never show the scrollbar:
// "never"
"show": null
"show": null,
},
// Set the terminal's font size. If this option is not included,
// the terminal will default to matching the buffer's font size.
@@ -1646,30 +1646,26 @@
// surrounding symbols or quotes
[
"(?x)",
"# optionally starts with 0-2 opening prefix symbols",
"[({\\[<]{0,2}",
"# which may be followed by an opening quote",
"(?<quote>[\"'`])?",
"# `path` is the shortest sequence of any non-space character",
"(?<link>(?<path>[^ ]+?",
" # which may end with a line and optionally a column,",
" (?<line_column>:+[0-9]+(:[0-9]+)?|:?\\([0-9]+([,:][0-9]+)?\\))?",
"))",
"# which must be followed by a matching quote",
"(?(<quote>)\\k<quote>)",
"# and optionally a single closing symbol",
"[)}\\]>]?",
"# if line/column matched, may be followed by a description",
"(?(<line_column>):[^ 0-9][^ ]*)?",
"# which may be followed by trailing punctuation",
"[.,:)}\\]>]*",
"# and always includes trailing whitespace or end of line",
"([ ]+|$)"
]
"(?<path>",
" (",
" # multi-char path: first char (not opening delimiter or space)",
" [^({\\[<\"'`\\ ]",
" # middle chars: non-space, and colon/paren only if not followed by digit/paren",
" ([^\\ :(]|[:(][^0-9()])*",
" # last char: not closing delimiter or colon",
" [^()}\\]>\"'`.,;:\\ ]",
" |",
" # single-char path: not delimiter, punctuation, or space",
" [^(){}\\[\\]<>\"'`.,;:\\ ]",
" )",
" # optional line/column suffix (included in path for PathWithPosition::parse_str)",
" (:+[0-9]+(:[0-9]+)?|:?\\([0-9]+([,:]?[0-9]+)?\\))?",
")",
],
],
// Timeout for hover and Cmd-click path hyperlink discovery in milliseconds. Specifying a
// timeout of `0` will disable path hyperlinking in terminal.
"path_hyperlink_timeout_ms": 1
"path_hyperlink_timeout_ms": 1,
},
"code_actions_on_format": {},
// Settings related to running tasks.
@@ -1685,7 +1681,7 @@
// * Zed task from history (e.g. one-off task was spawned before)
//
// Default: true
"prefer_lsp": true
"prefer_lsp": true,
},
// An object whose keys are language names, and whose values
// are arrays of filenames or extensions of files that should
@@ -1702,7 +1698,7 @@
"file_types": {
"JSONC": ["**/.zed/**/*.json", "**/zed/**/*.json", "**/Zed/**/*.json", "**/.vscode/**/*.json", "tsconfig*.json"],
"Markdown": [".rules", ".cursorrules", ".windsurfrules", ".clinerules"],
"Shell Script": [".env.*"]
"Shell Script": [".env.*"],
},
// Settings for which version of Node.js and NPM to use when installing
// language servers and Copilot.
@@ -1718,14 +1714,14 @@
// `path`, but not `npm_path`, Zed will assume that `npm` is located at
// `${path}/../npm`.
"path": null,
"npm_path": null
"npm_path": null,
},
// The extensions that Zed should automatically install on startup.
//
// If you don't want any of these extensions, add this field to your settings
// and change the value to `false`.
"auto_install_extensions": {
"html": true
"html": true,
},
// The capabilities granted to extensions.
//
@@ -1733,7 +1729,7 @@
"granted_extension_capabilities": [
{ "kind": "process:exec", "command": "*", "args": ["**"] },
{ "kind": "download_file", "host": "*", "path": ["**"] },
{ "kind": "npm:install", "package": "*" }
{ "kind": "npm:install", "package": "*" },
],
// Controls how completions are processed for this language.
"completions": {
@@ -1784,7 +1780,7 @@
// 4. "replace_suffix"
// Behaves like `"replace"` if the text after the cursor is a suffix of the completion, and like
// `"insert"` otherwise.
"lsp_insert_mode": "replace_suffix"
"lsp_insert_mode": "replace_suffix",
},
// Different settings for specific languages.
"languages": {
@@ -1792,116 +1788,116 @@
"language_servers": ["astro-language-server", "..."],
"prettier": {
"allowed": true,
"plugins": ["prettier-plugin-astro"]
}
"plugins": ["prettier-plugin-astro"],
},
},
"Blade": {
"prettier": {
"allowed": true
}
"allowed": true,
},
},
"C": {
"format_on_save": "off",
"use_on_type_format": false,
"prettier": {
"allowed": false
}
"allowed": false,
},
},
"C++": {
"format_on_save": "off",
"use_on_type_format": false,
"prettier": {
"allowed": false
}
"allowed": false,
},
},
"CSharp": {
"language_servers": ["roslyn", "!omnisharp", "..."]
"language_servers": ["roslyn", "!omnisharp", "..."],
},
"CSS": {
"prettier": {
"allowed": true
}
"allowed": true,
},
},
"Dart": {
"tab_size": 2
"tab_size": 2,
},
"Diff": {
"show_edit_predictions": false,
"remove_trailing_whitespace_on_save": false,
"ensure_final_newline_on_save": false
"ensure_final_newline_on_save": false,
},
"Elixir": {
"language_servers": ["elixir-ls", "!expert", "!next-ls", "!lexical", "..."]
"language_servers": ["elixir-ls", "!expert", "!next-ls", "!lexical", "..."],
},
"Elm": {
"tab_size": 4
"tab_size": 4,
},
"Erlang": {
"language_servers": ["erlang-ls", "!elp", "..."]
"language_servers": ["erlang-ls", "!elp", "..."],
},
"Git Commit": {
"allow_rewrap": "anywhere",
"soft_wrap": "editor_width",
"preferred_line_length": 72
"preferred_line_length": 72,
},
"Go": {
"hard_tabs": true,
"code_actions_on_format": {
"source.organizeImports": true
"source.organizeImports": true,
},
"debuggers": ["Delve"]
"debuggers": ["Delve"],
},
"GraphQL": {
"prettier": {
"allowed": true
}
"allowed": true,
},
},
"HEEX": {
"language_servers": ["elixir-ls", "!expert", "!next-ls", "!lexical", "..."]
"language_servers": ["elixir-ls", "!expert", "!next-ls", "!lexical", "..."],
},
"HTML": {
"prettier": {
"allowed": true
}
"allowed": true,
},
},
"HTML+ERB": {
"language_servers": ["herb", "!ruby-lsp", "..."]
"language_servers": ["herb", "!ruby-lsp", "..."],
},
"Java": {
"prettier": {
"allowed": true,
"plugins": ["prettier-plugin-java"]
}
"plugins": ["prettier-plugin-java"],
},
},
"JavaScript": {
"language_servers": ["!typescript-language-server", "vtsls", "..."],
"prettier": {
"allowed": true
}
"allowed": true,
},
},
"JSON": {
"prettier": {
"allowed": true
}
"allowed": true,
},
},
"JSONC": {
"prettier": {
"allowed": true
}
"allowed": true,
},
},
"JS+ERB": {
"language_servers": ["!ruby-lsp", "..."]
"language_servers": ["!ruby-lsp", "..."],
},
"Kotlin": {
"language_servers": ["!kotlin-language-server", "kotlin-lsp", "..."]
"language_servers": ["!kotlin-language-server", "kotlin-lsp", "..."],
},
"LaTeX": {
"formatter": "language_server",
"language_servers": ["texlab", "..."],
"prettier": {
"allowed": true,
"plugins": ["prettier-plugin-latex"]
}
"plugins": ["prettier-plugin-latex"],
},
},
"Markdown": {
"format_on_save": "off",
@@ -1909,136 +1905,142 @@
"remove_trailing_whitespace_on_save": false,
"allow_rewrap": "anywhere",
"soft_wrap": "editor_width",
"completions": {
"words": "disabled",
},
"prettier": {
"allowed": true
}
"allowed": true,
},
},
"PHP": {
"language_servers": ["phpactor", "!intelephense", "!phptools", "..."],
"prettier": {
"allowed": true,
"plugins": ["@prettier/plugin-php"],
"parser": "php"
}
"parser": "php",
},
},
"Plain Text": {
"allow_rewrap": "anywhere",
"soft_wrap": "editor_width"
"soft_wrap": "editor_width",
"completions": {
"words": "disabled",
},
},
"Python": {
"code_actions_on_format": {
"source.organizeImports.ruff": true
"source.organizeImports.ruff": true,
},
"formatter": {
"language_server": {
"name": "ruff"
}
"name": "ruff",
},
},
"debuggers": ["Debugpy"],
"language_servers": ["basedpyright", "ruff", "!ty", "!pyrefly", "!pyright", "!pylsp", "..."]
"language_servers": ["basedpyright", "ruff", "!ty", "!pyrefly", "!pyright", "!pylsp", "..."],
},
"Ruby": {
"language_servers": ["solargraph", "!ruby-lsp", "!rubocop", "!sorbet", "!steep", "..."]
"language_servers": ["solargraph", "!ruby-lsp", "!rubocop", "!sorbet", "!steep", "..."],
},
"Rust": {
"debuggers": ["CodeLLDB"]
"debuggers": ["CodeLLDB"],
},
"SCSS": {
"prettier": {
"allowed": true
}
"allowed": true,
},
},
"Starlark": {
"language_servers": ["starpls", "!buck2-lsp", "..."]
"language_servers": ["starpls", "!buck2-lsp", "..."],
},
"Svelte": {
"language_servers": ["svelte-language-server", "..."],
"prettier": {
"allowed": true,
"plugins": ["prettier-plugin-svelte"]
}
"plugins": ["prettier-plugin-svelte"],
},
},
"TSX": {
"language_servers": ["!typescript-language-server", "vtsls", "..."],
"prettier": {
"allowed": true
}
"allowed": true,
},
},
"Twig": {
"prettier": {
"allowed": true
}
"allowed": true,
},
},
"TypeScript": {
"language_servers": ["!typescript-language-server", "vtsls", "..."],
"prettier": {
"allowed": true
}
"allowed": true,
},
},
"SystemVerilog": {
"format_on_save": "off",
"language_servers": ["!slang", "..."],
"use_on_type_format": false
"use_on_type_format": false,
},
"Vue.js": {
"language_servers": ["vue-language-server", "vtsls", "..."],
"prettier": {
"allowed": true
}
"allowed": true,
},
},
"XML": {
"prettier": {
"allowed": true,
"plugins": ["@prettier/plugin-xml"]
}
"plugins": ["@prettier/plugin-xml"],
},
},
"YAML": {
"prettier": {
"allowed": true
}
"allowed": true,
},
},
"YAML+ERB": {
"language_servers": ["!ruby-lsp", "..."]
"language_servers": ["!ruby-lsp", "..."],
},
"Zig": {
"language_servers": ["zls", "..."]
}
"language_servers": ["zls", "..."],
},
},
// Different settings for specific language models.
"language_models": {
"anthropic": {
"api_url": "https://api.anthropic.com"
"api_url": "https://api.anthropic.com",
},
"bedrock": {},
"google": {
"api_url": "https://generativelanguage.googleapis.com"
"api_url": "https://generativelanguage.googleapis.com",
},
"ollama": {
"api_url": "http://localhost:11434"
"api_url": "http://localhost:11434",
},
"openai": {
"api_url": "https://api.openai.com/v1"
"api_url": "https://api.openai.com/v1",
},
"openai_compatible": {},
"open_router": {
"api_url": "https://openrouter.ai/api/v1"
"api_url": "https://openrouter.ai/api/v1",
},
"lmstudio": {
"api_url": "http://localhost:1234/api/v0"
"api_url": "http://localhost:1234/api/v0",
},
"deepseek": {
"api_url": "https://api.deepseek.com/v1"
"api_url": "https://api.deepseek.com/v1",
},
"mistral": {
"api_url": "https://api.mistral.ai/v1"
"api_url": "https://api.mistral.ai/v1",
},
"vercel": {
"api_url": "https://api.v0.dev/v1"
"api_url": "https://api.v0.dev/v1",
},
"x_ai": {
"api_url": "https://api.x.ai/v1"
"api_url": "https://api.x.ai/v1",
},
"zed.dev": {}
"zed.dev": {},
},
"session": {
// Whether or not to restore unsaved buffers on restart.
@@ -2047,7 +2049,7 @@
// dirty files when closing the application.
//
// Default: true
"restore_unsaved_buffers": true
"restore_unsaved_buffers": true,
},
// Zed's Prettier integration settings.
// Allows to enable/disable formatting with Prettier
@@ -2065,11 +2067,11 @@
// "singleQuote": true
// Forces Prettier integration to use a specific parser name when formatting files with the language
// when set to a non-empty string.
"parser": ""
"parser": "",
},
// Settings for auto-closing of JSX tags.
"jsx_tag_auto_close": {
"enabled": true
"enabled": true,
},
// LSP Specific settings.
"lsp": {
@@ -2090,19 +2092,19 @@
// Specify the DAP name as a key here.
"CodeLLDB": {
"env": {
"RUST_LOG": "info"
}
}
"RUST_LOG": "info",
},
},
},
// Common language server settings.
"global_lsp_settings": {
// Whether to show the LSP servers button in the status bar.
"button": true
"button": true,
},
// Jupyter settings
"jupyter": {
"enabled": true,
"kernel_selections": {}
"kernel_selections": {},
// Specify the language name as the key and the kernel name as the value.
// "kernel_selections": {
// "python": "conda-base"
@@ -2116,7 +2118,7 @@
"max_columns": 128,
// Maximum number of lines to keep in REPL's scrollback buffer.
// Clamped with [4, 256] range.
"max_lines": 32
"max_lines": 32,
},
// Vim settings
"vim": {
@@ -2130,7 +2132,7 @@
// Specify the mode as the key and the shape as the value.
// The mode can be one of the following: "normal", "replace", "insert", "visual".
// The shape can be one of the following: "block", "bar", "underline", "hollow".
"cursor_shape": {}
"cursor_shape": {},
},
// The server to connect to. If the environment variable
// ZED_SERVER_URL is set, it will override this setting.
@@ -2163,9 +2165,9 @@
"windows": {
"languages": {
"PHP": {
"language_servers": ["intelephense", "!phpactor", "!phptools", "..."]
}
}
"language_servers": ["intelephense", "!phpactor", "!phptools", "..."],
},
},
},
// Whether to show full labels in line indicator or short ones
//
@@ -2224,7 +2226,7 @@
"dock": "bottom",
"log_dap_communications": true,
"format_dap_log_messages": true,
"button": true
"button": true,
},
// Configures any number of settings profiles that are temporarily applied on
// top of your existing user settings when selected from
@@ -2251,5 +2253,5 @@
// Useful for filtering out noisy logs or enabling more verbose logging.
//
// Example: {"log": {"client": "warn"}}
"log": {}
"log": {},
}

View File

@@ -11,8 +11,6 @@ use project::agent_server_store::AgentServerCommand;
use serde::Deserialize;
use settings::Settings as _;
use task::ShellBuilder;
#[cfg(windows)]
use task::ShellKind;
use util::ResultExt as _;
use std::path::PathBuf;
@@ -92,23 +90,8 @@ impl AcpConnection {
) -> Result<Self> {
let shell = cx.update(|cx| TerminalSettings::get(None, cx).shell.clone())?;
let builder = ShellBuilder::new(&shell, cfg!(windows));
#[cfg(windows)]
let kind = builder.kind();
let (cmd, args) = builder.build(Some(command.path.display().to_string()), &command.args);
let mut child = util::command::new_smol_command(cmd);
#[cfg(windows)]
if kind == ShellKind::Cmd {
use smol::process::windows::CommandExt;
for arg in args {
child.raw_arg(arg);
}
} else {
child.args(args);
}
#[cfg(not(windows))]
child.args(args);
let mut child =
builder.build_command(Some(command.path.display().to_string()), &command.args);
child
.envs(command.env.iter().flatten())
.stdin(std::process::Stdio::piped())

View File

@@ -53,7 +53,7 @@ text.workspace = true
thiserror.workspace = true
time.workspace = true
tiny_http.workspace = true
tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io"] }
tokio-socks.workspace = true
tokio.workspace = true
url.workspace = true
util.workspace = true

View File

@@ -1,18 +0,0 @@
[package]
name = "cloud_zeta2_prompt"
version = "0.1.0"
publish.workspace = true
edition.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/cloud_zeta2_prompt.rs"
[dependencies]
anyhow.workspace = true
cloud_llm_client.workspace = true
indoc.workspace = true
serde.workspace = true

View File

@@ -1,485 +0,0 @@
use anyhow::Result;
use cloud_llm_client::predict_edits_v3::{
self, DiffPathFmt, Event, Excerpt, Line, Point, PromptFormat, RelatedFile,
};
use indoc::indoc;
use std::cmp;
use std::fmt::Write;
use std::path::Path;
use std::sync::Arc;
pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;
pub const CURSOR_MARKER: &str = "<|user_cursor|>";
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#"
You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.
## Edit History
"#};
const MINIMAL_PROMPT_REMINDER: &str = indoc! {"
---
Please analyze the edit history and the files, then provide the unified diff for your predicted edits.
Do not include the cursor marker in your output.
If you're editing multiple files, be sure to reflect filename in the hunk's header.
"};
const XML_TAGS_INSTRUCTIONS: &str = indoc! {r#"
# Instructions
You are an edit prediction agent in a code editor.
Analyze the history of edits made by the user in order to infer what they are currently trying to accomplish.
Then complete the remainder of the current change if it is incomplete, or predict the next edit the user intends to make.
Always continue along the user's current trajectory, rather than changing course.
## Output Format
You should briefly explain your understanding of the user's overall goal in one sentence, then explain what the next change
along the users current trajectory will be in another, and finally specify the next edit using the following XML-like format:
<edits path="my-project/src/myapp/cli.py">
<old_text>
OLD TEXT 1 HERE
</old_text>
<new_text>
NEW TEXT 1 HERE
</new_text>
<old_text>
OLD TEXT 1 HERE
</old_text>
<new_text>
NEW TEXT 1 HERE
</new_text>
</edits>
- Specify the file to edit using the `path` attribute.
- Use `<old_text>` and `<new_text>` tags to replace content
- `<old_text>` must exactly match existing file content, including indentation
- `<old_text>` cannot be empty
- Do not escape quotes, newlines, or other characters within tags
- Always close all tags properly
- Don't include the <|user_cursor|> marker in your output.
## Edit History
"#};
const OLD_TEXT_NEW_TEXT_REMINDER: &str = indoc! {r#"
---
Remember that the edits in the edit history have already been applied.
"#};
pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<String> {
let prompt_data = PromptData {
events: request.events.clone(),
cursor_point: request.cursor_point,
cursor_path: request.excerpt_path.clone(),
included_files: request.related_files.clone(),
};
match request.prompt_format {
PromptFormat::MinimalQwen => {
return Ok(MinimalQwenPrompt.render(&prompt_data));
}
PromptFormat::SeedCoder1120 => {
return Ok(SeedCoder1120Prompt.render(&prompt_data));
}
_ => (),
};
let insertions = match request.prompt_format {
PromptFormat::Minimal | PromptFormat::OldTextNewText => {
vec![(request.cursor_point, CURSOR_MARKER)]
}
PromptFormat::OnlySnippets => vec![],
PromptFormat::MinimalQwen => unreachable!(),
PromptFormat::SeedCoder1120 => unreachable!(),
};
let mut prompt = match request.prompt_format {
PromptFormat::OldTextNewText => XML_TAGS_INSTRUCTIONS.to_string(),
PromptFormat::OnlySnippets => String::new(),
PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(),
PromptFormat::MinimalQwen => unreachable!(),
PromptFormat::SeedCoder1120 => unreachable!(),
};
if request.events.is_empty() {
prompt.push_str("(No edit history)\n\n");
} else {
let edit_preamble = if request.prompt_format == PromptFormat::Minimal {
"The following are the latest edits made by the user, from earlier to later.\n\n"
} else {
"Here are the latest edits made by the user, from earlier to later.\n\n"
};
prompt.push_str(edit_preamble);
push_events(&mut prompt, &request.events);
}
let excerpts_preamble = match request.prompt_format {
PromptFormat::Minimal => indoc! {"
## Part of the file under the cursor
(The cursor marker <|user_cursor|> indicates the current user cursor position.
The file is in current state, edits from edit history has been applied.
We only show part of the file around the cursor.
You can only edit exactly this part of the file.
We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.)
"},
PromptFormat::OldTextNewText => indoc! {"
## Code Excerpts
Here is some excerpts of code that you should take into account to predict the next edit.
The cursor position is marked by `<|user_cursor|>` as it stands after the last edit in the history.
In addition other excerpts are included to better understand what the edit will be, including the declaration
or references of symbols around the cursor, or other similar code snippets that may need to be updated
following patterns that appear in the edit history.
Consider each of them carefully in relation to the edit history, and that the user may not have navigated
to the next place they want to edit yet.
Lines starting with `…` indicate omitted line ranges. These may appear inside multi-line code constructs.
"},
PromptFormat::OnlySnippets | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
indoc! {"
## Code Excerpts
The cursor marker <|user_cursor|> indicates the current user cursor position.
The file is in current state, edits from edit history have been applied.
"}
}
};
prompt.push_str(excerpts_preamble);
prompt.push('\n');
let include_line_numbers = matches!(request.prompt_format, PromptFormat::Minimal);
for related_file in &request.related_files {
if request.prompt_format == PromptFormat::Minimal {
write_codeblock_with_filename(
&related_file.path,
&related_file.excerpts,
if related_file.path == request.excerpt_path {
&insertions
} else {
&[]
},
related_file.max_row,
include_line_numbers,
&mut prompt,
);
} else {
write_codeblock(
&related_file.path,
&related_file.excerpts,
if related_file.path == request.excerpt_path {
&insertions
} else {
&[]
},
related_file.max_row,
include_line_numbers,
&mut prompt,
);
}
}
match request.prompt_format {
PromptFormat::OldTextNewText => {
prompt.push_str(OLD_TEXT_NEW_TEXT_REMINDER);
}
PromptFormat::Minimal => {
prompt.push_str(MINIMAL_PROMPT_REMINDER);
}
_ => {}
}
Ok(prompt)
}
pub fn generation_params(prompt_format: PromptFormat) -> GenerationParams {
match prompt_format {
PromptFormat::SeedCoder1120 => SeedCoder1120Prompt::generation_params(),
_ => GenerationParams::default(),
}
}
pub fn write_codeblock<'a>(
path: &Path,
excerpts: impl IntoIterator<Item = &'a Excerpt>,
sorted_insertions: &[(Point, &str)],
file_line_count: Line,
include_line_numbers: bool,
output: &'a mut String,
) {
writeln!(output, "`````{}", DiffPathFmt(path)).unwrap();
write_excerpts(
excerpts,
sorted_insertions,
file_line_count,
include_line_numbers,
output,
);
write!(output, "`````\n\n").unwrap();
}
fn write_codeblock_with_filename<'a>(
path: &Path,
excerpts: impl IntoIterator<Item = &'a Excerpt>,
sorted_insertions: &[(Point, &str)],
file_line_count: Line,
include_line_numbers: bool,
output: &'a mut String,
) {
writeln!(output, "`````filename={}", DiffPathFmt(path)).unwrap();
write_excerpts(
excerpts,
sorted_insertions,
file_line_count,
include_line_numbers,
output,
);
write!(output, "`````\n\n").unwrap();
}
pub fn write_excerpts<'a>(
excerpts: impl IntoIterator<Item = &'a Excerpt>,
sorted_insertions: &[(Point, &str)],
file_line_count: Line,
include_line_numbers: bool,
output: &mut String,
) {
let mut current_row = Line(0);
let mut sorted_insertions = sorted_insertions.iter().peekable();
for excerpt in excerpts {
if excerpt.start_line > current_row {
writeln!(output, "").unwrap();
}
if excerpt.text.is_empty() {
return;
}
current_row = excerpt.start_line;
for mut line in excerpt.text.lines() {
if include_line_numbers {
write!(output, "{}|", current_row.0 + 1).unwrap();
}
while let Some((insertion_location, insertion_marker)) = sorted_insertions.peek() {
match current_row.cmp(&insertion_location.line) {
cmp::Ordering::Equal => {
let (prefix, suffix) = line.split_at(insertion_location.column as usize);
output.push_str(prefix);
output.push_str(insertion_marker);
line = suffix;
sorted_insertions.next();
}
cmp::Ordering::Less => break,
cmp::Ordering::Greater => {
sorted_insertions.next();
break;
}
}
}
output.push_str(line);
output.push('\n');
current_row.0 += 1;
}
}
if current_row < file_line_count {
writeln!(output, "").unwrap();
}
}
pub fn push_events(output: &mut String, events: &[Arc<predict_edits_v3::Event>]) {
if events.is_empty() {
return;
};
writeln!(output, "`````diff").unwrap();
for event in events {
writeln!(output, "{}", event).unwrap();
}
writeln!(output, "`````\n").unwrap();
}
struct PromptData {
events: Vec<Arc<Event>>,
cursor_point: Point,
cursor_path: Arc<Path>, // TODO: make a common struct with cursor_point
included_files: Vec<RelatedFile>,
}
#[derive(Default)]
pub struct GenerationParams {
pub temperature: Option<f32>,
pub top_p: Option<f32>,
pub stop: Option<Vec<String>>,
}
trait PromptFormatter {
fn render(&self, data: &PromptData) -> String;
fn generation_params() -> GenerationParams {
return GenerationParams::default();
}
}
struct MinimalQwenPrompt;
impl PromptFormatter for MinimalQwenPrompt {
fn render(&self, data: &PromptData) -> String {
let edit_history = self.fmt_edit_history(data);
let context = self.fmt_context(data);
format!(
"{instructions}\n\n{edit_history}\n\n{context}",
instructions = MinimalQwenPrompt::INSTRUCTIONS,
edit_history = edit_history,
context = context
)
}
}
impl MinimalQwenPrompt {
const INSTRUCTIONS: &str = "You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.\n";
fn fmt_edit_history(&self, data: &PromptData) -> String {
if data.events.is_empty() {
"(No edit history)\n\n".to_string()
} else {
let mut events_str = String::new();
push_events(&mut events_str, &data.events);
format!(
"The following are the latest edits made by the user, from earlier to later.\n\n{}",
events_str
)
}
}
fn fmt_context(&self, data: &PromptData) -> String {
let mut context = String::new();
let include_line_numbers = true;
for related_file in &data.included_files {
writeln!(context, "<|file_sep|>{}", DiffPathFmt(&related_file.path)).unwrap();
if related_file.path == data.cursor_path {
write!(context, "<|fim_prefix|>").unwrap();
write_excerpts(
&related_file.excerpts,
&[(data.cursor_point, "<|fim_suffix|>")],
related_file.max_row,
include_line_numbers,
&mut context,
);
writeln!(context, "<|fim_middle|>").unwrap();
} else {
write_excerpts(
&related_file.excerpts,
&[],
related_file.max_row,
include_line_numbers,
&mut context,
);
}
}
context
}
}
struct SeedCoder1120Prompt;
impl PromptFormatter for SeedCoder1120Prompt {
fn render(&self, data: &PromptData) -> String {
let edit_history = self.fmt_edit_history(data);
let context = self.fmt_context(data);
format!(
"# Edit History:\n{edit_history}\n\n{context}",
edit_history = edit_history,
context = context
)
}
fn generation_params() -> GenerationParams {
GenerationParams {
temperature: Some(0.2),
top_p: Some(0.9),
stop: Some(vec!["<[end_of_sentence]>".into()]),
}
}
}
impl SeedCoder1120Prompt {
fn fmt_edit_history(&self, data: &PromptData) -> String {
if data.events.is_empty() {
"(No edit history)\n\n".to_string()
} else {
let mut events_str = String::new();
push_events(&mut events_str, &data.events);
events_str
}
}
fn fmt_context(&self, data: &PromptData) -> String {
let mut context = String::new();
let include_line_numbers = true;
for related_file in &data.included_files {
writeln!(context, "# Path: {}\n", DiffPathFmt(&related_file.path)).unwrap();
if related_file.path == data.cursor_path {
let fim_prompt = self.fmt_fim(&related_file, data.cursor_point);
context.push_str(&fim_prompt);
} else {
write_excerpts(
&related_file.excerpts,
&[],
related_file.max_row,
include_line_numbers,
&mut context,
);
}
}
context
}
fn fmt_fim(&self, file: &RelatedFile, cursor_point: Point) -> String {
let mut buf = String::new();
const FIM_SUFFIX: &str = "<[fim-suffix]>";
const FIM_PREFIX: &str = "<[fim-prefix]>";
const FIM_MIDDLE: &str = "<[fim-middle]>";
write!(buf, "{}", FIM_PREFIX).unwrap();
write_excerpts(
&file.excerpts,
&[(cursor_point, FIM_SUFFIX)],
file.max_row,
true,
&mut buf,
);
// Swap prefix and suffix parts
let index = buf.find(FIM_SUFFIX).unwrap();
let prefix = &buf[..index];
let suffix = &buf[index..];
format!("{}{}{}", suffix, prefix, FIM_MIDDLE)
}
}

View File

@@ -33,12 +33,10 @@ impl StdioTransport {
) -> Result<Self> {
let shell = cx.update(|cx| TerminalSettings::get(None, cx).shell.clone())?;
let builder = ShellBuilder::new(&shell, cfg!(windows));
let (command, args) =
builder.build(Some(binary.executable.display().to_string()), &binary.args);
let mut command =
builder.build_command(Some(binary.executable.display().to_string()), &binary.args);
let mut command = util::command::new_smol_command(command);
command
.args(args)
.envs(binary.env.unwrap_or_default())
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())

View File

@@ -1045,54 +1045,47 @@ async fn heuristic_syntactic_expand(
let node_range = node_start..node_end;
let row_count = node_end.row - node_start.row + 1;
let mut ancestor_range = None;
let reached_outline_node = cx.background_executor().scoped({
let node_range = node_range.clone();
let outline_range = outline_range.clone();
let ancestor_range = &mut ancestor_range;
|scope| {
scope.spawn(async move {
// Stop if we've exceeded the row count or reached an outline node. Then, find the interval
// of node children which contains the query range. For example, this allows just returning
// the header of a declaration rather than the entire declaration.
if row_count > max_row_count || outline_range == Some(node_range.clone()) {
let mut cursor = node.walk();
let mut included_child_start = None;
let mut included_child_end = None;
let mut previous_end = node_start;
if cursor.goto_first_child() {
loop {
let child_node = cursor.node();
let child_range =
previous_end..Point::from_ts_point(child_node.end_position());
if included_child_start.is_none()
&& child_range.contains(&input_range.start)
{
included_child_start = Some(child_range.start);
}
if child_range.contains(&input_range.end) {
included_child_end = Some(child_range.end);
}
previous_end = child_range.end;
if !cursor.goto_next_sibling() {
break;
}
cx.background_executor()
.await_on_background(async {
// Stop if we've exceeded the row count or reached an outline node. Then, find the interval
// of node children which contains the query range. For example, this allows just returning
// the header of a declaration rather than the entire declaration.
if row_count > max_row_count || outline_range == Some(node_range.clone()) {
let mut cursor = node.walk();
let mut included_child_start = None;
let mut included_child_end = None;
let mut previous_end = node_start;
if cursor.goto_first_child() {
loop {
let child_node = cursor.node();
let child_range =
previous_end..Point::from_ts_point(child_node.end_position());
if included_child_start.is_none()
&& child_range.contains(&input_range.start)
{
included_child_start = Some(child_range.start);
}
if child_range.contains(&input_range.end) {
included_child_end = Some(child_range.end);
}
previous_end = child_range.end;
if !cursor.goto_next_sibling() {
break;
}
}
let end = included_child_end.unwrap_or(node_range.end);
if let Some(start) = included_child_start {
let row_count = end.row - start.row;
if row_count < max_row_count {
*ancestor_range =
Some(Some(RangeInclusive::new(start.row, end.row)));
return;
}
}
*ancestor_range = Some(None);
}
})
}
});
reached_outline_node.await;
let end = included_child_end.unwrap_or(node_range.end);
if let Some(start) = included_child_start {
let row_count = end.row - start.row;
if row_count < max_row_count {
ancestor_range = Some(Some(RangeInclusive::new(start.row, end.row)));
return;
}
}
ancestor_range = Some(None);
}
})
.await;
if let Some(node) = ancestor_range {
return node;
}

View File

@@ -12,7 +12,7 @@ workspace = true
path = "src/edit_prediction.rs"
[features]
eval-support = []
cli-support = []
[dependencies]
ai_onboarding.workspace = true
@@ -21,7 +21,6 @@ arrayvec.workspace = true
brotli.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
cloud_zeta2_prompt.workspace = true
collections.workspace = true
copilot.workspace = true
credentials_provider.workspace = true
@@ -50,8 +49,6 @@ semver.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
strsim.workspace = true
strum.workspace = true
telemetry.workspace = true
telemetry_events.workspace = true
@@ -62,6 +59,7 @@ uuid.workspace = true
workspace.workspace = true
worktree.workspace = true
zed_actions.workspace = true
zeta_prompt.workspace = true
[dev-dependencies]
clock = { workspace = true, features = ["test-support"] }

View File

@@ -1,14 +1,13 @@
use anyhow::Result;
use arrayvec::ArrayVec;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
use cloud_llm_client::{
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
ZED_VERSION_HEADER_NAME,
};
use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
use collections::{HashMap, HashSet};
use db::kvp::{Dismissable, KEY_VALUE_STORE};
use edit_prediction_context::EditPredictionExcerptOptions;
@@ -16,10 +15,7 @@ use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, Rel
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
use futures::{
AsyncReadExt as _, FutureExt as _, StreamExt as _,
channel::{
mpsc::{self, UnboundedReceiver},
oneshot,
},
channel::mpsc::{self, UnboundedReceiver},
select_biased,
};
use gpui::BackgroundExecutor;
@@ -58,8 +54,10 @@ mod onboarding_modal;
pub mod open_ai_response;
mod prediction;
pub mod sweep_ai;
#[cfg(any(test, feature = "test-support", feature = "cli-support"))]
pub mod udiff;
mod xml_edits;
mod zed_edit_prediction_delegate;
pub mod zeta1;
pub mod zeta2;
@@ -72,7 +70,6 @@ use crate::mercury::Mercury;
use crate::onboarding_modal::ZedPredictModal;
pub use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
pub use crate::prediction::EditPredictionInputs;
use crate::prediction::EditPredictionResult;
pub use crate::sweep_ai::SweepAi;
pub use telemetry_events::EditPredictionRating;
@@ -112,7 +109,6 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
min_bytes: 128,
target_before_cursor_over_total_bytes: 0.5,
},
max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
prompt_format: PromptFormat::DEFAULT,
};
@@ -162,8 +158,7 @@ pub struct EditPredictionStore {
use_context: bool,
options: ZetaOptions,
update_required: bool,
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
#[cfg(feature = "eval-support")]
#[cfg(feature = "cli-support")]
eval_cache: Option<Arc<dyn EvalCache>>,
edit_prediction_model: EditPredictionModel,
pub sweep_ai: SweepAi,
@@ -183,10 +178,22 @@ pub enum EditPredictionModel {
Mercury,
}
pub struct EditPredictionModelInput {
project: Entity<Project>,
buffer: Entity<Buffer>,
snapshot: BufferSnapshot,
position: Anchor,
events: Vec<Arc<zeta_prompt::Event>>,
related_files: Arc<[RelatedFile]>,
recent_paths: VecDeque<ProjectPath>,
trigger: PredictEditsRequestTrigger,
diagnostic_search_range: Range<Point>,
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
}
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
pub context: EditPredictionExcerptOptions,
pub max_prompt_bytes: usize,
pub prompt_format: predict_edits_v3::PromptFormat,
}
@@ -194,7 +201,8 @@ pub struct ZetaOptions {
pub enum DebugEvent {
ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
EditPredictionRequested(EditPredictionRequestedDebugEvent),
EditPredictionStarted(EditPredictionStartedDebugEvent),
EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
#[derive(Debug)]
@@ -212,27 +220,30 @@ pub struct ContextRetrievalFinishedDebugEvent {
}
#[derive(Debug)]
pub struct EditPredictionRequestedDebugEvent {
pub inputs: EditPredictionInputs,
pub retrieval_time: Duration,
pub struct EditPredictionStartedDebugEvent {
pub buffer: WeakEntity<Buffer>,
pub position: Anchor,
pub local_prompt: Result<String, String>,
pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
pub prompt: Option<String>,
}
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
pub buffer: WeakEntity<Buffer>,
pub position: Anchor,
pub model_output: Option<String>,
}
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
struct ProjectState {
events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
events: VecDeque<Arc<zeta_prompt::Event>>,
last_event: Option<LastEvent>,
recent_paths: VecDeque<ProjectPath>,
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
current_prediction: Option<CurrentEditPrediction>,
next_pending_prediction_id: usize,
pending_predictions: ArrayVec<PendingPrediction, 2>,
context_updates_tx: smol::channel::Sender<()>,
context_updates_rx: smol::channel::Receiver<()>,
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
last_prediction_refresh: Option<(EntityId, Instant)>,
cancelled_predictions: HashSet<usize>,
context: Entity<RelatedExcerptStore>,
@@ -241,7 +252,7 @@ struct ProjectState {
}
impl ProjectState {
pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
self.events
.iter()
.cloned()
@@ -272,6 +283,18 @@ impl ProjectState {
})
.detach()
}
fn active_buffer(
&self,
project: &Entity<Project>,
cx: &App,
) -> Option<(Entity<Buffer>, Option<Anchor>)> {
let project = project.read(cx);
let active_path = project.path_for_entry(project.active_entry()?, cx)?;
let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
Some((active_buffer, registered_buffer.last_position))
}
}
#[derive(Debug, Clone)]
@@ -362,6 +385,7 @@ impl std::ops::Deref for BufferEditPrediction<'_> {
struct RegisteredBuffer {
snapshot: BufferSnapshot,
last_position: Option<Anchor>,
_subscriptions: [gpui::Subscription; 2],
}
@@ -376,7 +400,7 @@ impl LastEvent {
&self,
license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
cx: &App,
) -> Option<Arc<predict_edits_v3::Event>> {
) -> Option<Arc<zeta_prompt::Event>> {
let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
@@ -396,7 +420,7 @@ impl LastEvent {
if path == old_path && diff.is_empty() {
None
} else {
Some(Arc::new(predict_edits_v3::Event::BufferChange {
Some(Arc::new(zeta_prompt::Event::BufferChange {
old_path,
path,
diff,
@@ -481,8 +505,7 @@ impl EditPredictionStore {
},
),
update_required: false,
debug_tx: None,
#[cfg(feature = "eval-support")]
#[cfg(feature = "cli-support")]
eval_cache: None,
edit_prediction_model: EditPredictionModel::Zeta2,
sweep_ai: SweepAi::new(cx),
@@ -531,17 +554,11 @@ impl EditPredictionStore {
.is_some()
}
#[cfg(feature = "eval-support")]
#[cfg(feature = "cli-support")]
pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
self.eval_cache = Some(cache);
}
pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<DebugEvent> {
let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
self.debug_tx = Some(debug_watch_tx);
debug_watch_rx
}
pub fn options(&self) -> &ZetaOptions {
&self.options
}
@@ -560,15 +577,41 @@ impl EditPredictionStore {
}
}
pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
project_state.events.clear();
}
}
pub fn edit_history_for_project(
&self,
project: &Entity<Project>,
) -> Vec<Arc<zeta_prompt::Event>> {
self.projects
.get(&project.entity_id())
.map(|project_state| project_state.events.iter().cloned().collect())
.unwrap_or_default()
}
pub fn context_for_project<'a>(
&'a self,
project: &Entity<Project>,
cx: &'a App,
) -> &'a [RelatedFile] {
) -> Arc<[RelatedFile]> {
self.projects
.get(&project.entity_id())
.map(|project| project.context.read(cx).related_files())
.unwrap_or(&[])
.unwrap_or_else(|| vec![].into())
}
pub fn context_for_project_with_buffers<'a>(
&'a self,
project: &Entity<Project>,
cx: &'a App,
) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
self.projects
.get(&project.entity_id())
.map(|project| project.context.read(cx).related_files_with_buffers())
}
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
@@ -599,85 +642,21 @@ impl EditPredictionStore {
cx: &mut Context<Self>,
) -> &mut ProjectState {
let entity_id = project.entity_id();
let (context_updates_tx, context_updates_rx) = smol::channel::unbounded();
self.projects
.entry(entity_id)
.or_insert_with(|| ProjectState {
context: {
let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
cx.subscribe(
&related_excerpt_store,
move |this, _, event, _| match event {
RelatedExcerptStoreEvent::StartedRefresh => {
if let Some(debug_tx) = this.debug_tx.clone() {
debug_tx
.unbounded_send(DebugEvent::ContextRetrievalStarted(
ContextRetrievalStartedDebugEvent {
project_entity_id: entity_id,
timestamp: Instant::now(),
search_prompt: String::new(),
},
))
.ok();
}
}
RelatedExcerptStoreEvent::FinishedRefresh {
cache_hit_count,
cache_miss_count,
mean_definition_latency,
max_definition_latency,
} => {
if let Some(debug_tx) = this.debug_tx.clone() {
debug_tx
.unbounded_send(DebugEvent::ContextRetrievalFinished(
ContextRetrievalFinishedDebugEvent {
project_entity_id: entity_id,
timestamp: Instant::now(),
metadata: vec![
(
"Cache Hits",
format!(
"{}/{}",
cache_hit_count,
cache_hit_count + cache_miss_count
)
.into(),
),
(
"Max LSP Time",
format!(
"{} ms",
max_definition_latency.as_millis()
)
.into(),
),
(
"Mean LSP Time",
format!(
"{} ms",
mean_definition_latency.as_millis()
)
.into(),
),
],
},
))
.ok();
}
if let Some(project_state) = this.projects.get(&entity_id) {
project_state.context_updates_tx.send_blocking(()).ok();
}
}
},
)
cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
this.handle_excerpt_store_event(entity_id, event);
})
.detach();
related_excerpt_store
},
events: VecDeque::new(),
last_event: None,
recent_paths: VecDeque::new(),
context_updates_rx,
context_updates_tx,
debug_tx: None,
registered_buffers: HashMap::default(),
current_prediction: None,
cancelled_predictions: HashSet::default(),
@@ -689,12 +668,79 @@ impl EditPredictionStore {
})
}
pub fn project_context_updates(
&self,
pub fn remove_project(&mut self, project: &Entity<Project>) {
self.projects.remove(&project.entity_id());
}
fn handle_excerpt_store_event(
&mut self,
project_entity_id: EntityId,
event: &RelatedExcerptStoreEvent,
) {
if let Some(project_state) = self.projects.get(&project_entity_id) {
if let Some(debug_tx) = project_state.debug_tx.clone() {
match event {
RelatedExcerptStoreEvent::StartedRefresh => {
debug_tx
.unbounded_send(DebugEvent::ContextRetrievalStarted(
ContextRetrievalStartedDebugEvent {
project_entity_id: project_entity_id,
timestamp: Instant::now(),
search_prompt: String::new(),
},
))
.ok();
}
RelatedExcerptStoreEvent::FinishedRefresh {
cache_hit_count,
cache_miss_count,
mean_definition_latency,
max_definition_latency,
} => {
debug_tx
.unbounded_send(DebugEvent::ContextRetrievalFinished(
ContextRetrievalFinishedDebugEvent {
project_entity_id: project_entity_id,
timestamp: Instant::now(),
metadata: vec![
(
"Cache Hits",
format!(
"{}/{}",
cache_hit_count,
cache_hit_count + cache_miss_count
)
.into(),
),
(
"Max LSP Time",
format!("{} ms", max_definition_latency.as_millis())
.into(),
),
(
"Mean LSP Time",
format!("{} ms", mean_definition_latency.as_millis())
.into(),
),
],
},
))
.ok();
}
}
}
}
}
pub fn debug_info(
&mut self,
project: &Entity<Project>,
) -> Option<smol::channel::Receiver<()>> {
let project_state = self.projects.get(&project.entity_id())?;
Some(project_state.context_updates_rx.clone())
cx: &mut Context<Self>,
) -> mpsc::UnboundedReceiver<DebugEvent> {
let project_state = self.get_or_init_project(project, cx);
let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
project_state.debug_tx = Some(debug_watch_tx);
debug_watch_rx
}
fn handle_project_event(
@@ -768,6 +814,7 @@ impl EditPredictionStore {
let project_entity_id = project.entity_id();
entry.insert(RegisteredBuffer {
snapshot,
last_position: None,
_subscriptions: [
cx.subscribe(buffer, {
let project = project.downgrade();
@@ -855,13 +902,21 @@ impl EditPredictionStore {
});
}
fn current_prediction_for_buffer(
&self,
fn prediction_at(
&mut self,
buffer: &Entity<Buffer>,
position: Option<language::Anchor>,
project: &Entity<Project>,
cx: &App,
) -> Option<BufferEditPrediction<'_>> {
let project_state = self.projects.get(&project.entity_id())?;
let project_state = self.projects.get_mut(&project.entity_id())?;
if let Some(position) = position
&& let Some(buffer) = project_state
.registered_buffers
.get_mut(&buffer.entity_id())
{
buffer.last_position = Some(position);
}
let CurrentEditPrediction {
requested_by,
@@ -1104,12 +1159,21 @@ impl EditPredictionStore {
};
self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
let Some(open_buffer_task) = project
.update(cx, |project, cx| {
project
.active_entry()
.and_then(|entry| project.path_for_entry(entry, cx))
.map(|path| project.open_buffer(path, cx))
let Some((active_buffer, snapshot, cursor_point)) = this
.read_with(cx, |this, cx| {
let project_state = this.projects.get(&project.entity_id())?;
let (buffer, position) = project_state.active_buffer(&project, cx)?;
let snapshot = buffer.read(cx).snapshot();
if !Self::predictions_enabled_at(&snapshot, position, cx) {
return None;
}
let cursor_point = position
.map(|pos| pos.to_point(&snapshot))
.unwrap_or_default();
Some((buffer, snapshot, cursor_point))
})
.log_err()
.flatten()
@@ -1118,14 +1182,11 @@ impl EditPredictionStore {
};
cx.spawn(async move |cx| {
let active_buffer = open_buffer_task.await?;
let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
active_buffer,
&snapshot,
Default::default(),
Default::default(),
cursor_point,
&project,
cx,
)
@@ -1170,6 +1231,37 @@ impl EditPredictionStore {
});
}
fn predictions_enabled_at(
snapshot: &BufferSnapshot,
position: Option<language::Anchor>,
cx: &App,
) -> bool {
let file = snapshot.file();
let all_settings = all_language_settings(file, cx);
if !all_settings.show_edit_predictions(snapshot.language(), cx)
|| file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
{
return false;
}
if let Some(last_position) = position {
let settings = snapshot.settings_at(last_position, cx);
if !settings.edit_predictions_disabled_in.is_empty()
&& let Some(scope) = snapshot.language_scope_at(last_position)
&& let Some(scope_name) = scope.override_name()
&& settings
.edit_predictions_disabled_in
.iter()
.any(|s| s == scope_name)
{
return false;
}
}
true
}
#[cfg(not(test))]
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
#[cfg(test)]
@@ -1348,6 +1440,7 @@ impl EditPredictionStore {
let project_state = self.projects.get(&project.entity_id()).unwrap();
let events = project_state.events(cx);
let has_events = !events.is_empty();
let debug_tx = project_state.debug_tx.clone();
let snapshot = active_buffer.read(cx).snapshot();
let cursor_point = position.to_point(&snapshot);
@@ -1357,55 +1450,29 @@ impl EditPredictionStore {
Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
let related_files = if self.use_context {
self.context_for_project(&project, cx).to_vec()
self.context_for_project(&project, cx)
} else {
Vec::new()
Vec::new().into()
};
let inputs = EditPredictionModelInput {
project: project.clone(),
buffer: active_buffer.clone(),
snapshot: snapshot.clone(),
position,
events,
related_files,
recent_paths: project_state.recent_paths.clone(),
trigger,
diagnostic_search_range: diagnostic_search_range.clone(),
debug_tx,
};
let task = match self.edit_prediction_model {
EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(
self,
&project,
&active_buffer,
snapshot.clone(),
position,
events,
trigger,
cx,
),
EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
self,
&project,
&active_buffer,
snapshot.clone(),
position,
events,
related_files,
trigger,
cx,
),
EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
&project,
&active_buffer,
snapshot.clone(),
position,
events,
&project_state.recent_paths,
related_files,
diagnostic_search_range.clone(),
cx,
),
EditPredictionModel::Mercury => self.mercury.request_prediction(
&project,
&active_buffer,
snapshot.clone(),
position,
events,
&project_state.recent_paths,
related_files,
diagnostic_search_range.clone(),
cx,
),
EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
};
cx.spawn(async move |this, cx| {
@@ -1529,8 +1596,8 @@ impl EditPredictionStore {
client: Arc<Client>,
llm_token: LlmApiToken,
app_version: Version,
#[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
#[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
#[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
#[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
http_client::Url::parse(&predict_edits_url)?
@@ -1540,7 +1607,7 @@ impl EditPredictionStore {
.build_zed_llm_url("/predict_edits/raw", &[])?
};
#[cfg(feature = "eval-support")]
#[cfg(feature = "cli-support")]
let cache_key = if let Some(cache) = eval_cache {
use collections::FxHasher;
use std::hash::{Hash, Hasher};
@@ -1574,7 +1641,7 @@ impl EditPredictionStore {
)
.await?;
#[cfg(feature = "eval-support")]
#[cfg(feature = "cli-support")]
if let Some((cache, request, key)) = cache_key {
cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
}
@@ -1706,6 +1773,20 @@ impl EditPredictionStore {
}
}
#[cfg(feature = "cli-support")]
pub fn set_context_for_buffer(
&mut self,
project: &Entity<Project>,
related_files: Vec<RelatedFile>,
cx: &mut Context<Self>,
) {
self.get_or_init_project(project, cx)
.context
.update(cx, |store, _| {
store.set_related_files(related_files);
});
}
fn is_file_open_source(
&self,
project: &Entity<Project>,
@@ -1729,14 +1810,14 @@ impl EditPredictionStore {
self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
}
fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
if !self.data_collection_choice.is_enabled() {
return false;
}
events.iter().all(|event| {
matches!(
event.as_ref(),
Event::BufferChange {
zeta_prompt::Event::BufferChange {
in_open_source_repo: true,
..
}
@@ -1817,10 +1898,10 @@ pub struct ZedUpdateRequiredError {
minimum_version: Version,
}
#[cfg(feature = "eval-support")]
#[cfg(feature = "cli-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
#[cfg(feature = "eval-support")]
#[cfg(feature = "cli-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
Context,
@@ -1828,7 +1909,7 @@ pub enum EvalCacheEntryKind {
Prediction,
}
#[cfg(feature = "eval-support")]
#[cfg(feature = "cli-support")]
impl std::fmt::Display for EvalCacheEntryKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
@@ -1839,7 +1920,7 @@ impl std::fmt::Display for EvalCacheEntryKind {
}
}
#[cfg(feature = "eval-support")]
#[cfg(feature = "cli-support")]
pub trait EvalCache: Send + Sync {
fn read(&self, key: EvalCacheKey) -> Option<String>;
fn write(&self, key: EvalCacheKey, input: &str, value: &str);

View File

@@ -1,5 +1,5 @@
use super::*;
use crate::zeta1::MAX_EVENT_TOKENS;
use crate::{udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
use client::{UserStore, test::FakeServer};
use clock::{FakeSystemClock, ReplicaId};
use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
@@ -7,7 +7,6 @@ use cloud_llm_client::{
EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
RejectEditPredictionsBody,
};
use edit_prediction_context::Line;
use futures::{
AsyncReadExt, StreamExt,
channel::{mpsc, oneshot},
@@ -28,6 +27,7 @@ use settings::SettingsStore;
use std::{path::Path, sync::Arc, time::Duration};
use util::{path, rel_path::rel_path};
use uuid::Uuid;
use zeta_prompt::ZetaPromptInput;
use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
@@ -45,10 +45,6 @@ async fn test_current_state(cx: &mut TestAppContext) {
.await;
let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
ep_store.update(cx, |ep_store, cx| {
ep_store.register_project(&project, cx);
});
let buffer1 = project
.update(cx, |project, cx| {
let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
@@ -60,30 +56,38 @@ async fn test_current_state(cx: &mut TestAppContext) {
let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
let position = snapshot1.anchor_before(language::Point::new(1, 3));
ep_store.update(cx, |ep_store, cx| {
ep_store.register_project(&project, cx);
ep_store.register_buffer(&buffer1, &project, cx);
});
// Prediction for current file
ep_store.update(cx, |ep_store, cx| {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
});
let (_request, respond_tx) = requests.predict.next().await.unwrap();
let (request, respond_tx) = requests.predict.next().await.unwrap();
respond_tx
.send(model_response(indoc! {r"
--- a/root/1.txt
+++ b/root/1.txt
@@ ... @@
Hello!
-How
+How are you?
Bye
"}))
.send(model_response(
request,
indoc! {r"
--- a/root/1.txt
+++ b/root/1.txt
@@ ... @@
Hello!
-How
+How are you?
Bye
"},
))
.unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
let prediction = ep_store
.current_prediction_for_buffer(&buffer1, &project, cx)
.prediction_at(&buffer1, None, &project, cx)
.unwrap();
assert_matches!(prediction, BufferEditPrediction::Local { .. });
});
@@ -120,22 +124,26 @@ async fn test_current_state(cx: &mut TestAppContext) {
});
});
let (_request, respond_tx) = requests.predict.next().await.unwrap();
let (request, respond_tx) = requests.predict.next().await.unwrap();
respond_tx
.send(model_response(indoc! {r#"
--- a/root/2.txt
+++ b/root/2.txt
Hola!
-Como
+Como estas?
Adios
"#}))
.send(model_response(
request,
indoc! {r#"
--- a/root/2.txt
+++ b/root/2.txt
@@ ... @@
Hola!
-Como
+Como estas?
Adios
"#},
))
.unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
let prediction = ep_store
.current_prediction_for_buffer(&buffer1, &project, cx)
.prediction_at(&buffer1, None, &project, cx)
.unwrap();
assert_matches!(
prediction,
@@ -151,9 +159,9 @@ async fn test_current_state(cx: &mut TestAppContext) {
.await
.unwrap();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
let prediction = ep_store
.current_prediction_for_buffer(&buffer2, &project, cx)
.prediction_at(&buffer2, None, &project, cx)
.unwrap();
assert_matches!(prediction, BufferEditPrediction::Local { .. });
});
@@ -186,7 +194,7 @@ async fn test_simple_request(cx: &mut TestAppContext) {
ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
});
let (_, respond_tx) = requests.predict.next().await.unwrap();
let (request, respond_tx) = requests.predict.next().await.unwrap();
// TODO Put back when we have a structured request again
// assert_eq!(
@@ -202,15 +210,18 @@ async fn test_simple_request(cx: &mut TestAppContext) {
// );
respond_tx
.send(model_response(indoc! { r"
--- a/root/foo.md
+++ b/root/foo.md
@@ ... @@
Hello!
-How
+How are you?
Bye
"}))
.send(model_response(
request,
indoc! { r"
--- a/root/foo.md
+++ b/root/foo.md
@@ ... @@
Hello!
-How
+How are you?
Bye
"},
))
.unwrap();
let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
@@ -276,15 +287,18 @@ async fn test_request_events(cx: &mut TestAppContext) {
);
respond_tx
.send(model_response(indoc! {r#"
--- a/root/foo.md
+++ b/root/foo.md
@@ ... @@
Hello!
-How
+How are you?
Bye
"#}))
.send(model_response(
request,
indoc! {r#"
--- a/root/foo.md
+++ b/root/foo.md
@@ ... @@
Hello!
-How
+How are you?
Bye
"#},
))
.unwrap();
let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
@@ -324,27 +338,17 @@ async fn test_empty_prediction(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
const NO_OP_DIFF: &str = indoc! { r"
--- a/root/foo.md
+++ b/root/foo.md
@@ ... @@
Hello!
-How
+How
Bye
"};
let (_, respond_tx) = requests.predict.next().await.unwrap();
let response = model_response(NO_OP_DIFF);
let (request, respond_tx) = requests.predict.next().await.unwrap();
let response = model_response(request, "");
let id = response.id.clone();
respond_tx.send(response).unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
assert!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.is_none()
);
});
@@ -389,22 +393,22 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_tx) = requests.predict.next().await.unwrap();
let (request, respond_tx) = requests.predict.next().await.unwrap();
buffer.update(cx, |buffer, cx| {
buffer.set_text("Hello!\nHow are you?\nBye", cx);
});
let response = model_response(SIMPLE_DIFF);
let response = model_response(request, SIMPLE_DIFF);
let id = response.id.clone();
respond_tx.send(response).unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
assert!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.is_none()
);
});
@@ -459,17 +463,17 @@ async fn test_replace_current(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_tx) = requests.predict.next().await.unwrap();
let first_response = model_response(SIMPLE_DIFF);
let (request, respond_tx) = requests.predict.next().await.unwrap();
let first_response = model_response(request, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_tx.send(first_response).unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -482,18 +486,18 @@ async fn test_replace_current(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_tx) = requests.predict.next().await.unwrap();
let second_response = model_response(SIMPLE_DIFF);
let (request, respond_tx) = requests.predict.next().await.unwrap();
let second_response = model_response(request, SIMPLE_DIFF);
let second_id = second_response.id.clone();
respond_tx.send(second_response).unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
// second replaces first
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -541,17 +545,17 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_tx) = requests.predict.next().await.unwrap();
let first_response = model_response(SIMPLE_DIFF);
let (request, respond_tx) = requests.predict.next().await.unwrap();
let first_response = model_response(request, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_tx.send(first_response).unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -564,27 +568,30 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_tx) = requests.predict.next().await.unwrap();
let (request, respond_tx) = requests.predict.next().await.unwrap();
// worse than current prediction
let second_response = model_response(indoc! { r"
--- a/root/foo.md
+++ b/root/foo.md
@@ ... @@
Hello!
-How
+How are
Bye
"});
let second_response = model_response(
request,
indoc! { r"
--- a/root/foo.md
+++ b/root/foo.md
@@ ... @@
Hello!
-How
+How are
Bye
"},
);
let second_id = second_response.id.clone();
respond_tx.send(second_response).unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
// first is preferred over second
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -633,29 +640,29 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_first) = requests.predict.next().await.unwrap();
let (request1, respond_first) = requests.predict.next().await.unwrap();
ep_store.update(cx, |ep_store, cx| {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_second) = requests.predict.next().await.unwrap();
let (request, respond_second) = requests.predict.next().await.unwrap();
// wait for throttle
cx.run_until_parked();
// second responds first
let second_response = model_response(SIMPLE_DIFF);
let second_response = model_response(request, SIMPLE_DIFF);
let second_id = second_response.id.clone();
respond_second.send(second_response).unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
// current prediction is second
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -663,17 +670,17 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
);
});
let first_response = model_response(SIMPLE_DIFF);
let first_response = model_response(request1, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_first.send(first_response).unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
// current prediction is still second, since first was cancelled
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -724,13 +731,13 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_first) = requests.predict.next().await.unwrap();
let (request1, respond_first) = requests.predict.next().await.unwrap();
ep_store.update(cx, |ep_store, cx| {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_second) = requests.predict.next().await.unwrap();
let (request2, respond_second) = requests.predict.next().await.unwrap();
// wait for throttle, so requests are sent
cx.run_until_parked();
@@ -754,19 +761,19 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
// wait for throttle
cx.run_until_parked();
let (_, respond_third) = requests.predict.next().await.unwrap();
let (request3, respond_third) = requests.predict.next().await.unwrap();
let first_response = model_response(SIMPLE_DIFF);
let first_response = model_response(request1, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_first.send(first_response).unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
// current prediction is first
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -774,17 +781,17 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
);
});
let cancelled_response = model_response(SIMPLE_DIFF);
let cancelled_response = model_response(request2, SIMPLE_DIFF);
let cancelled_id = cancelled_response.id.clone();
respond_second.send(cancelled_response).unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
// current prediction is still first, since second was cancelled
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -792,17 +799,17 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
);
});
let third_response = model_response(SIMPLE_DIFF);
let third_response = model_response(request3, SIMPLE_DIFF);
let third_response_id = third_response.id.clone();
respond_third.send(third_response).unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
// third completes and replaces first
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -1036,7 +1043,24 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
// );
// }
fn model_response(text: &str) -> open_ai::Response {
// Generate a model response that would apply the given diff to the active file.
fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response {
let prompt = match &request.messages[0] {
open_ai::RequestMessage::User {
content: open_ai::MessageContent::Plain(content),
} => content,
_ => panic!("unexpected request {request:?}"),
};
let open = "<editable_region>\n";
let close = "</editable_region>";
let cursor = "<|user_cursor|>";
let start_ix = open.len() + prompt.find(open).unwrap();
let end_ix = start_ix + &prompt[start_ix..].find(close).unwrap();
let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
open_ai::Response {
id: Uuid::new_v4().to_string(),
object: "response".into(),
@@ -1045,7 +1069,7 @@ fn model_response(text: &str) -> open_ai::Response {
choices: vec![open_ai::Choice {
index: 0,
message: open_ai::RequestMessage::Assistant {
content: Some(open_ai::MessageContent::Plain(text.to_string())),
content: Some(open_ai::MessageContent::Plain(new_excerpt)),
tool_calls: vec![],
},
finish_reason: None,
@@ -1160,20 +1184,19 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
.read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
.await;
let completion = EditPrediction {
let prediction = EditPrediction {
edits,
edit_preview,
buffer: buffer.clone(),
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
id: EditPredictionId("the-id".into()),
inputs: EditPredictionInputs {
inputs: ZetaPromptInput {
events: Default::default(),
included_files: Default::default(),
cursor_point: cloud_llm_client::predict_edits_v3::Point {
line: Line(0),
column: 0,
},
related_files: Default::default(),
cursor_path: Path::new("").into(),
cursor_excerpt: "".into(),
editable_range_in_excerpt: 0..0,
cursor_offset_in_excerpt: 0,
},
buffer_snapshotted_at: Instant::now(),
response_received_at: Instant::now(),
@@ -1182,7 +1205,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
cx.update(|cx| {
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1192,7 +1215,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1202,7 +1225,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.undo(cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1212,7 +1235,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1222,7 +1245,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1232,7 +1255,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1242,7 +1265,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1252,7 +1275,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1260,7 +1283,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
})
}

View File

@@ -735,6 +735,7 @@ mod tests {
true,
fs.clone(),
Default::default(),
true,
&mut cx.to_async(),
)
.await
@@ -758,6 +759,7 @@ mod tests {
true,
fs.clone(),
Default::default(),
true,
&mut cx.to_async(),
)
.await
@@ -816,6 +818,7 @@ mod tests {
true,
fs.clone(),
Default::default(),
true,
&mut cx.to_async(),
)
.await

View File

@@ -1,20 +1,17 @@
use anyhow::{Context as _, Result};
use cloud_llm_client::predict_edits_v3::Event;
use credentials_provider::CredentialsProvider;
use edit_prediction_context::RelatedFile;
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
use gpui::{
App, AppContext as _, Entity, Task,
App, AppContext as _, Task,
http_client::{self, AsyncBody, Method},
};
use language::{Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _};
use project::{Project, ProjectPath};
use std::{
collections::VecDeque, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc, time::Instant,
};
use language::{OffsetRangeExt as _, ToOffset, ToPoint as _};
use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant};
use zeta_prompt::ZetaPromptInput;
use crate::{
EditPredictionId, EditPredictionInputs, open_ai_response::text_from_response,
DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
EditPredictionStartedDebugEvent, open_ai_response::text_from_response,
prediction::EditPredictionResult,
};
@@ -38,16 +35,17 @@ impl Mercury {
store_api_token_in_keychain(api_token, cx)
}
pub fn request_prediction(
pub(crate) fn request_prediction(
&self,
_project: &Entity<Project>,
active_buffer: &Entity<Buffer>,
snapshot: BufferSnapshot,
position: language::Anchor,
events: Vec<Arc<Event>>,
_recent_paths: &VecDeque<ProjectPath>,
related_files: Vec<RelatedFile>,
_diagnostic_search_range: Range<Point>,
EditPredictionModelInput {
buffer,
snapshot,
position,
events,
related_files,
debug_tx,
..
}: EditPredictionModelInput,
cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
@@ -62,6 +60,7 @@ impl Mercury {
let http_client = cx.http_client();
let cursor_point = position.to_point(&snapshot);
let buffer_snapshotted_at = Instant::now();
let active_buffer = buffer.clone();
let result = cx.background_spawn(async move {
let (editable_range, context_range) =
@@ -72,39 +71,39 @@ impl Mercury {
MAX_REWRITE_TOKENS,
);
let offset_range = editable_range.to_offset(&snapshot);
let prompt = build_prompt(
&events,
&related_files,
&snapshot,
full_path.as_ref(),
cursor_point,
editable_range,
context_range.clone(),
);
let context_offset_range = context_range.to_offset(&snapshot);
let inputs = EditPredictionInputs {
events: events,
included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
path: full_path.clone(),
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
start_line: cloud_llm_client::predict_edits_v3::Line(
context_range.start.row,
),
text: snapshot
.text_for_range(context_range.clone())
.collect::<String>()
.into(),
}],
}],
cursor_point: cloud_llm_client::predict_edits_v3::Point {
column: cursor_point.column,
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
},
let editable_offset_range = editable_range.to_offset(&snapshot);
let inputs = zeta_prompt::ZetaPromptInput {
events,
related_files,
cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot)
- context_range.start.to_offset(&snapshot),
cursor_path: full_path.clone(),
cursor_excerpt: snapshot
.text_for_range(context_range)
.collect::<String>()
.into(),
editable_range_in_excerpt: (editable_offset_range.start
- context_offset_range.start)
..(editable_offset_range.end - context_offset_range.start),
};
let prompt = build_prompt(&inputs);
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionStarted(
EditPredictionStartedDebugEvent {
buffer: active_buffer.downgrade(),
prompt: Some(prompt.clone()),
position,
},
))
.ok();
}
let request_body = open_ai::Request {
model: "mercury-coder".into(),
messages: vec![open_ai::RequestMessage::User {
@@ -160,6 +159,18 @@ impl Mercury {
let id = mem::take(&mut response.id);
let response_str = text_from_response(response).unwrap_or_default();
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionFinished(
EditPredictionFinishedDebugEvent {
buffer: active_buffer.downgrade(),
model_output: Some(response_str.clone()),
position,
},
))
.ok();
}
let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);
@@ -168,15 +179,16 @@ impl Mercury {
if response_str != NO_PREDICTION_OUTPUT {
let old_text = snapshot
.text_for_range(offset_range.clone())
.text_for_range(editable_offset_range.clone())
.collect::<String>();
edits.extend(
language::text_diff(&old_text, &response_str)
.into_iter()
.map(|(range, text)| {
(
snapshot.anchor_after(offset_range.start + range.start)
..snapshot.anchor_before(offset_range.start + range.end),
snapshot.anchor_after(editable_offset_range.start + range.start)
..snapshot
.anchor_before(editable_offset_range.start + range.end),
text,
)
}),
@@ -186,8 +198,6 @@ impl Mercury {
anyhow::Ok((id, edits, snapshot, response_received_at, inputs))
});
let buffer = active_buffer.clone();
cx.spawn(async move |cx| {
let (id, edits, old_snapshot, response_received_at, inputs) =
result.await.context("Mercury edit prediction failed")?;
@@ -208,15 +218,7 @@ impl Mercury {
}
}
fn build_prompt(
events: &[Arc<Event>],
related_files: &[RelatedFile],
cursor_buffer: &BufferSnapshot,
cursor_buffer_path: &Path,
cursor_point: Point,
editable_range: Range<Point>,
context_range: Range<Point>,
) -> String {
fn build_prompt(inputs: &ZetaPromptInput) -> String {
const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
@@ -237,14 +239,14 @@ fn build_prompt(
&mut prompt,
RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
|prompt| {
for related_file in related_files {
for related_file in inputs.related_files.iter() {
for related_excerpt in &related_file.excerpts {
push_delimited(
prompt,
RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
|prompt| {
prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
prompt.push_str(related_file.path.path.as_unix_str());
prompt.push_str(related_file.path.to_string_lossy().as_ref());
prompt.push('\n');
prompt.push_str(&related_excerpt.text.to_string());
},
@@ -259,21 +261,22 @@ fn build_prompt(
CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
|prompt| {
prompt.push_str(CURRENT_FILE_PATH_PREFIX);
prompt.push_str(cursor_buffer_path.as_os_str().to_string_lossy().as_ref());
prompt.push_str(inputs.cursor_path.as_os_str().to_string_lossy().as_ref());
prompt.push('\n');
let prefix_range = context_range.start..editable_range.start;
let suffix_range = editable_range.end..context_range.end;
prompt.extend(cursor_buffer.text_for_range(prefix_range));
prompt.push_str(&inputs.cursor_excerpt[0..inputs.editable_range_in_excerpt.start]);
push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
let range_before_cursor = editable_range.start..cursor_point;
let range_after_cursor = cursor_point..editable_range.end;
prompt.extend(cursor_buffer.text_for_range(range_before_cursor));
prompt.push_str(
&inputs.cursor_excerpt
[inputs.editable_range_in_excerpt.start..inputs.cursor_offset_in_excerpt],
);
prompt.push_str(CURSOR_TAG);
prompt.extend(cursor_buffer.text_for_range(range_after_cursor));
prompt.push_str(
&inputs.cursor_excerpt
[inputs.cursor_offset_in_excerpt..inputs.editable_range_in_excerpt.end],
);
});
prompt.extend(cursor_buffer.text_for_range(suffix_range));
prompt.push_str(&inputs.cursor_excerpt[inputs.editable_range_in_excerpt.end..]);
},
);
@@ -281,8 +284,8 @@ fn build_prompt(
&mut prompt,
EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
|prompt| {
for event in events {
writeln!(prompt, "{event}").unwrap();
for event in inputs.events.iter() {
zeta_prompt::write_event(prompt, &event);
}
},
);

View File

@@ -1,6 +1,5 @@
use std::{
ops::Range,
path::Path,
sync::Arc,
time::{Duration, Instant},
};
@@ -9,7 +8,7 @@ use cloud_llm_client::EditPredictionRejectReason;
use edit_prediction_types::interpolate_edits;
use gpui::{AsyncApp, Entity, SharedString};
use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
use serde::Serialize;
use zeta_prompt::ZetaPromptInput;
#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(pub SharedString);
@@ -40,7 +39,7 @@ impl EditPredictionResult {
edits: Arc<[(Range<Anchor>, Arc<str>)]>,
buffer_snapshotted_at: Instant,
response_received_at: Instant,
inputs: EditPredictionInputs,
inputs: ZetaPromptInput,
cx: &mut AsyncApp,
) -> Self {
if edits.is_empty() {
@@ -94,15 +93,7 @@ pub struct EditPrediction {
pub buffer: Entity<Buffer>,
pub buffer_snapshotted_at: Instant,
pub response_received_at: Instant,
pub inputs: EditPredictionInputs,
}
#[derive(Debug, Clone, Serialize)]
pub struct EditPredictionInputs {
pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
pub included_files: Vec<cloud_llm_client::predict_edits_v3::RelatedFile>,
pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
pub cursor_path: Arc<Path>,
pub inputs: zeta_prompt::ZetaPromptInput,
}
impl EditPrediction {
@@ -133,9 +124,12 @@ impl std::fmt::Debug for EditPrediction {
#[cfg(test)]
mod tests {
use std::path::Path;
use super::*;
use gpui::{App, Entity, TestAppContext, prelude::*};
use language::{Buffer, ToOffset as _};
use zeta_prompt::ZetaPromptInput;
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
@@ -154,14 +148,13 @@ mod tests {
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
buffer: buffer.clone(),
edit_preview,
inputs: EditPredictionInputs {
inputs: ZetaPromptInput {
events: vec![],
included_files: vec![],
cursor_point: cloud_llm_client::predict_edits_v3::Point {
line: cloud_llm_client::predict_edits_v3::Line(0),
column: 0,
},
related_files: vec![].into(),
cursor_path: Path::new("path.txt").into(),
cursor_offset_in_excerpt: 0,
cursor_excerpt: "".into(),
editable_range_in_excerpt: 0..0,
},
buffer_snapshotted_at: Instant::now(),
response_received_at: Instant::now(),

View File

@@ -1,26 +1,21 @@
use anyhow::{Context as _, Result};
use cloud_llm_client::predict_edits_v3::Event;
use credentials_provider::CredentialsProvider;
use edit_prediction_context::RelatedFile;
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
use gpui::{
App, AppContext as _, Entity, Task,
App, AppContext as _, Task,
http_client::{self, AsyncBody, Method},
};
use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
use language::{Point, ToOffset as _};
use lsp::DiagnosticSeverity;
use project::{Project, ProjectPath};
use serde::{Deserialize, Serialize};
use std::{
collections::VecDeque,
fmt::{self, Write as _},
ops::Range,
path::Path,
sync::Arc,
time::Instant,
};
use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionResult};
use crate::{EditPredictionId, EditPredictionModelInput, prediction::EditPredictionResult};
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
@@ -44,40 +39,34 @@ impl SweepAi {
pub fn request_prediction_with_sweep(
&self,
project: &Entity<Project>,
active_buffer: &Entity<Buffer>,
snapshot: BufferSnapshot,
position: language::Anchor,
events: Vec<Arc<Event>>,
recent_paths: &VecDeque<ProjectPath>,
related_files: Vec<RelatedFile>,
diagnostic_search_range: Range<Point>,
inputs: EditPredictionModelInput,
cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
let debug_info = self.debug_info.clone();
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
return Task::ready(Ok(None));
};
let full_path: Arc<Path> = snapshot
let full_path: Arc<Path> = inputs
.snapshot
.file()
.map(|file| file.full_path(cx))
.unwrap_or_else(|| "untitled".into())
.into();
let project_file = project::File::from_dyn(snapshot.file());
let project_file = project::File::from_dyn(inputs.snapshot.file());
let repo_name = project_file
.map(|file| file.worktree.read(cx).root_name_str())
.unwrap_or("untitled")
.into();
let offset = position.to_offset(&snapshot);
let offset = inputs.position.to_offset(&inputs.snapshot);
let recent_buffers = recent_paths.iter().cloned();
let recent_buffers = inputs.recent_paths.iter().cloned();
let http_client = cx.http_client();
let recent_buffer_snapshots = recent_buffers
.filter_map(|project_path| {
let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
if active_buffer == &buffer {
let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
if inputs.buffer == buffer {
None
} else {
Some(buffer.read(cx).snapshot())
@@ -86,14 +75,13 @@ impl SweepAi {
.take(3)
.collect::<Vec<_>>();
let cursor_point = position.to_point(&snapshot);
let buffer_snapshotted_at = Instant::now();
let result = cx.background_spawn(async move {
let text = snapshot.text();
let text = inputs.snapshot.text();
let mut recent_changes = String::new();
for event in &events {
for event in &inputs.events {
write_event(event.as_ref(), &mut recent_changes).unwrap();
}
@@ -122,20 +110,23 @@ impl SweepAi {
})
.collect::<Vec<_>>();
let retrieval_chunks = related_files
let retrieval_chunks = inputs
.related_files
.iter()
.flat_map(|related_file| {
related_file.excerpts.iter().map(|excerpt| FileChunk {
file_path: related_file.path.path.as_unix_str().to_string(),
start_line: excerpt.point_range.start.row as usize,
end_line: excerpt.point_range.end.row as usize,
file_path: related_file.path.to_string_lossy().to_string(),
start_line: excerpt.row_range.start as usize,
end_line: excerpt.row_range.end as usize,
content: excerpt.text.to_string(),
timestamp: None,
})
})
.collect();
let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
let diagnostic_entries = inputs
.snapshot
.diagnostics_in_range(inputs.diagnostic_search_range, false);
let mut diagnostic_content = String::new();
let mut diagnostic_count = 0;
@@ -195,21 +186,14 @@ impl SweepAi {
serde_json::to_writer(writer, &request_body)?;
let body: AsyncBody = buf.into();
let inputs = EditPredictionInputs {
events,
included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
path: full_path.clone(),
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
start_line: cloud_llm_client::predict_edits_v3::Line(0),
text: request_body.file_contents.into(),
}],
}],
cursor_point: cloud_llm_client::predict_edits_v3::Point {
column: cursor_point.column,
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
},
let ep_inputs = zeta_prompt::ZetaPromptInput {
events: inputs.events,
related_files: inputs.related_files.clone(),
cursor_path: full_path.clone(),
cursor_excerpt: request_body.file_contents.into(),
// we actually don't know
editable_range_in_excerpt: 0..inputs.snapshot.len(),
cursor_offset_in_excerpt: request_body.cursor_position,
};
let request = http_client::Request::builder()
@@ -237,15 +221,20 @@ impl SweepAi {
let response: AutocompleteResponse = serde_json::from_slice(&body)?;
let old_text = snapshot
let old_text = inputs
.snapshot
.text_for_range(response.start_index..response.end_index)
.collect::<String>();
let edits = language::text_diff(&old_text, &response.completion)
.into_iter()
.map(|(range, text)| {
(
snapshot.anchor_after(response.start_index + range.start)
..snapshot.anchor_before(response.start_index + range.end),
inputs
.snapshot
.anchor_after(response.start_index + range.start)
..inputs
.snapshot
.anchor_before(response.start_index + range.end),
text,
)
})
@@ -254,13 +243,13 @@ impl SweepAi {
anyhow::Ok((
response.autocomplete_id,
edits,
snapshot,
inputs.snapshot,
response_received_at,
inputs,
ep_inputs,
))
});
let buffer = active_buffer.clone();
let buffer = inputs.buffer.clone();
cx.spawn(async move |cx| {
let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
@@ -403,12 +392,9 @@ struct AdditionalCompletion {
pub finish_reason: Option<String>,
}
fn write_event(
event: &cloud_llm_client::predict_edits_v3::Event,
f: &mut impl fmt::Write,
) -> fmt::Result {
fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
match event {
cloud_llm_client::predict_edits_v3::Event::BufferChange {
zeta_prompt::Event::BufferChange {
old_path,
path,
diff,

View File

@@ -14,87 +14,48 @@ use anyhow::anyhow;
use collections::HashMap;
use gpui::AsyncApp;
use gpui::Entity;
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, TextBufferSnapshot};
use project::Project;
use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot};
use project::{Project, ProjectPath};
use util::paths::PathStyle;
use util::rel_path::RelPath;
pub async fn parse_diff<'a>(
diff_str: &'a str,
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
let mut diff = DiffParser::new(diff_str);
let mut edited_buffer = None;
let mut edits = Vec::new();
while let Some(event) = diff.next()? {
match event {
DiffEvent::Hunk {
path: file_path,
hunk,
} => {
let (buffer, ranges) = match edited_buffer {
None => {
edited_buffer = get_buffer(&Path::new(file_path.as_ref()));
edited_buffer
.as_ref()
.context("Model tried to edit a file that wasn't included")?
}
Some(ref current) => current,
};
edits.extend(
resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges)
.with_context(|| format!("Diff:\n{diff_str}"))?,
);
}
DiffEvent::FileEnd { renamed_to } => {
let (buffer, _) = edited_buffer
.take()
.context("Got a FileEnd event before an Hunk event")?;
if renamed_to.is_some() {
anyhow::bail!("edit predictions cannot rename files");
}
if diff.next()?.is_some() {
anyhow::bail!("Edited more than one file");
}
return Ok((buffer, edits));
}
}
}
Err(anyhow::anyhow!("No EOF"))
}
#[derive(Debug)]
pub struct OpenedBuffers<'a>(#[allow(unused)] HashMap<Cow<'a, str>, Entity<Buffer>>);
#[derive(Clone, Debug)]
pub struct OpenedBuffers(#[allow(unused)] HashMap<String, Entity<Buffer>>);
#[must_use]
pub async fn apply_diff<'a>(
diff_str: &'a str,
pub async fn apply_diff(
diff_str: &str,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<OpenedBuffers<'a>> {
) -> Result<OpenedBuffers> {
let mut included_files = HashMap::default();
let worktree_id = project.read_with(cx, |project, cx| {
anyhow::Ok(
project
.visible_worktrees(cx)
.next()
.context("no worktrees")?
.read(cx)
.id(),
)
})??;
for line in diff_str.lines() {
let diff_line = DiffLine::parse(line);
if let DiffLine::OldPath { path } = diff_line {
let buffer = project
.update(cx, |project, cx| {
let project_path =
project
.find_project_path(path.as_ref(), cx)
.with_context(|| {
format!("Failed to find worktree for new path: {}", path)
})?;
let project_path = ProjectPath {
worktree_id,
path: RelPath::new(Path::new(path.as_ref()), PathStyle::Posix)?.into_arc(),
};
anyhow::Ok(project.open_buffer(project_path, cx))
})??
.await?;
included_files.insert(path, buffer);
included_files.insert(path.to_string(), buffer);
}
}
@@ -113,7 +74,7 @@ pub async fn apply_diff<'a>(
let (buffer, ranges) = match current_file {
None => {
let buffer = included_files
.get_mut(&file_path)
.get_mut(file_path.as_ref())
.expect("Opened all files in diff");
current_file = Some((buffer, ranges.as_slice()));
@@ -167,6 +128,29 @@ pub async fn apply_diff<'a>(
Ok(OpenedBuffers(included_files))
}
pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
let mut diff = DiffParser::new(diff_str);
let mut text = text.to_string();
while let Some(event) = diff.next()? {
match event {
DiffEvent::Hunk { hunk, .. } => {
let hunk_offset = text
.find(&hunk.context)
.ok_or_else(|| anyhow!("couldn't resolve hunk {:?}", hunk.context))?;
for edit in hunk.edits.iter().rev() {
let range = (hunk_offset + edit.range.start)..(hunk_offset + edit.range.end);
text.replace_range(range, &edit.text);
}
}
DiffEvent::FileEnd { .. } => {}
}
}
Ok(text)
}
struct PatchFile<'a> {
old_path: Cow<'a, str>,
new_path: Cow<'a, str>,
@@ -492,7 +476,6 @@ mod tests {
use super::*;
use gpui::TestAppContext;
use indoc::indoc;
use language::Point;
use pretty_assertions::assert_eq;
use project::{FakeFs, Project};
use serde_json::json;
@@ -754,38 +737,38 @@ mod tests {
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
let diff = indoc! {r#"
--- a/root/file1
+++ b/root/file1
--- a/file1
+++ b/file1
one
two
-three
+3
four
five
--- a/root/file1
+++ b/root/file1
--- a/file1
+++ b/file1
3
-four
-five
+4
+5
--- a/root/file1
+++ b/root/file1
--- a/file1
+++ b/file1
-one
-two
3
4
--- a/root/file2
+++ b/root/file2
--- a/file2
+++ b/file2
+5
six
--- a/root/file2
+++ b/root/file2
--- a/file2
+++ b/file2
seven
+7.5
eight
--- a/root/file2
+++ b/root/file2
--- a/file2
+++ b/file2
ten
+11
"#};
@@ -817,137 +800,6 @@ mod tests {
});
}
#[gpui::test]
async fn test_apply_diff_non_unique(cx: &mut TestAppContext) {
let fs = init_test(cx);
let buffer_1_text = indoc! {r#"
one
two
three
four
five
one
two
three
four
five
"# };
fs.insert_tree(
path!("/root"),
json!({
"file1": buffer_1_text,
}),
)
.await;
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
let buffer = project
.update(cx, |project, cx| {
project.open_local_buffer(path!("/root/file1"), cx)
})
.await
.unwrap();
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
let diff = indoc! {r#"
--- a/root/file1
+++ b/root/file1
one
two
-three
+3
four
five
"#};
let final_text = indoc! {r#"
one
two
three
four
five
one
two
3
four
five
"#};
apply_diff(diff, &project, &mut cx.to_async())
.await
.expect_err("Non-unique edits should fail");
let ranges = [buffer_snapshot.anchor_before(Point::new(1, 0))
..buffer_snapshot.anchor_after(buffer_snapshot.max_point())];
let (edited_snapshot, edits) = parse_diff(diff, |_path| Some((&buffer_snapshot, &ranges)))
.await
.unwrap();
assert_eq!(edited_snapshot.remote_id(), buffer_snapshot.remote_id());
buffer.update(cx, |buffer, cx| {
buffer.edit(edits, None, cx);
assert_eq!(buffer.text(), final_text);
});
}
#[gpui::test]
async fn test_parse_diff_with_edits_within_line(cx: &mut TestAppContext) {
let fs = init_test(cx);
let buffer_1_text = indoc! {r#"
one two three four
five six seven eight
nine ten eleven twelve
"# };
fs.insert_tree(
path!("/root"),
json!({
"file1": buffer_1_text,
}),
)
.await;
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
let buffer = project
.update(cx, |project, cx| {
project.open_local_buffer(path!("/root/file1"), cx)
})
.await
.unwrap();
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
let diff = indoc! {r#"
--- a/root/file1
+++ b/root/file1
one two three four
-five six seven eight
+five SIX seven eight!
nine ten eleven twelve
"#};
let (buffer, edits) = parse_diff(diff, |_path| {
Some((&buffer_snapshot, &[(Anchor::MIN..Anchor::MAX)] as &[_]))
})
.await
.unwrap();
let edits = edits
.into_iter()
.map(|(range, text)| (range.to_point(&buffer), text))
.collect::<Vec<_>>();
assert_eq!(
edits,
&[
(Point::new(1, 5)..Point::new(1, 8), "SIX".into()),
(Point::new(1, 20)..Point::new(1, 20), "!".into())
]
);
}
#[gpui::test]
async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) {
let fs = init_test(cx);
@@ -985,8 +837,8 @@ mod tests {
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
let diff = indoc! {r#"
--- a/root/file1
+++ b/root/file1
--- a/file1
+++ b/file1
one
two
-three

View File

@@ -1,637 +0,0 @@
use anyhow::{Context as _, Result};
use language::{Anchor, BufferSnapshot, OffsetRangeExt as _, Point};
use std::{cmp, ops::Range, path::Path, sync::Arc};
const EDITS_TAG_NAME: &'static str = "edits";
const OLD_TEXT_TAG_NAME: &'static str = "old_text";
const NEW_TEXT_TAG_NAME: &'static str = "new_text";
const XML_TAGS: &[&str] = &[EDITS_TAG_NAME, OLD_TEXT_TAG_NAME, NEW_TEXT_TAG_NAME];
pub async fn parse_xml_edits<'a>(
input: &'a str,
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
parse_xml_edits_inner(input, get_buffer)
.await
.with_context(|| format!("Failed to parse XML edits:\n{input}"))
}
async fn parse_xml_edits_inner<'a>(
input: &'a str,
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
let xml_edits = extract_xml_replacements(input)?;
let (buffer, context_ranges) = get_buffer(xml_edits.file_path.as_ref())
.with_context(|| format!("no buffer for file {}", xml_edits.file_path))?;
let mut all_edits = vec![];
for (old_text, new_text) in xml_edits.replacements {
let match_range = fuzzy_match_in_ranges(old_text, buffer, context_ranges)?;
let matched_old_text = buffer
.text_for_range(match_range.clone())
.collect::<String>();
let edits_within_hunk = language::text_diff(&matched_old_text, new_text);
all_edits.extend(
edits_within_hunk
.into_iter()
.map(move |(inner_range, inner_text)| {
(
buffer.anchor_after(match_range.start + inner_range.start)
..buffer.anchor_before(match_range.start + inner_range.end),
inner_text,
)
}),
);
}
Ok((buffer, all_edits))
}
fn fuzzy_match_in_ranges(
old_text: &str,
buffer: &BufferSnapshot,
context_ranges: &[Range<Anchor>],
) -> Result<Range<usize>> {
let mut state = FuzzyMatcher::new(buffer, old_text);
let mut best_match = None;
let mut tie_match_range = None;
for range in context_ranges {
let best_match_cost = best_match.as_ref().map(|(score, _)| *score);
match (best_match_cost, state.match_range(range.to_offset(buffer))) {
(Some(lowest_cost), Some((new_cost, new_range))) => {
if new_cost == lowest_cost {
tie_match_range = Some(new_range);
} else if new_cost < lowest_cost {
tie_match_range.take();
best_match = Some((new_cost, new_range));
}
}
(None, Some(new_match)) => {
best_match = Some(new_match);
}
(None, None) | (Some(_), None) => {}
};
}
if let Some((_, best_match_range)) = best_match {
if let Some(tie_match_range) = tie_match_range {
anyhow::bail!(
"Multiple ambiguous matches:\n{:?}:\n{}\n\n{:?}:\n{}",
best_match_range.clone(),
buffer.text_for_range(best_match_range).collect::<String>(),
tie_match_range.clone(),
buffer.text_for_range(tie_match_range).collect::<String>()
);
}
return Ok(best_match_range);
}
anyhow::bail!(
"Failed to fuzzy match `old_text`:\n{}\nin:\n```\n{}\n```",
old_text,
context_ranges
.iter()
.map(|range| buffer.text_for_range(range.clone()).collect::<String>())
.collect::<Vec<String>>()
.join("```\n```")
);
}
#[derive(Debug)]
struct XmlEdits<'a> {
file_path: &'a str,
/// Vec of (old_text, new_text) pairs
replacements: Vec<(&'a str, &'a str)>,
}
fn extract_xml_replacements(input: &str) -> Result<XmlEdits<'_>> {
let mut cursor = 0;
let (edits_body_start, edits_attrs) =
find_tag_open(input, &mut cursor, EDITS_TAG_NAME)?.context("No edits tag found")?;
let file_path = edits_attrs
.trim_start()
.strip_prefix("path")
.context("no path attribute on edits tag")?
.trim_end()
.strip_prefix('=')
.context("no value for path attribute")?
.trim()
.trim_start_matches('"')
.trim_end_matches('"');
cursor = edits_body_start;
let mut edits_list = Vec::new();
while let Some((old_body_start, _)) = find_tag_open(input, &mut cursor, OLD_TEXT_TAG_NAME)? {
let old_body_end = find_tag_close(input, &mut cursor)?;
let old_text = trim_surrounding_newlines(&input[old_body_start..old_body_end]);
let (new_body_start, _) = find_tag_open(input, &mut cursor, NEW_TEXT_TAG_NAME)?
.context("no new_text tag following old_text")?;
let new_body_end = find_tag_close(input, &mut cursor)?;
let new_text = trim_surrounding_newlines(&input[new_body_start..new_body_end]);
edits_list.push((old_text, new_text));
}
Ok(XmlEdits {
file_path,
replacements: edits_list,
})
}
/// Trims a single leading and trailing newline
fn trim_surrounding_newlines(input: &str) -> &str {
let start = input.strip_prefix('\n').unwrap_or(input);
let end = start.strip_suffix('\n').unwrap_or(start);
end
}
fn find_tag_open<'a>(
input: &'a str,
cursor: &mut usize,
expected_tag: &str,
) -> Result<Option<(usize, &'a str)>> {
let mut search_pos = *cursor;
while search_pos < input.len() {
let Some(tag_start) = input[search_pos..].find("<") else {
break;
};
let tag_start = search_pos + tag_start;
if !input[tag_start + 1..].starts_with(expected_tag) {
search_pos = search_pos + tag_start + 1;
continue;
};
let after_tag_name = tag_start + expected_tag.len() + 1;
let close_bracket = input[after_tag_name..]
.find('>')
.with_context(|| format!("missing > after <{}", expected_tag))?;
let attrs_end = after_tag_name + close_bracket;
let body_start = attrs_end + 1;
let attributes = input[after_tag_name..attrs_end].trim();
*cursor = body_start;
return Ok(Some((body_start, attributes)));
}
Ok(None)
}
fn find_tag_close(input: &str, cursor: &mut usize) -> Result<usize> {
let mut depth = 1;
let mut search_pos = *cursor;
while search_pos < input.len() && depth > 0 {
let Some(bracket_offset) = input[search_pos..].find('<') else {
break;
};
let bracket_pos = search_pos + bracket_offset;
if input[bracket_pos..].starts_with("</")
&& let Some(close_end) = input[bracket_pos + 2..].find('>')
{
let close_start = bracket_pos + 2;
let tag_name = input[close_start..close_start + close_end].trim();
if XML_TAGS.contains(&tag_name) {
depth -= 1;
if depth == 0 {
*cursor = close_start + close_end + 1;
return Ok(bracket_pos);
}
}
search_pos = close_start + close_end + 1;
continue;
} else if let Some(close_bracket_offset) = input[bracket_pos..].find('>') {
let close_bracket_pos = bracket_pos + close_bracket_offset;
let tag_name = &input[bracket_pos + 1..close_bracket_pos].trim();
if XML_TAGS.contains(&tag_name) {
depth += 1;
}
}
search_pos = bracket_pos + 1;
}
anyhow::bail!("no closing tag found")
}
const REPLACEMENT_COST: u32 = 1;
const INSERTION_COST: u32 = 3;
const DELETION_COST: u32 = 10;
/// A fuzzy matcher that can process text chunks incrementally
/// and return the best match found so far at each step.
struct FuzzyMatcher<'a> {
snapshot: &'a BufferSnapshot,
query_lines: Vec<&'a str>,
matrix: SearchMatrix,
}
impl<'a> FuzzyMatcher<'a> {
fn new(snapshot: &'a BufferSnapshot, old_text: &'a str) -> Self {
let query_lines = old_text.lines().collect();
Self {
snapshot,
query_lines,
matrix: SearchMatrix::new(0),
}
}
fn match_range(&mut self, range: Range<usize>) -> Option<(u32, Range<usize>)> {
let point_range = range.to_point(&self.snapshot);
let buffer_line_count = (point_range.end.row - point_range.start.row + 1) as usize;
self.matrix
.reset(self.query_lines.len() + 1, buffer_line_count + 1);
let query_line_count = self.query_lines.len();
for row in 0..query_line_count {
let query_line = self.query_lines[row].trim();
let leading_deletion_cost = (row + 1) as u32 * DELETION_COST;
self.matrix.set(
row + 1,
0,
SearchState::new(leading_deletion_cost, SearchDirection::Up),
);
let mut buffer_lines = self.snapshot.text_for_range(range.clone()).lines();
let mut col = 0;
while let Some(buffer_line) = buffer_lines.next() {
let buffer_line = buffer_line.trim();
let up = SearchState::new(
self.matrix
.get(row, col + 1)
.cost
.saturating_add(DELETION_COST),
SearchDirection::Up,
);
let left = SearchState::new(
self.matrix
.get(row + 1, col)
.cost
.saturating_add(INSERTION_COST),
SearchDirection::Left,
);
let diagonal = SearchState::new(
if query_line == buffer_line {
self.matrix.get(row, col).cost
} else if fuzzy_eq(query_line, buffer_line) {
self.matrix.get(row, col).cost + REPLACEMENT_COST
} else {
self.matrix
.get(row, col)
.cost
.saturating_add(DELETION_COST + INSERTION_COST)
},
SearchDirection::Diagonal,
);
self.matrix
.set(row + 1, col + 1, up.min(left).min(diagonal));
col += 1;
}
}
// Find all matches with the best cost
let mut best_cost = u32::MAX;
let mut matches_with_best_cost = Vec::new();
for col in 1..=buffer_line_count {
let cost = self.matrix.get(query_line_count, col).cost;
if cost < best_cost {
best_cost = cost;
matches_with_best_cost.clear();
matches_with_best_cost.push(col as u32);
} else if cost == best_cost {
matches_with_best_cost.push(col as u32);
}
}
// Find ranges for the matches
for &match_end_col in &matches_with_best_cost {
let mut matched_lines = 0;
let mut query_row = query_line_count;
let mut match_start_col = match_end_col;
while query_row > 0 && match_start_col > 0 {
let current = self.matrix.get(query_row, match_start_col as usize);
match current.direction {
SearchDirection::Diagonal => {
query_row -= 1;
match_start_col -= 1;
matched_lines += 1;
}
SearchDirection::Up => {
query_row -= 1;
}
SearchDirection::Left => {
match_start_col -= 1;
}
}
}
let buffer_row_start = match_start_col + point_range.start.row;
let buffer_row_end = match_end_col + point_range.start.row;
let matched_buffer_row_count = buffer_row_end - buffer_row_start;
let matched_ratio = matched_lines as f32
/ (matched_buffer_row_count as f32).max(query_line_count as f32);
if matched_ratio >= 0.8 {
let buffer_start_ix = self
.snapshot
.point_to_offset(Point::new(buffer_row_start, 0));
let buffer_end_ix = self.snapshot.point_to_offset(Point::new(
buffer_row_end - 1,
self.snapshot.line_len(buffer_row_end - 1),
));
return Some((best_cost, buffer_start_ix..buffer_end_ix));
}
}
None
}
}
fn fuzzy_eq(left: &str, right: &str) -> bool {
const THRESHOLD: f64 = 0.8;
let min_levenshtein = left.len().abs_diff(right.len());
let min_normalized_levenshtein =
1. - (min_levenshtein as f64 / cmp::max(left.len(), right.len()) as f64);
if min_normalized_levenshtein < THRESHOLD {
return false;
}
strsim::normalized_levenshtein(left, right) >= THRESHOLD
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum SearchDirection {
Up,
Left,
Diagonal,
}
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct SearchState {
cost: u32,
direction: SearchDirection,
}
impl SearchState {
fn new(cost: u32, direction: SearchDirection) -> Self {
Self { cost, direction }
}
}
struct SearchMatrix {
cols: usize,
rows: usize,
data: Vec<SearchState>,
}
impl SearchMatrix {
fn new(cols: usize) -> Self {
SearchMatrix {
cols,
rows: 0,
data: Vec::new(),
}
}
fn reset(&mut self, rows: usize, cols: usize) {
self.rows = rows;
self.cols = cols;
self.data
.fill(SearchState::new(0, SearchDirection::Diagonal));
self.data.resize(
self.rows * self.cols,
SearchState::new(0, SearchDirection::Diagonal),
);
}
fn get(&self, row: usize, col: usize) -> SearchState {
debug_assert!(row < self.rows);
debug_assert!(col < self.cols);
self.data[row * self.cols + col]
}
fn set(&mut self, row: usize, col: usize, state: SearchState) {
debug_assert!(row < self.rows && col < self.cols);
self.data[row * self.cols + col] = state;
}
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::TestAppContext;
use indoc::indoc;
use language::Point;
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use util::path;
#[test]
fn test_extract_xml_edits() {
let input = indoc! {r#"
<edits path="test.rs">
<old_text>
old content
</old_text>
<new_text>
new content
</new_text>
</edits>
"#};
let result = extract_xml_replacements(input).unwrap();
assert_eq!(result.file_path, "test.rs");
assert_eq!(result.replacements.len(), 1);
assert_eq!(result.replacements[0].0, "old content");
assert_eq!(result.replacements[0].1, "new content");
}
#[test]
fn test_extract_xml_edits_with_wrong_closing_tags() {
let input = indoc! {r#"
<edits path="test.rs">
<old_text>
old content
</new_text>
<new_text>
new content
</old_text>
</ edits >
"#};
let result = extract_xml_replacements(input).unwrap();
assert_eq!(result.file_path, "test.rs");
assert_eq!(result.replacements.len(), 1);
assert_eq!(result.replacements[0].0, "old content");
assert_eq!(result.replacements[0].1, "new content");
}
#[test]
fn test_extract_xml_edits_with_xml_like_content() {
let input = indoc! {r#"
<edits path="component.tsx">
<old_text>
<foo><bar></bar></foo>
</old_text>
<new_text>
<foo><bar><baz></baz></bar></foo>
</new_text>
</edits>
"#};
let result = extract_xml_replacements(input).unwrap();
assert_eq!(result.file_path, "component.tsx");
assert_eq!(result.replacements.len(), 1);
assert_eq!(result.replacements[0].0, "<foo><bar></bar></foo>");
assert_eq!(
result.replacements[0].1,
"<foo><bar><baz></baz></bar></foo>"
);
}
#[test]
fn test_extract_xml_edits_with_conflicting_content() {
let input = indoc! {r#"
<edits path="component.tsx">
<old_text>
<new_text></new_text>
</old_text>
<new_text>
<old_text></old_text>
</new_text>
</edits>
"#};
let result = extract_xml_replacements(input).unwrap();
assert_eq!(result.file_path, "component.tsx");
assert_eq!(result.replacements.len(), 1);
assert_eq!(result.replacements[0].0, "<new_text></new_text>");
assert_eq!(result.replacements[0].1, "<old_text></old_text>");
}
#[test]
fn test_extract_xml_edits_multiple_pairs() {
let input = indoc! {r#"
Some reasoning before edits. Lots of thinking going on here
<edits path="test.rs">
<old_text>
first old
</old_text>
<new_text>
first new
</new_text>
<old_text>
second old
</edits>
<new_text>
second new
</old_text>
</edits>
"#};
let result = extract_xml_replacements(input).unwrap();
assert_eq!(result.file_path, "test.rs");
assert_eq!(result.replacements.len(), 2);
assert_eq!(result.replacements[0].0, "first old");
assert_eq!(result.replacements[0].1, "first new");
assert_eq!(result.replacements[1].0, "second old");
assert_eq!(result.replacements[1].1, "second new");
}
#[test]
fn test_extract_xml_edits_unexpected_eof() {
let input = indoc! {r#"
<edits path="test.rs">
<old_text>
first old
</
"#};
extract_xml_replacements(input).expect_err("Unexpected end of file");
}
#[gpui::test]
async fn test_parse_xml_edits(cx: &mut TestAppContext) {
let fs = init_test(cx);
let buffer_1_text = indoc! {r#"
one two three four
five six seven eight
nine ten eleven twelve
thirteen fourteen fifteen
sixteen seventeen eighteen
"#};
fs.insert_tree(
path!("/root"),
json!({
"file1": buffer_1_text,
}),
)
.await;
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
let buffer = project
.update(cx, |project, cx| {
project.open_local_buffer(path!("/root/file1"), cx)
})
.await
.unwrap();
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
let edits = indoc! {r#"
<edits path="root/file1">
<old_text>
nine ten eleven twelve
</old_text>
<new_text>
nine TEN eleven twelve!
</new_text>
</edits>
"#};
let included_ranges = [(buffer_snapshot.anchor_before(Point::new(1, 0))..Anchor::MAX)];
let (buffer, edits) = parse_xml_edits(edits, |_path| {
Some((&buffer_snapshot, included_ranges.as_slice()))
})
.await
.unwrap();
let edits = edits
.into_iter()
.map(|(range, text)| (range.to_point(&buffer), text))
.collect::<Vec<_>>();
assert_eq!(
edits,
&[
(Point::new(2, 5)..Point::new(2, 8), "TEN".into()),
(Point::new(2, 22)..Point::new(2, 22), "!".into())
]
);
}
fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
});
FakeFs::new(cx.background_executor.clone())
}
}

View File

@@ -125,14 +125,15 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
return;
}
if let Some(current) = store.current_prediction_for_buffer(&buffer, &self.project, cx)
&& let BufferEditPrediction::Local { prediction } = current
&& prediction.interpolate(buffer.read(cx)).is_some()
{
return;
}
self.store.update(cx, |store, cx| {
if let Some(current) =
store.prediction_at(&buffer, Some(cursor_position), &self.project, cx)
&& let BufferEditPrediction::Local { prediction } = current
&& prediction.interpolate(buffer.read(cx)).is_some()
{
return;
}
store.refresh_context(&self.project, &buffer, cursor_position, cx);
store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
});
@@ -171,69 +172,68 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
cursor_position: language::Anchor,
cx: &mut Context<Self>,
) -> Option<edit_prediction_types::EditPrediction> {
let prediction =
self.store
.read(cx)
.current_prediction_for_buffer(buffer, &self.project, cx)?;
self.store.update(cx, |store, cx| {
let prediction =
store.prediction_at(buffer, Some(cursor_position), &self.project, cx)?;
let prediction = match prediction {
BufferEditPrediction::Local { prediction } => prediction,
BufferEditPrediction::Jump { prediction } => {
return Some(edit_prediction_types::EditPrediction::Jump {
id: Some(prediction.id.to_string().into()),
snapshot: prediction.snapshot.clone(),
target: prediction.edits.first().unwrap().0.start,
});
}
};
let prediction = match prediction {
BufferEditPrediction::Local { prediction } => prediction,
BufferEditPrediction::Jump { prediction } => {
return Some(edit_prediction_types::EditPrediction::Jump {
id: Some(prediction.id.to_string().into()),
snapshot: prediction.snapshot.clone(),
target: prediction.edits.first().unwrap().0.start,
});
}
};
let buffer = buffer.read(cx);
let snapshot = buffer.snapshot();
let buffer = buffer.read(cx);
let snapshot = buffer.snapshot();
let Some(edits) = prediction.interpolate(&snapshot) else {
self.store.update(cx, |store, _cx| {
let Some(edits) = prediction.interpolate(&snapshot) else {
store.reject_current_prediction(
EditPredictionRejectReason::InterpolatedEmpty,
&self.project,
);
});
return None;
};
return None;
};
let cursor_row = cursor_position.to_point(&snapshot).row;
let (closest_edit_ix, (closest_edit_range, _)) =
edits.iter().enumerate().min_by_key(|(_, (range, _))| {
let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
cmp::min(distance_from_start, distance_from_end)
})?;
let cursor_row = cursor_position.to_point(&snapshot).row;
let (closest_edit_ix, (closest_edit_range, _)) =
edits.iter().enumerate().min_by_key(|(_, (range, _))| {
let distance_from_start =
cursor_row.abs_diff(range.start.to_point(&snapshot).row);
let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
cmp::min(distance_from_start, distance_from_end)
})?;
let mut edit_start_ix = closest_edit_ix;
for (range, _) in edits[..edit_start_ix].iter().rev() {
let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
- range.end.to_point(&snapshot).row;
if distance_from_closest_edit <= 1 {
edit_start_ix -= 1;
} else {
break;
let mut edit_start_ix = closest_edit_ix;
for (range, _) in edits[..edit_start_ix].iter().rev() {
let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
- range.end.to_point(&snapshot).row;
if distance_from_closest_edit <= 1 {
edit_start_ix -= 1;
} else {
break;
}
}
}
let mut edit_end_ix = closest_edit_ix + 1;
for (range, _) in &edits[edit_end_ix..] {
let distance_from_closest_edit =
range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
if distance_from_closest_edit <= 1 {
edit_end_ix += 1;
} else {
break;
let mut edit_end_ix = closest_edit_ix + 1;
for (range, _) in &edits[edit_end_ix..] {
let distance_from_closest_edit = range.start.to_point(buffer).row
- closest_edit_range.end.to_point(&snapshot).row;
if distance_from_closest_edit <= 1 {
edit_end_ix += 1;
} else {
break;
}
}
}
Some(edit_prediction_types::EditPrediction::Local {
id: Some(prediction.id.to_string().into()),
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
edit_preview: Some(prediction.edit_preview.clone()),
Some(edit_prediction_types::EditPrediction::Local {
id: Some(prediction.id.to_string().into()),
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
edit_preview: Some(prediction.edit_preview.clone()),
})
})
}
}

View File

@@ -1,22 +1,23 @@
use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
use crate::{
EditPredictionId, EditPredictionStore, ZedUpdateRequiredError,
DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
prediction::{EditPredictionInputs, EditPredictionResult},
prediction::EditPredictionResult,
};
use anyhow::{Context as _, Result};
use cloud_llm_client::{
PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
predict_edits_v3::Event,
};
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
use language::{
Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
};
use project::{Project, ProjectPath};
use release_channel::AppVersion;
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
use zeta_prompt::{Event, ZetaPromptInput};
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
@@ -29,24 +30,27 @@ pub(crate) const MAX_EVENT_TOKENS: usize = 500;
pub(crate) fn request_prediction_with_zeta1(
store: &mut EditPredictionStore,
project: &Entity<Project>,
buffer: &Entity<Buffer>,
snapshot: BufferSnapshot,
position: language::Anchor,
events: Vec<Arc<Event>>,
trigger: PredictEditsRequestTrigger,
EditPredictionModelInput {
project,
buffer,
snapshot,
position,
events,
trigger,
debug_tx,
..
}: EditPredictionModelInput,
cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
let buffer = buffer.clone();
let buffer_snapshotted_at = Instant::now();
let client = store.client.clone();
let llm_token = store.llm_token.clone();
let app_version = AppVersion::global(cx);
let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
let can_collect_file = store.can_collect_file(project, file, cx);
let can_collect_file = store.can_collect_file(&project, file, cx);
let git_info = if can_collect_file {
git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
} else {
None
};
@@ -120,33 +124,33 @@ pub(crate) fn request_prediction_with_zeta1(
)
.await;
let inputs = EditPredictionInputs {
let context_start_offset = context_range.start.to_offset(&snapshot);
let editable_offset_range = editable_range.to_offset(&snapshot);
let inputs = ZetaPromptInput {
events: included_events.into(),
included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
path: full_path.clone(),
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row),
text: snapshot
.text_for_range(context_range)
.collect::<String>()
.into(),
}],
}],
cursor_point: cloud_llm_client::predict_edits_v3::Point {
column: cursor_point.column,
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
},
related_files: vec![].into(),
cursor_path: full_path,
cursor_excerpt: snapshot
.text_for_range(context_range)
.collect::<String>()
.into(),
editable_range_in_excerpt: (editable_range.start - context_start_offset)
..(editable_offset_range.end - context_start_offset),
cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
};
// let response = perform_predict_edits(PerformPredictEditsParams {
// client,
// llm_token,
// app_version,
// body,
// })
// .await;
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionStarted(
EditPredictionStartedDebugEvent {
buffer: buffer.downgrade(),
prompt: Some(serde_json::to_string(&inputs).unwrap()),
position,
},
))
.ok();
}
let (response, usage) = match response {
Ok(response) => response,
@@ -189,6 +193,18 @@ pub(crate) fn request_prediction_with_zeta1(
.ok();
}
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionFinished(
EditPredictionFinishedDebugEvent {
buffer: buffer.downgrade(),
model_output: Some(response.output_excerpt.clone()),
position,
},
))
.ok();
}
let edit_prediction = process_completion_response(
response,
buffer,
@@ -226,7 +242,7 @@ fn process_completion_response(
buffer: Entity<Buffer>,
snapshot: &BufferSnapshot,
editable_range: Range<usize>,
inputs: EditPredictionInputs,
inputs: ZetaPromptInput,
buffer_snapshotted_at: Instant,
received_response_at: Instant,
cx: &AsyncApp,

View File

@@ -1,48 +1,41 @@
#[cfg(feature = "eval-support")]
#[cfg(feature = "cli-support")]
use crate::EvalCacheEntryKind;
use crate::open_ai_response::text_from_response;
use crate::prediction::EditPredictionResult;
use crate::{
DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs,
EditPredictionRequestedDebugEvent, EditPredictionStore,
DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent, EditPredictionId,
EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
};
use anyhow::{Result, anyhow, bail};
use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
use cloud_llm_client::{EditPredictionRejectReason, PredictEditsRequestTrigger};
use cloud_zeta2_prompt::CURSOR_MARKER;
use edit_prediction_context::{EditPredictionExcerpt, Line};
use edit_prediction_context::{RelatedExcerpt, RelatedFile};
use futures::channel::oneshot;
use gpui::{Entity, Task, prelude::*};
use language::{Anchor, BufferSnapshot};
use language::{Buffer, Point, ToOffset as _, ToPoint};
use project::{Project, ProjectItem as _};
use anyhow::{Result, anyhow};
use cloud_llm_client::EditPredictionRejectReason;
use gpui::{Task, prelude::*};
use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
use release_channel::AppVersion;
use std::{
env,
path::Path,
sync::Arc,
time::{Duration, Instant},
};
use std::{path::Path, sync::Arc, time::Instant};
use zeta_prompt::CURSOR_MARKER;
use zeta_prompt::format_zeta_prompt;
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
pub fn request_prediction_with_zeta2(
store: &mut EditPredictionStore,
project: &Entity<Project>,
active_buffer: &Entity<Buffer>,
active_snapshot: BufferSnapshot,
position: Anchor,
events: Vec<Arc<Event>>,
mut included_files: Vec<RelatedFile>,
trigger: PredictEditsRequestTrigger,
EditPredictionModelInput {
buffer,
snapshot,
position,
related_files,
events,
debug_tx,
..
}: EditPredictionModelInput,
cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
let options = store.options.clone();
let buffer_snapshotted_at = Instant::now();
let Some((excerpt_path, active_project_path)) = active_snapshot
let Some(excerpt_path) = snapshot
.file()
.map(|file| -> Arc<Path> { file.full_path(cx).into() })
.zip(active_buffer.read(cx).project_path(cx))
else {
return Task::ready(Err(anyhow!("No file path for excerpt")));
};
@@ -50,148 +43,35 @@ pub fn request_prediction_with_zeta2(
let client = store.client.clone();
let llm_token = store.llm_token.clone();
let app_version = AppVersion::global(cx);
let debug_tx = store.debug_tx.clone();
let file = active_buffer.read(cx).file();
let active_file_full_path = file.as_ref().map(|f| f.full_path(cx));
// TODO data collection
let can_collect_data = file
.as_ref()
.map_or(false, |file| store.can_collect_file(project, file, cx));
#[cfg(feature = "eval-support")]
#[cfg(feature = "cli-support")]
let eval_cache = store.eval_cache.clone();
let request_task = cx.background_spawn({
let active_buffer = active_buffer.clone();
async move {
let cursor_offset = position.to_offset(&active_snapshot);
let cursor_point = cursor_offset.to_point(&active_snapshot);
let before_retrieval = Instant::now();
let excerpt_options = options.context;
let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
cursor_point,
&active_snapshot,
&excerpt_options,
) else {
return Ok((None, None));
};
let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
..active_snapshot.anchor_before(excerpt.range.end);
let related_excerpt = RelatedExcerpt {
anchor_range: excerpt_anchor_range.clone(),
point_range: Point::new(excerpt.line_range.start.0, 0)
..Point::new(excerpt.line_range.end.0, 0),
text: active_snapshot.as_rope().slice(excerpt.range),
};
if let Some(buffer_ix) = included_files
.iter()
.position(|file| file.buffer.entity_id() == active_buffer.entity_id())
{
let file = &mut included_files[buffer_ix];
file.excerpts.push(related_excerpt);
file.merge_excerpts();
let last_ix = included_files.len() - 1;
included_files.swap(buffer_ix, last_ix);
} else {
let active_file = RelatedFile {
path: active_project_path,
buffer: active_buffer.downgrade(),
excerpts: vec![related_excerpt],
max_row: active_snapshot.max_point().row,
};
included_files.push(active_file);
}
let included_files = included_files
.iter()
.map(|related_file| predict_edits_v3::RelatedFile {
path: Arc::from(related_file.path.path.as_std_path()),
max_row: Line(related_file.max_row),
excerpts: related_file
.excerpts
.iter()
.map(|excerpt| predict_edits_v3::Excerpt {
start_line: Line(excerpt.point_range.start.row),
text: excerpt.text.to_string().into(),
})
.collect(),
})
.collect::<Vec<_>>();
let cloud_request = predict_edits_v3::PredictEditsRequest {
excerpt_path,
excerpt: String::new(),
excerpt_line_range: Line(0)..Line(0),
excerpt_range: 0..0,
cursor_point: predict_edits_v3::Point {
line: predict_edits_v3::Line(cursor_point.row),
column: cursor_point.column,
},
related_files: included_files,
let cursor_offset = position.to_offset(&snapshot);
let (editable_offset_range, prompt_input) = zeta2_prompt_input(
&snapshot,
related_files,
events,
can_collect_data,
debug_info: debug_tx.is_some(),
prompt_max_bytes: Some(options.max_prompt_bytes),
prompt_format: options.prompt_format,
excerpt_parent: None,
git_info: None,
trigger,
};
excerpt_path,
cursor_offset,
);
let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
let inputs = EditPredictionInputs {
included_files: cloud_request.related_files,
events: cloud_request.events,
cursor_point: cloud_request.cursor_point,
cursor_path: cloud_request.excerpt_path,
};
let retrieval_time = Instant::now() - before_retrieval;
let debug_response_tx = if let Some(debug_tx) = &debug_tx {
let (response_tx, response_rx) = oneshot::channel();
let prompt = format_zeta_prompt(&prompt_input);
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionRequested(
EditPredictionRequestedDebugEvent {
inputs: inputs.clone(),
retrieval_time,
buffer: active_buffer.downgrade(),
local_prompt: match prompt_result.as_ref() {
Ok(prompt) => Ok(prompt.clone()),
Err(err) => Err(err.to_string()),
},
.unbounded_send(DebugEvent::EditPredictionStarted(
EditPredictionStartedDebugEvent {
buffer: buffer.downgrade(),
prompt: Some(prompt.clone()),
position,
response_rx,
},
))
.ok();
Some(response_tx)
} else {
None
};
if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
if let Some(debug_response_tx) = debug_response_tx {
debug_response_tx
.send((Err("Request skipped".to_string()), Duration::ZERO))
.ok();
}
anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
}
let prompt = prompt_result?;
let generation_params =
cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
let request = open_ai::Request {
model: EDIT_PREDICTIONS_MODEL_ID.clone(),
messages: vec![open_ai::RequestMessage::User {
@@ -199,8 +79,8 @@ pub fn request_prediction_with_zeta2(
}],
stream: false,
max_completion_tokens: None,
stop: generation_params.stop.unwrap_or_default(),
temperature: generation_params.temperature.or(Some(0.7)),
stop: Default::default(),
temperature: Default::default(),
tool_choice: None,
parallel_tool_calls: None,
tools: vec![],
@@ -210,81 +90,65 @@ pub fn request_prediction_with_zeta2(
log::trace!("Sending edit prediction request");
let before_request = Instant::now();
let response = EditPredictionStore::send_raw_llm_request(
request,
client,
llm_token,
app_version,
#[cfg(feature = "eval-support")]
#[cfg(feature = "cli-support")]
eval_cache,
#[cfg(feature = "eval-support")]
#[cfg(feature = "cli-support")]
EvalCacheEntryKind::Prediction,
)
.await;
let received_response_at = Instant::now();
let request_time = received_response_at - before_request;
log::trace!("Got edit prediction response");
if let Some(debug_response_tx) = debug_response_tx {
debug_response_tx
.send((
response
.as_ref()
.map_err(|err| err.to_string())
.map(|response| response.0.clone()),
request_time,
))
.ok();
}
let (res, usage) = response?;
let request_id = EditPredictionId(res.id.clone().into());
let Some(mut output_text) = text_from_response(res) else {
return Ok((Some((request_id, None)), usage));
};
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionFinished(
EditPredictionFinishedDebugEvent {
buffer: buffer.downgrade(),
position,
model_output: Some(output_text.clone()),
},
))
.ok();
}
if output_text.contains(CURSOR_MARKER) {
log::trace!("Stripping out {CURSOR_MARKER} from response");
output_text = output_text.replace(CURSOR_MARKER, "");
}
let get_buffer_from_context = |path: &Path| {
if Some(path) == active_file_full_path.as_deref() {
Some((
&active_snapshot,
std::slice::from_ref(&excerpt_anchor_range),
))
} else {
None
}
};
let (_, edits) = match options.prompt_format {
PromptFormat::Minimal | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
if output_text.contains("--- a/\n+++ b/\nNo edits") {
let edits = vec![];
(&active_snapshot, edits)
} else {
crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
}
}
PromptFormat::OldTextNewText => {
crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context).await?
}
_ => {
bail!("unsupported prompt format {}", options.prompt_format)
}
};
let old_text = snapshot
.text_for_range(editable_offset_range.clone())
.collect::<String>();
let edits: Vec<_> = language::text_diff(&old_text, &output_text)
.into_iter()
.map(|(range, text)| {
(
snapshot.anchor_after(editable_offset_range.start + range.start)
..snapshot.anchor_before(editable_offset_range.start + range.end),
text,
)
})
.collect();
anyhow::Ok((
Some((
request_id,
Some((
inputs,
active_buffer,
active_snapshot.clone(),
prompt_input,
buffer,
snapshot.clone(),
edits,
received_response_at,
)),
@@ -325,3 +189,52 @@ pub fn request_prediction_with_zeta2(
))
})
}
pub fn zeta2_prompt_input(
snapshot: &language::BufferSnapshot,
related_files: Arc<[zeta_prompt::RelatedFile]>,
events: Vec<Arc<zeta_prompt::Event>>,
excerpt_path: Arc<Path>,
cursor_offset: usize,
) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
let cursor_point = cursor_offset.to_point(snapshot);
let (editable_range, context_range) =
crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
cursor_point,
snapshot,
MAX_CONTEXT_TOKENS,
MAX_REWRITE_TOKENS,
);
let context_start_offset = context_range.start.to_offset(snapshot);
let editable_offset_range = editable_range.to_offset(snapshot);
let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
..(editable_offset_range.end - context_start_offset);
let prompt_input = zeta_prompt::ZetaPromptInput {
cursor_path: excerpt_path,
cursor_excerpt: snapshot
.text_for_range(context_range)
.collect::<String>()
.into(),
editable_range_in_excerpt,
cursor_offset_in_excerpt,
events,
related_files,
};
(editable_offset_range, prompt_input)
}
#[cfg(feature = "cli-support")]
pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> String {
eprintln!("{}", patch);
eprintln!("---------------------");
eprintln!("{}", input.cursor_excerpt);
crate::udiff::apply_diff_to_string(
patch,
&input.cursor_excerpt[input.editable_range_in_excerpt.clone()],
)
.unwrap()
}

View File

@@ -9,7 +9,7 @@ license = "GPL-3.0-or-later"
workspace = true
[[bin]]
name = "ep_cli"
name = "ep"
path = "src/main.rs"
[dependencies]
@@ -20,10 +20,9 @@ chrono.workspace = true
clap.workspace = true
client.workspace = true
cloud_llm_client.workspace= true
cloud_zeta2_prompt.workspace = true
collections.workspace = true
debug_adapter_extension.workspace = true
edit_prediction_context.workspace = true
dirs.workspace = true
extension.workspace = true
fs.workspace = true
futures.workspace = true
@@ -35,6 +34,7 @@ language_extension.workspace = true
language_model.workspace = true
language_models.workspace = true
languages = { workspace = true, features = ["load-grammars"] }
libc.workspace = true
log.workspace = true
node_runtime.workspace = true
paths.workspace = true
@@ -51,11 +51,19 @@ smol.workspace = true
sqlez.workspace = true
sqlez_macros.workspace = true
terminal_view.workspace = true
toml.workspace = true
util.workspace = true
watch.workspace = true
edit_prediction = { workspace = true, features = ["eval-support"] }
zlog.workspace = true
edit_prediction = { workspace = true, features = ["cli-support"] }
wasmtime.workspace = true
zeta_prompt.workspace = true
# Wasmtime is included as a dependency in order to enable the same
# features that are enabled in Zed.
#
# If we don't enable these features we get crashes when creating
# a Tree-sitter WasmStore.
[package.metadata.cargo-machete]
ignored = ["wasmtime"]
[dev-dependencies]
indoc.workspace = true

View File

@@ -5,11 +5,13 @@ use anthropic::{
use anyhow::Result;
use http_client::HttpClient;
use indoc::indoc;
use reqwest_client::ReqwestClient;
use sqlez::bindable::Bind;
use sqlez::bindable::StaticColumnCount;
use sqlez_macros::sql;
use std::hash::Hash;
use std::hash::Hasher;
use std::path::Path;
use std::sync::Arc;
pub struct PlainLlmClient {
@@ -18,7 +20,8 @@ pub struct PlainLlmClient {
}
impl PlainLlmClient {
fn new(http_client: Arc<dyn HttpClient>) -> Result<Self> {
fn new() -> Result<Self> {
let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
let api_key = std::env::var("ANTHROPIC_API_KEY")
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
Ok(Self {
@@ -29,12 +32,12 @@ impl PlainLlmClient {
async fn generate(
&self,
model: String,
model: &str,
max_tokens: u64,
messages: Vec<Message>,
) -> Result<AnthropicResponse> {
let request = AnthropicRequest {
model,
model: model.to_string(),
max_tokens,
messages,
tools: Vec::new(),
@@ -105,11 +108,12 @@ struct SerializableMessage {
}
impl BatchingLlmClient {
fn new(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
fn new(cache_path: &Path) -> Result<Self> {
let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
let api_key = std::env::var("ANTHROPIC_API_KEY")
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
let connection = sqlez::connection::Connection::open_file(&cache_path);
let connection = sqlez::connection::Connection::open_file(&cache_path.to_str().unwrap());
let mut statement = sqlez::statement::Statement::prepare(
&connection,
indoc! {"
@@ -182,16 +186,16 @@ impl BatchingLlmClient {
async fn generate(
&self,
model: String,
model: &str,
max_tokens: u64,
messages: Vec<Message>,
) -> Result<Option<AnthropicResponse>> {
let response = self.lookup(&model, max_tokens, &messages)?;
let response = self.lookup(model, max_tokens, &messages)?;
if let Some(response) = response {
return Ok(Some(response));
}
self.mark_for_batch(&model, max_tokens, &messages)?;
self.mark_for_batch(model, max_tokens, &messages)?;
Ok(None)
}
@@ -258,7 +262,7 @@ impl BatchingLlmClient {
}
}
}
log::info!("Uploaded {} successful requests", success_count);
log::info!("Downloaded {} successful requests", success_count);
}
}
@@ -363,23 +367,20 @@ fn message_content_to_string(content: &[RequestContent]) -> String {
.join("\n")
}
pub enum LlmClient {
pub enum AnthropicClient {
// No batching
Plain(PlainLlmClient),
Batch(BatchingLlmClient),
Dummy,
}
impl LlmClient {
pub fn plain(http_client: Arc<dyn HttpClient>) -> Result<Self> {
Ok(Self::Plain(PlainLlmClient::new(http_client)?))
impl AnthropicClient {
pub fn plain() -> Result<Self> {
Ok(Self::Plain(PlainLlmClient::new()?))
}
pub fn batch(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
Ok(Self::Batch(BatchingLlmClient::new(
cache_path,
http_client,
)?))
pub fn batch(cache_path: &Path) -> Result<Self> {
Ok(Self::Batch(BatchingLlmClient::new(cache_path)?))
}
#[allow(dead_code)]
@@ -389,29 +390,29 @@ impl LlmClient {
pub async fn generate(
&self,
model: String,
model: &str,
max_tokens: u64,
messages: Vec<Message>,
) -> Result<Option<AnthropicResponse>> {
match self {
LlmClient::Plain(plain_llm_client) => plain_llm_client
AnthropicClient::Plain(plain_llm_client) => plain_llm_client
.generate(model, max_tokens, messages)
.await
.map(Some),
LlmClient::Batch(batching_llm_client) => {
AnthropicClient::Batch(batching_llm_client) => {
batching_llm_client
.generate(model, max_tokens, messages)
.await
}
LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
}
}
pub async fn sync_batches(&self) -> Result<()> {
match self {
LlmClient::Plain(_) => Ok(()),
LlmClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
AnthropicClient::Plain(_) => Ok(()),
AnthropicClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
}
}
}

View File

@@ -0,0 +1,22 @@
use anyhow::{Result, anyhow};
use std::mem;
use crate::example::Example;
pub async fn run_distill(example: &mut Example) -> Result<()> {
let [prediction]: [_; 1] =
mem::take(&mut example.predictions)
.try_into()
.map_err(|preds: Vec<_>| {
anyhow!(
"Example has {} predictions, but it should have exactly one",
preds.len()
)
})?;
example.expected_patch = prediction.actual_patch;
example.prompt = None;
example.predictions = Vec::new();
example.score = Vec::new();
Ok(())
}

View File

@@ -1,641 +0,0 @@
use crate::metrics::{self, Scores};
use std::{
collections::HashMap,
io::{IsTerminal, Write},
sync::Arc,
};
use anyhow::Result;
use edit_prediction::{EditPredictionStore, udiff::DiffLine};
use gpui::{AsyncApp, Entity};
use project::Project;
use util::ResultExt as _;
use crate::{
EvaluateArguments, PredictionOptions,
example::{Example, NamedExample},
headless::ZetaCliAppState,
paths::print_run_data_dir,
predict::{PredictionDetails, perform_predict, setup_store},
};
#[derive(Debug)]
pub(crate) struct ExecutionData {
execution_id: String,
diff: String,
reasoning: String,
}
pub async fn run_evaluate(
args: EvaluateArguments,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) {
if args.example_paths.is_empty() {
eprintln!("No examples provided");
return;
}
let all_tasks = args.example_paths.into_iter().map(|path| {
let options = args.options.clone();
let app_state = app_state.clone();
let example = NamedExample::load(&path).expect("Failed to load example");
cx.spawn(async move |cx| {
let project = example.setup_project(&app_state, cx).await.unwrap();
let providers = (0..args.repetitions)
.map(|_| setup_store(args.options.provider, &project, &app_state, cx).unwrap())
.collect::<Vec<_>>();
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
let tasks = providers
.into_iter()
.enumerate()
.map(move |(repetition_ix, store)| {
let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
let example = example.clone();
let project = project.clone();
let options = options.clone();
cx.spawn(async move |cx| {
let name = example.name.clone();
run_evaluate_one(
example,
repetition_ix,
project,
store,
options,
!args.skip_prediction,
cx,
)
.await
.map_err(|err| (err, name, repetition_ix))
})
});
futures::future::join_all(tasks).await
})
});
let all_results = futures::future::join_all(all_tasks).await;
write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
if let Some(mut output_file) =
std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
{
write_aggregated_scores(&mut output_file, &all_results).log_err();
};
if args.repetitions > 1 {
if let Err(e) = write_bucketed_analysis(&all_results) {
eprintln!("Failed to write bucketed analysis: {:?}", e);
}
}
print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
}
fn write_aggregated_scores(
w: &mut impl std::io::Write,
all_results: &Vec<
Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
>,
) -> Result<()> {
let mut successful = Vec::new();
let mut failed_count = 0;
for result in all_results.iter().flatten() {
match result {
Ok((eval_result, _execution_data)) => successful.push(eval_result),
Err((err, name, repetition_ix)) => {
if failed_count == 0 {
writeln!(w, "## Errors\n")?;
}
failed_count += 1;
writeln!(w, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
}
}
}
if successful.len() > 1 {
let edit_scores = successful
.iter()
.filter_map(|r| r.edit_scores.clone())
.collect::<Vec<_>>();
let has_edit_predictions = edit_scores.len() > 0;
let aggregated_result = EvaluationResult {
context_scores: Scores::aggregate(successful.iter().map(|r| &r.context_scores)),
edit_scores: has_edit_predictions.then(|| EditScores::aggregate(&edit_scores)),
prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
/ successful.len(),
};
writeln!(w, "\n{}", "-".repeat(80))?;
writeln!(w, "\n## TOTAL SCORES")?;
writeln!(w, "{:#}", aggregated_result)?;
}
if successful.len() + failed_count > 1 {
writeln!(
w,
"\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
successful.len(),
successful.len() + failed_count,
(successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
)?;
}
Ok(())
}
pub async fn run_evaluate_one(
example: NamedExample,
repetition_ix: Option<u16>,
project: Entity<Project>,
store: Entity<EditPredictionStore>,
prediction_options: PredictionOptions,
predict: bool,
cx: &mut AsyncApp,
) -> Result<(EvaluationResult, ExecutionData)> {
let predict_result = perform_predict(
example.clone(),
project,
store,
repetition_ix,
prediction_options,
cx,
)
.await?;
let evaluation_result = evaluate(&example.example, &predict_result, predict);
if repetition_ix.is_none() {
write_eval_result(
&example,
&predict_result,
&evaluation_result,
&mut std::io::stdout(),
std::io::stdout().is_terminal(),
predict,
)?;
}
if let Some(mut results_file) =
std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
{
write_eval_result(
&example,
&predict_result,
&evaluation_result,
&mut results_file,
false,
predict,
)
.log_err();
}
let execution_data = ExecutionData {
execution_id: if let Some(rep_ix) = repetition_ix {
format!("{:03}", rep_ix)
} else {
example.name.clone()
},
diff: predict_result.diff.clone(),
reasoning: std::fs::read_to_string(
predict_result
.run_example_dir
.join("prediction_response.md"),
)
.unwrap_or_default(),
};
anyhow::Ok((evaluation_result, execution_data))
}
fn write_eval_result(
example: &NamedExample,
predictions: &PredictionDetails,
evaluation_result: &EvaluationResult,
out: &mut impl Write,
use_color: bool,
predict: bool,
) -> Result<()> {
if predict {
writeln!(
out,
"## Expected edit prediction:\n\n```diff\n{}\n```\n",
compare_diffs(
&example.example.expected_patch,
&predictions.diff,
use_color
)
)?;
writeln!(
out,
"## Actual edit prediction:\n\n```diff\n{}\n```\n",
compare_diffs(
&predictions.diff,
&example.example.expected_patch,
use_color
)
)?;
}
writeln!(out, "{:#}", evaluation_result)?;
anyhow::Ok(())
}
#[derive(Debug, Default, Clone)]
pub struct EditScores {
pub line_match: Scores,
pub chr_f: f64,
}
impl EditScores {
pub fn aggregate(scores: &[EditScores]) -> EditScores {
let line_match = Scores::aggregate(scores.iter().map(|s| &s.line_match));
let chr_f = scores.iter().map(|s| s.chr_f).sum::<f64>() / scores.len() as f64;
EditScores { line_match, chr_f }
}
}
#[derive(Debug, Default)]
pub struct EvaluationResult {
pub edit_scores: Option<EditScores>,
pub context_scores: Scores,
pub prompt_len: usize,
pub generated_len: usize,
}
impl std::fmt::Display for EvaluationResult {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if f.alternate() {
self.fmt_table(f)
} else {
self.fmt_markdown(f)
}
}
}
impl EvaluationResult {
fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
r#"
### Context Scores
{}
"#,
self.context_scores.to_markdown(),
)?;
if let Some(scores) = &self.edit_scores {
write!(
f,
r#"
### Edit Prediction Scores
{}"#,
scores.line_match.to_markdown()
)?;
}
Ok(())
}
fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "#### Prompt Statistics")?;
writeln!(f, "─────────────────────────")?;
writeln!(f, "Prompt_len Generated_len")?;
writeln!(f, "─────────────────────────")?;
writeln!(f, "{:<11} {:<14}", self.prompt_len, self.generated_len,)?;
writeln!(f)?;
writeln!(f)?;
writeln!(f, "#### Performance Scores")?;
writeln!(
f,
"──────────────────────────────────────────────────────────────────"
)?;
writeln!(
f,
" TP FP FN Precision Recall F1"
)?;
writeln!(
f,
"──────────────────────────────────────────────────────────────────"
)?;
writeln!(
f,
"Context Retrieval {:<6} {:<6} {:<6} {:>8.2} {:>7.2} {:>6.2}",
self.context_scores.true_positives,
self.context_scores.false_positives,
self.context_scores.false_negatives,
self.context_scores.precision() * 100.0,
self.context_scores.recall() * 100.0,
self.context_scores.f1_score() * 100.0
)?;
if let Some(edit_scores) = &self.edit_scores {
let line_match = &edit_scores.line_match;
writeln!(f, "Edit Prediction")?;
writeln!(
f,
" ├─ exact lines {:<6} {:<6} {:<6} {:>8.2} {:>7.2} {:>6.2}",
line_match.true_positives,
line_match.false_positives,
line_match.false_negatives,
line_match.precision() * 100.0,
line_match.recall() * 100.0,
line_match.f1_score() * 100.0
)?;
writeln!(
f,
" └─ diff chrF {:<6} {:<6} {:<6} {:>8} {:>8} {:>6.2}",
"-", "-", "-", "-", "-", edit_scores.chr_f
)?;
}
Ok(())
}
}
fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
let mut eval_result = EvaluationResult {
prompt_len: preds.prompt_len,
generated_len: preds.generated_len,
..Default::default()
};
if predict {
// todo: alternatives for patches
let expected_patch = example
.expected_patch
.lines()
.map(DiffLine::parse)
.collect::<Vec<_>>();
let actual_patch = preds.diff.lines().map(DiffLine::parse).collect::<Vec<_>>();
let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
let chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch);
eval_result.edit_scores = Some(EditScores { line_match, chr_f });
}
eval_result
}
/// Return annotated `patch_a` so that:
/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
/// Additions and deletions that are present in `patch_b` will be highlighted in green.
pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
let green = if use_color { "\x1b[32m✓ " } else { "" };
let red = if use_color { "\x1b[31m✗ " } else { "" };
let neutral = if use_color { " " } else { "" };
let reset = if use_color { "\x1b[0m" } else { "" };
let lines_a = patch_a.lines().map(DiffLine::parse);
let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
let annotated = lines_a
.map(|line| match line {
DiffLine::Addition(_) | DiffLine::Deletion(_) => {
if lines_b.contains(&line) {
format!("{green}{line}{reset}")
} else {
format!("{red}{line}{reset}")
}
}
_ => format!("{neutral}{line}{reset}"),
})
.collect::<Vec<String>>();
annotated.join("\n")
}
fn write_bucketed_analysis(
all_results: &Vec<
Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
>,
) -> Result<()> {
#[derive(Debug)]
struct EditBucket {
diff: String,
is_correct: bool,
execution_indices: Vec<String>,
reasoning_samples: Vec<String>,
}
let mut total_executions = 0;
let mut empty_predictions = Vec::new();
let mut errors = Vec::new();
let mut buckets: HashMap<String, EditBucket> = HashMap::new();
for result in all_results.iter().flatten() {
total_executions += 1;
let (evaluation_result, execution_data) = match result {
Ok((eval_result, execution_data)) => {
if execution_data.diff.is_empty() {
empty_predictions.push(execution_data);
continue;
}
(eval_result, execution_data)
}
Err(err) => {
errors.push(err);
continue;
}
};
buckets
.entry(execution_data.diff.clone())
.and_modify(|bucket| {
bucket
.execution_indices
.push(execution_data.execution_id.clone());
bucket
.reasoning_samples
.push(execution_data.reasoning.clone());
})
.or_insert_with(|| EditBucket {
diff: execution_data.diff.clone(),
is_correct: {
evaluation_result
.edit_scores
.as_ref()
.map_or(false, |edit_scores| {
edit_scores.line_match.false_positives == 0
&& edit_scores.line_match.false_negatives == 0
&& edit_scores.line_match.true_positives > 0
})
},
execution_indices: vec![execution_data.execution_id.clone()],
reasoning_samples: vec![execution_data.reasoning.clone()],
});
}
let mut sorted_buckets = buckets.into_values().collect::<Vec<_>>();
sorted_buckets.sort_by(|a, b| match (a.is_correct, b.is_correct) {
(true, false) => std::cmp::Ordering::Less,
(false, true) => std::cmp::Ordering::Greater,
_ => b.execution_indices.len().cmp(&a.execution_indices.len()),
});
let output_path = crate::paths::RUN_DIR.join("bucketed_analysis.md");
let mut output = std::fs::File::create(&output_path)?;
writeln!(output, "# Bucketed Edit Analysis\n")?;
writeln!(output, "## Summary\n")?;
writeln!(output, "- **Total executions**: {}", total_executions)?;
let correct_count: usize = sorted_buckets
.iter()
.filter(|b| b.is_correct)
.map(|b| b.execution_indices.len())
.sum();
let incorrect_count: usize = sorted_buckets
.iter()
.filter(|b| !b.is_correct)
.map(|b| b.execution_indices.len())
.sum();
writeln!(
output,
"- **Correct predictions**: {} ({:.1}%)",
correct_count,
(correct_count as f64 / total_executions as f64) * 100.0
)?;
writeln!(
output,
"- **Incorrect predictions**: {} ({:.1}%)",
incorrect_count,
(incorrect_count as f64 / total_executions as f64) * 100.0
)?;
writeln!(
output,
"- **No Predictions**: {} ({:.1}%)",
empty_predictions.len(),
(empty_predictions.len() as f64 / total_executions as f64) * 100.0
)?;
let unique_incorrect = sorted_buckets.iter().filter(|b| !b.is_correct).count();
writeln!(
output,
"- **Unique incorrect edit patterns**: {}\n",
unique_incorrect
)?;
writeln!(output, "---\n")?;
for (idx, bucket) in sorted_buckets.iter().filter(|b| b.is_correct).enumerate() {
if idx == 0 {
writeln!(
output,
"## Correct Predictions ({} occurrences)\n",
bucket.execution_indices.len()
)?;
}
writeln!(output, "**Predicted Edit:**\n")?;
writeln!(output, "```diff")?;
writeln!(output, "{}", bucket.diff)?;
writeln!(output, "```\n")?;
writeln!(
output,
"**Executions:** {}\n",
bucket.execution_indices.join(", ")
)?;
writeln!(output, "---\n")?;
}
for (idx, bucket) in sorted_buckets.iter().filter(|b| !b.is_correct).enumerate() {
writeln!(
output,
"## Incorrect Prediction #{} ({} occurrences)\n",
idx + 1,
bucket.execution_indices.len()
)?;
writeln!(output, "**Predicted Edit:**\n")?;
writeln!(output, "```diff")?;
writeln!(output, "{}", bucket.diff)?;
writeln!(output, "```\n")?;
writeln!(
output,
"**Executions:** {}\n",
bucket.execution_indices.join(", ")
)?;
for (exec_id, reasoning) in bucket
.execution_indices
.iter()
.zip(bucket.reasoning_samples.iter())
{
writeln!(output, "{}", fmt_execution(exec_id, reasoning))?;
}
writeln!(output, "\n---\n")?;
}
if !empty_predictions.is_empty() {
writeln!(
output,
"## No Predictions ({} occurrences)\n",
empty_predictions.len()
)?;
for execution_data in &empty_predictions {
writeln!(
output,
"{}",
fmt_execution(&execution_data.execution_id, &execution_data.reasoning)
)?;
}
writeln!(output, "\n---\n")?;
}
if !errors.is_empty() {
writeln!(output, "## Errors ({} occurrences)\n", errors.len())?;
for (err, name, repetition_ix) in &errors {
writeln!(output, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
}
writeln!(output, "\n---\n")?;
}
fn fmt_execution(exec_id: &str, reasoning: &str) -> String {
let exec_content = format!(
"\n### Execution {} `{}/{}/prediction_response.md`{}",
exec_id,
crate::paths::RUN_DIR.display(),
exec_id,
indent_text(&format!("\n\n```\n{}\n```\n", reasoning,), 2)
);
indent_text(&exec_content, 2)
}
fn indent_text(text: &str, spaces: usize) -> String {
let indent = " ".repeat(spaces);
text.lines()
.collect::<Vec<_>>()
.join(&format!("\n{}", indent))
}
Ok(())
}
fn fmt_evaluation_error(err: &anyhow::Error, name: &str, repetition_ix: &Option<u16>) -> String {
let err = format!("{err:?}")
.replace("<edits", "```xml\n<edits")
.replace("</edits>", "</edits>\n```");
format!(
"### ERROR {name}{}\n\n{err}\n",
repetition_ix
.map(|ix| format!(" [RUN {ix:03}]"))
.unwrap_or_default()
)
}

View File

@@ -1,63 +1,105 @@
use std::{
borrow::Cow,
cell::RefCell,
fmt::{self, Display},
fs,
hash::Hash,
hash::Hasher,
io::Write,
mem,
path::{Path, PathBuf},
sync::{Arc, OnceLock},
};
use crate::headless::ZetaCliAppState;
use anyhow::{Context as _, Result, anyhow};
use clap::ValueEnum;
use cloud_zeta2_prompt::CURSOR_MARKER;
use crate::{PredictionProvider, PromptFormat, metrics::ClassificationMetrics};
use anyhow::{Context as _, Result};
use collections::HashMap;
use edit_prediction::udiff::OpenedBuffers;
use futures::{
AsyncWriteExt as _,
lock::{Mutex, OwnedMutexGuard},
};
use futures::{FutureExt as _, future::Shared};
use gpui::{AsyncApp, Entity, Task, http_client::Url};
use gpui::Entity;
use http_client::Url;
use language::{Anchor, Buffer};
use project::{Project, ProjectPath};
use pulldown_cmark::CowStr;
use project::Project;
use serde::{Deserialize, Serialize};
use util::{paths::PathStyle, rel_path::RelPath};
use std::sync::Arc;
use std::{
borrow::Cow,
io::{Read, Write},
mem,
path::{Path, PathBuf},
};
use zeta_prompt::RelatedFile;
use crate::paths::{REPOS_DIR, WORKTREES_DIR};
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
const REPOSITORY_URL_FIELD: &str = "repository_url";
const REVISION_FIELD: &str = "revision";
#[derive(Debug, Clone)]
pub struct NamedExample {
pub name: String,
pub example: Example,
}
#[derive(Clone, Debug, Hash, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
#[serde(default)]
pub name: String,
pub repository_url: String,
pub revision: String,
#[serde(default)]
pub uncommitted_diff: String,
pub cursor_path: PathBuf,
pub cursor_path: Arc<Path>,
pub cursor_position: String,
pub edit_history: String,
pub expected_patch: String,
/// The full content of the file where an edit is being predicted, and the
/// actual cursor offset.
#[serde(skip_serializing_if = "Option::is_none")]
pub buffer: Option<ExampleBuffer>,
/// The context retrieved for the prediction. This requires the worktree to
/// be loaded and the language server to be started.
#[serde(skip_serializing_if = "Option::is_none")]
pub context: Option<ExampleContext>,
/// The input and expected output from the edit prediction model.
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt: Option<ExamplePrompt>,
/// The actual predictions from the model.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub predictions: Vec<ExamplePrediction>,
/// The scores, for how well the actual predictions match the expected
/// predictions.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub score: Vec<ExampleScore>,
/// The application state used to process this example.
#[serde(skip)]
pub state: Option<ExampleState>,
}
#[derive(Clone, Debug)]
pub struct ExampleState {
pub project: Entity<Project>,
pub buffer: Entity<Buffer>,
pub cursor_position: Anchor,
pub _open_buffers: OpenedBuffers,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleContext {
pub files: Arc<[RelatedFile]>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleBuffer {
pub content: String,
pub cursor_row: u32,
pub cursor_column: u32,
pub cursor_offset: usize,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
pub input: String,
pub expected_output: String,
pub format: PromptFormat,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
pub actual_patch: String,
pub actual_output: String,
pub provider: PredictionProvider,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
pub delta_chr_f: f32,
pub line_match: ClassificationMetrics,
}
impl Example {
fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
pub fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
// git@github.com:owner/repo.git
if self.repository_url.contains('@') {
let (owner, repo) = self
@@ -89,486 +131,249 @@ impl Example {
Ok((owner.into(), repo.into()))
}
}
}
pub async fn setup_worktree(&self, file_name: String) -> Result<PathBuf> {
let (repo_owner, repo_name) = self.repo_name()?;
pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
let mut examples = Vec::new();
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
let repo_lock = lock_repo(&repo_dir).await;
let stdin_path: PathBuf = PathBuf::from("-");
if !repo_dir.is_dir() {
fs::create_dir_all(&repo_dir)?;
run_git(&repo_dir, &["init"]).await?;
run_git(
&repo_dir,
&["remote", "add", "origin", &self.repository_url],
)
.await?;
}
let inputs = if inputs.is_empty() {
&[stdin_path]
} else {
inputs
};
// Resolve the example to a revision, fetching it if needed.
let revision = run_git(
&repo_dir,
&["rev-parse", &format!("{}^{{commit}}", self.revision)],
)
.await;
let revision = if let Ok(revision) = revision {
revision
for path in inputs {
let is_stdin = path.as_path() == Path::new("-");
let content = if is_stdin {
let mut buffer = String::new();
std::io::stdin()
.read_to_string(&mut buffer)
.expect("Failed to read from stdin");
buffer
} else {
if run_git(
&repo_dir,
&["fetch", "--depth", "1", "origin", &self.revision],
)
.await
.is_err()
{
run_git(&repo_dir, &["fetch", "origin"]).await?;
}
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
if revision != self.revision {
run_git(&repo_dir, &["tag", &self.revision, &revision]).await?;
}
revision
std::fs::read_to_string(path)
.unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
};
let filename = path.file_stem().unwrap().to_string_lossy().to_string();
let ext = if !is_stdin {
path.extension()
.map(|ext| ext.to_string_lossy().to_string())
.unwrap_or_else(|| panic!("{} should have an extension", path.display()))
} else {
"jsonl".to_string()
};
// Create the worktree for this example if needed.
let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
if worktree_path.is_dir() {
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
} else {
let worktree_path_string = worktree_path.to_string_lossy();
run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
run_git(
&repo_dir,
&["worktree", "add", "-f", &worktree_path_string, &file_name],
)
.await?;
}
drop(repo_lock);
// Apply the uncommitted diff for this example.
if !self.uncommitted_diff.is_empty() {
let mut apply_process = smol::process::Command::new("git")
.current_dir(&worktree_path)
.args(&["apply", "-"])
.stdin(std::process::Stdio::piped())
.spawn()?;
let mut stdin = apply_process.stdin.take().unwrap();
stdin.write_all(self.uncommitted_diff.as_bytes()).await?;
stdin.close().await?;
drop(stdin);
let apply_result = apply_process.output().await?;
if !apply_result.status.success() {
anyhow::bail!(
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
apply_result.status,
String::from_utf8_lossy(&apply_result.stderr),
String::from_utf8_lossy(&apply_result.stdout),
);
match ext.as_ref() {
"json" => {
let mut example =
serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
panic!("Failed to parse example file: {}\n{error}", path.display())
});
if example.name.is_empty() {
example.name = filename;
}
examples.push(example);
}
"jsonl" => examples.extend(
content
.lines()
.enumerate()
.map(|(line_ix, line)| {
let mut example =
serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
panic!(
"Failed to parse example on {}:{}\n{error}",
path.display(),
line_ix + 1
)
});
if example.name.is_empty() {
example.name = format!("{filename}-{line_ix}")
}
example
})
.collect::<Vec<Example>>(),
),
"md" => {
examples.push(parse_markdown_example(filename, &content).unwrap());
}
ext => {
panic!("{} has invalid example extension `{ext}`", path.display())
}
}
Ok(worktree_path)
}
pub fn unique_name(&self) -> String {
let mut hasher = std::hash::DefaultHasher::new();
self.hash(&mut hasher);
let disambiguator = hasher.finish();
let hash = format!("{:04x}", disambiguator);
format!("{}_{}", &self.revision[..8], &hash[..4])
sort_examples_by_repo_and_rev(&mut examples);
examples
}
pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
let mut content = String::new();
for example in examples {
let line = serde_json::to_string(example).unwrap();
content.push_str(&line);
content.push('\n');
}
if let Some(output_path) = output_path {
std::fs::write(output_path, content).expect("Failed to write examples");
} else {
std::io::stdout().write_all(&content.as_bytes()).unwrap();
}
}
pub type ActualExcerpt = Excerpt;
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Excerpt {
pub path: PathBuf,
pub text: String,
pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
examples.sort_by(|a, b| {
a.repository_url
.cmp(&b.repository_url)
.then(b.revision.cmp(&a.revision))
});
}
#[derive(ValueEnum, Debug, Clone)]
pub enum ExampleFormat {
Json,
Toml,
Md,
pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
let mut examples_by_repo = HashMap::default();
for example in examples.iter_mut() {
examples_by_repo
.entry(example.repository_url.clone())
.or_insert_with(Vec::new)
.push(example);
}
examples_by_repo.into_values().collect()
}
impl NamedExample {
pub fn load(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref();
let content = std::fs::read_to_string(path)?;
let ext = path.extension();
fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
match ext.and_then(|s| s.to_str()) {
Some("json") => Ok(Self {
name: path.file_stem().unwrap_or_default().display().to_string(),
example: serde_json::from_str(&content)?,
}),
Some("toml") => Ok(Self {
name: path.file_stem().unwrap_or_default().display().to_string(),
example: toml::from_str(&content)?,
}),
Some("md") => Self::parse_md(&content),
Some(_) => {
anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display());
}
None => {
anyhow::bail!(
"Failed to determine example type since the file does not have an extension."
);
}
}
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
const REPOSITORY_URL_FIELD: &str = "repository_url";
const REVISION_FIELD: &str = "revision";
let parser = Parser::new(input);
let mut example = Example {
name: id,
repository_url: String::new(),
revision: String::new(),
uncommitted_diff: String::new(),
cursor_path: PathBuf::new().into(),
cursor_position: String::new(),
edit_history: String::new(),
expected_patch: String::new(),
buffer: None,
context: None,
prompt: None,
predictions: Vec::new(),
score: Vec::new(),
state: None,
};
let mut text = String::new();
let mut block_info: CowStr = "".into();
#[derive(PartialEq)]
enum Section {
Start,
UncommittedDiff,
EditHistory,
CursorPosition,
ExpectedExcerpts,
ExpectedPatch,
Other,
}
pub fn parse_md(input: &str) -> Result<Self> {
use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd};
let mut current_section = Section::Start;
let parser = Parser::new(input);
for event in parser {
match event {
Event::Text(line) => {
text.push_str(&line);
let mut named = NamedExample {
name: String::new(),
example: Example {
repository_url: String::new(),
revision: String::new(),
uncommitted_diff: String::new(),
cursor_path: PathBuf::new(),
cursor_position: String::new(),
edit_history: String::new(),
expected_patch: String::new(),
},
};
let mut text = String::new();
let mut block_info: CowStr = "".into();
#[derive(PartialEq)]
enum Section {
UncommittedDiff,
EditHistory,
CursorPosition,
ExpectedExcerpts,
ExpectedPatch,
Other,
}
let mut current_section = Section::Other;
for event in parser {
match event {
Event::Text(line) => {
text.push_str(&line);
if !named.name.is_empty()
&& current_section == Section::Other
// in h1 section
&& let Some((field, value)) = line.split_once('=')
{
match field.trim() {
REPOSITORY_URL_FIELD => {
named.example.repository_url = value.trim().to_string();
}
REVISION_FIELD => {
named.example.revision = value.trim().to_string();
}
_ => {}
if let Section::Start = current_section
&& let Some((field, value)) = line.split_once('=')
{
match field.trim() {
REPOSITORY_URL_FIELD => {
example.repository_url = value.trim().to_string();
}
REVISION_FIELD => {
example.revision = value.trim().to_string();
}
_ => {}
}
}
Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
if !named.name.is_empty() {
anyhow::bail!(
"Found multiple H1 headings. There should only be one with the name of the example."
);
}
named.name = mem::take(&mut text);
}
Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
let title = mem::take(&mut text);
current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
Section::UncommittedDiff
} else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
Section::EditHistory
} else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
Section::CursorPosition
} else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
Section::ExpectedPatch
} else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
Section::ExpectedExcerpts
} else {
Section::Other
};
}
Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
mem::take(&mut text);
}
Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
mem::take(&mut text);
}
Event::End(TagEnd::Heading(level)) => {
anyhow::bail!("Unexpected heading level: {level}");
}
Event::Start(Tag::CodeBlock(kind)) => {
match kind {
CodeBlockKind::Fenced(info) => {
block_info = info;
}
CodeBlockKind::Indented => {
anyhow::bail!("Unexpected indented codeblock");
}
};
}
Event::Start(_) => {
text.clear();
block_info = "".into();
}
Event::End(TagEnd::CodeBlock) => {
let block_info = block_info.trim();
match current_section {
Section::UncommittedDiff => {
named.example.uncommitted_diff = mem::take(&mut text);
}
Section::EditHistory => {
named.example.edit_history.push_str(&mem::take(&mut text));
}
Section::CursorPosition => {
named.example.cursor_path = block_info.into();
named.example.cursor_position = mem::take(&mut text);
}
Section::ExpectedExcerpts => {
mem::take(&mut text);
}
Section::ExpectedPatch => {
named.example.expected_patch = mem::take(&mut text);
}
Section::Other => {}
}
}
_ => {}
}
}
if named.example.cursor_path.as_path() == Path::new("")
|| named.example.cursor_position.is_empty()
{
anyhow::bail!("Missing cursor position codeblock");
}
Ok(named)
}
pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> {
match format {
ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?),
ExampleFormat::Toml => {
Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?)
}
ExampleFormat::Md => Ok(write!(out, "{}", self)?),
}
}
pub async fn setup_project(
&self,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<Entity<Project>> {
let worktree_path = self.setup_worktree().await?;
static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
AUTHENTICATED
.get_or_init(|| {
let client = app_state.client.clone();
cx.spawn(async move |cx| {
client
.sign_in_with_optional_connect(true, cx)
.await
.unwrap();
})
.shared()
})
.clone()
.await;
let project = cx.update(|cx| {
Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
app_state.fs.clone(),
None,
cx,
)
})?;
let worktree = project
.update(cx, |project, cx| {
project.create_worktree(&worktree_path, true, cx)
})?
.await?;
worktree
.read_with(cx, |worktree, _cx| {
worktree.as_local().unwrap().scan_complete()
})?
.await;
anyhow::Ok(project)
}
pub async fn setup_worktree(&self) -> Result<PathBuf> {
self.example.setup_worktree(self.file_name()).await
}
pub fn file_name(&self) -> String {
self.name
.chars()
.map(|c| {
if c.is_whitespace() {
'-'
Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
let title = mem::take(&mut text);
current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
Section::UncommittedDiff
} else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
Section::EditHistory
} else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
Section::CursorPosition
} else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
Section::ExpectedPatch
} else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
Section::ExpectedExcerpts
} else {
c.to_ascii_lowercase()
Section::Other
};
}
Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
mem::take(&mut text);
}
Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
mem::take(&mut text);
}
Event::End(TagEnd::Heading(level)) => {
anyhow::bail!("Unexpected heading level: {level}");
}
Event::Start(Tag::CodeBlock(kind)) => {
match kind {
CodeBlockKind::Fenced(info) => {
block_info = info;
}
CodeBlockKind::Indented => {
anyhow::bail!("Unexpected indented codeblock");
}
};
}
Event::Start(_) => {
text.clear();
block_info = "".into();
}
Event::End(TagEnd::CodeBlock) => {
let block_info = block_info.trim();
match current_section {
Section::UncommittedDiff => {
example.uncommitted_diff = mem::take(&mut text);
}
Section::EditHistory => {
example.edit_history.push_str(&mem::take(&mut text));
}
Section::CursorPosition => {
example.cursor_path = Path::new(block_info).into();
example.cursor_position = mem::take(&mut text);
}
Section::ExpectedExcerpts => {
mem::take(&mut text);
}
Section::ExpectedPatch => {
example.expected_patch = mem::take(&mut text);
}
Section::Start | Section::Other => {}
}
})
.collect()
}
pub async fn cursor_position(
&self,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<(Entity<Buffer>, Anchor)> {
let worktree = project.read_with(cx, |project, cx| {
project.visible_worktrees(cx).next().unwrap()
})?;
let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc();
let cursor_buffer = project
.update(cx, |project, cx| {
project.open_buffer(
ProjectPath {
worktree_id: worktree.read(cx).id(),
path: cursor_path,
},
cx,
)
})?
.await?;
let cursor_offset_within_excerpt = self
.example
.cursor_position
.find(CURSOR_MARKER)
.ok_or_else(|| anyhow!("missing cursor marker"))?;
let mut cursor_excerpt = self.example.cursor_position.clone();
cursor_excerpt.replace_range(
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
"",
);
let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
let text = buffer.text();
let mut matches = text.match_indices(&cursor_excerpt);
let Some((excerpt_offset, _)) = matches.next() else {
anyhow::bail!(
"\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
);
};
assert!(matches.next().is_none());
Ok(excerpt_offset)
})??;
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
let cursor_anchor =
cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
Ok((cursor_buffer, cursor_anchor))
}
#[must_use]
pub async fn apply_edit_history(
&self,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<OpenedBuffers<'_>> {
edit_prediction::udiff::apply_diff(&self.example.edit_history, project, cx).await
}
}
async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
let output = smol::process::Command::new("git")
.current_dir(repo_path)
.args(args)
.output()
.await?;
anyhow::ensure!(
output.status.success(),
"`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
args.join(" "),
repo_path.display(),
output.status,
String::from_utf8_lossy(&output.stderr),
String::from_utf8_lossy(&output.stdout),
);
Ok(String::from_utf8(output.stdout)?.trim().to_string())
}
impl Display for NamedExample {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "# {}\n\n", self.name)?;
write!(
f,
"{REPOSITORY_URL_FIELD} = {}\n",
self.example.repository_url
)?;
write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
write!(f, "`````diff\n")?;
write!(f, "{}", self.example.uncommitted_diff)?;
write!(f, "`````\n")?;
if !self.example.edit_history.is_empty() {
write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
}
_ => {}
}
write!(
f,
"## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
self.example.cursor_path.display(),
self.example.cursor_position
)?;
write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
if !self.example.expected_patch.is_empty() {
write!(
f,
"\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
self.example.expected_patch
)?;
}
Ok(())
}
}
if example.cursor_path.as_ref() == Path::new("") || example.cursor_position.is_empty() {
anyhow::bail!("Missing cursor position codeblock");
}
thread_local! {
static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
}
#[must_use]
pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
REPO_LOCKS
.with(|cell| {
cell.borrow_mut()
.entry(path.as_ref().to_path_buf())
.or_default()
.clone()
})
.lock_owned()
.await
Ok(example)
}

View File

@@ -0,0 +1,287 @@
use crate::{
PromptFormat,
example::{Example, ExamplePrompt},
headless::EpAppState,
load_project::run_load_project,
progress::{Progress, Step},
retrieve_context::run_context_retrieval,
};
use anyhow::{Context as _, Result, ensure};
use edit_prediction::{
EditPredictionStore,
zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
};
use gpui::AsyncApp;
use std::sync::Arc;
use zeta_prompt::format_zeta_prompt;
pub async fn run_format_prompt(
example: &mut Example,
prompt_format: PromptFormat,
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
) -> Result<()> {
run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
let _step_progress = Progress::global().start(Step::FormatPrompt, &example.name);
match prompt_format {
PromptFormat::Teacher => {
let prompt = TeacherPrompt::format_prompt(example);
example.prompt = Some(ExamplePrompt {
input: prompt,
expected_output: example.expected_patch.clone(), // TODO
format: prompt_format,
});
}
PromptFormat::Zeta2 => {
run_load_project(example, app_state, cx.clone()).await?;
let ep_store = cx.update(|cx| {
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
})??;
let state = example.state.as_ref().context("state must be set")?;
let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
let project = state.project.clone();
let (_, input) = ep_store.update(&mut cx, |ep_store, _cx| {
anyhow::Ok(zeta2_prompt_input(
&snapshot,
example
.context
.as_ref()
.context("context must be set")?
.files
.clone(),
ep_store.edit_history_for_project(&project),
example.cursor_path.clone(),
example
.buffer
.as_ref()
.context("buffer must be set")?
.cursor_offset,
))
})??;
let prompt = format_zeta_prompt(&input);
let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone());
example.prompt = Some(ExamplePrompt {
input: prompt,
expected_output,
format: prompt_format,
});
}
};
Ok(())
}
pub struct TeacherPrompt;
impl TeacherPrompt {
const PROMPT: &str = include_str!("teacher.prompt.md");
pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
/// Truncate edit history to this number of last lines
const MAX_HISTORY_LINES: usize = 128;
pub fn format_prompt(example: &Example) -> String {
let edit_history = Self::format_edit_history(&example.edit_history);
let context = Self::format_context(example);
let editable_region = Self::format_editable_region(example);
let prompt = Self::PROMPT
.replace("{{context}}", &context)
.replace("{{edit_history}}", &edit_history)
.replace("{{editable_region}}", &editable_region);
prompt
}
pub fn parse(example: &Example, response: &str) -> Result<String> {
// Ideally, we should always be able to find cursor position in the retrieved context.
// In reality, sometimes we don't find it for these reasons:
// 1. `example.cursor_position` contains _more_ context than included in the retrieved context
// (can be fixed by getting cursor coordinates at the load_example stage)
// 2. Context retriever just didn't include cursor line.
//
// In that case, fallback to using `cursor_position` as excerpt.
let cursor_file = &example
.buffer
.as_ref()
.context("`buffer` should be filled in in the context collection step")?
.content;
// Extract updated (new) editable region from the model response
let new_editable_region = extract_last_codeblock(response);
// Reconstruct old editable region we sent to the model
let old_editable_region = Self::format_editable_region(example);
let old_editable_region = Self::extract_editable_region(&old_editable_region);
ensure!(
cursor_file.contains(&old_editable_region),
"Something's wrong: editable_region is not found in the cursor file"
);
// Apply editable region to a larger context and compute diff.
// This is needed to get a better context lines around the editable region
let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
let diff = language::unified_diff(&cursor_file, &edited_file);
let diff = indoc::formatdoc! {"
--- a/{path}
+++ b/{path}
{diff}",
path = example.cursor_path.to_string_lossy(),
diff = diff,
};
Ok(diff)
}
fn format_edit_history(edit_history: &str) -> String {
// Strip comments ("garbage lines") from edit history
let lines = edit_history
.lines()
.filter(|&s| Self::is_udiff_content_line(s))
.collect::<Vec<_>>();
let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
&lines[lines.len() - Self::MAX_HISTORY_LINES..]
} else {
&lines
};
if history_lines.is_empty() {
return "(No edit history)".to_string();
}
history_lines.join("\n")
}
fn format_context(example: &Example) -> String {
assert!(example.context.is_some(), "Missing context retriever step");
let mut prompt = String::new();
zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
prompt
}
fn format_editable_region(example: &Example) -> String {
let mut result = String::new();
let path_str = example.cursor_path.to_string_lossy();
result.push_str(&format!("`````path=\"{path_str}\"\n"));
result.push_str(Self::EDITABLE_REGION_START);
// TODO: control number of lines around cursor
result.push_str(&example.cursor_position);
if !example.cursor_position.ends_with('\n') {
result.push('\n');
}
result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END));
result.push_str("`````");
result
}
fn extract_editable_region(text: &str) -> String {
let start = text
.find(Self::EDITABLE_REGION_START)
.map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
let region = &text[start..end];
region.replace("<|user_cursor|>", "")
}
fn is_udiff_content_line(s: &str) -> bool {
s.starts_with("-")
|| s.starts_with("+")
|| s.starts_with(" ")
|| s.starts_with("---")
|| s.starts_with("+++")
|| s.starts_with("@@")
}
}
fn extract_last_codeblock(text: &str) -> String {
let mut last_block = None;
let mut search_start = 0;
while let Some(start) = text[search_start..].find("```") {
let start = start + search_start;
let bytes = text.as_bytes();
let mut backtick_end = start;
while backtick_end < bytes.len() && bytes[backtick_end] == b'`' {
backtick_end += 1;
}
let backtick_count = backtick_end - start;
let closing_backticks = "`".repeat(backtick_count);
while backtick_end < bytes.len() && bytes[backtick_end] != b'\n' {
backtick_end += 1;
}
if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
let code_block = &text[backtick_end + 1..backtick_end + end_pos];
last_block = Some(code_block.to_string());
search_start = backtick_end + end_pos + backtick_count;
} else {
break;
}
}
last_block.unwrap_or_else(|| text.to_string())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_extract_last_code_block() {
let text = indoc::indoc! {"
Some thinking
```
first block
```
`````path='something' lines=1:2
last block
`````
"};
let last_block = extract_last_codeblock(text);
assert_eq!(last_block, "last block\n");
}
#[test]
fn test_extract_editable_region() {
let text = indoc::indoc! {"
some lines
are
here
<|editable_region_start|>
one
two three
<|editable_region_end|>
more
lines here
"};
let parsed = TeacherPrompt::extract_editable_region(text);
assert_eq!(
parsed,
indoc::indoc! {"
one
two three
"}
);
}
}

View File

@@ -1,4 +1,5 @@
use client::{Client, ProxySettings, UserStore};
use collections::HashMap;
use extension::ExtensionHostProxy;
use fs::RealFs;
use gpui::http_client::read_proxy_from_env;
@@ -7,25 +8,39 @@ use gpui_tokio::Tokio;
use language::LanguageRegistry;
use language_extension::LspAccess;
use node_runtime::{NodeBinaryOptions, NodeRuntime};
use project::Project;
use project::project_settings::ProjectSettings;
use release_channel::{AppCommitSha, AppVersion};
use reqwest_client::ReqwestClient;
use settings::{Settings, SettingsStore};
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use util::ResultExt as _;
/// Headless subset of `workspace::AppState`.
pub struct ZetaCliAppState {
pub struct EpAppState {
pub languages: Arc<LanguageRegistry>,
pub client: Arc<Client>,
pub user_store: Entity<UserStore>,
pub fs: Arc<dyn fs::Fs>,
pub node_runtime: NodeRuntime,
pub project_cache: ProjectCache,
}
// TODO: dedupe with crates/eval/src/eval.rs
pub fn init(cx: &mut App) -> ZetaCliAppState {
#[derive(Default)]
pub struct ProjectCache(Mutex<HashMap<String, Entity<Project>>>);
impl ProjectCache {
pub fn insert(&self, repository_url: String, project: Entity<Project>) {
self.0.lock().unwrap().insert(repository_url, project);
}
pub fn get(&self, repository_url: &String) -> Option<Entity<Project>> {
self.0.lock().unwrap().get(repository_url).cloned()
}
}
pub fn init(cx: &mut App) -> EpAppState {
let app_commit_sha = option_env!("ZED_COMMIT_SHA").map(|s| AppCommitSha::new(s.to_owned()));
let app_version = AppVersion::load(
@@ -112,11 +127,14 @@ pub fn init(cx: &mut App) -> ZetaCliAppState {
prompt_store::init(cx);
terminal_view::init(cx);
ZetaCliAppState {
let project_cache = ProjectCache::default();
EpAppState {
languages,
client,
user_store,
fs,
node_runtime,
project_cache,
}
}

View File

@@ -0,0 +1,346 @@
use crate::{
example::{Example, ExampleBuffer, ExampleState},
headless::EpAppState,
paths::{REPOS_DIR, WORKTREES_DIR},
progress::{InfoStyle, Progress, Step, StepProgress},
};
use anyhow::{Context as _, Result};
use collections::HashMap;
use edit_prediction::EditPredictionStore;
use edit_prediction::udiff::OpenedBuffers;
use futures::{
AsyncWriteExt as _,
lock::{Mutex, OwnedMutexGuard},
};
use gpui::{AsyncApp, Entity};
use language::{Anchor, Buffer, LanguageNotFound, ToOffset, ToPoint};
use project::buffer_store::BufferStoreEvent;
use project::{Project, ProjectPath};
use std::{
cell::RefCell,
fs,
path::{Path, PathBuf},
sync::Arc,
};
use util::{paths::PathStyle, rel_path::RelPath};
use zeta_prompt::CURSOR_MARKER;
pub async fn run_load_project(
example: &mut Example,
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
) -> Result<()> {
if example.state.is_some() {
return Ok(());
}
let progress = Progress::global().start(Step::LoadProject, &example.name);
let project = setup_project(example, &app_state, &progress, &mut cx).await?;
let _open_buffers = apply_edit_history(example, &project, &mut cx).await?;
let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await?;
let (example_buffer, language_name) = buffer.read_with(&cx, |buffer, _cx| {
let cursor_point = cursor_position.to_point(&buffer);
let language_name = buffer
.language()
.map(|l| l.name().to_string())
.unwrap_or_else(|| "Unknown".to_string());
(
ExampleBuffer {
content: buffer.text(),
cursor_row: cursor_point.row,
cursor_column: cursor_point.column,
cursor_offset: cursor_position.to_offset(&buffer),
},
language_name,
)
})?;
progress.set_info(language_name, InfoStyle::Normal);
example.buffer = Some(example_buffer);
example.state = Some(ExampleState {
buffer,
project,
cursor_position,
_open_buffers,
});
Ok(())
}
async fn cursor_position(
example: &Example,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<(Entity<Buffer>, Anchor)> {
let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
let result = language_registry
.load_language_for_file_path(&example.cursor_path)
.await;
if let Err(error) = result
&& !error.is::<LanguageNotFound>()
{
return Err(error);
}
let worktree = project.read_with(cx, |project, cx| {
project
.visible_worktrees(cx)
.next()
.context("No visible worktrees")
})??;
let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix)
.context("Failed to create RelPath")?
.into_arc();
let cursor_buffer = project
.update(cx, |project, cx| {
project.open_buffer(
ProjectPath {
worktree_id: worktree.read(cx).id(),
path: cursor_path,
},
cx,
)
})?
.await?;
let cursor_offset_within_excerpt = example
.cursor_position
.find(CURSOR_MARKER)
.context("missing cursor marker")?;
let mut cursor_excerpt = example.cursor_position.clone();
cursor_excerpt.replace_range(
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
"",
);
let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
let text = buffer.text();
let mut matches = text.match_indices(&cursor_excerpt);
let (excerpt_offset, _) = matches.next().with_context(|| {
format!(
"\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Example: {}\nCursor excerpt did not exist in buffer.",
example.name
)
})?;
anyhow::ensure!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
Ok(excerpt_offset)
})??;
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
let cursor_anchor =
cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
Ok((cursor_buffer, cursor_anchor))
}
async fn setup_project(
example: &mut Example,
app_state: &Arc<EpAppState>,
step_progress: &StepProgress,
cx: &mut AsyncApp,
) -> Result<Entity<Project>> {
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx))?
.context("Store should be initialized at init")?;
let worktree_path = setup_worktree(example, step_progress).await?;
if let Some(project) = app_state.project_cache.get(&example.repository_url) {
ep_store.update(cx, |ep_store, _| {
ep_store.clear_history_for_project(&project);
})?;
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
let buffers = buffer_store.read_with(cx, |buffer_store, _| {
buffer_store.buffers().collect::<Vec<_>>()
})?;
for buffer in buffers {
buffer
.update(cx, |buffer, cx| buffer.reload(cx))?
.await
.ok();
}
return Ok(project);
}
let project = cx.update(|cx| {
Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
app_state.fs.clone(),
None,
cx,
)
})?;
project
.update(cx, |project, cx| {
project.disable_worktree_scanner(cx);
project.create_worktree(&worktree_path, true, cx)
})?
.await?;
app_state
.project_cache
.insert(example.repository_url.clone(), project.clone());
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
cx.subscribe(&buffer_store, {
let project = project.clone();
move |_, event, cx| match event {
BufferStoreEvent::BufferAdded(buffer) => {
ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
}
_ => {}
}
})?
.detach();
Ok(project)
}
async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Result<PathBuf> {
let (repo_owner, repo_name) = example.repo_name().context("failed to get repo name")?;
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
let worktree_path = WORKTREES_DIR
.join(repo_owner.as_ref())
.join(repo_name.as_ref());
let repo_lock = lock_repo(&repo_dir).await;
if !repo_dir.is_dir() {
step_progress.set_substatus(format!("cloning {}", repo_name));
fs::create_dir_all(&repo_dir)?;
run_git(&repo_dir, &["init"]).await?;
run_git(
&repo_dir,
&["remote", "add", "origin", &example.repository_url],
)
.await?;
}
// Resolve the example to a revision, fetching it if needed.
let revision = run_git(
&repo_dir,
&["rev-parse", &format!("{}^{{commit}}", example.revision)],
)
.await;
let revision = if let Ok(revision) = revision {
revision
} else {
step_progress.set_substatus("fetching");
if run_git(
&repo_dir,
&["fetch", "--depth", "1", "origin", &example.revision],
)
.await
.is_err()
{
run_git(&repo_dir, &["fetch", "origin"]).await?;
}
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
revision
};
// Create the worktree for this example if needed.
step_progress.set_substatus("preparing worktree");
if worktree_path.is_dir() {
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
} else {
let worktree_path_string = worktree_path.to_string_lossy();
run_git(
&repo_dir,
&["branch", "-f", &example.name, revision.as_str()],
)
.await?;
run_git(
&repo_dir,
&[
"worktree",
"add",
"-f",
&worktree_path_string,
&example.name,
],
)
.await?;
}
drop(repo_lock);
// Apply the uncommitted diff for this example.
if !example.uncommitted_diff.is_empty() {
step_progress.set_substatus("applying diff");
let mut apply_process = smol::process::Command::new("git")
.current_dir(&worktree_path)
.args(&["apply", "-"])
.stdin(std::process::Stdio::piped())
.spawn()?;
let mut stdin = apply_process.stdin.take().context("Failed to get stdin")?;
stdin.write_all(example.uncommitted_diff.as_bytes()).await?;
stdin.close().await?;
drop(stdin);
let apply_result = apply_process.output().await?;
anyhow::ensure!(
apply_result.status.success(),
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
apply_result.status,
String::from_utf8_lossy(&apply_result.stderr),
String::from_utf8_lossy(&apply_result.stdout),
);
}
step_progress.clear_substatus();
Ok(worktree_path)
}
async fn apply_edit_history(
example: &Example,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<OpenedBuffers> {
edit_prediction::udiff::apply_diff(&example.edit_history, project, cx).await
}
thread_local! {
static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
}
#[must_use]
pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
REPO_LOCKS
.with(|cell| {
cell.borrow_mut()
.entry(path.as_ref().to_path_buf())
.or_default()
.clone()
})
.lock_owned()
.await
}
async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
let output = smol::process::Command::new("git")
.current_dir(repo_path)
.args(args)
.output()
.await?;
anyhow::ensure!(
output.status.success(),
"`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
args.join(" "),
repo_path.display(),
output.status,
String::from_utf8_lossy(&output.stderr),
String::from_utf8_lossy(&output.stdout),
);
Ok(String::from_utf8(output.stdout)?.trim().to_string())
}

View File

@@ -1,523 +1,340 @@
mod evaluate;
mod anthropic_client;
mod distill;
mod example;
mod format_prompt;
mod headless;
mod load_project;
mod metrics;
mod paths;
mod predict;
mod source_location;
mod training;
mod util;
mod progress;
mod retrieve_context;
mod score;
use crate::{
evaluate::run_evaluate,
example::{ExampleFormat, NamedExample},
headless::ZetaCliAppState,
predict::run_predict,
source_location::SourceLocation,
training::{context::ContextType, distill::run_distill},
util::{open_buffer, open_buffer_with_language_server},
};
use ::util::{ResultExt, paths::PathStyle};
use anyhow::{Result, anyhow};
use clap::{Args, Parser, Subcommand, ValueEnum};
use cloud_llm_client::predict_edits_v3;
use edit_prediction::udiff::DiffLine;
use edit_prediction_context::EditPredictionExcerptOptions;
use gpui::{Application, AsyncApp, Entity, prelude::*};
use language::{Bias, Buffer, BufferSnapshot, Point};
use metrics::delta_chr_f;
use project::{Project, Worktree, lsp_store::OpenLspBufferHandle};
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use edit_prediction::EditPredictionStore;
use gpui::Application;
use reqwest_client::ReqwestClient;
use std::io::{self};
use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use std::{path::PathBuf, sync::Arc};
use crate::distill::run_distill;
use crate::example::{group_examples_by_repo, read_examples, write_examples};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::FAILED_EXAMPLES_DIR;
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
#[command(name = "ep")]
struct EpArgs {
#[arg(long, default_value_t = false)]
printenv: bool,
#[clap(long, default_value_t = 10, global = true)]
max_parallelism: usize,
#[command(subcommand)]
command: Option<Command>,
#[clap(global = true)]
inputs: Vec<PathBuf>,
#[arg(long, short, global = true)]
output: Option<PathBuf>,
#[arg(long, short, global = true)]
in_place: bool,
#[arg(long, short, global = true)]
failfast: bool,
}
#[derive(Subcommand, Debug)]
enum Command {
Context(ContextArgs),
Predict(PredictArguments),
Eval(EvaluateArguments),
Distill(DistillArguments),
ConvertExample {
path: PathBuf,
#[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
output_format: ExampleFormat,
},
Score {
golden_patch: PathBuf,
actual_patch: PathBuf,
},
/// Parse markdown examples and output a combined .jsonl file
ParseExample,
/// Create git worktrees for each example and load file contents
LoadProject,
/// Retrieve context for input examples.
Context,
/// Generate a prompt string for a specific model
FormatPrompt(FormatPromptArgs),
/// Runs edit prediction
Predict(PredictArgs),
/// Computes a score based on actual and expected patches
Score(PredictArgs),
/// Prepares a distillation dataset by copying expected outputs to
/// predicted outputs and removing actual outputs and prompts.
Distill,
/// Print aggregated scores
Eval(PredictArgs),
/// Remove git repositories and worktrees
Clean,
}
#[derive(Debug, Args)]
struct ContextArgs {
#[arg(long)]
provider: ContextProvider,
#[arg(long)]
worktree: PathBuf,
#[arg(long)]
cursor: SourceLocation,
#[arg(long)]
use_language_server: bool,
#[arg(long)]
edit_history: Option<FileOrStdin>,
#[clap(flatten)]
zeta2_args: Zeta2Args,
impl Display for Command {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Command::ParseExample => write!(f, "parse-example"),
Command::LoadProject => write!(f, "load-project"),
Command::Context => write!(f, "context"),
Command::FormatPrompt(format_prompt_args) => write!(
f,
"format-prompt --prompt-format={}",
format_prompt_args
.prompt_format
.to_possible_value()
.unwrap()
.get_name()
),
Command::Predict(predict_args) => {
write!(
f,
"predict --provider={:?}",
predict_args
.provider
.to_possible_value()
.unwrap()
.get_name()
)
}
Command::Score(predict_args) => {
write!(
f,
"score --provider={:?}",
predict_args
.provider
.to_possible_value()
.unwrap()
.get_name()
)
}
Command::Distill => write!(f, "distill"),
Command::Eval(predict_args) => write!(
f,
"eval --provider={:?}",
predict_args
.provider
.to_possible_value()
.unwrap()
.get_name()
),
Command::Clean => write!(f, "clean"),
}
}
}
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum ContextProvider {
Zeta1,
#[default]
#[derive(Debug, Args)]
struct FormatPromptArgs {
#[clap(long)]
prompt_format: PromptFormat,
}
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
Teacher,
Zeta2,
}
#[derive(Clone, Debug, Args)]
struct Zeta2Args {
#[arg(long, default_value_t = 8192)]
max_prompt_bytes: usize,
#[arg(long, default_value_t = 2048)]
max_excerpt_bytes: usize,
#[arg(long, default_value_t = 1024)]
min_excerpt_bytes: usize,
#[arg(long, default_value_t = 0.66)]
target_before_cursor_over_total_bytes: f32,
#[arg(long, default_value_t = 1024)]
max_diagnostic_bytes: usize,
#[arg(long, value_enum, default_value_t = PromptFormat::default())]
prompt_format: PromptFormat,
#[arg(long, value_enum, default_value_t = Default::default())]
output_format: OutputFormat,
#[arg(long, default_value_t = 42)]
file_indexing_parallelism: usize,
#[arg(long, default_value_t = false)]
disable_imports_gathering: bool,
#[arg(long, default_value_t = u8::MAX)]
max_retrieved_definitions: u8,
}
#[derive(Debug, Args)]
pub struct PredictArguments {
#[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
format: PredictionsOutputFormat,
example_path: PathBuf,
#[clap(flatten)]
options: PredictionOptions,
}
#[derive(Debug, Args)]
pub struct DistillArguments {
split_commit_dataset: PathBuf,
#[clap(long, value_enum, default_value_t = ContextType::CurrentFile)]
context_type: ContextType,
#[clap(long)]
batch: Option<String>,
}
#[derive(Clone, Debug, Args)]
pub struct PredictionOptions {
#[clap(flatten)]
zeta2: Zeta2Args,
struct PredictArgs {
#[clap(long)]
provider: PredictionProvider,
#[clap(long, value_enum, default_value_t = CacheMode::default())]
cache: CacheMode,
#[clap(long, default_value_t = 1)]
repetitions: usize,
}
#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
pub enum CacheMode {
/// Use cached LLM requests and responses, except when multiple repetitions are requested
#[default]
Auto,
/// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
#[value(alias = "request")]
Requests,
/// Ignore existing cache entries for both LLM and search.
Skip,
/// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
/// Useful for reproducing results and fixing bugs outside of search queries
Force,
}
impl CacheMode {
fn use_cached_llm_responses(&self) -> bool {
self.assert_not_auto();
matches!(self, CacheMode::Requests | CacheMode::Force)
}
fn use_cached_search_results(&self) -> bool {
self.assert_not_auto();
matches!(self, CacheMode::Force)
}
fn assert_not_auto(&self) {
assert_ne!(
*self,
CacheMode::Auto,
"Cache mode should not be auto at this point!"
);
}
}
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
Json,
Md,
Diff,
}
#[derive(Debug, Args)]
pub struct EvaluateArguments {
example_paths: Vec<PathBuf>,
#[clap(flatten)]
options: PredictionOptions,
#[clap(short, long, default_value_t = 1, alias = "repeat")]
repetitions: u16,
#[arg(long)]
skip_prediction: bool,
}
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
Zeta1,
#[default]
Zeta2,
Sweep,
Mercury,
Zeta1,
Zeta2,
Teacher,
TeacherNonBatching,
}
fn zeta2_args_to_options(args: &Zeta2Args) -> edit_prediction::ZetaOptions {
edit_prediction::ZetaOptions {
context: EditPredictionExcerptOptions {
max_bytes: args.max_excerpt_bytes,
min_bytes: args.min_excerpt_bytes,
target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
},
max_prompt_bytes: args.max_prompt_bytes,
prompt_format: args.prompt_format.into(),
}
}
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum PromptFormat {
OnlySnippets,
#[default]
OldTextNewText,
Minimal,
MinimalQwen,
SeedCoder1120,
}
impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
fn into(self) -> predict_edits_v3::PromptFormat {
match self {
Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
}
}
}
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
#[default]
Prompt,
Request,
Full,
}
#[derive(Debug, Clone)]
enum FileOrStdin {
File(PathBuf),
Stdin,
}
impl FileOrStdin {
async fn read_to_string(&self) -> Result<String, std::io::Error> {
match self {
FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
}
}
}
impl FromStr for FileOrStdin {
type Err = <PathBuf as FromStr>::Err;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"-" => Ok(Self::Stdin),
_ => Ok(Self::File(PathBuf::from_str(s)?)),
}
}
}
struct LoadedContext {
full_path_str: String,
snapshot: BufferSnapshot,
clipped_cursor: Point,
worktree: Entity<Worktree>,
project: Entity<Project>,
buffer: Entity<Buffer>,
lsp_open_handle: Option<OpenLspBufferHandle>,
}
async fn load_context(
args: &ContextArgs,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<LoadedContext> {
let ContextArgs {
worktree: worktree_path,
cursor,
use_language_server,
..
} = args;
let worktree_path = worktree_path.canonicalize()?;
let project = cx.update(|cx| {
Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
app_state.fs.clone(),
None,
cx,
)
})?;
let worktree = project
.update(cx, |project, cx| {
project.create_worktree(&worktree_path, true, cx)
})?
.await?;
let mut ready_languages = HashSet::default();
let (lsp_open_handle, buffer) = if *use_language_server {
let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
project.clone(),
worktree.clone(),
cursor.path.clone(),
&mut ready_languages,
cx,
)
.await?;
(Some(lsp_open_handle), buffer)
} else {
let buffer =
open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
(None, buffer)
};
let full_path_str = worktree
.read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
.display(PathStyle::local())
.to_string();
let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
if clipped_cursor != cursor.point {
let max_row = snapshot.max_point().row;
if cursor.point.row < max_row {
return Err(anyhow!(
"Cursor position {:?} is out of bounds (line length is {})",
cursor.point,
snapshot.line_len(cursor.point.row)
));
impl EpArgs {
fn output_path(&self) -> Option<PathBuf> {
if self.in_place {
if self.inputs.len() == 1 {
self.inputs.first().cloned()
} else {
panic!("--in-place requires exactly one input file")
}
} else {
return Err(anyhow!(
"Cursor position {:?} is out of bounds (max row is {})",
cursor.point,
max_row
));
self.output.clone()
}
}
Ok(LoadedContext {
full_path_str,
snapshot,
clipped_cursor,
worktree,
project,
buffer,
lsp_open_handle,
})
}
async fn zeta2_context(
args: ContextArgs,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<String> {
let LoadedContext {
worktree,
project,
buffer,
clipped_cursor,
lsp_open_handle: _handle,
..
} = load_context(&args, app_state, cx).await?;
// wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
// the whole worktree.
worktree
.read_with(cx, |worktree, _cx| {
worktree.as_local().unwrap().scan_complete()
})?
.await;
let output = cx
.update(|cx| {
let store = cx.new(|cx| {
edit_prediction::EditPredictionStore::new(
app_state.client.clone(),
app_state.user_store.clone(),
cx,
)
});
store.update(cx, |store, cx| {
store.set_options(zeta2_args_to_options(&args.zeta2_args));
store.register_buffer(&buffer, &project, cx);
});
cx.spawn(async move |cx| {
let updates_rx = store.update(cx, |store, cx| {
let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
store.set_use_context(true);
store.refresh_context(&project, &buffer, cursor, cx);
store.project_context_updates(&project).unwrap()
})?;
updates_rx.recv().await.ok();
let context = store.update(cx, |store, cx| {
store.context_for_project(&project, cx).to_vec()
})?;
anyhow::Ok(serde_json::to_string_pretty(&context).unwrap())
})
})?
.await?;
Ok(output)
}
async fn zeta1_context(
args: ContextArgs,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<edit_prediction::zeta1::GatherContextOutput> {
let LoadedContext {
full_path_str,
snapshot,
clipped_cursor,
..
} = load_context(&args, app_state, cx).await?;
let events = match args.edit_history {
Some(events) => events.read_to_string().await?,
None => String::new(),
};
let prompt_for_events = move || (events, 0);
cx.update(|cx| {
edit_prediction::zeta1::gather_context(
full_path_str,
&snapshot,
clipped_cursor,
prompt_for_events,
cloud_llm_client::PredictEditsRequestTrigger::Cli,
cx,
)
})?
.await
}
fn main() {
zlog::init();
zlog::init_output_stderr();
let args = ZetaCliArgs::parse();
let args = EpArgs::parse();
if args.printenv {
::util::shell_env::print_env();
return;
}
let output = args.output_path();
let command = match args.command {
Some(cmd) => cmd,
None => {
EpArgs::command().print_help().unwrap();
return;
}
};
match &command {
Command::Clean => {
std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
return;
}
_ => {}
}
let mut examples = read_examples(&args.inputs);
let http_client = Arc::new(ReqwestClient::new());
let app = Application::headless().with_http_client(http_client);
app.run(move |cx| {
let app_state = Arc::new(headless::init(cx));
EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
cx.spawn(async move |cx| {
match args.command {
None => {
if args.printenv {
::util::shell_env::print_env();
} else {
panic!("Expected a command");
}
let result = async {
if let Command::Predict(args) = &command {
predict::sync_batches(&args.provider).await?;
}
Some(Command::Context(context_args)) => {
let result = match context_args.provider {
ContextProvider::Zeta1 => {
let context =
zeta1_context(context_args, &app_state, cx).await.unwrap();
serde_json::to_string_pretty(&context.body).unwrap()
let total_examples = examples.len();
Progress::global().set_total_examples(total_examples);
let mut grouped_examples = group_examples_by_repo(&mut examples);
let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
for example_batch in example_batches {
let futures = example_batch.into_iter().map(|repo_examples| async {
for example in repo_examples.iter_mut() {
let result = async {
match &command {
Command::ParseExample => {}
Command::LoadProject => {
run_load_project(example, app_state.clone(), cx.clone())
.await?;
}
Command::Context => {
run_context_retrieval(
example,
app_state.clone(),
cx.clone(),
)
.await?;
}
Command::FormatPrompt(args) => {
run_format_prompt(
example,
args.prompt_format,
app_state.clone(),
cx.clone(),
)
.await?;
}
Command::Predict(args) => {
run_prediction(
example,
Some(args.provider),
args.repetitions,
app_state.clone(),
cx.clone(),
)
.await?;
}
Command::Distill => {
run_distill(example).await?;
}
Command::Score(args) | Command::Eval(args) => {
run_scoring(example, &args, app_state.clone(), cx.clone())
.await?;
}
Command::Clean => {
unreachable!()
}
}
anyhow::Ok(())
}
.await;
if let Err(e) = result {
Progress::global().increment_failed();
let failed_example_path =
FAILED_EXAMPLES_DIR.join(format!("{}.json", example.name));
app_state
.fs
.write(
&failed_example_path,
&serde_json::to_vec_pretty(&example).unwrap(),
)
.await
.unwrap();
let err_path =
FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example.name));
app_state
.fs
.write(&err_path, e.to_string().as_bytes())
.await
.unwrap();
let msg = format!(
indoc::indoc! {"
While processing {}:
{:?}
Written to: \x1b[36m{}\x1b[0m
Explore this example data with:
fx \x1b[36m{}\x1b[0m
Re-run this example with:
cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
"},
example.name,
e,
err_path.display(),
failed_example_path.display(),
command,
failed_example_path.display(),
);
if args.failfast || total_examples == 1 {
Progress::global().finalize();
panic!("{}", msg);
} else {
log::error!("{}", msg);
}
}
}
ContextProvider::Zeta2 => {
zeta2_context(context_args, &app_state, cx).await.unwrap()
}
};
println!("{}", result);
});
futures::future::join_all(futures).await;
}
Some(Command::Predict(arguments)) => {
run_predict(arguments, &app_state, cx).await;
}
Some(Command::Eval(arguments)) => {
run_evaluate(arguments, &app_state, cx).await;
}
Some(Command::Distill(arguments)) => {
let _guard = cx
.update(|cx| gpui_tokio::Tokio::handle(cx))
.unwrap()
.enter();
run_distill(arguments).await.log_err();
}
Some(Command::ConvertExample {
path,
output_format,
}) => {
let example = NamedExample::load(path).unwrap();
example.write(output_format, io::stdout()).unwrap();
}
Some(Command::Score {
golden_patch,
actual_patch,
}) => {
let golden_content = std::fs::read_to_string(golden_patch).unwrap();
let actual_content = std::fs::read_to_string(actual_patch).unwrap();
Progress::global().finalize();
let golden_diff: Vec<DiffLine> = golden_content
.lines()
.map(|line| DiffLine::parse(line))
.collect();
let actual_diff: Vec<DiffLine> = actual_content
.lines()
.map(|line| DiffLine::parse(line))
.collect();
let score = delta_chr_f(&golden_diff, &actual_diff);
println!("{:.2}", score);
if args.output.is_some() || !matches!(command, Command::Eval(_)) {
write_examples(&examples, output.as_ref());
}
Some(Command::Clean) => {
std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
}
};
match &command {
Command::Predict(args) => predict::sync_batches(&args.provider).await?,
Command::Eval(_) => score::print_report(&examples),
_ => (),
};
anyhow::Ok(())
}
.await;
if let Err(e) = result {
panic!("Fatal error: {:?}", e);
}
let _ = cx.update(|cx| cx.quit());
})

View File

@@ -1,30 +1,34 @@
use collections::{HashMap, HashSet};
use edit_prediction::udiff::DiffLine;
use serde::{Deserialize, Serialize};
type Counts = HashMap<String, usize>;
type CountsDelta = HashMap<String, isize>;
#[derive(Default, Debug, Clone)]
pub struct Scores {
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct ClassificationMetrics {
pub true_positives: usize,
pub false_positives: usize,
pub false_negatives: usize,
}
impl Scores {
pub fn from_sets(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
impl ClassificationMetrics {
pub fn from_sets(
expected: &HashSet<String>,
actual: &HashSet<String>,
) -> ClassificationMetrics {
let true_positives = expected.intersection(actual).count();
let false_positives = actual.difference(expected).count();
let false_negatives = expected.difference(actual).count();
Scores {
ClassificationMetrics {
true_positives,
false_positives,
false_negatives,
}
}
pub fn from_counts(expected: &Counts, actual: &Counts) -> Scores {
pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
let mut true_positives = 0;
let mut false_positives = 0;
let mut false_negatives = 0;
@@ -45,32 +49,16 @@ impl Scores {
}
}
Scores {
ClassificationMetrics {
true_positives,
false_positives,
false_negatives,
}
}
pub fn to_markdown(&self) -> String {
format!(
"
Precision : {:.4}
Recall : {:.4}
F1 Score : {:.4}
True Positives : {}
False Positives : {}
False Negatives : {}",
self.precision(),
self.recall(),
self.f1_score(),
self.true_positives,
self.false_positives,
self.false_negatives
)
}
pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
pub fn aggregate<'a>(
scores: impl Iterator<Item = &'a ClassificationMetrics>,
) -> ClassificationMetrics {
let mut true_positives = 0;
let mut false_positives = 0;
let mut false_negatives = 0;
@@ -81,7 +69,7 @@ False Negatives : {}",
false_negatives += score.false_negatives;
}
Scores {
ClassificationMetrics {
true_positives,
false_positives,
false_negatives,
@@ -115,7 +103,10 @@ False Negatives : {}",
}
}
pub fn line_match_score(expected_patch: &[DiffLine], actual_patch: &[DiffLine]) -> Scores {
pub fn line_match_score(
expected_patch: &[DiffLine],
actual_patch: &[DiffLine],
) -> ClassificationMetrics {
let expected_change_lines = expected_patch
.iter()
.filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
@@ -128,7 +119,7 @@ pub fn line_match_score(expected_patch: &[DiffLine], actual_patch: &[DiffLine])
.map(|line| line.to_string())
.collect();
Scores::from_sets(&expected_change_lines, &actual_change_lines)
ClassificationMetrics::from_sets(&expected_change_lines, &actual_change_lines)
}
enum ChrfWhitespace {
@@ -204,7 +195,7 @@ pub fn delta_chr_f(expected: &[DiffLine], actual: &[DiffLine]) -> f64 {
let expected_counts = ngram_delta_to_counts(&expected_delta);
let actual_counts = ngram_delta_to_counts(&actual_delta);
let score = Scores::from_counts(&expected_counts, &actual_counts);
let score = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
total_precision += score.precision();
total_recall += score.recall();
}

View File

@@ -1,57 +1,27 @@
use std::{env, path::PathBuf, sync::LazyLock};
use std::{
path::{Path, PathBuf},
sync::LazyLock,
};
pub static TARGET_ZETA_DIR: LazyLock<PathBuf> =
LazyLock::new(|| env::current_dir().unwrap().join("target/zeta"));
pub static CACHE_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("cache"));
pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("repos"));
pub static WORKTREES_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("worktrees"));
pub static DATA_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
let dir = dirs::home_dir().unwrap().join(".zed_ep");
ensure_dir(&dir)
});
pub static CACHE_DIR: LazyLock<PathBuf> = LazyLock::new(|| ensure_dir(&DATA_DIR.join("cache")));
pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| ensure_dir(&DATA_DIR.join("repos")));
pub static WORKTREES_DIR: LazyLock<PathBuf> =
LazyLock::new(|| ensure_dir(&DATA_DIR.join("worktrees")));
pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
TARGET_ZETA_DIR
DATA_DIR
.join("runs")
.join(chrono::Local::now().format("%d-%m-%y-%H_%M_%S").to_string())
});
pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> =
LazyLock::new(|| TARGET_ZETA_DIR.join("latest"));
pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("latest"));
pub static LLM_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("llm_cache.sqlite"));
pub static FAILED_EXAMPLES_DIR: LazyLock<PathBuf> =
LazyLock::new(|| ensure_dir(&RUN_DIR.join("failed")));
pub fn print_run_data_dir(deep: bool, use_color: bool) {
println!("\n## Run Data\n");
let mut files = Vec::new();
let current_dir = std::env::current_dir().unwrap();
for file in std::fs::read_dir(&*RUN_DIR).unwrap() {
let file = file.unwrap();
if file.file_type().unwrap().is_dir() && deep {
for file in std::fs::read_dir(file.path()).unwrap() {
let path = file.unwrap().path();
let path = path.strip_prefix(&current_dir).unwrap_or(&path);
files.push(format!(
"- {}/{}{}{}",
path.parent().unwrap().display(),
if use_color { "\x1b[34m" } else { "" },
path.file_name().unwrap().display(),
if use_color { "\x1b[0m" } else { "" },
));
}
} else {
let path = file.path();
let path = path.strip_prefix(&current_dir).unwrap_or(&path);
files.push(format!(
"- {}/{}{}{}",
path.parent().unwrap().display(),
if use_color { "\x1b[34m" } else { "" },
path.file_name().unwrap().display(),
if use_color { "\x1b[0m" } else { "" }
));
}
}
files.sort();
for file in files {
println!("{}", file);
}
println!(
"\n💡 Tip of the day: {} always points to the latest run\n",
LATEST_EXAMPLE_RUN_DIR.display()
);
fn ensure_dir(path: &Path) -> PathBuf {
std::fs::create_dir_all(path).expect("Failed to create directory");
path.to_path_buf()
}

View File

@@ -1,374 +1,291 @@
use crate::example::{ActualExcerpt, NamedExample};
use crate::headless::ZetaCliAppState;
use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
use crate::{
CacheMode, PredictArguments, PredictionOptions, PredictionProvider, PredictionsOutputFormat,
PredictionProvider, PromptFormat,
anthropic_client::AnthropicClient,
example::{Example, ExamplePrediction},
format_prompt::{TeacherPrompt, run_format_prompt},
headless::EpAppState,
load_project::run_load_project,
paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
progress::{InfoStyle, Progress, Step},
retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
fs,
sync::{
Arc, Mutex, OnceLock,
atomic::{AtomicUsize, Ordering::SeqCst},
},
};
use ::serde::Serialize;
use anyhow::{Context, Result, anyhow};
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
use edit_prediction::{EditPredictionStore, EvalCache, EvalCacheEntryKind, EvalCacheKey};
use futures::StreamExt as _;
use gpui::{AppContext, AsyncApp, Entity};
use project::Project;
use project::buffer_store::BufferStoreEvent;
use serde::Deserialize;
use std::fs;
use std::io::{IsTerminal, Write};
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::{Duration, Instant};
pub async fn run_predict(
args: PredictArguments,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) {
let example = NamedExample::load(args.example_path).unwrap();
let project = example.setup_project(app_state, cx).await.unwrap();
let store = setup_store(args.options.provider, &project, app_state, cx).unwrap();
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
let result = perform_predict(example, project, store, None, args.options, cx)
.await
.unwrap();
result.write(args.format, std::io::stdout()).unwrap();
pub async fn run_prediction(
example: &mut Example,
provider: Option<PredictionProvider>,
repetition_count: usize,
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
) -> anyhow::Result<()> {
if !example.predictions.is_empty() {
return Ok(());
}
print_run_data_dir(true, std::io::stdout().is_terminal());
}
let provider = provider.context("provider is required")?;
pub fn setup_store(
provider: PredictionProvider,
project: &Entity<Project>,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<Entity<EditPredictionStore>> {
let store = cx.new(|cx| {
edit_prediction::EditPredictionStore::new(
app_state.client.clone(),
app_state.user_store.clone(),
cx,
)
})?;
run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
store.update(cx, |store, _cx| {
if matches!(
provider,
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
) {
let _step_progress = Progress::global().start(Step::Predict, &example.name);
if example.prompt.is_none() {
run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?;
}
let batched = matches!(provider, PredictionProvider::Teacher);
return predict_anthropic(example, repetition_count, batched).await;
}
run_load_project(example, app_state.clone(), cx.clone()).await?;
let _step_progress = Progress::global().start(Step::Predict, &example.name);
if matches!(
provider,
PredictionProvider::Zeta1 | PredictionProvider::Zeta2
) {
static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
AUTHENTICATED
.get_or_init(|| {
let client = app_state.client.clone();
cx.spawn(async move |cx| {
if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
eprintln!("Authentication failed: {}", e);
}
})
.shared()
})
.clone()
.await;
}
let ep_store = cx.update(|cx| {
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
})??;
ep_store.update(&mut cx, |store, _cx| {
let model = match provider {
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
unreachable!()
}
};
store.set_edit_prediction_model(model);
})?;
let state = example.state.as_ref().context("state must be set")?;
let run_dir = RUN_DIR.join(&example.name);
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
let updated_example = Arc::new(Mutex::new(example.clone()));
let current_run_ix = Arc::new(AtomicUsize::new(0));
cx.subscribe(&buffer_store, {
let project = project.clone();
let store = store.clone();
move |_, event, cx| match event {
BufferStoreEvent::BufferAdded(buffer) => {
store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
}
_ => {}
}
})?
.detach();
let mut debug_rx =
ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))?;
let debug_task = cx.background_spawn({
let updated_example = updated_example.clone();
let current_run_ix = current_run_ix.clone();
let run_dir = run_dir.clone();
async move {
while let Some(event) = debug_rx.next().await {
let run_ix = current_run_ix.load(SeqCst);
let mut updated_example = updated_example.lock().unwrap();
anyhow::Ok(store)
}
let run_dir = if repetition_count > 1 {
run_dir.join(format!("{:03}", run_ix))
} else {
run_dir.clone()
};
pub async fn perform_predict(
example: NamedExample,
project: Entity<Project>,
store: Entity<EditPredictionStore>,
repetition_ix: Option<u16>,
options: PredictionOptions,
cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
let mut cache_mode = options.cache;
if repetition_ix.is_some() {
if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip {
panic!("Repetitions are not supported in Auto cache mode");
} else {
cache_mode = CacheMode::Skip;
}
} else if cache_mode == CacheMode::Auto {
cache_mode = CacheMode::Requests;
}
match event {
DebugEvent::EditPredictionStarted(request) => {
assert_eq!(updated_example.predictions.len(), run_ix + 1);
let mut example_run_dir = RUN_DIR.join(&example.file_name());
if let Some(repetition_ix) = repetition_ix {
example_run_dir = example_run_dir.join(format!("{:03}", repetition_ix));
}
fs::create_dir_all(&example_run_dir)?;
if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
}
#[cfg(unix)]
std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
.context("creating latest link")?;
#[cfg(windows)]
std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
.context("creating latest link")?;
store.update(cx, |store, _cx| {
store.with_eval_cache(Arc::new(RunCache {
example_run_dir: example_run_dir.clone(),
cache_mode,
}));
})?;
let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone())));
let prompt_format = options.zeta2.prompt_format;
store.update(cx, |store, _cx| {
let mut options = store.options().clone();
options.prompt_format = prompt_format.into();
store.set_options(options);
})?;
let mut debug_task = gpui::Task::ready(Ok(()));
if options.provider == crate::PredictionProvider::Zeta2 {
let mut debug_rx = store.update(cx, |store, _| store.debug_info())?;
debug_task = cx.background_spawn({
let result = result.clone();
async move {
let mut start_time = None;
let mut retrieval_finished_at = None;
while let Some(event) = debug_rx.next().await {
match event {
edit_prediction::DebugEvent::ContextRetrievalStarted(info) => {
start_time = Some(info.timestamp);
fs::write(
example_run_dir.join("search_prompt.md"),
&info.search_prompt,
)?;
if let Some(prompt) = request.prompt {
fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
}
edit_prediction::DebugEvent::ContextRetrievalFinished(info) => {
retrieval_finished_at = Some(info.timestamp);
for (key, value) in &info.metadata {
if *key == "search_queries" {
fs::write(
example_run_dir.join("search_queries.json"),
value.as_bytes(),
)?;
}
}
}
DebugEvent::EditPredictionFinished(request) => {
assert_eq!(updated_example.predictions.len(), run_ix + 1);
if let Some(output) = request.model_output {
fs::write(run_dir.join("prediction_response.md"), &output)?;
updated_example
.predictions
.last_mut()
.unwrap()
.actual_output = output;
}
edit_prediction::DebugEvent::EditPredictionRequested(request) => {
let prediction_started_at = Instant::now();
start_time.get_or_insert(prediction_started_at);
let prompt = request.local_prompt.unwrap_or_default();
fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
{
let mut result = result.lock().unwrap();
result.prompt_len = prompt.chars().count();
for included_file in request.inputs.included_files {
let insertions =
vec![(request.inputs.cursor_point, CURSOR_MARKER)];
result.excerpts.extend(included_file.excerpts.iter().map(
|excerpt| ActualExcerpt {
path: included_file.path.components().skip(1).collect(),
text: String::from(excerpt.text.as_ref()),
},
));
write_codeblock(
&included_file.path,
included_file.excerpts.iter(),
if included_file.path == request.inputs.cursor_path {
&insertions
} else {
&[]
},
included_file.max_row,
false,
&mut result.excerpts_text,
);
}
}
let response =
request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
let response =
edit_prediction::open_ai_response::text_from_response(response)
.unwrap_or_default();
let prediction_finished_at = Instant::now();
fs::write(example_run_dir.join("prediction_response.md"), &response)?;
let mut result = result.lock().unwrap();
result.generated_len = response.chars().count();
result.retrieval_time =
retrieval_finished_at.unwrap() - start_time.unwrap();
result.prediction_time = prediction_finished_at - prediction_started_at;
result.total_time = prediction_finished_at - start_time.unwrap();
if run_ix >= repetition_count {
break;
}
}
_ => {}
}
anyhow::Ok(())
}
});
anyhow::Ok(())
}
});
store.update(cx, |store, cx| {
store.refresh_context(&project, &cursor_buffer, cursor_anchor, cx)
})?;
}
let prediction = store
.update(cx, |store, cx| {
store.request_prediction(
&project,
&cursor_buffer,
cursor_anchor,
cloud_llm_client::PredictEditsRequestTrigger::Cli,
cx,
)
})?
.await?;
debug_task.await?;
let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
result.diff = prediction
.and_then(|prediction| {
let prediction = prediction.prediction.ok()?;
prediction.edit_preview.as_unified_diff(&prediction.edits)
})
.unwrap_or_default();
anyhow::Ok(result)
}
struct RunCache {
cache_mode: CacheMode,
example_run_dir: PathBuf,
}
impl RunCache {
fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
CACHE_DIR.join(format!("{kind}_out_{key:x}.json",))
}
fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
CACHE_DIR.join(format!("{kind}_in_{key:x}.json",))
}
fn link_to_run(&self, key: &EvalCacheKey) {
let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();
let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
}
}
impl EvalCache for RunCache {
fn read(&self, key: EvalCacheKey) -> Option<String> {
let path = RunCache::output_cache_path(&key);
if path.exists() {
let use_cache = match key.0 {
EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
self.cache_mode.use_cached_llm_responses()
}
};
if use_cache {
log::info!("Using cache entry: {}", path.display());
self.link_to_run(&key);
Some(fs::read_to_string(path).unwrap())
} else {
log::trace!("Skipping cached entry: {}", path.display());
None
}
} else if matches!(self.cache_mode, CacheMode::Force) {
panic!(
"No cached entry found for {:?}. Run without `--cache force` at least once.",
key.0
);
for ix in 0..repetition_count {
current_run_ix.store(ix, SeqCst);
let run_dir = if repetition_count > 1 {
run_dir.join(format!("{:03}", ix))
} else {
None
}
}
fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
fs::create_dir_all(&*CACHE_DIR).unwrap();
let input_path = RunCache::input_cache_path(&key);
fs::write(&input_path, input).unwrap();
let output_path = RunCache::output_cache_path(&key);
log::trace!("Writing cache entry: {}", output_path.display());
fs::write(&output_path, output).unwrap();
self.link_to_run(&key);
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PredictionDetails {
pub diff: String,
pub excerpts: Vec<ActualExcerpt>,
pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
pub retrieval_time: Duration,
pub prediction_time: Duration,
pub total_time: Duration,
pub run_example_dir: PathBuf,
pub prompt_len: usize,
pub generated_len: usize,
}
impl PredictionDetails {
pub fn new(run_example_dir: PathBuf) -> Self {
Self {
diff: Default::default(),
excerpts: Default::default(),
excerpts_text: Default::default(),
retrieval_time: Default::default(),
prediction_time: Default::default(),
total_time: Default::default(),
run_example_dir,
prompt_len: 0,
generated_len: 0,
}
}
pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
let formatted = match format {
PredictionsOutputFormat::Md => self.to_markdown(),
PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
PredictionsOutputFormat::Diff => self.diff.clone(),
run_dir.clone()
};
Ok(out.write_all(formatted.as_bytes())?)
fs::create_dir_all(&run_dir)?;
if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
}
#[cfg(unix)]
std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
#[cfg(windows)]
std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
updated_example
.lock()
.unwrap()
.predictions
.push(ExamplePrediction {
actual_patch: String::new(),
actual_output: String::new(),
provider,
});
let prediction = ep_store
.update(&mut cx, |store, cx| {
store.request_prediction(
&state.project,
&state.buffer,
state.cursor_position,
cloud_llm_client::PredictEditsRequestTrigger::Cli,
cx,
)
})?
.await?;
let actual_patch = prediction
.and_then(|prediction| {
let prediction = prediction.prediction.ok()?;
prediction.edit_preview.as_unified_diff(&prediction.edits)
})
.unwrap_or_default();
let has_prediction = !actual_patch.is_empty();
updated_example
.lock()
.unwrap()
.predictions
.last_mut()
.unwrap()
.actual_patch = actual_patch;
if ix == repetition_count - 1 {
let (info, style) = if has_prediction {
("predicted", InfoStyle::Normal)
} else {
("no prediction", InfoStyle::Warning)
};
_step_progress.set_info(info, style);
}
}
pub fn to_markdown(&self) -> String {
format!(
"## Excerpts\n\n\
{}\n\n\
## Prediction\n\n\
{}\n\n\
## Time\n\n\
Retrieval: {}ms\n\
Prediction: {}ms\n\n\
Total: {}ms\n",
self.excerpts_text,
self.diff,
self.retrieval_time.as_millis(),
self.prediction_time.as_millis(),
self.total_time.as_millis(),
)
}
ep_store.update(&mut cx, |store, _| {
store.remove_project(&state.project);
})?;
debug_task.await?;
*example = Arc::into_inner(updated_example)
.ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
.into_inner()
.map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
Ok(())
}
async fn predict_anthropic(
example: &mut Example,
_repetition_count: usize,
batched: bool,
) -> anyhow::Result<()> {
let llm_model_name = "claude-sonnet-4-5";
let max_tokens = 16384;
let llm_client = if batched {
AnthropicClient::batch(&crate::paths::LLM_CACHE_DB.as_ref())
} else {
AnthropicClient::plain()
};
let llm_client = llm_client.context("Failed to create LLM client")?;
let prompt = example.prompt.as_ref().context("Prompt is required")?;
let messages = vec![anthropic::Message {
role: anthropic::Role::User,
content: vec![anthropic::RequestContent::Text {
text: prompt.input.clone(),
cache_control: None,
}],
}];
let Some(response) = llm_client
.generate(llm_model_name, max_tokens, messages)
.await?
else {
// Request stashed for batched processing
return Ok(());
};
let actual_output = response
.content
.into_iter()
.filter_map(|content| match content {
anthropic::ResponseContent::Text { text } => Some(text),
_ => None,
})
.collect::<Vec<String>>()
.join("\n");
let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
let prediction = ExamplePrediction {
actual_patch,
actual_output,
provider: PredictionProvider::Teacher,
};
example.predictions.push(prediction);
Ok(())
}
pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
match provider {
PredictionProvider::Teacher => {
let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
let llm_client =
AnthropicClient::batch(cache_path).context("Failed to create LLM client")?;
llm_client
.sync_batches()
.await
.context("Failed to sync batches")?;
}
_ => (),
};
Ok(())
}

View File

@@ -0,0 +1,508 @@
use std::{
borrow::Cow,
collections::HashMap,
io::{IsTerminal, Write},
sync::{Arc, Mutex, OnceLock},
time::{Duration, Instant},
};
use log::{Level, Log, Metadata, Record};
pub struct Progress {
inner: Mutex<ProgressInner>,
}
struct ProgressInner {
completed: Vec<CompletedTask>,
in_progress: HashMap<String, InProgressTask>,
is_tty: bool,
terminal_width: usize,
max_example_name_len: usize,
status_lines_displayed: usize,
total_examples: usize,
failed_examples: usize,
last_line_is_logging: bool,
}
#[derive(Clone)]
struct InProgressTask {
step: Step,
started_at: Instant,
substatus: Option<String>,
info: Option<(String, InfoStyle)>,
}
struct CompletedTask {
step: Step,
example_name: String,
duration: Duration,
info: Option<(String, InfoStyle)>,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Step {
LoadProject,
Context,
FormatPrompt,
Predict,
Score,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InfoStyle {
Normal,
Warning,
}
impl Step {
pub fn label(&self) -> &'static str {
match self {
Step::LoadProject => "Load",
Step::Context => "Context",
Step::FormatPrompt => "Format",
Step::Predict => "Predict",
Step::Score => "Score",
}
}
fn color_code(&self) -> &'static str {
match self {
Step::LoadProject => "\x1b[33m",
Step::Context => "\x1b[35m",
Step::FormatPrompt => "\x1b[34m",
Step::Predict => "\x1b[32m",
Step::Score => "\x1b[31m",
}
}
}
static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
static LOGGER: ProgressLogger = ProgressLogger;
const MARGIN: usize = 4;
const MAX_STATUS_LINES: usize = 10;
impl Progress {
/// Returns the global Progress instance, initializing it if necessary.
pub fn global() -> Arc<Progress> {
GLOBAL
.get_or_init(|| {
let progress = Arc::new(Self {
inner: Mutex::new(ProgressInner {
completed: Vec::new(),
in_progress: HashMap::new(),
is_tty: std::io::stderr().is_terminal(),
terminal_width: get_terminal_width(),
max_example_name_len: 0,
status_lines_displayed: 0,
total_examples: 0,
failed_examples: 0,
last_line_is_logging: false,
}),
});
let _ = log::set_logger(&LOGGER);
log::set_max_level(log::LevelFilter::Error);
progress
})
.clone()
}
pub fn set_total_examples(&self, total: usize) {
let mut inner = self.inner.lock().unwrap();
inner.total_examples = total;
}
pub fn increment_failed(&self) {
let mut inner = self.inner.lock().unwrap();
inner.failed_examples += 1;
}
/// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
/// This should be used for any output that needs to appear above the status lines.
fn log(&self, message: &str) {
let mut inner = self.inner.lock().unwrap();
Self::clear_status_lines(&mut inner);
if !inner.last_line_is_logging {
let reset = "\x1b[0m";
let dim = "\x1b[2m";
let divider = "".repeat(inner.terminal_width.saturating_sub(MARGIN));
eprintln!("{dim}{divider}{reset}");
inner.last_line_is_logging = true;
}
eprintln!("{}", message);
}
pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
let mut inner = self.inner.lock().unwrap();
Self::clear_status_lines(&mut inner);
inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
inner.in_progress.insert(
example_name.to_string(),
InProgressTask {
step,
started_at: Instant::now(),
substatus: None,
info: None,
},
);
Self::print_status_lines(&mut inner);
StepProgress {
progress: self.clone(),
step,
example_name: example_name.to_string(),
}
}
fn finish(&self, step: Step, example_name: &str) {
let mut inner = self.inner.lock().unwrap();
let Some(task) = inner.in_progress.remove(example_name) else {
return;
};
if task.step == step {
inner.completed.push(CompletedTask {
step: task.step,
example_name: example_name.to_string(),
duration: task.started_at.elapsed(),
info: task.info,
});
Self::clear_status_lines(&mut inner);
Self::print_logging_closing_divider(&mut inner);
Self::print_completed(&inner, inner.completed.last().unwrap());
Self::print_status_lines(&mut inner);
} else {
inner.in_progress.insert(example_name.to_string(), task);
}
}
fn print_logging_closing_divider(inner: &mut ProgressInner) {
if inner.last_line_is_logging {
let reset = "\x1b[0m";
let dim = "\x1b[2m";
let divider = "".repeat(inner.terminal_width.saturating_sub(MARGIN));
eprintln!("{dim}{divider}{reset}");
inner.last_line_is_logging = false;
}
}
fn clear_status_lines(inner: &mut ProgressInner) {
if inner.is_tty && inner.status_lines_displayed > 0 {
// Move up and clear each line we previously displayed
for _ in 0..inner.status_lines_displayed {
eprint!("\x1b[A\x1b[K");
}
let _ = std::io::stderr().flush();
inner.status_lines_displayed = 0;
}
}
fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
let duration = format_duration(task.duration);
let name_width = inner.max_example_name_len;
if inner.is_tty {
let reset = "\x1b[0m";
let bold = "\x1b[1m";
let dim = "\x1b[2m";
let yellow = "\x1b[33m";
let info_part = task
.info
.as_ref()
.map(|(s, style)| {
if *style == InfoStyle::Warning {
format!("{yellow}{s}{reset}")
} else {
s.to_string()
}
})
.unwrap_or_default();
let prefix = format!(
"{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}{reset} {info_part}",
color = task.step.color_code(),
label = task.step.label(),
name = task.example_name,
);
let duration_with_margin = format!("{duration} ");
let padding_needed = inner
.terminal_width
.saturating_sub(MARGIN)
.saturating_sub(duration_with_margin.len())
.saturating_sub(strip_ansi_len(&prefix));
let padding = " ".repeat(padding_needed);
eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
} else {
let info_part = task
.info
.as_ref()
.map(|(s, _)| format!(" | {}", s))
.unwrap_or_default();
eprintln!(
"{label:>12} {name:<name_width$}{info_part} {duration}",
label = task.step.label(),
name = task.example_name,
);
}
}
fn print_status_lines(inner: &mut ProgressInner) {
if !inner.is_tty || inner.in_progress.is_empty() {
inner.status_lines_displayed = 0;
return;
}
let reset = "\x1b[0m";
let bold = "\x1b[1m";
let dim = "\x1b[2m";
// Build the done/in-progress/total label
let done_count = inner.completed.len();
let in_progress_count = inner.in_progress.len();
let failed_count = inner.failed_examples;
let failed_label = if failed_count > 0 {
format!(" {} failed ", failed_count)
} else {
String::new()
};
let range_label = format!(
" {}/{}/{} ",
done_count, in_progress_count, inner.total_examples
);
// Print a divider line with failed count on left, range label on right
let failed_visible_len = strip_ansi_len(&failed_label);
let range_visible_len = range_label.len();
let middle_divider_len = inner
.terminal_width
.saturating_sub(MARGIN * 2)
.saturating_sub(failed_visible_len)
.saturating_sub(range_visible_len);
let left_divider = "".repeat(MARGIN);
let middle_divider = "".repeat(middle_divider_len);
let right_divider = "".repeat(MARGIN);
eprintln!(
"{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
);
let mut tasks: Vec<_> = inner.in_progress.iter().collect();
tasks.sort_by_key(|(name, _)| *name);
let total_tasks = tasks.len();
let mut lines_printed = 0;
for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
let elapsed = format_duration(task.started_at.elapsed());
let substatus_part = task
.substatus
.as_ref()
.map(|s| truncate_with_ellipsis(s, 30))
.unwrap_or_default();
let step_label = task.step.label();
let step_color = task.step.color_code();
let name_width = inner.max_example_name_len;
let prefix = format!(
"{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}{reset} {substatus_part}",
name = name,
);
let duration_with_margin = format!("{elapsed} ");
let padding_needed = inner
.terminal_width
.saturating_sub(MARGIN)
.saturating_sub(duration_with_margin.len())
.saturating_sub(strip_ansi_len(&prefix));
let padding = " ".repeat(padding_needed);
eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
lines_printed += 1;
}
// Show "+N more" on its own line if there are more tasks
if total_tasks > MAX_STATUS_LINES {
let remaining = total_tasks - MAX_STATUS_LINES;
eprintln!("{:>12} +{remaining} more", "");
lines_printed += 1;
}
inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
let _ = std::io::stderr().flush();
}
pub fn finalize(&self) {
let mut inner = self.inner.lock().unwrap();
Self::clear_status_lines(&mut inner);
// Print summary if there were failures
if inner.failed_examples > 0 {
let total_processed = inner.completed.len() + inner.failed_examples;
let percentage = if total_processed > 0 {
inner.failed_examples as f64 / total_processed as f64 * 100.0
} else {
0.0
};
eprintln!(
"\n{} of {} examples failed ({:.1}%)",
inner.failed_examples, total_processed, percentage
);
}
}
}
pub struct StepProgress {
progress: Arc<Progress>,
step: Step,
example_name: String,
}
impl StepProgress {
pub fn set_substatus(&self, substatus: impl Into<Cow<'static, str>>) {
let mut inner = self.progress.inner.lock().unwrap();
if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
task.substatus = Some(substatus.into().into_owned());
Progress::clear_status_lines(&mut inner);
Progress::print_status_lines(&mut inner);
}
}
pub fn clear_substatus(&self) {
let mut inner = self.progress.inner.lock().unwrap();
if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
task.substatus = None;
Progress::clear_status_lines(&mut inner);
Progress::print_status_lines(&mut inner);
}
}
pub fn set_info(&self, info: impl Into<String>, style: InfoStyle) {
let mut inner = self.progress.inner.lock().unwrap();
if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
task.info = Some((info.into(), style));
}
}
}
impl Drop for StepProgress {
fn drop(&mut self) {
self.progress.finish(self.step, &self.example_name);
}
}
struct ProgressLogger;
impl Log for ProgressLogger {
fn enabled(&self, metadata: &Metadata) -> bool {
metadata.level() <= Level::Info
}
fn log(&self, record: &Record) {
if !self.enabled(record.metadata()) {
return;
}
let level_color = match record.level() {
Level::Error => "\x1b[31m",
Level::Warn => "\x1b[33m",
Level::Info => "\x1b[32m",
Level::Debug => "\x1b[34m",
Level::Trace => "\x1b[35m",
};
let reset = "\x1b[0m";
let bold = "\x1b[1m";
let level_label = match record.level() {
Level::Error => "Error",
Level::Warn => "Warn",
Level::Info => "Info",
Level::Debug => "Debug",
Level::Trace => "Trace",
};
let message = format!(
"{bold}{level_color}{level_label:>12}{reset} {}",
record.args()
);
if let Some(progress) = GLOBAL.get() {
progress.log(&message);
} else {
eprintln!("{}", message);
}
}
fn flush(&self) {
let _ = std::io::stderr().flush();
}
}
#[cfg(unix)]
fn get_terminal_width() -> usize {
unsafe {
let mut winsize: libc::winsize = std::mem::zeroed();
if libc::ioctl(libc::STDERR_FILENO, libc::TIOCGWINSZ, &mut winsize) == 0
&& winsize.ws_col > 0
{
winsize.ws_col as usize
} else {
80
}
}
}
#[cfg(not(unix))]
fn get_terminal_width() -> usize {
80
}
fn strip_ansi_len(s: &str) -> usize {
let mut len = 0;
let mut in_escape = false;
for c in s.chars() {
if c == '\x1b' {
in_escape = true;
} else if in_escape {
if c == 'm' {
in_escape = false;
}
} else {
len += 1;
}
}
len
}
fn truncate_with_ellipsis(s: &str, max_len: usize) -> String {
if s.len() <= max_len {
s.to_string()
} else {
format!("{}", &s[..max_len.saturating_sub(1)])
}
}
fn format_duration(duration: Duration) -> String {
const MINUTE_IN_MILLIS: f32 = 60. * 1000.;
let millis = duration.as_millis() as f32;
if millis < 1000.0 {
format!("{}ms", millis)
} else if millis < MINUTE_IN_MILLIS {
format!("{:.1}s", millis / 1_000.0)
} else {
format!("{:.1}m", millis / MINUTE_IN_MILLIS)
}
}

View File

@@ -0,0 +1,192 @@
use crate::{
example::{Example, ExampleContext},
headless::EpAppState,
load_project::run_load_project,
progress::{InfoStyle, Progress, Step, StepProgress},
};
use anyhow::Context as _;
use collections::HashSet;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
use gpui::{AsyncApp, Entity};
use language::Buffer;
use project::Project;
use std::sync::Arc;
use std::time::Duration;
pub async fn run_context_retrieval(
example: &mut Example,
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
) -> anyhow::Result<()> {
if example.context.is_some() {
return Ok(());
}
run_load_project(example, app_state.clone(), cx.clone()).await?;
let step_progress: Arc<StepProgress> = Progress::global()
.start(Step::Context, &example.name)
.into();
let state = example.state.as_ref().unwrap();
let project = state.project.clone();
let _lsp_handle = project.update(&mut cx, |project, cx| {
project.register_buffer_with_language_servers(&state.buffer, cx)
})?;
wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
let ep_store = cx.update(|cx| {
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
})??;
let mut events = ep_store.update(&mut cx, |store, cx| {
store.register_buffer(&state.buffer, &project, cx);
store.set_use_context(true);
store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
store.debug_info(&project, cx)
})?;
while let Some(event) = events.next().await {
match event {
DebugEvent::ContextRetrievalFinished(_) => {
break;
}
_ => {}
}
}
let context_files =
ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx))?;
let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
example.context = Some(ExampleContext {
files: context_files,
});
Ok(())
}
async fn wait_for_language_servers_to_start(
project: &Entity<Project>,
buffer: &Entity<Buffer>,
step_progress: &Arc<StepProgress>,
cx: &mut AsyncApp,
) -> anyhow::Result<()> {
let lsp_store = project.read_with(cx, |project, _| project.lsp_store())?;
let (language_server_ids, mut starting_language_server_ids) = buffer
.update(cx, |buffer, cx| {
lsp_store.update(cx, |lsp_store, cx| {
let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
let starting_ids = ids
.iter()
.copied()
.filter(|id| !lsp_store.language_server_statuses.contains_key(&id))
.collect::<HashSet<_>>();
(ids, starting_ids)
})
})
.unwrap_or_default();
step_progress.set_substatus(format!("waiting for {} LSPs", language_server_ids.len()));
let timeout = cx
.background_executor()
.timer(Duration::from_secs(60 * 5))
.shared();
let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
let added_subscription = cx.subscribe(project, {
let step_progress = step_progress.clone();
move |_, event, _| match event {
project::Event::LanguageServerAdded(language_server_id, name, _) => {
step_progress.set_substatus(format!("LSP started: {}", name));
tx.try_send(*language_server_id).ok();
}
_ => {}
}
});
while !starting_language_server_ids.is_empty() {
futures::select! {
language_server_id = rx.next() => {
if let Some(id) = language_server_id {
starting_language_server_ids.remove(&id);
}
},
_ = timeout.clone().fuse() => {
return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
}
}
}
drop(added_subscription);
if !language_server_ids.is_empty() {
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.detach();
}
let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
let subscriptions = [
cx.subscribe(&lsp_store, {
let step_progress = step_progress.clone();
move |_, event, _| {
if let project::LspStoreEvent::LanguageServerUpdate {
message:
client::proto::update_language_server::Variant::WorkProgress(
client::proto::LspWorkProgress {
message: Some(message),
..
},
),
..
} = event
{
step_progress.set_substatus(message.clone());
}
}
}),
cx.subscribe(project, {
let step_progress = step_progress.clone();
move |_, event, cx| match event {
project::Event::DiskBasedDiagnosticsFinished { language_server_id } => {
let lsp_store = lsp_store.read(cx);
let name = lsp_store
.language_server_adapter_for_id(*language_server_id)
.unwrap()
.name();
step_progress.set_substatus(format!("LSP idle: {}", name));
tx.try_send(*language_server_id).ok();
}
_ => {}
}
}),
];
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.await?;
let mut pending_language_server_ids = HashSet::from_iter(language_server_ids.into_iter());
while !pending_language_server_ids.is_empty() {
futures::select! {
language_server_id = rx.next() => {
if let Some(id) = language_server_id {
pending_language_server_ids.remove(&id);
}
},
_ = timeout.clone().fuse() => {
return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
}
}
}
drop(subscriptions);
step_progress.clear_substatus();
Ok(())
}

View File

@@ -0,0 +1,123 @@
use crate::{
PredictArgs,
example::{Example, ExampleScore},
headless::EpAppState,
metrics::{self, ClassificationMetrics},
predict::run_prediction,
progress::{Progress, Step},
};
use edit_prediction::udiff::DiffLine;
use gpui::AsyncApp;
use std::sync::Arc;
pub async fn run_scoring(
example: &mut Example,
args: &PredictArgs,
app_state: Arc<EpAppState>,
cx: AsyncApp,
) -> anyhow::Result<()> {
run_prediction(
example,
Some(args.provider),
args.repetitions,
app_state,
cx,
)
.await?;
let _progress = Progress::global().start(Step::Score, &example.name);
let expected_patch = parse_patch(&example.expected_patch);
let mut scores = vec![];
for pred in &example.predictions {
let actual_patch = parse_patch(&pred.actual_patch);
let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;
scores.push(ExampleScore {
delta_chr_f,
line_match,
});
}
example.score = scores;
Ok(())
}
fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
patch.lines().map(DiffLine::parse).collect()
}
pub fn print_report(examples: &[Example]) {
eprintln!(
"──────────────────────────────────────────────────────────────────────────────────────"
);
eprintln!(
"{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
"Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
);
eprintln!(
"──────────────────────────────────────────────────────────────────────────────────────"
);
let mut all_line_match_scores = Vec::new();
let mut all_delta_chr_f_scores = Vec::new();
for example in examples {
for score in example.score.iter() {
let line_match = &score.line_match;
eprintln!(
"{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
truncate_name(&example.name, 30),
line_match.true_positives,
line_match.false_positives,
line_match.false_negatives,
line_match.precision() * 100.0,
line_match.recall() * 100.0,
line_match.f1_score() * 100.0,
score.delta_chr_f
);
all_line_match_scores.push(line_match.clone());
all_delta_chr_f_scores.push(score.delta_chr_f);
}
}
eprintln!(
"──────────────────────────────────────────────────────────────────────────────────────"
);
if !all_line_match_scores.is_empty() {
let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
let avg_delta_chr_f: f32 =
all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
eprintln!(
"{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
"TOTAL",
total_line_match.true_positives,
total_line_match.false_positives,
total_line_match.false_negatives,
total_line_match.precision() * 100.0,
total_line_match.recall() * 100.0,
total_line_match.f1_score() * 100.0,
avg_delta_chr_f
);
eprintln!(
"──────────────────────────────────────────────────────────────────────────────────────"
);
}
eprintln!("\n");
}
fn truncate_name(name: &str, max_len: usize) -> String {
if name.len() <= max_len {
name.to_string()
} else {
format!("{}...", &name[..max_len - 3])
}
}

View File

@@ -1,70 +0,0 @@
use std::{fmt, fmt::Display, path::Path, str::FromStr, sync::Arc};
use ::util::{paths::PathStyle, rel_path::RelPath};
use anyhow::{Result, anyhow};
use language::Point;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct SourceLocation {
pub path: Arc<RelPath>,
pub point: Point,
}
impl Serialize for SourceLocation {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&self.to_string())
}
}
impl<'de> Deserialize<'de> for SourceLocation {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
s.parse().map_err(serde::de::Error::custom)
}
}
impl Display for SourceLocation {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"{}:{}:{}",
self.path.display(PathStyle::Posix),
self.point.row + 1,
self.point.column + 1
)
}
}
impl FromStr for SourceLocation {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self> {
let parts: Vec<&str> = s.split(':').collect();
if parts.len() != 3 {
return Err(anyhow!(
"Invalid source location. Expected 'file.rs:line:column', got '{}'",
s
));
}
let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
let line: u32 = parts[1]
.parse()
.map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
let column: u32 = parts[2]
.parse()
.map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
// Convert from 1-based to 0-based indexing
let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
Ok(SourceLocation { path, point })
}
}

View File

@@ -18,6 +18,7 @@ Focus on:
Rules:
- Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals.
- Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code.
- Keep existing formatting unless it's absolutely necessary
Input format:
- You receive small code fragments called context (structs, field definitions, function signatures, etc.). They may or may not be relevant.
@@ -46,3 +47,7 @@ Output example:
## Code Context
{{context}}
## Editable region
{{editable_region}}

View File

@@ -1,89 +0,0 @@
use std::path::Path;
use crate::{source_location::SourceLocation, training::teacher::TeacherModel};
#[derive(Debug, Clone, Default, clap::ValueEnum)]
pub enum ContextType {
#[default]
CurrentFile,
}
const MAX_CONTEXT_SIZE: usize = 32768;
pub fn collect_context(
context_type: &ContextType,
worktree_dir: &Path,
cursor: SourceLocation,
) -> String {
let context = match context_type {
ContextType::CurrentFile => {
let file_path = worktree_dir.join(cursor.path.as_std_path());
let context = std::fs::read_to_string(&file_path).unwrap_or_default();
let context = add_special_tags(&context, worktree_dir, cursor);
context
}
};
let region_end_offset = context.find(TeacherModel::REGION_END);
if context.len() <= MAX_CONTEXT_SIZE {
return context;
}
if let Some(region_end_offset) = region_end_offset
&& region_end_offset + TeacherModel::REGION_END.len() > MAX_CONTEXT_SIZE
{
let to_truncate = context.len() - MAX_CONTEXT_SIZE;
format!(
"[...{} bytes truncated]\n{}\n",
to_truncate,
&context[to_truncate..]
)
} else {
format!(
"{}\n[...{} bytes truncated]\n",
&context[..MAX_CONTEXT_SIZE],
context.len() - MAX_CONTEXT_SIZE
)
}
}
/// Add <|editable_region_start/end|> tags
fn add_special_tags(context: &str, worktree_dir: &Path, cursor: SourceLocation) -> String {
let path = worktree_dir.join(cursor.path.as_std_path());
let file = std::fs::read_to_string(&path).unwrap_or_default();
let lines = file.lines().collect::<Vec<_>>();
let cursor_row = cursor.point.row as usize;
let start_line = cursor_row.saturating_sub(TeacherModel::LEFT_CONTEXT_SIZE);
let end_line = (cursor_row + TeacherModel::RIGHT_CONTEXT_SIZE).min(lines.len());
let snippet = lines[start_line..end_line].join("\n");
if context.contains(&snippet) {
let mut cursor_line = lines[cursor_row].to_string();
cursor_line.insert_str(cursor.point.column as usize, TeacherModel::USER_CURSOR);
let mut snippet_with_tags_lines = vec![];
snippet_with_tags_lines.push(TeacherModel::REGION_START);
snippet_with_tags_lines.extend(&lines[start_line..cursor_row]);
snippet_with_tags_lines.push(&cursor_line);
snippet_with_tags_lines.extend(&lines[cursor_row + 1..end_line]);
snippet_with_tags_lines.push(TeacherModel::REGION_END);
let snippet_with_tags = snippet_with_tags_lines.join("\n");
context.replace(&snippet, &snippet_with_tags)
} else {
log::warn!(
"Can't find area around the cursor in the context; proceeding without special tags"
);
context.to_string()
}
}
pub fn strip_special_tags(context: &str) -> String {
context
.replace(TeacherModel::REGION_START, "")
.replace(TeacherModel::REGION_END, "")
.replace(TeacherModel::USER_CURSOR, "")
}

View File

@@ -1,94 +0,0 @@
use serde::Deserialize;
use std::sync::Arc;
use crate::{
DistillArguments,
example::Example,
source_location::SourceLocation,
training::{
context::ContextType,
llm_client::LlmClient,
teacher::{TeacherModel, TeacherOutput},
},
};
use anyhow::Result;
use reqwest_client::ReqwestClient;
#[derive(Debug, Deserialize)]
pub struct SplitCommit {
repo_url: String,
commit_sha: String,
edit_history: String,
expected_patch: String,
cursor_position: String,
}
pub async fn run_distill(arguments: DistillArguments) -> Result<()> {
let split_commits: Vec<SplitCommit> = std::fs::read_to_string(&arguments.split_commit_dataset)
.expect("Failed to read split commit dataset")
.lines()
.map(|line| serde_json::from_str(line).expect("Failed to parse JSON line"))
.collect();
let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
let llm_client = if let Some(cache_path) = arguments.batch {
LlmClient::batch(&cache_path, http_client)?
} else {
LlmClient::plain(http_client)?
};
let mut teacher = TeacherModel::new(
"claude-sonnet-4-5".to_string(),
ContextType::CurrentFile,
llm_client,
);
let mut num_marked_for_batching = 0;
for commit in split_commits {
if let Some(distilled) = distill_one(&mut teacher, commit).await? {
println!("{}", serde_json::to_string(&distilled)?);
} else {
if num_marked_for_batching == 0 {
log::warn!("Marked for batching");
}
num_marked_for_batching += 1;
}
}
eprintln!(
"{} requests are marked for batching",
num_marked_for_batching
);
let llm_client = teacher.client;
llm_client.sync_batches().await?;
Ok(())
}
pub async fn distill_one(
teacher: &mut TeacherModel,
commit: SplitCommit,
) -> Result<Option<TeacherOutput>> {
let cursor: SourceLocation = commit
.cursor_position
.parse()
.expect("Failed to parse cursor position");
let path = cursor.path.to_rel_path_buf();
let example = Example {
repository_url: commit.repo_url,
revision: commit.commit_sha,
uncommitted_diff: commit.edit_history.clone(),
cursor_path: path.as_std_path().to_path_buf(),
cursor_position: commit.cursor_position,
edit_history: commit.edit_history, // todo: trim
expected_patch: commit.expected_patch,
};
let prediction = teacher.predict(example).await;
prediction
}

View File

@@ -1,4 +0,0 @@
pub mod context;
pub mod distill;
pub mod llm_client;
pub mod teacher;

View File

@@ -1,266 +0,0 @@
use crate::{
example::Example,
source_location::SourceLocation,
training::{
context::{ContextType, collect_context, strip_special_tags},
llm_client::LlmClient,
},
};
use anthropic::{Message, RequestContent, ResponseContent, Role};
use anyhow::Result;
pub struct TeacherModel {
pub llm_name: String,
pub context: ContextType,
pub client: LlmClient,
}
#[derive(Debug, serde::Serialize)]
pub struct TeacherOutput {
parsed_output: String,
prompt: String,
raw_llm_response: String,
context: String,
diff: String,
}
impl TeacherModel {
const PROMPT: &str = include_str!("teacher.prompt.md");
pub(crate) const REGION_START: &str = "<|editable_region_start|>\n";
pub(crate) const REGION_END: &str = "<|editable_region_end|>";
pub(crate) const USER_CURSOR: &str = "<|user_cursor|>";
/// Number of lines to include before the cursor position
pub(crate) const LEFT_CONTEXT_SIZE: usize = 5;
/// Number of lines to include after the cursor position
pub(crate) const RIGHT_CONTEXT_SIZE: usize = 5;
/// Truncate edit history to this number of last lines
const MAX_HISTORY_LINES: usize = 128;
pub fn new(llm_name: String, context: ContextType, client: LlmClient) -> Self {
TeacherModel {
llm_name,
context,
client,
}
}
pub async fn predict(&self, input: Example) -> Result<Option<TeacherOutput>> {
let name = input.unique_name();
let worktree_dir = input.setup_worktree(name).await?;
let cursor: SourceLocation = input
.cursor_position
.parse()
.expect("Failed to parse cursor position");
let context = collect_context(&self.context, &worktree_dir, cursor.clone());
let edit_history = Self::format_edit_history(&input.edit_history);
let prompt = Self::PROMPT
.replace("{{context}}", &context)
.replace("{{edit_history}}", &edit_history);
let messages = vec![Message {
role: Role::User,
content: vec![RequestContent::Text {
text: prompt.clone(),
cache_control: None,
}],
}];
let Some(response) = self
.client
.generate(self.llm_name.clone(), 16384, messages)
.await?
else {
return Ok(None);
};
let response_text = response
.content
.into_iter()
.filter_map(|content| match content {
ResponseContent::Text { text } => Some(text),
_ => None,
})
.collect::<Vec<String>>()
.join("\n");
let parsed_output = self.parse_response(&response_text);
let original_editable_region = Self::extract_editable_region(&context);
let context_after_edit = context.replace(&original_editable_region, &parsed_output);
let context_after_edit = strip_special_tags(&context_after_edit);
let context_before_edit = strip_special_tags(&context);
let diff = language::unified_diff(&context_before_edit, &context_after_edit);
// zeta distill --batch batch_results.txt
// zeta distill
// 1. Run `zeta distill <2000 examples <- all examples>` for the first time
// - store LLM requests in a batch, don't actual send the request
// - send the batch (2000 requests) after all inputs are processed
// 2. `zeta send-batches`
// - upload the batch to Anthropic
// https://platform.claude.com/docs/en/build-with-claude/batch-processing
// https://crates.io/crates/anthropic-sdk-rust
// - poll for results
// - when ready, store results in cache (a database)
// 3. `zeta distill` again
// - use the cached results this time
Ok(Some(TeacherOutput {
parsed_output,
prompt,
raw_llm_response: response_text,
context,
diff,
}))
}
fn parse_response(&self, content: &str) -> String {
let codeblock = Self::extract_last_codeblock(content);
let editable_region = Self::extract_editable_region(&codeblock);
editable_region
}
/// Extract content from the last code-fenced block if any, or else return content as is
fn extract_last_codeblock(text: &str) -> String {
let mut last_block = None;
let mut search_start = 0;
while let Some(start) = text[search_start..].find("```") {
let start = start + search_start;
let bytes = text.as_bytes();
let mut backtick_end = start;
while backtick_end < bytes.len() && bytes[backtick_end] == b'`' {
backtick_end += 1;
}
let backtick_count = backtick_end - start;
let closing_backticks = "`".repeat(backtick_count);
if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1];
last_block = Some(code_block.to_string());
search_start = backtick_end + end_pos + backtick_count;
} else {
break;
}
}
last_block.unwrap_or_else(|| text.to_string())
}
fn extract_editable_region(text: &str) -> String {
let start = text
.find(Self::REGION_START)
.map_or(0, |pos| pos + Self::REGION_START.len());
let end = text.find(Self::REGION_END).unwrap_or(text.len());
text[start..end].to_string()
}
/// Truncates edit history to a maximum length and removes comments (unified diff garbage lines)
fn format_edit_history(edit_history: &str) -> String {
let lines = edit_history
.lines()
.filter(|&s| Self::is_content_line(s))
.collect::<Vec<_>>();
let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
&lines[lines.len() - Self::MAX_HISTORY_LINES..]
} else {
&lines
};
history_lines.join("\n")
}
fn is_content_line(s: &str) -> bool {
s.starts_with("-")
|| s.starts_with("+")
|| s.starts_with(" ")
|| s.starts_with("---")
|| s.starts_with("+++")
|| s.starts_with("@@")
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_response() {
let teacher = TeacherModel::new(
"test".to_string(),
ContextType::CurrentFile,
LlmClient::dummy(),
);
let response = "This is a test response.";
let parsed = teacher.parse_response(response);
assert_eq!(parsed, response.to_string());
let response = indoc::indoc! {"
Some thinking
`````
actual response
`````
"};
let parsed = teacher.parse_response(response);
assert_eq!(parsed, "actual response");
}
#[test]
fn test_extract_last_code_block() {
let text = indoc::indoc! {"
Some thinking
```
first block
```
`````
last block
`````
"};
let last_block = TeacherModel::extract_last_codeblock(text);
assert_eq!(last_block, "last block");
}
#[test]
fn test_extract_editable_region() {
let teacher = TeacherModel::new(
"test".to_string(),
ContextType::CurrentFile,
LlmClient::dummy(),
);
let response = indoc::indoc! {"
some lines
are
here
<|editable_region_start|>
one
two three
<|editable_region_end|>
more
lines here
"};
let parsed = teacher.parse_response(response);
assert_eq!(
parsed,
indoc::indoc! {"
one
two three
"}
);
}
}

View File

@@ -1,198 +0,0 @@
use anyhow::{Result, anyhow};
use futures::channel::mpsc;
use futures::{FutureExt as _, StreamExt as _};
use gpui::{AsyncApp, Entity, Task};
use language::{Buffer, LanguageId, LanguageNotFound, LanguageServerId, ParseStatus};
use project::lsp_store::OpenLspBufferHandle;
use project::{Project, ProjectPath, Worktree};
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use util::rel_path::RelPath;
pub fn open_buffer(
project: Entity<Project>,
worktree: Entity<Worktree>,
path: Arc<RelPath>,
cx: &AsyncApp,
) -> Task<Result<Entity<Buffer>>> {
cx.spawn(async move |cx| {
let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
worktree_id: worktree.id(),
path,
})?;
let buffer = project
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?;
let mut parse_status = buffer.read_with(cx, |buffer, _cx| buffer.parse_status())?;
while *parse_status.borrow() != ParseStatus::Idle {
parse_status.changed().await?;
}
Ok(buffer)
})
}
pub async fn open_buffer_with_language_server(
project: Entity<Project>,
worktree: Entity<Worktree>,
path: Arc<RelPath>,
ready_languages: &mut HashSet<LanguageId>,
cx: &mut AsyncApp,
) -> Result<(OpenLspBufferHandle, LanguageServerId, Entity<Buffer>)> {
let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?;
let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
(
project.register_buffer_with_language_servers(&buffer, cx),
project.path_style(cx),
)
})?;
let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
let result = language_registry
.load_language_for_file_path(path.as_std_path())
.await;
if let Err(error) = result
&& !error.is::<LanguageNotFound>()
{
anyhow::bail!(error);
}
let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
buffer.language().map(|language| language.id())
})?
else {
return Err(anyhow!("No language for {}", path.display(path_style)));
};
let log_prefix = format!("{} | ", path.display(path_style));
if !ready_languages.contains(&language_id) {
wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
ready_languages.insert(language_id);
}
let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
// hacky wait for buffer to be registered with the language server
for _ in 0..100 {
let Some(language_server_id) = lsp_store.update(cx, |lsp_store, cx| {
buffer.update(cx, |buffer, cx| {
lsp_store
.language_servers_for_local_buffer(&buffer, cx)
.next()
.map(|(_, language_server)| language_server.server_id())
})
})?
else {
cx.background_executor()
.timer(Duration::from_millis(10))
.await;
continue;
};
return Ok((lsp_open_handle, language_server_id, buffer));
}
return Err(anyhow!("No language server found for buffer"));
}
// TODO: Dedupe with similar function in crates/eval/src/instance.rs
pub fn wait_for_lang_server(
project: &Entity<Project>,
buffer: &Entity<Buffer>,
log_prefix: String,
cx: &mut AsyncApp,
) -> Task<Result<()>> {
eprintln!("{}⏵ Waiting for language server", log_prefix);
let (mut tx, mut rx) = mpsc::channel(1);
let lsp_store = project
.read_with(cx, |project, _| project.lsp_store())
.unwrap();
let has_lang_server = buffer
.update(cx, |buffer, cx| {
lsp_store.update(cx, |lsp_store, cx| {
lsp_store
.language_servers_for_local_buffer(buffer, cx)
.next()
.is_some()
})
})
.unwrap_or(false);
if has_lang_server {
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
.unwrap()
.detach();
}
let (mut added_tx, mut added_rx) = mpsc::channel(1);
let subscriptions = [
cx.subscribe(&lsp_store, {
let log_prefix = log_prefix.clone();
move |_, event, _| {
if let project::LspStoreEvent::LanguageServerUpdate {
message:
client::proto::update_language_server::Variant::WorkProgress(
client::proto::LspWorkProgress {
message: Some(message),
..
},
),
..
} = event
{
eprintln!("{}{message}", log_prefix)
}
}
}),
cx.subscribe(project, {
let buffer = buffer.clone();
move |project, event, cx| match event {
project::Event::LanguageServerAdded(_, _, _) => {
let buffer = buffer.clone();
project
.update(cx, |project, cx| project.save_buffer(buffer, cx))
.detach();
added_tx.try_send(()).ok();
}
project::Event::DiskBasedDiagnosticsFinished { .. } => {
tx.try_send(()).ok();
}
_ => {}
}
}),
];
cx.spawn(async move |cx| {
if !has_lang_server {
// some buffers never have a language server, so this aborts quickly in that case.
let timeout = cx.background_executor().timer(Duration::from_secs(500));
futures::select! {
_ = added_rx.next() => {},
_ = timeout.fuse() => {
anyhow::bail!("Waiting for language server add timed out after 5 seconds");
}
};
}
let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
let result = futures::select! {
_ = rx.next() => {
eprintln!("{}⚑ Language server idle", log_prefix);
anyhow::Ok(())
},
_ = timeout.fuse() => {
anyhow::bail!("LSP wait timed out after 5 minutes");
}
};
drop(subscriptions);
result
})
}

View File

@@ -26,6 +26,7 @@ serde.workspace = true
smallvec.workspace = true
tree-sitter.workspace = true
util.workspace = true
zeta_prompt.workspace = true
[dev-dependencies]
env_logger.workspace = true

View File

@@ -1,6 +1,6 @@
use crate::RelatedExcerpt;
use language::{BufferSnapshot, OffsetRangeExt as _, Point};
use std::ops::Range;
use zeta_prompt::RelatedExcerpt;
#[cfg(not(test))]
const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 512;
@@ -76,14 +76,9 @@ pub fn assemble_excerpts(
input_ranges
.into_iter()
.map(|range| {
let offset_range = range.to_offset(buffer);
RelatedExcerpt {
point_range: range,
anchor_range: buffer.anchor_before(offset_range.start)
..buffer.anchor_after(offset_range.end),
text: buffer.as_rope().slice(offset_range),
}
.map(|range| RelatedExcerpt {
row_range: range.start.row..range.end.row,
text: buffer.text_for_range(range).collect(),
})
.collect()
}

View File

@@ -3,13 +3,13 @@ use anyhow::Result;
use collections::HashMap;
use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _};
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset as _};
use project::{LocationLink, Project, ProjectPath};
use serde::{Serialize, Serializer};
use smallvec::SmallVec;
use std::{
collections::hash_map,
ops::Range,
path::Path,
sync::Arc,
time::{Duration, Instant},
};
@@ -24,12 +24,14 @@ mod fake_definition_lsp;
pub use cloud_llm_client::predict_edits_v3::Line;
pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
pub use zeta_prompt::{RelatedExcerpt, RelatedFile};
const IDENTIFIER_LINE_COUNT: u32 = 3;
pub struct RelatedExcerptStore {
project: WeakEntity<Project>,
related_files: Vec<RelatedFile>,
related_files: Arc<[RelatedFile]>,
related_file_buffers: Vec<Entity<Buffer>>,
cache: HashMap<Identifier, Arc<CacheEntry>>,
update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
identifier_line_count: u32,
@@ -68,82 +70,6 @@ struct CachedDefinition {
anchor_range: Range<Anchor>,
}
#[derive(Clone, Debug, Serialize)]
pub struct RelatedFile {
#[serde(serialize_with = "serialize_project_path")]
pub path: ProjectPath,
#[serde(skip)]
pub buffer: WeakEntity<Buffer>,
pub excerpts: Vec<RelatedExcerpt>,
pub max_row: u32,
}
impl RelatedFile {
pub fn merge_excerpts(&mut self) {
self.excerpts.sort_unstable_by(|a, b| {
a.point_range
.start
.cmp(&b.point_range.start)
.then(b.point_range.end.cmp(&a.point_range.end))
});
let mut index = 1;
while index < self.excerpts.len() {
if self.excerpts[index - 1]
.point_range
.end
.cmp(&self.excerpts[index].point_range.start)
.is_ge()
{
let removed = self.excerpts.remove(index);
if removed
.point_range
.end
.cmp(&self.excerpts[index - 1].point_range.end)
.is_gt()
{
self.excerpts[index - 1].point_range.end = removed.point_range.end;
self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end;
}
} else {
index += 1;
}
}
}
}
#[derive(Clone, Debug, Serialize)]
pub struct RelatedExcerpt {
#[serde(skip)]
pub anchor_range: Range<Anchor>,
#[serde(serialize_with = "serialize_point_range")]
pub point_range: Range<Point>,
#[serde(serialize_with = "serialize_rope")]
pub text: Rope,
}
fn serialize_project_path<S: Serializer>(
project_path: &ProjectPath,
serializer: S,
) -> Result<S::Ok, S::Error> {
project_path.path.serialize(serializer)
}
fn serialize_rope<S: Serializer>(rope: &Rope, serializer: S) -> Result<S::Ok, S::Error> {
rope.to_string().serialize(serializer)
}
fn serialize_point_range<S: Serializer>(
range: &Range<Point>,
serializer: S,
) -> Result<S::Ok, S::Error> {
[
[range.start.row, range.start.column],
[range.end.row, range.end.column],
]
.serialize(serializer)
}
const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);
impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
@@ -179,7 +105,8 @@ impl RelatedExcerptStore {
RelatedExcerptStore {
project: project.downgrade(),
update_tx,
related_files: Vec::new(),
related_files: Vec::new().into(),
related_file_buffers: Vec::new(),
cache: Default::default(),
identifier_line_count: IDENTIFIER_LINE_COUNT,
}
@@ -193,8 +120,21 @@ impl RelatedExcerptStore {
self.update_tx.unbounded_send((buffer, position)).ok();
}
pub fn related_files(&self) -> &[RelatedFile] {
&self.related_files
pub fn related_files(&self) -> Arc<[RelatedFile]> {
self.related_files.clone()
}
pub fn related_files_with_buffers(
&self,
) -> impl Iterator<Item = (RelatedFile, Entity<Buffer>)> {
self.related_files
.iter()
.cloned()
.zip(self.related_file_buffers.iter().cloned())
}
pub fn set_related_files(&mut self, files: Vec<RelatedFile>) {
self.related_files = files.into();
}
async fn fetch_excerpts(
@@ -297,7 +237,8 @@ impl RelatedExcerptStore {
}
mean_definition_latency /= cache_miss_count.max(1) as u32;
let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?;
let (new_cache, related_files, related_file_buffers) =
rebuild_related_files(&project, new_cache, cx).await?;
if let Some(file) = &file {
log::debug!(
@@ -309,7 +250,8 @@ impl RelatedExcerptStore {
this.update(cx, |this, cx| {
this.cache = new_cache;
this.related_files = related_files;
this.related_files = related_files.into();
this.related_file_buffers = related_file_buffers;
cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
cache_hit_count,
cache_miss_count,
@@ -323,10 +265,16 @@ impl RelatedExcerptStore {
}
async fn rebuild_related_files(
project: &Entity<Project>,
new_entries: HashMap<Identifier, Arc<CacheEntry>>,
cx: &mut AsyncApp,
) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedFile>)> {
) -> Result<(
HashMap<Identifier, Arc<CacheEntry>>,
Vec<RelatedFile>,
Vec<Entity<Buffer>>,
)> {
let mut snapshots = HashMap::default();
let mut worktree_root_names = HashMap::default();
for entry in new_entries.values() {
for definition in &entry.definitions {
if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
@@ -340,12 +288,22 @@ async fn rebuild_related_files(
.read_with(cx, |buffer, _| buffer.snapshot())?,
);
}
let worktree_id = definition.path.worktree_id;
if let hash_map::Entry::Vacant(e) =
worktree_root_names.entry(definition.path.worktree_id)
{
project.read_with(cx, |project, cx| {
if let Some(worktree) = project.worktree_for_id(worktree_id, cx) {
e.insert(worktree.read(cx).root_name().as_unix_str().to_string());
}
})?;
}
}
}
Ok(cx
.background_spawn(async move {
let mut files = Vec::<RelatedFile>::new();
let mut files = Vec::new();
let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
let mut paths_by_buffer = HashMap::default();
for entry in new_entries.values() {
@@ -369,16 +327,31 @@ async fn rebuild_related_files(
continue;
};
let excerpts = assemble_excerpts(snapshot, ranges);
files.push(RelatedFile {
path: project_path.clone(),
buffer: buffer.downgrade(),
excerpts,
max_row: snapshot.max_point().row,
});
let Some(root_name) = worktree_root_names.get(&project_path.worktree_id) else {
continue;
};
let path = Path::new(&format!(
"{}/{}",
root_name,
project_path.path.as_unix_str()
))
.into();
files.push((
buffer,
RelatedFile {
path,
excerpts,
max_row: snapshot.max_point().row,
},
));
}
files.sort_by_key(|file| file.path.clone());
(new_entries, files)
files.sort_by_key(|(_, file)| file.path.clone());
let (related_buffers, related_files) = files.into_iter().unzip();
(new_entries, related_files, related_buffers)
})
.await)
}

View File

@@ -48,7 +48,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
&excerpts,
&[
(
"src/company.rs",
"root/src/company.rs",
&[indoc! {"
pub struct Company {
owner: Arc<Person>,
@@ -56,7 +56,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
}"}],
),
(
"src/main.rs",
"root/src/main.rs",
&[
indoc! {"
pub struct Session {
@@ -71,7 +71,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
],
),
(
"src/person.rs",
"root/src/person.rs",
&[
indoc! {"
impl Person {
@@ -446,7 +446,7 @@ fn assert_related_files(actual_files: &[RelatedFile], expected_files: &[(&str, &
.iter()
.map(|excerpt| excerpt.text.to_string())
.collect::<Vec<_>>();
(file.path.path.as_unix_str(), excerpts)
(file.path.to_str().unwrap(), excerpts)
})
.collect::<Vec<_>>();
let expected_excerpts = expected_files
@@ -492,10 +492,10 @@ fn format_excerpts(buffer: &Buffer, excerpts: &[RelatedExcerpt]) -> String {
if excerpt.text.is_empty() {
continue;
}
if current_row < excerpt.point_range.start.row {
if current_row < excerpt.row_range.start {
writeln!(&mut output, "").unwrap();
}
current_row = excerpt.point_range.start.row;
current_row = excerpt.row_range.start;
for line in excerpt.text.to_string().lines() {
output.push_str(line);

View File

@@ -17,7 +17,6 @@ anyhow.workspace = true
buffer_diff.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
cloud_zeta2_prompt.workspace = true
codestral.workspace = true
command_palette_hooks.workspace = true
copilot.workspace = true
@@ -46,6 +45,7 @@ ui_input.workspace = true
util.workspace = true
workspace.workspace = true
zed_actions.workspace = true
zeta_prompt.workspace = true
[dev-dependencies]
copilot = { workspace = true, features = ["test-support"] }

View File

@@ -17,7 +17,7 @@ use gpui::{
};
use multi_buffer::MultiBuffer;
use project::Project;
use text::OffsetRangeExt;
use text::Point;
use ui::{
ButtonCommon, Clickable, Disableable, FluentBuilder as _, IconButton, IconName,
StyledTypography as _, h_flex, v_flex,
@@ -66,7 +66,7 @@ impl EditPredictionContextView {
) -> Self {
let store = EditPredictionStore::global(client, user_store, cx);
let mut debug_rx = store.update(cx, |store, _| store.debug_info());
let mut debug_rx = store.update(cx, |store, cx| store.debug_info(&project, cx));
let _update_task = cx.spawn_in(window, async move |this, cx| {
while let Some(event) = debug_rx.next().await {
this.update_in(cx, |this, window, cx| {
@@ -103,7 +103,8 @@ impl EditPredictionContextView {
self.handle_context_retrieval_finished(info, window, cx);
}
}
DebugEvent::EditPredictionRequested(_) => {}
DebugEvent::EditPredictionStarted(_) => {}
DebugEvent::EditPredictionFinished(_) => {}
}
}
@@ -152,12 +153,11 @@ impl EditPredictionContextView {
run.finished_at = Some(info.timestamp);
run.metadata = info.metadata;
let project = self.project.clone();
let related_files = self
.store
.read(cx)
.context_for_project(&self.project, cx)
.to_vec();
.context_for_project_with_buffers(&self.project, cx)
.map_or(Vec::new(), |files| files.collect());
let editor = run.editor.clone();
let multibuffer = run.editor.read(cx).buffer().clone();
@@ -168,33 +168,14 @@ impl EditPredictionContextView {
cx.spawn_in(window, async move |this, cx| {
let mut paths = Vec::new();
for related_file in related_files {
let (buffer, point_ranges): (_, Vec<_>) =
if let Some(buffer) = related_file.buffer.upgrade() {
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
(
buffer,
related_file
.excerpts
.iter()
.map(|excerpt| excerpt.anchor_range.to_point(&snapshot))
.collect(),
)
} else {
(
project
.update(cx, |project, cx| {
project.open_buffer(related_file.path.clone(), cx)
})?
.await?,
related_file
.excerpts
.iter()
.map(|excerpt| excerpt.point_range.clone())
.collect(),
)
};
for (related_file, buffer) in related_files {
let point_ranges = related_file
.excerpts
.iter()
.map(|excerpt| {
Point::new(excerpt.row_range.start, 0)..Point::new(excerpt.row_range.end, 0)
})
.collect::<Vec<_>>();
cx.update(|_, cx| {
let path = PathKey::for_buffer(&buffer, cx);
paths.push((path, buffer, point_ranges));

View File

@@ -1,5 +1,4 @@
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use cloud_zeta2_prompt::write_codeblock;
use edit_prediction::{EditPrediction, EditPredictionRating, EditPredictionStore};
use editor::{Editor, ExcerptRange, MultiBuffer};
use feature_flags::FeatureFlag;
@@ -362,14 +361,14 @@ impl RatePredictionsModal {
write!(&mut formatted_inputs, "## Events\n\n").unwrap();
for event in &prediction.inputs.events {
write!(&mut formatted_inputs, "```diff\n{event}```\n\n").unwrap();
formatted_inputs.push_str("```diff\n");
zeta_prompt::write_event(&mut formatted_inputs, event.as_ref());
formatted_inputs.push_str("```\n\n");
}
write!(&mut formatted_inputs, "## Included files\n\n").unwrap();
for included_file in &prediction.inputs.included_files {
let cursor_insertions = &[(prediction.inputs.cursor_point, "<|CURSOR|>")];
write!(&mut formatted_inputs, "## Related files\n\n").unwrap();
for included_file in prediction.inputs.related_files.as_ref() {
write!(
&mut formatted_inputs,
"### {}\n\n",
@@ -377,20 +376,28 @@ impl RatePredictionsModal {
)
.unwrap();
write_codeblock(
&included_file.path,
&included_file.excerpts,
if included_file.path == prediction.inputs.cursor_path {
cursor_insertions.as_slice()
} else {
&[]
},
included_file.max_row,
false,
&mut formatted_inputs,
);
for excerpt in included_file.excerpts.iter() {
write!(
&mut formatted_inputs,
"```{}\n{}\n```\n",
included_file.path.display(),
excerpt.text
)
.unwrap();
}
}
write!(&mut formatted_inputs, "## Cursor Excerpt\n\n").unwrap();
writeln!(
&mut formatted_inputs,
"```{}\n{}<CURSOR>{}\n```\n",
prediction.inputs.cursor_path.display(),
&prediction.inputs.cursor_excerpt[..prediction.inputs.cursor_offset_in_excerpt],
&prediction.inputs.cursor_excerpt[prediction.inputs.cursor_offset_in_excerpt..],
)
.unwrap();
self.active_prediction = Some(ActivePrediction {
prediction,
feedback_editor: cx.new(|cx| {

View File

@@ -45,7 +45,7 @@ impl Editor {
let bracket_matches_by_accent = self.visible_excerpts(false, cx).into_iter().fold(
HashMap::default(),
|mut acc, (excerpt_id, (buffer, buffer_version, buffer_range))| {
|mut acc, (excerpt_id, (buffer, _, buffer_range))| {
let buffer_snapshot = buffer.read(cx).snapshot();
if language_settings::language_settings(
buffer_snapshot.language().map(|language| language.name()),
@@ -62,7 +62,7 @@ impl Editor {
let brackets_by_accent = buffer_snapshot
.fetch_bracket_ranges(
buffer_range.start..buffer_range.end,
Some((&buffer_version, fetched_chunks)),
Some(fetched_chunks),
)
.into_iter()
.flat_map(|(chunk_range, pairs)| {

View File

@@ -56,6 +56,7 @@ use sum_tree::{Bias, TreeMap};
use text::{BufferId, LineIndent};
use ui::{SharedString, px};
use unicode_segmentation::UnicodeSegmentation;
use ztracing::instrument;
use std::{
any::TypeId,
@@ -168,6 +169,7 @@ impl DisplayMap {
}
}
#[instrument(skip_all)]
pub fn snapshot(&mut self, cx: &mut Context<Self>) -> DisplaySnapshot {
let tab_size = Self::tab_size(&self.buffer, cx);
@@ -195,6 +197,7 @@ impl DisplayMap {
}
}
#[instrument(skip_all)]
pub fn set_state(&mut self, other: &DisplaySnapshot, cx: &mut Context<Self>) {
self.fold(
other
@@ -211,6 +214,7 @@ impl DisplayMap {
}
/// Creates folds for the given creases.
#[instrument(skip_all)]
pub fn fold<T: Clone + ToOffset>(&mut self, creases: Vec<Crease<T>>, cx: &mut Context<Self>) {
let buffer_snapshot = self.buffer.read(cx).snapshot(cx);
let edits = self.buffer_subscription.consume().into_inner();
@@ -279,6 +283,7 @@ impl DisplayMap {
}
/// Removes any folds with the given ranges.
#[instrument(skip_all)]
pub fn remove_folds_with_type<T: ToOffset>(
&mut self,
ranges: impl IntoIterator<Item = Range<T>>,
@@ -304,6 +309,7 @@ impl DisplayMap {
}
/// Removes any folds whose ranges intersect any of the given ranges.
#[instrument(skip_all)]
pub fn unfold_intersecting<T: ToOffset>(
&mut self,
ranges: impl IntoIterator<Item = Range<T>>,
@@ -335,6 +341,7 @@ impl DisplayMap {
block_map.remove_intersecting_replace_blocks(offset_ranges, inclusive);
}
#[instrument(skip_all)]
pub fn disable_header_for_buffer(&mut self, buffer_id: BufferId, cx: &mut Context<Self>) {
let snapshot = self.buffer.read(cx).snapshot(cx);
let edits = self.buffer_subscription.consume().into_inner();
@@ -349,6 +356,7 @@ impl DisplayMap {
block_map.disable_header_for_buffer(buffer_id)
}
#[instrument(skip_all)]
pub fn fold_buffers(
&mut self,
buffer_ids: impl IntoIterator<Item = language::BufferId>,
@@ -367,6 +375,7 @@ impl DisplayMap {
block_map.fold_buffers(buffer_ids, self.buffer.read(cx), cx)
}
#[instrument(skip_all)]
pub fn unfold_buffers(
&mut self,
buffer_ids: impl IntoIterator<Item = language::BufferId>,
@@ -385,14 +394,17 @@ impl DisplayMap {
block_map.unfold_buffers(buffer_ids, self.buffer.read(cx), cx)
}
#[instrument(skip_all)]
pub(crate) fn is_buffer_folded(&self, buffer_id: language::BufferId) -> bool {
self.block_map.folded_buffers.contains(&buffer_id)
}
#[instrument(skip_all)]
pub(crate) fn folded_buffers(&self) -> &HashSet<BufferId> {
&self.block_map.folded_buffers
}
#[instrument(skip_all)]
pub fn insert_creases(
&mut self,
creases: impl IntoIterator<Item = Crease<Anchor>>,
@@ -402,6 +414,7 @@ impl DisplayMap {
self.crease_map.insert(creases, &snapshot)
}
#[instrument(skip_all)]
pub fn remove_creases(
&mut self,
crease_ids: impl IntoIterator<Item = CreaseId>,
@@ -411,6 +424,7 @@ impl DisplayMap {
self.crease_map.remove(crease_ids, &snapshot)
}
#[instrument(skip_all)]
pub fn insert_blocks(
&mut self,
blocks: impl IntoIterator<Item = BlockProperties<Anchor>>,
@@ -429,6 +443,7 @@ impl DisplayMap {
block_map.insert(blocks)
}
#[instrument(skip_all)]
pub fn resize_blocks(&mut self, heights: HashMap<CustomBlockId, u32>, cx: &mut Context<Self>) {
let snapshot = self.buffer.read(cx).snapshot(cx);
let edits = self.buffer_subscription.consume().into_inner();
@@ -443,10 +458,12 @@ impl DisplayMap {
block_map.resize(heights);
}
#[instrument(skip_all)]
pub fn replace_blocks(&mut self, renderers: HashMap<CustomBlockId, RenderBlock>) {
self.block_map.replace_blocks(renderers);
}
#[instrument(skip_all)]
pub fn remove_blocks(&mut self, ids: HashSet<CustomBlockId>, cx: &mut Context<Self>) {
let snapshot = self.buffer.read(cx).snapshot(cx);
let edits = self.buffer_subscription.consume().into_inner();
@@ -461,6 +478,7 @@ impl DisplayMap {
block_map.remove(ids);
}
#[instrument(skip_all)]
pub fn row_for_block(
&mut self,
block_id: CustomBlockId,
@@ -480,6 +498,7 @@ impl DisplayMap {
Some(DisplayRow(block_row.0))
}
#[instrument(skip_all)]
pub fn highlight_text(
&mut self,
key: HighlightKey,
@@ -507,6 +526,7 @@ impl DisplayMap {
self.text_highlights.insert(key, to_insert);
}
#[instrument(skip_all)]
pub(crate) fn highlight_inlays(
&mut self,
type_id: TypeId,
@@ -526,6 +546,7 @@ impl DisplayMap {
}
}
#[instrument(skip_all)]
pub fn text_highlights(&self, type_id: TypeId) -> Option<(HighlightStyle, &[Range<Anchor>])> {
let highlights = self.text_highlights.get(&HighlightKey::Type(type_id))?;
Some((highlights.0, &highlights.1))
@@ -538,6 +559,7 @@ impl DisplayMap {
self.text_highlights.values()
}
#[instrument(skip_all)]
pub fn clear_highlights(&mut self, type_id: TypeId) -> bool {
let mut cleared = self
.text_highlights
@@ -566,6 +588,7 @@ impl DisplayMap {
.update(cx, |map, cx| map.set_wrap_width(width, cx))
}
#[instrument(skip_all)]
pub fn update_fold_widths(
&mut self,
widths: impl IntoIterator<Item = (ChunkRendererId, Pixels)>,
@@ -597,6 +620,7 @@ impl DisplayMap {
self.inlay_map.current_inlays()
}
#[instrument(skip_all)]
pub(crate) fn splice_inlays(
&mut self,
to_remove: &[InlayId],
@@ -626,6 +650,7 @@ impl DisplayMap {
self.block_map.read(snapshot, edits);
}
#[instrument(skip_all)]
fn tab_size(buffer: &Entity<MultiBuffer>, cx: &App) -> NonZeroU32 {
let buffer = buffer.read(cx).as_singleton().map(|buffer| buffer.read(cx));
let language = buffer
@@ -675,6 +700,7 @@ pub struct HighlightedChunk<'a> {
}
impl<'a> HighlightedChunk<'a> {
#[instrument(skip_all)]
fn highlight_invisibles(
self,
editor_style: &'a EditorStyle,
@@ -832,6 +858,7 @@ impl DisplaySnapshot {
self.buffer_snapshot().widest_line_number()
}
#[instrument(skip_all)]
pub fn prev_line_boundary(&self, mut point: MultiBufferPoint) -> (Point, DisplayPoint) {
loop {
let mut inlay_point = self.inlay_snapshot().to_inlay_point(point);
@@ -850,6 +877,7 @@ impl DisplaySnapshot {
}
}
#[instrument(skip_all)]
pub fn next_line_boundary(
&self,
mut point: MultiBufferPoint,
@@ -888,6 +916,7 @@ impl DisplaySnapshot {
new_start..new_end
}
#[instrument(skip_all)]
pub fn point_to_display_point(&self, point: MultiBufferPoint, bias: Bias) -> DisplayPoint {
let inlay_point = self.inlay_snapshot().to_inlay_point(point);
let fold_point = self.fold_snapshot().to_fold_point(inlay_point, bias);
@@ -917,6 +946,7 @@ impl DisplaySnapshot {
.anchor_at(point.to_offset(self, bias), bias)
}
#[instrument(skip_all)]
fn display_point_to_inlay_point(&self, point: DisplayPoint, bias: Bias) -> InlayPoint {
let block_point = point.0;
let wrap_point = self.block_snapshot.to_wrap_point(block_point, bias);
@@ -928,6 +958,7 @@ impl DisplaySnapshot {
fold_point.to_inlay_point(self.fold_snapshot())
}
#[instrument(skip_all)]
pub fn display_point_to_fold_point(&self, point: DisplayPoint, bias: Bias) -> FoldPoint {
let block_point = point.0;
let wrap_point = self.block_snapshot.to_wrap_point(block_point, bias);
@@ -937,6 +968,7 @@ impl DisplaySnapshot {
.0
}
#[instrument(skip_all)]
pub fn fold_point_to_display_point(&self, fold_point: FoldPoint) -> DisplayPoint {
let tab_point = self.tab_snapshot().fold_point_to_tab_point(fold_point);
let wrap_point = self.wrap_snapshot().tab_point_to_wrap_point(tab_point);
@@ -949,6 +981,7 @@ impl DisplaySnapshot {
}
/// Returns text chunks starting at the given display row until the end of the file
#[instrument(skip_all)]
pub fn text_chunks(&self, display_row: DisplayRow) -> impl Iterator<Item = &str> {
self.block_snapshot
.chunks(
@@ -961,6 +994,7 @@ impl DisplaySnapshot {
}
/// Returns text chunks starting at the end of the given display row in reverse until the start of the file
#[instrument(skip_all)]
pub fn reverse_text_chunks(&self, display_row: DisplayRow) -> impl Iterator<Item = &str> {
(0..=display_row.0).rev().flat_map(move |row| {
self.block_snapshot
@@ -977,6 +1011,7 @@ impl DisplaySnapshot {
})
}
#[instrument(skip_all)]
pub fn chunks(
&self,
display_rows: Range<DisplayRow>,
@@ -995,6 +1030,7 @@ impl DisplaySnapshot {
)
}
#[instrument(skip_all)]
pub fn highlighted_chunks<'a>(
&'a self,
display_rows: Range<DisplayRow>,
@@ -1071,6 +1107,7 @@ impl DisplaySnapshot {
})
}
#[instrument(skip_all)]
pub fn layout_row(
&self,
display_row: DisplayRow,
@@ -1132,6 +1169,7 @@ impl DisplaySnapshot {
layout_line.closest_index_for_x(x) as u32
}
#[instrument(skip_all)]
pub fn grapheme_at(&self, mut point: DisplayPoint) -> Option<SharedString> {
point = DisplayPoint(self.block_snapshot.clip_point(point.0, Bias::Left));
let chars = self
@@ -1321,6 +1359,7 @@ impl DisplaySnapshot {
.unwrap_or(false)
}
#[instrument(skip_all)]
pub fn crease_for_buffer_row(&self, buffer_row: MultiBufferRow) -> Option<Crease<Point>> {
let start =
MultiBufferPoint::new(buffer_row.0, self.buffer_snapshot().line_len(buffer_row));
@@ -1407,6 +1446,7 @@ impl DisplaySnapshot {
}
#[cfg(any(test, feature = "test-support"))]
#[instrument(skip_all)]
pub fn text_highlight_ranges<Tag: ?Sized + 'static>(
&self,
) -> Option<Arc<(HighlightStyle, Vec<Range<Anchor>>)>> {
@@ -1417,6 +1457,7 @@ impl DisplaySnapshot {
}
#[cfg(any(test, feature = "test-support"))]
#[instrument(skip_all)]
pub fn all_text_highlight_ranges<Tag: ?Sized + 'static>(
&self,
) -> Vec<(gpui::Hsla, Range<Point>)> {
@@ -1466,6 +1507,7 @@ impl DisplaySnapshot {
///
/// This moves by buffer rows instead of display rows, a distinction that is
/// important when soft wrapping is enabled.
#[instrument(skip_all)]
pub fn start_of_relative_buffer_row(&self, point: DisplayPoint, times: isize) -> DisplayPoint {
let start = self.display_point_to_fold_point(point, Bias::Left);
let target = start.row() as isize + times;

View File

@@ -529,7 +529,7 @@ impl BlockMap {
BlockMapWriter(self)
}
#[ztracing::instrument(skip_all, fields(edits))]
#[ztracing::instrument(skip_all, fields(edits = ?edits))]
fn sync(&self, wrap_snapshot: &WrapSnapshot, mut edits: WrapPatch) {
let _timer = zlog::time!("BlockMap::sync").warn_if_gt(std::time::Duration::from_millis(50));
@@ -570,6 +570,9 @@ impl BlockMap {
let mut wrap_point_cursor = wrap_snapshot.wrap_point_cursor();
while let Some(edit) = edits.next() {
let span = ztracing::debug_span!("while edits", edit = ?edit);
let _enter = span.enter();
let mut old_start = edit.old.start;
let mut new_start = edit.new.start;
@@ -628,6 +631,8 @@ impl BlockMap {
let mut old_end = edit.old.end;
let mut new_end = edit.new.end;
loop {
let span = ztracing::debug_span!("decide where edit ends loop");
let _enter = span.enter();
// Seek to the transform starting at or after the end of the edit
cursor.seek(&old_end, Bias::Left);
cursor.next();
@@ -736,6 +741,10 @@ impl BlockMap {
// and then insert the block itself.
let mut just_processed_folded_buffer = false;
for (block_placement, block) in blocks_in_edit.drain(..) {
let span =
ztracing::debug_span!("for block in edits", block_height = block.height());
let _enter = span.enter();
let mut summary = TransformSummary {
input_rows: WrapRow(0),
output_rows: BlockRow(block.height()),
@@ -957,6 +966,7 @@ impl BlockMap {
}
}
#[ztracing::instrument(skip(tree, wrap_snapshot))]
fn push_isomorphic(tree: &mut SumTree<Transform>, rows: RowDelta, wrap_snapshot: &WrapSnapshot) {
if rows == RowDelta(0) {
return;

View File

@@ -840,7 +840,7 @@ impl WrapSnapshot {
self.tab_point_to_wrap_point(self.tab_snapshot.clip_point(self.to_tab_point(point), bias))
}
#[ztracing::instrument(skip_all, fields(point, ret))]
#[ztracing::instrument(skip_all, fields(point=?point, ret))]
pub fn prev_row_boundary(&self, mut point: WrapPoint) -> WrapRow {
if self.transforms.is_empty() {
return WrapRow(0);
@@ -851,11 +851,14 @@ impl WrapSnapshot {
let mut cursor = self
.transforms
.cursor::<Dimensions<WrapPoint, TabPoint>>(());
// start
cursor.seek(&point, Bias::Right);
// end
if cursor.item().is_none() {
cursor.prev();
}
// start
while let Some(transform) = cursor.item() {
if transform.is_isomorphic() && cursor.start().1.column() == 0 {
return cmp::min(cursor.end().0.row(), point.row());
@@ -863,6 +866,7 @@ impl WrapSnapshot {
cursor.prev();
}
}
// end
unreachable!()
}

View File

@@ -7135,6 +7135,7 @@ impl Editor {
Some((query, selection_anchor_range))
}
#[ztracing::instrument(skip_all)]
fn update_selection_occurrence_highlights(
&mut self,
query_text: String,
@@ -7279,6 +7280,7 @@ impl Editor {
});
}
#[ztracing::instrument(skip_all)]
fn refresh_selected_text_highlights(
&mut self,
on_buffer_edit: bool,
@@ -20973,9 +20975,22 @@ impl Editor {
buffer_ranges.last()
}?;
let selection = text::ToPoint::to_point(&range.start, buffer).row
..text::ToPoint::to_point(&range.end, buffer).row;
Some((multi_buffer.buffer(buffer.remote_id()).unwrap(), selection))
let start_row_in_buffer = text::ToPoint::to_point(&range.start, buffer).row;
let end_row_in_buffer = text::ToPoint::to_point(&range.end, buffer).row;
let Some(buffer_diff) = multi_buffer.diff_for(buffer.remote_id()) else {
let selection = start_row_in_buffer..end_row_in_buffer;
return Some((multi_buffer.buffer(buffer.remote_id()).unwrap(), selection));
};
let buffer_diff_snapshot = buffer_diff.read(cx).snapshot(cx);
Some((
multi_buffer.buffer(buffer.remote_id()).unwrap(),
buffer_diff_snapshot.row_to_base_text_row(start_row_in_buffer, buffer)
..buffer_diff_snapshot.row_to_base_text_row(end_row_in_buffer, buffer),
))
});
let Some((buffer, selection)) = buffer_and_selection else {

View File

@@ -27701,6 +27701,7 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
cx.update_editor(|editor, window, cx| {
editor.handle_input("x", window, cx);
});
cx.run_until_parked();
cx.assert_editor_state(indoc! {"
- [ ] Item 1
- [ ] Item 1.a
@@ -27716,8 +27717,7 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
- [ ] Item 1.a
- [x] Item 2
- [x] Item 2.a
- [x] Item 2.bˇ
"
- [x] Item 2.bˇ"
});
cx.update_editor(|editor, window, cx| {
editor.newline(&Newline, window, cx);
@@ -27728,34 +27728,41 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
- [x] Item 2
- [x] Item 2.a
- [x] Item 2.b
ˇ
"
ˇ"
});
// Case 3: Test adding a new nested list item preserves indent
cx.set_state(&indoc! {"
- [ ] Item 1
- [ ] Item 1.a
- [x] Item 2
- [x] Item 2.a
- [x] Item 2.b
ˇ"
});
cx.update_editor(|editor, window, cx| {
editor.handle_input("-", window, cx);
});
cx.run_until_parked();
cx.assert_editor_state(indoc! {"
- [ ] Item 1
- [ ] Item 1.a
- [x] Item 2
- [x] Item 2.a
- [x] Item 2.b
"
"
});
cx.update_editor(|editor, window, cx| {
editor.handle_input(" [x] Item 2.c", window, cx);
});
cx.run_until_parked();
cx.assert_editor_state(indoc! {"
- [ ] Item 1
- [ ] Item 1.a
- [x] Item 2
- [x] Item 2.a
- [x] Item 2.b
- [x] Item 2.cˇ
"
- [x] Item 2.cˇ"
});
// Case 4: Test adding new line after nested ordered list preserves indent of previous line
@@ -27764,8 +27771,7 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
1. Item 1.a
2. Item 2
1. Item 2.a
2. Item 2.bˇ
"
2. Item 2.bˇ"
});
cx.update_editor(|editor, window, cx| {
editor.newline(&Newline, window, cx);
@@ -27776,60 +27782,81 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
2. Item 2
1. Item 2.a
2. Item 2.b
ˇ
"
ˇ"
});
// Case 5: Adding new ordered list item preserves indent
cx.set_state(indoc! {"
1. Item 1
1. Item 1.a
2. Item 2
1. Item 2.a
2. Item 2.b
ˇ"
});
cx.update_editor(|editor, window, cx| {
editor.handle_input("3", window, cx);
});
cx.run_until_parked();
cx.assert_editor_state(indoc! {"
1. Item 1
1. Item 1.a
2. Item 2
1. Item 2.a
2. Item 2.b
"
"
});
cx.update_editor(|editor, window, cx| {
editor.handle_input(".", window, cx);
});
cx.run_until_parked();
cx.assert_editor_state(indoc! {"
1. Item 1
1. Item 1.a
2. Item 2
1. Item 2.a
2. Item 2.b
3.ˇ
"
3.ˇ"
});
cx.update_editor(|editor, window, cx| {
editor.handle_input(" Item 2.c", window, cx);
});
cx.run_until_parked();
cx.assert_editor_state(indoc! {"
1. Item 1
1. Item 1.a
2. Item 2
1. Item 2.a
2. Item 2.b
3. Item 2.cˇ
"
3. Item 2.cˇ"
});
// Case 6: Test adding new line after nested ordered list preserves indent of previous line
cx.set_state(indoc! {"
- Item 1
- Item 1.a
- Item 1.a
ˇ"});
cx.update_editor(|editor, window, cx| {
editor.handle_input("-", window, cx);
});
cx.run_until_parked();
cx.assert_editor_state(indoc! {"
- Item 1
- Item 1.a
- Item 1.a
"});
// Case 7: Test blockquote newline preserves something
cx.set_state(indoc! {"
> Item 1ˇ
"
> Item 1ˇ"
});
cx.update_editor(|editor, window, cx| {
editor.newline(&Newline, window, cx);
});
cx.assert_editor_state(indoc! {"
> Item 1
ˇ
"
ˇ"
});
}

View File

@@ -7,6 +7,7 @@ use theme::ActiveTheme;
enum MatchingBracketHighlight {}
impl Editor {
#[ztracing::instrument(skip_all)]
pub fn refresh_matching_bracket_highlights(
&mut self,
window: &Window,

View File

@@ -623,7 +623,10 @@ pub fn hover_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
});
MarkdownStyle {
base_text_style,
code_block: StyleRefinement::default().my(rems(1.)).font_buffer(cx),
code_block: StyleRefinement::default()
.my(rems(1.))
.font_buffer(cx)
.font_features(buffer_font_features.clone()),
inline_code: TextStyleRefinement {
background_color: Some(cx.theme().colors().background),
font_family: Some(buffer_font_family),

View File

@@ -892,7 +892,7 @@ pub fn wait_for_lang_server(
.update(cx, |buffer, cx| {
lsp_store.update(cx, |lsp_store, cx| {
lsp_store
.language_servers_for_local_buffer(buffer, cx)
.running_language_servers_for_local_buffer(buffer, cx)
.next()
.is_some()
})

View File

@@ -23,6 +23,7 @@ use std::{
path::PathBuf,
sync::{Arc, LazyLock},
};
use text::LineEnding;
use util::{paths::PathStyle, rel_path::RelPath};
pub static LOAD_INDEX_TEXT_TASK: LazyLock<TaskLabel> = LazyLock::new(TaskLabel::new);
@@ -200,6 +201,7 @@ impl GitRepository for FakeGitRepository {
async {
Ok(CommitDetails {
sha: commit.into(),
message: "initial commit".into(),
..Default::default()
})
}
@@ -451,7 +453,12 @@ impl GitRepository for FakeGitRepository {
})
}
fn blame(&self, path: RepoPath, _content: Rope) -> BoxFuture<'_, Result<git::blame::Blame>> {
fn blame(
&self,
path: RepoPath,
_content: Rope,
_line_ending: LineEnding,
) -> BoxFuture<'_, Result<git::blame::Blame>> {
self.with_state_async(false, move |state| {
state
.blames
@@ -568,7 +575,7 @@ impl GitRepository for FakeGitRepository {
_askpass: AskPassDelegate,
_env: Arc<HashMap<String, String>>,
) -> BoxFuture<'_, Result<()>> {
unimplemented!()
async { Ok(()) }.boxed()
}
fn run_hook(
@@ -576,7 +583,7 @@ impl GitRepository for FakeGitRepository {
_hook: RunHook,
_env: Arc<HashMap<String, String>>,
) -> BoxFuture<'_, Result<()>> {
unimplemented!()
async { Ok(()) }.boxed()
}
fn push(

View File

@@ -803,7 +803,7 @@ impl Fs for RealFs {
}
let file = smol::fs::File::create(path).await?;
let mut writer = smol::io::BufWriter::with_capacity(buffer_size, file);
for chunk in chunks(text, line_ending) {
for chunk in text::chunks_with_line_ending(text, line_ending) {
writer.write_all(chunk.as_bytes()).await?;
}
writer.flush().await?;
@@ -2555,7 +2555,7 @@ impl Fs for FakeFs {
async fn save(&self, path: &Path, text: &Rope, line_ending: LineEnding) -> Result<()> {
self.simulate_random_delay().await;
let path = normalize_path(path);
let content = chunks(text, line_ending).collect::<String>();
let content = text::chunks_with_line_ending(text, line_ending).collect::<String>();
if let Some(path) = path.parent() {
self.create_dir(path).await?;
}
@@ -2773,25 +2773,6 @@ impl Fs for FakeFs {
}
}
fn chunks(rope: &Rope, line_ending: LineEnding) -> impl Iterator<Item = &str> {
rope.chunks().flat_map(move |chunk| {
let mut newline = false;
let end_with_newline = chunk.ends_with('\n').then_some(line_ending.as_str());
chunk
.lines()
.flat_map(move |line| {
let ending = if newline {
Some(line_ending.as_str())
} else {
None
};
newline = true;
ending.into_iter().chain([line])
})
.chain(end_with_newline)
})
}
pub fn normalize_path(path: &Path) -> PathBuf {
let mut components = path.components().peekable();
let mut ret = if let Some(c @ Component::Prefix(..)) = components.peek().cloned() {

View File

@@ -8,7 +8,7 @@ use gpui::SharedString;
use serde::{Deserialize, Serialize};
use std::process::Stdio;
use std::{ops::Range, path::Path};
use text::Rope;
use text::{LineEnding, Rope};
use time::OffsetDateTime;
use time::UtcOffset;
use time::macros::format_description;
@@ -35,8 +35,10 @@ impl Blame {
working_directory: &Path,
path: &RepoPath,
content: &Rope,
line_ending: LineEnding,
) -> Result<Self> {
let output = run_git_blame(git_binary, working_directory, path, content).await?;
let output =
run_git_blame(git_binary, working_directory, path, content, line_ending).await?;
let mut entries = parse_git_blame(&output)?;
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
@@ -63,12 +65,12 @@ async fn run_git_blame(
working_directory: &Path,
path: &RepoPath,
contents: &Rope,
line_ending: LineEnding,
) -> Result<String> {
let mut child = util::command::new_smol_command(git_binary)
.current_dir(working_directory)
.arg("blame")
.arg("--incremental")
.arg("-w")
.arg("--contents")
.arg("-")
.arg(path.as_unix_str())
@@ -83,7 +85,7 @@ async fn run_git_blame(
.as_mut()
.context("failed to get pipe to stdin of git blame command")?;
for chunk in contents.chunks() {
for chunk in text::chunks_with_line_ending(contents, line_ending) {
stdin.write_all(chunk.as_bytes()).await?;
}
stdin.flush().await?;

View File

@@ -232,14 +232,12 @@ impl From<Oid> for usize {
#[derive(Copy, Clone, Debug)]
pub enum RunHook {
PreCommit,
PrePush,
}
impl RunHook {
pub fn as_str(&self) -> &str {
match self {
Self::PreCommit => "pre-commit",
Self::PrePush => "pre-push",
}
}
@@ -250,7 +248,6 @@ impl RunHook {
pub fn from_proto(value: i32) -> Option<Self> {
match value {
0 => Some(Self::PreCommit),
1 => Some(Self::PrePush),
_ => None,
}
}

View File

@@ -14,6 +14,7 @@ use rope::Rope;
use schemars::JsonSchema;
use serde::Deserialize;
use smol::io::{AsyncBufReadExt, AsyncReadExt, BufReader};
use text::LineEnding;
use std::collections::HashSet;
use std::ffi::{OsStr, OsString};
@@ -487,7 +488,12 @@ pub trait GitRepository: Send + Sync {
fn show(&self, commit: String) -> BoxFuture<'_, Result<CommitDetails>>;
fn load_commit(&self, commit: String, cx: AsyncApp) -> BoxFuture<'_, Result<CommitDiff>>;
fn blame(&self, path: RepoPath, content: Rope) -> BoxFuture<'_, Result<crate::blame::Blame>>;
fn blame(
&self,
path: RepoPath,
content: Rope,
line_ending: LineEnding,
) -> BoxFuture<'_, Result<crate::blame::Blame>>;
fn file_history(&self, path: RepoPath) -> BoxFuture<'_, Result<FileHistory>>;
fn file_history_paginated(
&self,
@@ -652,6 +658,7 @@ pub struct RealGitRepository {
pub repository: Arc<Mutex<git2::Repository>>,
pub system_git_binary_path: Option<PathBuf>,
pub any_git_binary_path: PathBuf,
any_git_binary_help_output: Arc<Mutex<Option<SharedString>>>,
executor: BackgroundExecutor,
}
@@ -670,6 +677,7 @@ impl RealGitRepository {
system_git_binary_path,
any_git_binary_path,
executor,
any_git_binary_help_output: Arc::new(Mutex::new(None)),
})
}
@@ -680,6 +688,27 @@ impl RealGitRepository {
.context("failed to read git work directory")
.map(Path::to_path_buf)
}
async fn any_git_binary_help_output(&self) -> SharedString {
if let Some(output) = self.any_git_binary_help_output.lock().clone() {
return output;
}
let git_binary_path = self.any_git_binary_path.clone();
let executor = self.executor.clone();
let working_directory = self.working_directory();
let output: SharedString = self
.executor
.spawn(async move {
GitBinary::new(git_binary_path, working_directory?, executor)
.run(["help", "-a"])
.await
})
.await
.unwrap_or_default()
.into();
*self.any_git_binary_help_output.lock() = Some(output.clone());
output
}
}
#[derive(Clone, Debug)]
@@ -1489,7 +1518,12 @@ impl GitRepository for RealGitRepository {
.boxed()
}
fn blame(&self, path: RepoPath, content: Rope) -> BoxFuture<'_, Result<crate::blame::Blame>> {
fn blame(
&self,
path: RepoPath,
content: Rope,
line_ending: LineEnding,
) -> BoxFuture<'_, Result<crate::blame::Blame>> {
let working_directory = self.working_directory();
let git_binary_path = self.any_git_binary_path.clone();
let executor = self.executor.clone();
@@ -1501,6 +1535,7 @@ impl GitRepository for RealGitRepository {
&working_directory?,
&path,
&content,
line_ending,
)
.await
})
@@ -1819,6 +1854,7 @@ impl GitRepository for RealGitRepository {
.args(["commit", "--quiet", "-m"])
.arg(&message.to_string())
.arg("--cleanup=strip")
.arg("--no-verify")
.stdout(smol::process::Stdio::piped())
.stderr(smol::process::Stdio::piped());
@@ -2289,48 +2325,47 @@ impl GitRepository for RealGitRepository {
env: Arc<HashMap<String, String>>,
) -> BoxFuture<'_, Result<()>> {
let working_directory = self.working_directory();
let repository = self.repository.clone();
let git_binary_path = self.any_git_binary_path.clone();
let executor = self.executor.clone();
self.executor
.spawn(async move {
let working_directory = working_directory?;
let git = GitBinary::new(git_binary_path, working_directory.clone(), executor)
.envs(HashMap::clone(&env));
let help_output = self.any_git_binary_help_output();
let output = git.run(&["help", "-a"]).await?;
if !output.lines().any(|line| line.trim().starts_with("hook ")) {
log::warn!(
"git hook command not available, running the {} hook manually",
hook.as_str()
);
// Note: Do not spawn these commands on the background thread, as this causes some git hooks to hang.
async move {
let working_directory = working_directory?;
if !help_output
.await
.lines()
.any(|line| line.trim().starts_with("hook "))
{
let hook_abs_path = repository.lock().path().join("hooks").join(hook.as_str());
if hook_abs_path.is_file() {
let output = new_smol_command(&hook_abs_path)
.envs(env.iter())
.current_dir(&working_directory)
.output()
.await?;
let hook_abs_path = working_directory
.join(".git")
.join("hooks")
.join(hook.as_str());
if hook_abs_path.is_file() {
let output = new_smol_command(&hook_abs_path)
.envs(env.iter())
.current_dir(&working_directory)
.output()
.await?;
anyhow::ensure!(
output.status.success(),
"{} hook failed:\n{}",
hook.as_str(),
String::from_utf8_lossy(&output.stderr)
);
if !output.status.success() {
return Err(GitBinaryCommandError {
stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
status: output.status,
}
.into());
}
return Ok(());
}
git.run(&["hook", "run", "--ignore-missing", hook.as_str()])
.await?;
Ok(())
})
.boxed()
return Ok(());
}
let git = GitBinary::new(git_binary_path, working_directory, executor)
.envs(HashMap::clone(&env));
git.run(&["hook", "run", "--ignore-missing", hook.as_str()])
.await?;
Ok(())
}
.boxed()
}
}

View File

@@ -47,11 +47,13 @@ impl BlameRenderer for GitBlameRenderer {
let name = util::truncate_and_trailoff(author_name, GIT_BLAME_MAX_AUTHOR_CHARS_DISPLAYED);
let avatar = if ProjectSettings::get_global(cx).git.blame.show_avatar {
CommitAvatar::new(
&blame_entry.sha.to_string().into(),
details.as_ref().and_then(|it| it.remote.as_ref()),
Some(
CommitAvatar::new(
&blame_entry.sha.to_string().into(),
details.as_ref().and_then(|it| it.remote.as_ref()),
)
.render(window, cx),
)
.render(window, cx)
} else {
None
};
@@ -65,7 +67,7 @@ impl BlameRenderer for GitBlameRenderer {
.w_full()
.gap_2()
.justify_between()
.font_family(style.font().family)
.font(style.font())
.line_height(style.line_height)
.text_color(cx.theme().status().hint)
.child(
@@ -264,7 +266,7 @@ impl BlameRenderer for GitBlameRenderer {
.flex_wrap()
.border_b_1()
.border_color(cx.theme().colors().border_variant)
.children(avatar)
.child(avatar)
.child(author)
.when(!author_email.is_empty(), |this| {
this.child(

View File

@@ -139,7 +139,7 @@ impl CommitModal {
&& !git_panel.amend_pending()
{
git_panel.set_amend_pending(true, cx);
git_panel.load_last_commit_message_if_empty(cx);
git_panel.load_last_commit_message(cx);
}
}
ForceMode::Commit => {
@@ -492,53 +492,20 @@ impl CommitModal {
}
}
fn commit(&mut self, _: &git::Commit, window: &mut Window, cx: &mut Context<Self>) {
if self.git_panel.read(cx).amend_pending() {
return;
fn on_commit(&mut self, _: &git::Commit, window: &mut Window, cx: &mut Context<Self>) {
if self.git_panel.update(cx, |git_panel, cx| {
git_panel.commit(&self.commit_editor.focus_handle(cx), window, cx)
}) {
telemetry::event!("Git Committed", source = "Git Modal");
cx.emit(DismissEvent);
}
telemetry::event!("Git Committed", source = "Git Modal");
self.git_panel.update(cx, |git_panel, cx| {
git_panel.commit_changes(
CommitOptions {
amend: false,
signoff: git_panel.signoff_enabled(),
},
window,
cx,
)
});
cx.emit(DismissEvent);
}
fn amend(&mut self, _: &git::Amend, window: &mut Window, cx: &mut Context<Self>) {
if self
.git_panel
.read(cx)
.active_repository
.as_ref()
.and_then(|repo| repo.read(cx).head_commit.as_ref())
.is_none()
{
return;
}
if !self.git_panel.read(cx).amend_pending() {
self.git_panel.update(cx, |git_panel, cx| {
git_panel.set_amend_pending(true, cx);
git_panel.load_last_commit_message_if_empty(cx);
});
} else {
fn on_amend(&mut self, _: &git::Amend, window: &mut Window, cx: &mut Context<Self>) {
if self.git_panel.update(cx, |git_panel, cx| {
git_panel.amend(&self.commit_editor.focus_handle(cx), window, cx)
}) {
telemetry::event!("Git Amended", source = "Git Modal");
self.git_panel.update(cx, |git_panel, cx| {
git_panel.set_amend_pending(false, cx);
git_panel.commit_changes(
CommitOptions {
amend: true,
signoff: git_panel.signoff_enabled(),
},
window,
cx,
);
});
cx.emit(DismissEvent);
}
}
@@ -564,8 +531,8 @@ impl Render for CommitModal {
.id("commit-modal")
.key_context("GitCommit")
.on_action(cx.listener(Self::dismiss))
.on_action(cx.listener(Self::commit))
.on_action(cx.listener(Self::amend))
.on_action(cx.listener(Self::on_commit))
.on_action(cx.listener(Self::on_amend))
.when(!DisableAiSettings::get_global(cx).disable_ai, |this| {
this.on_action(cx.listener(|this, _: &GenerateCommitMessage, _, cx| {
this.git_panel.update(cx, |panel, cx| {

View File

@@ -29,11 +29,16 @@ pub struct CommitDetails {
pub struct CommitAvatar<'a> {
sha: &'a SharedString,
remote: Option<&'a GitRemote>,
size: Option<IconSize>,
}
impl<'a> CommitAvatar<'a> {
pub fn new(sha: &'a SharedString, remote: Option<&'a GitRemote>) -> Self {
Self { sha, remote }
Self {
sha,
remote,
size: None,
}
}
pub fn from_commit_details(details: &'a CommitDetails) -> Self {
@@ -43,28 +48,37 @@ impl<'a> CommitAvatar<'a> {
.message
.as_ref()
.and_then(|details| details.remote.as_ref()),
size: None,
}
}
}
impl<'a> CommitAvatar<'a> {
pub fn render(&'a self, window: &mut Window, cx: &mut App) -> Option<impl IntoElement + use<>> {
pub fn size(mut self, size: IconSize) -> Self {
self.size = Some(size);
self
}
pub fn render(&'a self, window: &mut Window, cx: &mut App) -> AnyElement {
match self.avatar(window, cx) {
// Loading or no avatar found
None => Icon::new(IconName::Person)
.color(Color::Muted)
.when_some(self.size, |this, size| this.size(size))
.into_any_element(),
// Found
Some(avatar) => avatar
.when_some(self.size, |this, size| this.size(size.rems()))
.into_any_element(),
}
}
pub fn avatar(&'a self, window: &mut Window, cx: &mut App) -> Option<Avatar> {
let remote = self
.remote
.filter(|remote| remote.host_supports_avatars())?;
let avatar_url = CommitAvatarAsset::new(remote.clone(), self.sha.clone());
let element = match window.use_asset::<CommitAvatarAsset>(&avatar_url, cx) {
// Loading or no avatar found
None | Some(None) => Icon::new(IconName::Person)
.color(Color::Muted)
.into_element()
.into_any(),
// Found
Some(Some(url)) => Avatar::new(url.to_string()).into_element().into_any(),
};
Some(element)
let url = window.use_asset::<CommitAvatarAsset>(&avatar_url, cx)??;
Some(Avatar::new(url.to_string()))
}
}
@@ -253,7 +267,7 @@ impl Render for CommitTooltip {
.gap_x_2()
.overflow_x_hidden()
.flex_wrap()
.children(avatar)
.child(avatar)
.child(author)
.when(!author_email.is_empty(), |this| {
this.child(

View File

@@ -5,8 +5,8 @@ use editor::{Editor, EditorEvent, ExcerptRange, MultiBuffer, multibuffer_context
use git::repository::{CommitDetails, CommitDiff, RepoPath};
use git::{GitHostingProviderRegistry, GitRemote, parse_git_remote_url};
use gpui::{
AnyElement, App, AppContext as _, Asset, AsyncApp, AsyncWindowContext, Context, Element,
Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement, ParentElement,
AnyElement, App, AppContext as _, AsyncApp, AsyncWindowContext, Context, Element, Entity,
EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement, ParentElement,
PromptLevel, Render, Styled, Task, WeakEntity, Window, actions,
};
use language::{
@@ -21,7 +21,7 @@ use std::{
sync::Arc,
};
use theme::ActiveTheme;
use ui::{Avatar, DiffStat, Tooltip, prelude::*};
use ui::{DiffStat, Tooltip, prelude::*};
use util::{ResultExt, paths::PathStyle, rel_path::RelPath, truncate_and_trailoff};
use workspace::item::TabTooltipContent;
use workspace::{
@@ -33,6 +33,7 @@ use workspace::{
searchable::SearchableItemHandle,
};
use crate::commit_tooltip::CommitAvatar;
use crate::git_panel::GitPanel;
actions!(git, [ApplyCurrentStash, PopCurrentStash, DropCurrentStash,]);
@@ -318,17 +319,7 @@ impl CommitView {
cx: &mut App,
) -> AnyElement {
let size = size.into();
let remote = self.remote.as_ref().filter(|r| r.host_supports_avatars());
if let Some(remote) = remote {
let avatar_asset = CommitAvatarAsset::new(remote.clone(), sha.clone());
if let Some(Some(url)) = window.use_asset::<CommitAvatarAsset>(&avatar_asset, cx) {
return Avatar::new(url.to_string())
.size(size)
.into_element()
.into_any();
}
}
let avatar = CommitAvatar::new(sha, self.remote.as_ref());
v_flex()
.w(size)
@@ -339,10 +330,15 @@ impl CommitView {
.justify_center()
.items_center()
.child(
Icon::new(IconName::Person)
.color(Color::Muted)
.size(IconSize::Medium)
.into_element(),
avatar
.avatar(window, cx)
.map(|a| a.size(size).into_any_element())
.unwrap_or_else(|| {
Icon::new(IconName::Person)
.color(Color::Muted)
.size(IconSize::Medium)
.into_any_element()
}),
)
.into_any()
}
@@ -647,54 +643,6 @@ impl CommitView {
}
}
#[derive(Clone, Debug)]
struct CommitAvatarAsset {
sha: SharedString,
remote: GitRemote,
}
impl std::hash::Hash for CommitAvatarAsset {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.sha.hash(state);
self.remote.host.name().hash(state);
}
}
impl CommitAvatarAsset {
fn new(remote: GitRemote, sha: SharedString) -> Self {
Self { remote, sha }
}
}
impl Asset for CommitAvatarAsset {
type Source = Self;
type Output = Option<SharedString>;
fn load(
source: Self::Source,
cx: &mut App,
) -> impl Future<Output = Self::Output> + Send + 'static {
let client = cx.http_client();
async move {
match source
.remote
.host
.commit_author_avatar_url(
&source.remote.owner,
&source.remote.repo,
source.sha.clone(),
client,
)
.await
{
Ok(Some(url)) => Some(SharedString::from(url.to_string())),
Ok(None) => None,
Err(_) => None,
}
}
}
}
impl language::File for GitBlob {
fn as_local(&self) -> Option<&dyn language::LocalFile> {
None

View File

@@ -111,6 +111,7 @@ fn excerpt_for_buffer_updated(
);
}
#[ztracing::instrument(skip_all)]
fn buffer_added(editor: &mut Editor, buffer: Entity<Buffer>, cx: &mut Context<Editor>) {
let Some(project) = editor.project() else {
return;
@@ -166,6 +167,7 @@ fn buffers_removed(editor: &mut Editor, removed_buffer_ids: &[BufferId], cx: &mu
editor.remove_blocks(removed_block_ids, None, cx);
}
#[ztracing::instrument(skip_all)]
fn conflicts_updated(
editor: &mut Editor,
conflict_set: Entity<ConflictSet>,
@@ -311,6 +313,7 @@ fn conflicts_updated(
}
}
#[ztracing::instrument(skip_all)]
fn update_conflict_highlighting(
editor: &mut Editor,
conflict: &ConflictRegion,

View File

@@ -1934,16 +1934,26 @@ impl GitPanel {
}
}
fn commit(&mut self, _: &git::Commit, window: &mut Window, cx: &mut Context<Self>) {
if self.amend_pending {
return;
}
if self
.commit_editor
.focus_handle(cx)
.contains_focused(window, cx)
{
fn on_commit(&mut self, _: &git::Commit, window: &mut Window, cx: &mut Context<Self>) {
if self.commit(&self.commit_editor.focus_handle(cx), window, cx) {
telemetry::event!("Git Committed", source = "Git Panel");
}
}
/// Commits staged changes with the current commit message.
///
/// Returns `true` if the commit was executed, `false` otherwise.
pub(crate) fn commit(
&mut self,
commit_editor_focus_handle: &FocusHandle,
window: &mut Window,
cx: &mut Context<Self>,
) -> bool {
if self.amend_pending {
return false;
}
if commit_editor_focus_handle.contains_focused(window, cx) {
self.commit_changes(
CommitOptions {
amend: false,
@@ -1951,24 +1961,39 @@ impl GitPanel {
},
window,
cx,
)
);
true
} else {
cx.propagate();
false
}
}
fn amend(&mut self, _: &git::Amend, window: &mut Window, cx: &mut Context<Self>) {
if self
.commit_editor
.focus_handle(cx)
.contains_focused(window, cx)
{
fn on_amend(&mut self, _: &git::Amend, window: &mut Window, cx: &mut Context<Self>) {
if self.amend(&self.commit_editor.focus_handle(cx), window, cx) {
telemetry::event!("Git Amended", source = "Git Panel");
}
}
/// Amends the most recent commit with staged changes and/or an updated commit message.
///
/// Uses a two-stage workflow where the first invocation loads the commit
/// message for editing, second invocation performs the amend. Returns
/// `true` if the amend was executed, `false` otherwise.
pub(crate) fn amend(
&mut self,
commit_editor_focus_handle: &FocusHandle,
window: &mut Window,
cx: &mut Context<Self>,
) -> bool {
if commit_editor_focus_handle.contains_focused(window, cx) {
if self.head_commit(cx).is_some() {
if !self.amend_pending {
self.set_amend_pending(true, cx);
self.load_last_commit_message_if_empty(cx);
self.load_last_commit_message(cx);
return false;
} else {
telemetry::event!("Git Amended", source = "Git Panel");
self.commit_changes(
CommitOptions {
amend: true,
@@ -1977,13 +2002,16 @@ impl GitPanel {
window,
cx,
);
return true;
}
}
return false;
} else {
cx.propagate();
return false;
}
}
pub fn head_commit(&self, cx: &App) -> Option<CommitDetails> {
self.active_repository
.as_ref()
@@ -1991,13 +2019,11 @@ impl GitPanel {
.cloned()
}
pub fn load_last_commit_message_if_empty(&mut self, cx: &mut Context<Self>) {
if !self.commit_editor.read(cx).is_empty(cx) {
return;
}
pub fn load_last_commit_message(&mut self, cx: &mut Context<Self>) {
let Some(head_commit) = self.head_commit(cx) else {
return;
};
let recent_sha = head_commit.sha.to_string();
let detail_task = self.load_commit_details(recent_sha, cx);
cx.spawn(async move |this, cx| {
@@ -2133,11 +2159,16 @@ impl GitPanel {
let result = task.await;
this.update_in(cx, |this, window, cx| {
this.pending_commit.take();
match result {
Ok(()) => {
this.commit_editor
.update(cx, |editor, cx| editor.clear(window, cx));
this.original_commit_message = None;
if options.amend {
this.set_amend_pending(false, cx);
} else {
this.commit_editor
.update(cx, |editor, cx| editor.clear(window, cx));
this.original_commit_message = None;
}
}
Err(e) => this.show_error_toast("commit", e, cx),
}
@@ -2146,9 +2177,6 @@ impl GitPanel {
});
self.pending_commit = Some(task);
if options.amend {
self.set_amend_pending(false, cx);
}
}
pub(crate) fn uncommit(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -5067,6 +5095,9 @@ impl GitPanel {
self.amend_pending
}
/// Sets the pending amend state, ensuring that the original commit message
/// is either saved, when `value` is `true` and there's no pending amend, or
/// restored, when `value` is `false` and there's a pending amend.
pub fn set_amend_pending(&mut self, value: bool, cx: &mut Context<Self>) {
if value && !self.amend_pending {
let current_message = self.commit_message_buffer(cx).read(cx).text();
@@ -5184,7 +5215,7 @@ impl GitPanel {
pub(crate) fn toggle_amend_pending(&mut self, cx: &mut Context<Self>) {
self.set_amend_pending(!self.amend_pending, cx);
if self.amend_pending {
self.load_last_commit_message_if_empty(cx);
self.load_last_commit_message(cx);
}
}
}
@@ -5215,8 +5246,8 @@ impl Render for GitPanel {
.when(has_write_access && !project.is_read_only(cx), |this| {
this.on_action(cx.listener(Self::toggle_staged_for_selected))
.on_action(cx.listener(Self::stage_range))
.on_action(cx.listener(GitPanel::commit))
.on_action(cx.listener(GitPanel::amend))
.on_action(cx.listener(GitPanel::on_commit))
.on_action(cx.listener(GitPanel::on_amend))
.on_action(cx.listener(GitPanel::toggle_signoff_enabled))
.on_action(cx.listener(Self::stage_all))
.on_action(cx.listener(Self::unstage_all))
@@ -6557,6 +6588,94 @@ mod tests {
});
}
#[gpui::test]
async fn test_amend(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.background_executor.clone());
fs.insert_tree(
"/root",
json!({
"project": {
".git": {},
"src": {
"main.rs": "fn main() {}"
}
}
}),
)
.await;
fs.set_status_for_repo(
Path::new(path!("/root/project/.git")),
&[("src/main.rs", StatusCode::Modified.worktree())],
);
let project = Project::test(fs.clone(), [Path::new(path!("/root/project"))], cx).await;
let workspace =
cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx));
let cx = &mut VisualTestContext::from_window(*workspace, cx);
// Wait for the project scanning to finish so that `head_commit(cx)` is
// actually set, otherwise no head commit would be available from which
// to fetch the latest commit message from.
cx.executor().run_until_parked();
let panel = workspace.update(cx, GitPanel::new).unwrap();
panel.read_with(cx, |panel, cx| {
assert!(panel.active_repository.is_some());
assert!(panel.head_commit(cx).is_some());
});
panel.update_in(cx, |panel, window, cx| {
// Update the commit editor's message to ensure that its contents
// are later restored, after amending is finished.
panel.commit_message_buffer(cx).update(cx, |buffer, cx| {
buffer.set_text("refactor: update main.rs", cx);
});
// Start amending the previous commit.
panel.focus_editor(&Default::default(), window, cx);
panel.on_amend(&Amend, window, cx);
});
// Since `GitPanel.amend` attempts to fetch the latest commit message in
// a background task, we need to wait for it to complete before being
// able to assert that the commit message editor's state has been
// updated.
cx.run_until_parked();
panel.update_in(cx, |panel, window, cx| {
assert_eq!(
panel.commit_message_buffer(cx).read(cx).text(),
"initial commit"
);
assert_eq!(
panel.original_commit_message,
Some("refactor: update main.rs".to_string())
);
// Finish amending the previous commit.
panel.focus_editor(&Default::default(), window, cx);
panel.on_amend(&Amend, window, cx);
});
// Since the actual commit logic is run in a background task, we need to
// await its completion to actually ensure that the commit message
// editor's contents are set to the original message and haven't been
// cleared.
cx.run_until_parked();
panel.update_in(cx, |panel, _window, cx| {
// After amending, the commit editor's message should be restored to
// the original message.
assert_eq!(
panel.commit_message_buffer(cx).read(cx).text(),
"refactor: update main.rs"
);
assert!(panel.original_commit_message.is_none());
});
}
#[gpui::test]
async fn test_open_diff(cx: &mut TestAppContext) {
init_test(cx);

View File

@@ -21,7 +21,6 @@ default = ["font-kit", "wayland", "x11", "windows-manifest"]
test-support = [
"leak-detection",
"collections/test-support",
"rand",
"util/test-support",
"http_client/test-support",
"wayland",
@@ -109,7 +108,7 @@ parking = "2.0.0"
parking_lot.workspace = true
postage.workspace = true
profiling.workspace = true
rand = { optional = true, workspace = true }
rand.workspace = true
raw-window-handle = "0.6"
refineable.workspace = true
resvg = { version = "0.45.0", default-features = false, features = [
@@ -158,8 +157,10 @@ media.workspace = true
objc.workspace = true
objc2 = { version = "0.6", optional = true }
objc2-metal = { version = "0.3", optional = true }
mach2.workspace = true
#TODO: replace with "objc2"
metal.workspace = true
flume = "0.11"
[target.'cfg(any(target_os = "linux", target_os = "freebsd", target_os = "macos"))'.dependencies]
pathfinder_geometry = "0.5"

View File

@@ -84,6 +84,8 @@ mod macos {
.allowlist_var("_dispatch_main_q")
.allowlist_var("_dispatch_source_type_data_add")
.allowlist_var("DISPATCH_QUEUE_PRIORITY_HIGH")
.allowlist_var("DISPATCH_QUEUE_PRIORITY_DEFAULT")
.allowlist_var("DISPATCH_QUEUE_PRIORITY_LOW")
.allowlist_var("DISPATCH_TIME_NOW")
.allowlist_function("dispatch_get_global_queue")
.allowlist_function("dispatch_async_f")

View File

@@ -38,10 +38,11 @@ use crate::{
AssetSource, BackgroundExecutor, Bounds, ClipboardItem, CursorStyle, DispatchPhase, DisplayId,
EventEmitter, FocusHandle, FocusMap, ForegroundExecutor, Global, KeyBinding, KeyContext,
Keymap, Keystroke, LayoutId, Menu, MenuItem, OwnedMenu, PathPromptOptions, Pixels, Platform,
PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper, Point, PromptBuilder,
PromptButton, PromptHandle, PromptLevel, Render, RenderImage, RenderablePromptHandle,
Reservation, ScreenCaptureSource, SharedString, SubscriberSet, Subscription, SvgRenderer, Task,
TextSystem, Window, WindowAppearance, WindowHandle, WindowId, WindowInvalidator,
PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper, Point, Priority,
PromptBuilder, PromptButton, PromptHandle, PromptLevel, Render, RenderImage,
RenderablePromptHandle, Reservation, ScreenCaptureSource, SharedString, SubscriberSet,
Subscription, SvgRenderer, Task, TextSystem, Window, WindowAppearance, WindowHandle, WindowId,
WindowInvalidator,
colors::{Colors, GlobalColors},
current_platform, hash, init_app_menus,
};
@@ -1494,6 +1495,24 @@ impl App {
.spawn(async move { f(&mut cx).await })
}
/// Spawns the future returned by the given function on the main thread with
/// the given priority. The closure will be invoked with [AsyncApp], which
/// allows the application state to be accessed across await points.
pub fn spawn_with_priority<AsyncFn, R>(&self, priority: Priority, f: AsyncFn) -> Task<R>
where
AsyncFn: AsyncFnOnce(&mut AsyncApp) -> R + 'static,
R: 'static,
{
if self.quitting {
debug_panic!("Can't spawn on main thread after on_app_quit")
};
let mut cx = self.to_async();
self.foreground_executor
.spawn_with_priority(priority, async move { f(&mut cx).await })
}
/// Schedules the given function to be run at the end of the current effect cycle, allowing entities
/// that are currently on the stack to be returned to the app.
pub fn defer(&mut self, f: impl FnOnce(&mut App) + 'static) {

View File

@@ -1,7 +1,7 @@
use crate::{
AnyView, AnyWindowHandle, AppContext, AsyncApp, DispatchPhase, Effect, EntityId, EventEmitter,
FocusHandle, FocusOutEvent, Focusable, Global, KeystrokeObserver, Reservation, SubscriberSet,
Subscription, Task, WeakEntity, WeakFocusHandle, Window, WindowHandle,
FocusHandle, FocusOutEvent, Focusable, Global, KeystrokeObserver, Priority, Reservation,
SubscriberSet, Subscription, Task, WeakEntity, WeakFocusHandle, Window, WindowHandle,
};
use anyhow::Result;
use futures::FutureExt;
@@ -667,6 +667,25 @@ impl<'a, T: 'static> Context<'a, T> {
window.spawn(self, async move |cx| f(view, cx).await)
}
/// Schedule a future to be run asynchronously with the given priority.
/// The given callback is invoked with a [`WeakEntity<V>`] to avoid leaking the entity for a long-running process.
/// It's also given an [`AsyncWindowContext`], which can be used to access the state of the entity across await points.
/// The returned future will be polled on the main thread.
#[track_caller]
pub fn spawn_in_with_priority<AsyncFn, R>(
&self,
priority: Priority,
window: &Window,
f: AsyncFn,
) -> Task<R>
where
R: 'static,
AsyncFn: AsyncFnOnce(WeakEntity<T>, &mut AsyncWindowContext) -> R + 'static,
{
let view = self.weak_entity();
window.spawn_with_priority(priority, self, async move |cx| f(view, cx).await)
}
/// Register a callback to be invoked when the given global state changes.
pub fn observe_global_in<G: Global>(
&mut self,

View File

@@ -1,6 +1,7 @@
use crate::{App, PlatformDispatcher, RunnableMeta, RunnableVariant};
use crate::{App, PlatformDispatcher, RunnableMeta, RunnableVariant, TaskTiming, profiler};
use async_task::Runnable;
use futures::channel::mpsc;
use parking_lot::{Condvar, Mutex};
use smol::prelude::*;
use std::{
fmt::Debug,
@@ -46,6 +47,52 @@ pub struct ForegroundExecutor {
not_send: PhantomData<Rc<()>>,
}
/// Realtime task priority
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[repr(u8)]
pub enum RealtimePriority {
/// Audio task
Audio,
/// Other realtime task
#[default]
Other,
}
/// Task priority
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[repr(u8)]
pub enum Priority {
/// Realtime priority
///
/// Spawning a task with this priority will spin it off on a separate thread dedicated just to that task.
Realtime(RealtimePriority),
/// High priority
///
/// Only use for tasks that are critical to the user experience / responsiveness of the editor.
High,
/// Medium priority, probably suits most of your use cases.
#[default]
Medium,
/// Low priority
///
/// Prioritize this for background work that can come in large quantities
/// to not starve the executor of resources for high priority tasks
Low,
}
impl Priority {
#[allow(dead_code)]
pub(crate) const fn probability(&self) -> u32 {
match self {
// realtime priorities are not considered for probability scheduling
Priority::Realtime(_) => 0,
Priority::High => 60,
Priority::Medium => 30,
Priority::Low => 10,
}
}
}
/// Task is a primitive that allows work to happen in the background.
///
/// It implements [`Future`] so you can `.await` on it.
@@ -151,7 +198,77 @@ impl BackgroundExecutor {
where
R: Send + 'static,
{
self.spawn_internal::<R>(Box::pin(future), None)
self.spawn_with_priority(Priority::default(), future)
}
/// Enqueues the given future to be run to completion on a background thread.
#[track_caller]
pub fn spawn_with_priority<R>(
&self,
priority: Priority,
future: impl Future<Output = R> + Send + 'static,
) -> Task<R>
where
R: Send + 'static,
{
self.spawn_internal::<R>(Box::pin(future), None, priority)
}
/// Enqueues the given future to be run to completion on a background thread and blocking the current task on it.
///
/// This allows to spawn background work that borrows from its scope. Note that the supplied future will run to
/// completion before the current task is resumed, even if the current task is slated for cancellation.
pub async fn await_on_background<R>(&self, future: impl Future<Output = R> + Send) -> R
where
R: Send,
{
// We need to ensure that cancellation of the parent task does not drop the environment
// before the our own task has completed or got cancelled.
struct NotifyOnDrop<'a>(&'a (Condvar, Mutex<bool>));
impl Drop for NotifyOnDrop<'_> {
fn drop(&mut self) {
*self.0.1.lock() = true;
self.0.0.notify_all();
}
}
struct WaitOnDrop<'a>(&'a (Condvar, Mutex<bool>));
impl Drop for WaitOnDrop<'_> {
fn drop(&mut self) {
let mut done = self.0.1.lock();
if !*done {
self.0.0.wait(&mut done);
}
}
}
let dispatcher = self.dispatcher.clone();
let location = core::panic::Location::caller();
let pair = &(Condvar::new(), Mutex::new(false));
let _wait_guard = WaitOnDrop(pair);
let (runnable, task) = unsafe {
async_task::Builder::new()
.metadata(RunnableMeta { location })
.spawn_unchecked(
move |_| async {
let _notify_guard = NotifyOnDrop(pair);
future.await
},
move |runnable| {
dispatcher.dispatch(
RunnableVariant::Meta(runnable),
None,
Priority::default(),
)
},
)
};
runnable.schedule();
task.await
}
/// Enqueues the given future to be run to completion on a background thread.
@@ -165,7 +282,7 @@ impl BackgroundExecutor {
where
R: Send + 'static,
{
self.spawn_internal::<R>(Box::pin(future), Some(label))
self.spawn_internal::<R>(Box::pin(future), Some(label), Priority::default())
}
#[track_caller]
@@ -173,15 +290,55 @@ impl BackgroundExecutor {
&self,
future: AnyFuture<R>,
label: Option<TaskLabel>,
priority: Priority,
) -> Task<R> {
let dispatcher = self.dispatcher.clone();
let location = core::panic::Location::caller();
let (runnable, task) = async_task::Builder::new()
.metadata(RunnableMeta { location })
.spawn(
move |_| future,
move |runnable| dispatcher.dispatch(RunnableVariant::Meta(runnable), label),
let (runnable, task) = if let Priority::Realtime(realtime) = priority {
let location = core::panic::Location::caller();
let (mut tx, rx) = flume::bounded::<Runnable<RunnableMeta>>(1);
dispatcher.spawn_realtime(
realtime,
Box::new(move || {
while let Ok(runnable) = rx.recv() {
let start = Instant::now();
let location = runnable.metadata().location;
let mut timing = TaskTiming {
location,
start,
end: None,
};
profiler::add_task_timing(timing);
runnable.run();
let end = Instant::now();
timing.end = Some(end);
profiler::add_task_timing(timing);
}
}),
);
async_task::Builder::new()
.metadata(RunnableMeta { location })
.spawn(
move |_| future,
move |runnable| {
let _ = tx.send(runnable);
},
)
} else {
let location = core::panic::Location::caller();
async_task::Builder::new()
.metadata(RunnableMeta { location })
.spawn(
move |_| future,
move |runnable| {
dispatcher.dispatch(RunnableVariant::Meta(runnable), label, priority)
},
)
};
runnable.schedule();
Task(TaskState::Spawned(task))
}
@@ -354,11 +511,28 @@ impl BackgroundExecutor {
where
F: FnOnce(&mut Scope<'scope>),
{
let mut scope = Scope::new(self.clone());
let mut scope = Scope::new(self.clone(), Priority::default());
(scheduler)(&mut scope);
let spawned = mem::take(&mut scope.futures)
.into_iter()
.map(|f| self.spawn(f))
.map(|f| self.spawn_with_priority(scope.priority, f))
.collect::<Vec<_>>();
for task in spawned {
task.await;
}
}
/// Scoped lets you start a number of tasks and waits
/// for all of them to complete before returning.
pub async fn scoped_priority<'scope, F>(&self, priority: Priority, scheduler: F)
where
F: FnOnce(&mut Scope<'scope>),
{
let mut scope = Scope::new(self.clone(), priority);
(scheduler)(&mut scope);
let spawned = mem::take(&mut scope.futures)
.into_iter()
.map(|f| self.spawn_with_priority(scope.priority, f))
.collect::<Vec<_>>();
for task in spawned {
task.await;
@@ -494,6 +668,19 @@ impl ForegroundExecutor {
/// Enqueues the given Task to run on the main thread at some point in the future.
#[track_caller]
pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
where
R: 'static,
{
self.spawn_with_priority(Priority::default(), future)
}
/// Enqueues the given Task to run on the main thread at some point in the future.
#[track_caller]
pub fn spawn_with_priority<R>(
&self,
priority: Priority,
future: impl Future<Output = R> + 'static,
) -> Task<R>
where
R: 'static,
{
@@ -505,16 +692,19 @@ impl ForegroundExecutor {
dispatcher: Arc<dyn PlatformDispatcher>,
future: AnyLocalFuture<R>,
location: &'static core::panic::Location<'static>,
priority: Priority,
) -> Task<R> {
let (runnable, task) = spawn_local_with_source_location(
future,
move |runnable| dispatcher.dispatch_on_main_thread(RunnableVariant::Meta(runnable)),
move |runnable| {
dispatcher.dispatch_on_main_thread(RunnableVariant::Meta(runnable), priority)
},
RunnableMeta { location },
);
runnable.schedule();
Task(TaskState::Spawned(task))
}
inner::<R>(dispatcher, Box::pin(future), location)
inner::<R>(dispatcher, Box::pin(future), location, priority)
}
}
@@ -590,6 +780,7 @@ where
/// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`].
pub struct Scope<'a> {
executor: BackgroundExecutor,
priority: Priority,
futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
tx: Option<mpsc::Sender<()>>,
rx: mpsc::Receiver<()>,
@@ -597,10 +788,11 @@ pub struct Scope<'a> {
}
impl<'a> Scope<'a> {
fn new(executor: BackgroundExecutor) -> Self {
fn new(executor: BackgroundExecutor, priority: Priority) -> Self {
let (tx, rx) = mpsc::channel(1);
Self {
executor,
priority,
tx: Some(tx),
rx,
futures: Default::default(),

View File

@@ -1416,9 +1416,9 @@ where
/// ```
pub fn contains(&self, point: &Point<T>) -> bool {
point.x >= self.origin.x
&& point.x <= self.origin.x.clone() + self.size.width.clone()
&& point.x < self.origin.x.clone() + self.size.width.clone()
&& point.y >= self.origin.y
&& point.y <= self.origin.y.clone() + self.size.height.clone()
&& point.y < self.origin.y.clone() + self.size.height.clone()
}
/// Checks if this bounds is completely contained within another bounds.

View File

@@ -31,6 +31,8 @@ mod path_builder;
mod platform;
pub mod prelude;
mod profiler;
#[cfg(any(target_os = "windows", target_os = "linux"))]
mod queue;
mod scene;
mod shared_string;
mod shared_uri;
@@ -89,16 +91,20 @@ pub use keymap::*;
pub use path_builder::*;
pub use platform::*;
pub use profiler::*;
#[cfg(any(target_os = "windows", target_os = "linux"))]
pub(crate) use queue::{PriorityQueueReceiver, PriorityQueueSender};
pub use refineable::*;
pub use scene::*;
pub use shared_string::*;
pub use shared_uri::*;
pub use smol::Timer;
use std::{any::Any, future::Future};
pub use style::*;
pub use styled::*;
pub use subscription::*;
pub use svg_renderer::*;
pub(crate) use tab_stop::*;
use taffy::TaffyLayoutEngine;
pub use taffy::{AvailableSpace, LayoutId};
#[cfg(any(test, feature = "test-support"))]
pub use test::*;
@@ -109,9 +115,6 @@ pub use util::{FutureExt, Timeout, arc_cow::ArcCow};
pub use view::*;
pub use window::*;
use std::{any::Any, future::Future};
use taffy::TaffyLayoutEngine;
/// The context trait, allows the different contexts in GPUI to be used
/// interchangeably for certain operations.
pub trait AppContext {

View File

@@ -39,9 +39,10 @@ use crate::{
Action, AnyWindowHandle, App, AsyncWindowContext, BackgroundExecutor, Bounds,
DEFAULT_WINDOW_SIZE, DevicePixels, DispatchEventResult, Font, FontId, FontMetrics, FontRun,
ForegroundExecutor, GlyphId, GpuSpecs, ImageSource, Keymap, LineLayout, Pixels, PlatformInput,
Point, RenderGlyphParams, RenderImage, RenderImageParams, RenderSvgParams, Scene, ShapedGlyph,
ShapedRun, SharedString, Size, SvgRenderer, SystemWindowTab, Task, TaskLabel, TaskTiming,
ThreadTaskTimings, Window, WindowControlArea, hash, point, px, size,
Point, Priority, RealtimePriority, RenderGlyphParams, RenderImage, RenderImageParams,
RenderSvgParams, Scene, ShapedGlyph, ShapedRun, SharedString, Size, SvgRenderer,
SystemWindowTab, Task, TaskLabel, TaskTiming, ThreadTaskTimings, Window, WindowControlArea,
hash, point, px, size,
};
use anyhow::Result;
use async_task::Runnable;
@@ -289,6 +290,13 @@ pub trait PlatformDisplay: Send + Sync + Debug {
/// Get the bounds for this display
fn bounds(&self) -> Bounds<Pixels>;
/// Get the visible bounds for this display, excluding taskbar/dock areas.
/// This is the usable area where windows can be placed without being obscured.
/// Defaults to the full display bounds if not overridden.
fn visible_bounds(&self) -> Bounds<Pixels> {
self.bounds()
}
/// Get the default bounds for this display to place a window
fn default_bounds(&self) -> Bounds<Pixels> {
let bounds = self.bounds();
@@ -580,9 +588,10 @@ pub trait PlatformDispatcher: Send + Sync {
fn get_all_timings(&self) -> Vec<ThreadTaskTimings>;
fn get_current_thread_timings(&self) -> Vec<TaskTiming>;
fn is_main_thread(&self) -> bool;
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>);
fn dispatch_on_main_thread(&self, runnable: RunnableVariant);
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>, priority: Priority);
fn dispatch_on_main_thread(&self, runnable: RunnableVariant, priority: Priority);
fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant);
fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>);
fn now(&self) -> Instant {
Instant::now()

View File

@@ -1,9 +1,10 @@
use crate::{
GLOBAL_THREAD_TIMINGS, PlatformDispatcher, RunnableVariant, THREAD_TIMINGS, TaskLabel,
TaskTiming, ThreadTaskTimings,
GLOBAL_THREAD_TIMINGS, PlatformDispatcher, Priority, PriorityQueueReceiver,
PriorityQueueSender, RealtimePriority, RunnableVariant, THREAD_TIMINGS, TaskLabel, TaskTiming,
ThreadTaskTimings, profiler,
};
use calloop::{
EventLoop,
EventLoop, PostAction,
channel::{self, Sender},
timer::TimeoutAction,
};
@@ -19,9 +20,9 @@ struct TimerAfter {
}
pub(crate) struct LinuxDispatcher {
main_sender: Sender<RunnableVariant>,
main_sender: PriorityQueueCalloopSender<RunnableVariant>,
timer_sender: Sender<TimerAfter>,
background_sender: flume::Sender<RunnableVariant>,
background_sender: PriorityQueueSender<RunnableVariant>,
_background_threads: Vec<thread::JoinHandle<()>>,
main_thread_id: thread::ThreadId,
}
@@ -29,18 +30,20 @@ pub(crate) struct LinuxDispatcher {
const MIN_THREADS: usize = 2;
impl LinuxDispatcher {
pub fn new(main_sender: Sender<RunnableVariant>) -> Self {
let (background_sender, background_receiver) = flume::unbounded::<RunnableVariant>();
pub fn new(main_sender: PriorityQueueCalloopSender<RunnableVariant>) -> Self {
let (background_sender, background_receiver) = PriorityQueueReceiver::new();
let thread_count =
std::thread::available_parallelism().map_or(MIN_THREADS, |i| i.get().max(MIN_THREADS));
// These thread should really be lower prio then the foreground
// executor
let mut background_threads = (0..thread_count)
.map(|i| {
let receiver = background_receiver.clone();
let mut receiver = background_receiver.clone();
std::thread::Builder::new()
.name(format!("Worker-{i}"))
.spawn(move || {
for runnable in receiver {
for runnable in receiver.iter() {
let start = Instant::now();
let mut location = match runnable {
@@ -51,7 +54,7 @@ impl LinuxDispatcher {
start,
end: None,
};
Self::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -63,7 +66,7 @@ impl LinuxDispatcher {
start,
end: None,
};
Self::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -72,7 +75,7 @@ impl LinuxDispatcher {
let end = Instant::now();
location.end = Some(end);
Self::add_task_timing(location);
profiler::add_task_timing(location);
log::trace!(
"background thread {}: ran runnable. took: {:?}",
@@ -113,7 +116,7 @@ impl LinuxDispatcher {
start,
end: None,
};
Self::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -124,7 +127,7 @@ impl LinuxDispatcher {
start,
end: None,
};
Self::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -133,7 +136,7 @@ impl LinuxDispatcher {
let end = Instant::now();
timing.end = Some(end);
Self::add_task_timing(timing);
profiler::add_task_timing(timing);
}
TimeoutAction::Drop
},
@@ -157,22 +160,6 @@ impl LinuxDispatcher {
main_thread_id: thread::current().id(),
}
}
pub(crate) fn add_task_timing(timing: TaskTiming) {
THREAD_TIMINGS.with(|timings| {
let mut timings = timings.lock();
let timings = &mut timings.timings;
if let Some(last_timing) = timings.iter_mut().rev().next() {
if last_timing.location == timing.location {
last_timing.end = timing.end;
return;
}
}
timings.push_back(timing);
});
}
}
impl PlatformDispatcher for LinuxDispatcher {
@@ -199,22 +186,26 @@ impl PlatformDispatcher for LinuxDispatcher {
thread::current().id() == self.main_thread_id
}
fn dispatch(&self, runnable: RunnableVariant, _: Option<TaskLabel>) {
self.background_sender.send(runnable).unwrap();
fn dispatch(&self, runnable: RunnableVariant, _: Option<TaskLabel>, priority: Priority) {
self.background_sender
.send(priority, runnable)
.unwrap_or_else(|_| panic!("blocking sender returned without value"));
}
fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
self.main_sender.send(runnable).unwrap_or_else(|runnable| {
// NOTE: Runnable may wrap a Future that is !Send.
//
// This is usually safe because we only poll it on the main thread.
// However if the send fails, we know that:
// 1. main_receiver has been dropped (which implies the app is shutting down)
// 2. we are on a background thread.
// It is not safe to drop something !Send on the wrong thread, and
// the app will exit soon anyway, so we must forget the runnable.
std::mem::forget(runnable);
});
fn dispatch_on_main_thread(&self, runnable: RunnableVariant, priority: Priority) {
self.main_sender
.send(priority, runnable)
.unwrap_or_else(|runnable| {
// NOTE: Runnable may wrap a Future that is !Send.
//
// This is usually safe because we only poll it on the main thread.
// However if the send fails, we know that:
// 1. main_receiver has been dropped (which implies the app is shutting down)
// 2. we are on a background thread.
// It is not safe to drop something !Send on the wrong thread, and
// the app will exit soon anyway, so we must forget the runnable.
std::mem::forget(runnable);
});
}
fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant) {
@@ -222,4 +213,252 @@ impl PlatformDispatcher for LinuxDispatcher {
.send(TimerAfter { duration, runnable })
.ok();
}
fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>) {
std::thread::spawn(move || {
// SAFETY: always safe to call
let thread_id = unsafe { libc::pthread_self() };
let policy = match priority {
RealtimePriority::Audio => libc::SCHED_FIFO,
RealtimePriority::Other => libc::SCHED_RR,
};
let sched_priority = match priority {
RealtimePriority::Audio => 65,
RealtimePriority::Other => 45,
};
let sched_param = libc::sched_param { sched_priority };
// SAFETY: sched_param is a valid initialized structure
let result = unsafe { libc::pthread_setschedparam(thread_id, policy, &sched_param) };
if result != 0 {
log::warn!("failed to set realtime thread priority to {:?}", priority);
}
f();
});
}
}
pub struct PriorityQueueCalloopSender<T> {
sender: PriorityQueueSender<T>,
ping: calloop::ping::Ping,
}
impl<T> PriorityQueueCalloopSender<T> {
fn new(tx: PriorityQueueSender<T>, ping: calloop::ping::Ping) -> Self {
Self { sender: tx, ping }
}
fn send(&self, priority: Priority, item: T) -> Result<(), crate::queue::SendError<T>> {
let res = self.sender.send(priority, item);
if res.is_ok() {
self.ping.ping();
}
res
}
}
impl<T> Drop for PriorityQueueCalloopSender<T> {
fn drop(&mut self) {
self.ping.ping();
}
}
pub struct PriorityQueueCalloopReceiver<T> {
receiver: PriorityQueueReceiver<T>,
source: calloop::ping::PingSource,
ping: calloop::ping::Ping,
}
impl<T> PriorityQueueCalloopReceiver<T> {
pub fn new() -> (PriorityQueueCalloopSender<T>, Self) {
let (ping, source) = calloop::ping::make_ping().expect("Failed to create a Ping.");
let (tx, rx) = PriorityQueueReceiver::new();
(
PriorityQueueCalloopSender::new(tx, ping.clone()),
Self {
receiver: rx,
source,
ping,
},
)
}
}
use calloop::channel::Event;
#[derive(Debug)]
pub struct ChannelError(calloop::ping::PingError);
impl std::fmt::Display for ChannelError {
#[cfg_attr(feature = "nightly_coverage", coverage(off))]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(&self.0, f)
}
}
impl std::error::Error for ChannelError {
#[cfg_attr(feature = "nightly_coverage", coverage(off))]
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
Some(&self.0)
}
}
impl<T> calloop::EventSource for PriorityQueueCalloopReceiver<T> {
type Event = Event<T>;
type Metadata = ();
type Ret = ();
type Error = ChannelError;
fn process_events<F>(
&mut self,
readiness: calloop::Readiness,
token: calloop::Token,
mut callback: F,
) -> Result<calloop::PostAction, Self::Error>
where
F: FnMut(Self::Event, &mut Self::Metadata) -> Self::Ret,
{
let mut clear_readiness = false;
let mut disconnected = false;
let action = self
.source
.process_events(readiness, token, |(), &mut ()| {
let mut is_empty = true;
let mut receiver = self.receiver.clone();
for runnable in receiver.try_iter() {
match runnable {
Ok(r) => {
callback(Event::Msg(r), &mut ());
is_empty = false;
}
Err(_) => {
disconnected = true;
}
}
}
if disconnected {
callback(Event::Closed, &mut ());
}
if is_empty {
clear_readiness = true;
}
})
.map_err(ChannelError)?;
if disconnected {
Ok(PostAction::Remove)
} else if clear_readiness {
Ok(action)
} else {
// Re-notify the ping source so we can try again.
self.ping.ping();
Ok(PostAction::Continue)
}
}
fn register(
&mut self,
poll: &mut calloop::Poll,
token_factory: &mut calloop::TokenFactory,
) -> calloop::Result<()> {
self.source.register(poll, token_factory)
}
fn reregister(
&mut self,
poll: &mut calloop::Poll,
token_factory: &mut calloop::TokenFactory,
) -> calloop::Result<()> {
self.source.reregister(poll, token_factory)
}
fn unregister(&mut self, poll: &mut calloop::Poll) -> calloop::Result<()> {
self.source.unregister(poll)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn calloop_works() {
let mut event_loop = calloop::EventLoop::try_new().unwrap();
let handle = event_loop.handle();
let (tx, rx) = PriorityQueueCalloopReceiver::new();
struct Data {
got_msg: bool,
got_closed: bool,
}
let mut data = Data {
got_msg: false,
got_closed: false,
};
let _channel_token = handle
.insert_source(rx, move |evt, &mut (), data: &mut Data| match evt {
Event::Msg(()) => {
data.got_msg = true;
}
Event::Closed => {
data.got_closed = true;
}
})
.unwrap();
// nothing is sent, nothing is received
event_loop
.dispatch(Some(::std::time::Duration::ZERO), &mut data)
.unwrap();
assert!(!data.got_msg);
assert!(!data.got_closed);
// a message is send
tx.send(Priority::Medium, ()).unwrap();
event_loop
.dispatch(Some(::std::time::Duration::ZERO), &mut data)
.unwrap();
assert!(data.got_msg);
assert!(!data.got_closed);
// the sender is dropped
drop(tx);
event_loop
.dispatch(Some(::std::time::Duration::ZERO), &mut data)
.unwrap();
assert!(data.got_msg);
assert!(data.got_closed);
}
}
// running 1 test
// test platform::linux::dispatcher::tests::tomato ... FAILED
// failures:
// ---- platform::linux::dispatcher::tests::tomato stdout ----
// [crates/gpui/src/platform/linux/dispatcher.rs:262:9]
// returning 1 tasks to process
// [crates/gpui/src/platform/linux/dispatcher.rs:480:75] evt = Msg(
// (),
// )
// returning 0 tasks to process
// thread 'platform::linux::dispatcher::tests::tomato' (478301) panicked at crates/gpui/src/platform/linux/dispatcher.rs:515:9:
// assertion failed: data.got_closed
// note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace

View File

@@ -14,7 +14,7 @@ use std::{
};
use anyhow::{Context as _, anyhow};
use calloop::{LoopSignal, channel::Channel};
use calloop::LoopSignal;
use futures::channel::oneshot;
use util::ResultExt as _;
use util::command::{new_smol_command, new_std_command};
@@ -25,8 +25,8 @@ use crate::{
Action, AnyWindowHandle, BackgroundExecutor, ClipboardItem, CursorStyle, DisplayId,
ForegroundExecutor, Keymap, LinuxDispatcher, Menu, MenuItem, OwnedMenu, PathPromptOptions,
Pixels, Platform, PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper,
PlatformTextSystem, PlatformWindow, Point, Result, RunnableVariant, Task, WindowAppearance,
WindowParams, px,
PlatformTextSystem, PlatformWindow, Point, PriorityQueueCalloopReceiver, Result,
RunnableVariant, Task, WindowAppearance, WindowParams, px,
};
#[cfg(any(feature = "wayland", feature = "x11"))]
@@ -149,8 +149,8 @@ pub(crate) struct LinuxCommon {
}
impl LinuxCommon {
pub fn new(signal: LoopSignal) -> (Self, Channel<RunnableVariant>) {
let (main_sender, main_receiver) = calloop::channel::channel::<RunnableVariant>();
pub fn new(signal: LoopSignal) -> (Self, PriorityQueueCalloopReceiver<RunnableVariant>) {
let (main_sender, main_receiver) = PriorityQueueCalloopReceiver::new();
#[cfg(any(feature = "wayland", feature = "x11"))]
let text_system = Arc::new(crate::CosmicTextSystem::new());

View File

@@ -77,10 +77,10 @@ use crate::{
LinuxKeyboardLayout, Modifiers, ModifiersChangedEvent, MouseButton, MouseDownEvent,
MouseExitEvent, MouseMoveEvent, MouseUpEvent, NavigationDirection, Pixels, PlatformDisplay,
PlatformInput, PlatformKeyboardLayout, Point, ResultExt as _, SCROLL_LINES, ScrollDelta,
ScrollWheelEvent, Size, TouchPhase, WindowParams, point, px, size,
ScrollWheelEvent, Size, TouchPhase, WindowParams, point, profiler, px, size,
};
use crate::{
LinuxDispatcher, RunnableVariant, TaskTiming,
RunnableVariant, TaskTiming,
platform::{PlatformWindow, blade::BladeContext},
};
use crate::{
@@ -503,7 +503,7 @@ impl WaylandClient {
start,
end: None,
};
LinuxDispatcher::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -515,7 +515,7 @@ impl WaylandClient {
start,
end: None,
};
LinuxDispatcher::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -524,7 +524,7 @@ impl WaylandClient {
let end = Instant::now();
timing.end = Some(end);
LinuxDispatcher::add_task_timing(timing);
profiler::add_task_timing(timing);
});
}
}

View File

@@ -1,4 +1,4 @@
use crate::{Capslock, LinuxDispatcher, ResultExt as _, RunnableVariant, TaskTiming, xcb_flush};
use crate::{Capslock, ResultExt as _, RunnableVariant, TaskTiming, profiler, xcb_flush};
use anyhow::{Context as _, anyhow};
use ashpd::WindowIdentifier;
use calloop::{
@@ -322,7 +322,7 @@ impl X11Client {
start,
end: None,
};
LinuxDispatcher::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -334,7 +334,7 @@ impl X11Client {
start,
end: None,
};
LinuxDispatcher::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -343,7 +343,7 @@ impl X11Client {
let end = Instant::now();
timing.end = Some(end);
LinuxDispatcher::add_task_timing(timing);
profiler::add_task_timing(timing);
});
}
}

View File

@@ -3,11 +3,22 @@
#![allow(non_snake_case)]
use crate::{
GLOBAL_THREAD_TIMINGS, PlatformDispatcher, RunnableMeta, RunnableVariant, THREAD_TIMINGS,
TaskLabel, TaskTiming, ThreadTaskTimings,
GLOBAL_THREAD_TIMINGS, PlatformDispatcher, Priority, RealtimePriority, RunnableMeta,
RunnableVariant, THREAD_TIMINGS, TaskLabel, TaskTiming, ThreadTaskTimings,
};
use anyhow::Context;
use async_task::Runnable;
use mach2::{
kern_return::KERN_SUCCESS,
mach_time::mach_timebase_info_data_t,
thread_policy::{
THREAD_EXTENDED_POLICY, THREAD_EXTENDED_POLICY_COUNT, THREAD_PRECEDENCE_POLICY,
THREAD_PRECEDENCE_POLICY_COUNT, THREAD_TIME_CONSTRAINT_POLICY,
THREAD_TIME_CONSTRAINT_POLICY_COUNT, thread_extended_policy_data_t,
thread_precedence_policy_data_t, thread_time_constraint_policy_data_t,
},
};
use objc::{
class, msg_send,
runtime::{BOOL, YES},
@@ -15,9 +26,11 @@ use objc::{
};
use std::{
ffi::c_void,
mem::MaybeUninit,
ptr::{NonNull, addr_of},
time::{Duration, Instant},
};
use util::ResultExt;
/// All items in the generated file are marked as pub, so we're gonna wrap it in a separate mod to prevent
/// these pub items from leaking into public API.
@@ -56,7 +69,7 @@ impl PlatformDispatcher for MacDispatcher {
is_main_thread == YES
}
fn dispatch(&self, runnable: RunnableVariant, _: Option<TaskLabel>) {
fn dispatch(&self, runnable: RunnableVariant, _: Option<TaskLabel>, priority: Priority) {
let (context, trampoline) = match runnable {
RunnableVariant::Meta(runnable) => (
runnable.into_raw().as_ptr() as *mut c_void,
@@ -67,16 +80,24 @@ impl PlatformDispatcher for MacDispatcher {
Some(trampoline_compat as unsafe extern "C" fn(*mut c_void)),
),
};
let queue_priority = match priority {
Priority::Realtime(_) => unreachable!(),
Priority::High => DISPATCH_QUEUE_PRIORITY_HIGH as isize,
Priority::Medium => DISPATCH_QUEUE_PRIORITY_DEFAULT as isize,
Priority::Low => DISPATCH_QUEUE_PRIORITY_LOW as isize,
};
unsafe {
dispatch_async_f(
dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_HIGH.try_into().unwrap(), 0),
dispatch_get_global_queue(queue_priority, 0),
context,
trampoline,
);
}
}
fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
fn dispatch_on_main_thread(&self, runnable: RunnableVariant, _priority: Priority) {
let (context, trampoline) = match runnable {
RunnableVariant::Meta(runnable) => (
runnable.into_raw().as_ptr() as *mut c_void,
@@ -110,6 +131,120 @@ impl PlatformDispatcher for MacDispatcher {
dispatch_after_f(when, queue, context, trampoline);
}
}
fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>) {
std::thread::spawn(move || {
match priority {
RealtimePriority::Audio => set_audio_thread_priority(),
RealtimePriority::Other => set_high_thread_priority(),
}
.context(format!("for priority {:?}", priority))
.log_err();
f();
});
}
}
fn set_high_thread_priority() -> anyhow::Result<()> {
// SAFETY: always safe to call
let thread_id = unsafe { libc::pthread_self() };
// SAFETY: all sched_param members are valid when initialized to zero.
let mut sched_param = unsafe { MaybeUninit::<libc::sched_param>::zeroed().assume_init() };
sched_param.sched_priority = 45;
let result = unsafe { libc::pthread_setschedparam(thread_id, libc::SCHED_FIFO, &sched_param) };
if result != 0 {
anyhow::bail!("failed to set realtime thread priority")
}
Ok(())
}
fn set_audio_thread_priority() -> anyhow::Result<()> {
// https://chromium.googlesource.com/chromium/chromium/+/master/base/threading/platform_thread_mac.mm#93
// SAFETY: always safe to call
let thread_id = unsafe { libc::pthread_self() };
// SAFETY: thread_id is a valid thread id
let thread_id = unsafe { libc::pthread_mach_thread_np(thread_id) };
// Fixed priority thread
let mut policy = thread_extended_policy_data_t { timeshare: 0 };
// SAFETY: thread_id is a valid thread id
// SAFETY: thread_extended_policy_data_t is passed as THREAD_EXTENDED_POLICY
let result = unsafe {
mach2::thread_policy::thread_policy_set(
thread_id,
THREAD_EXTENDED_POLICY,
&mut policy as *mut _ as *mut _,
THREAD_EXTENDED_POLICY_COUNT,
)
};
if result != KERN_SUCCESS {
anyhow::bail!("failed to set thread extended policy");
}
// relatively high priority
let mut precedence = thread_precedence_policy_data_t { importance: 63 };
// SAFETY: thread_id is a valid thread id
// SAFETY: thread_precedence_policy_data_t is passed as THREAD_PRECEDENCE_POLICY
let result = unsafe {
mach2::thread_policy::thread_policy_set(
thread_id,
THREAD_PRECEDENCE_POLICY,
&mut precedence as *mut _ as *mut _,
THREAD_PRECEDENCE_POLICY_COUNT,
)
};
if result != KERN_SUCCESS {
anyhow::bail!("failed to set thread precedence policy");
}
const GUARANTEED_AUDIO_DUTY_CYCLE: f32 = 0.75;
const MAX_AUDIO_DUTY_CYCLE: f32 = 0.85;
// ~128 frames @ 44.1KHz
const TIME_QUANTUM: f32 = 2.9;
const AUDIO_TIME_NEEDED: f32 = GUARANTEED_AUDIO_DUTY_CYCLE * TIME_QUANTUM;
const MAX_TIME_ALLOWED: f32 = MAX_AUDIO_DUTY_CYCLE * TIME_QUANTUM;
let mut timebase_info = mach_timebase_info_data_t { numer: 0, denom: 0 };
// SAFETY: timebase_info is a valid pointer to a mach_timebase_info_data_t struct
unsafe { mach2::mach_time::mach_timebase_info(&mut timebase_info) };
let ms_to_abs_time = ((timebase_info.denom as f32) / (timebase_info.numer as f32)) * 1000000f32;
let mut time_constraints = thread_time_constraint_policy_data_t {
period: (TIME_QUANTUM * ms_to_abs_time) as u32,
computation: (AUDIO_TIME_NEEDED * ms_to_abs_time) as u32,
constraint: (MAX_TIME_ALLOWED * ms_to_abs_time) as u32,
preemptible: 0,
};
// SAFETY: thread_id is a valid thread id
// SAFETY: thread_precedence_pthread_time_constraint_policy_data_t is passed as THREAD_TIME_CONSTRAINT_POLICY
let result = unsafe {
mach2::thread_policy::thread_policy_set(
thread_id,
THREAD_TIME_CONSTRAINT_POLICY,
&mut time_constraints as *mut _ as *mut _,
THREAD_TIME_CONSTRAINT_POLICY_COUNT,
)
};
if result != KERN_SUCCESS {
anyhow::bail!("failed to set thread time constraint policy");
}
Ok(())
}
extern "C" fn trampoline(runnable: *mut c_void) {

View File

@@ -1,9 +1,9 @@
use crate::{Bounds, DisplayId, Pixels, PlatformDisplay, px, size};
use crate::{Bounds, DisplayId, Pixels, PlatformDisplay, point, px, size};
use anyhow::Result;
use cocoa::{
appkit::NSScreen,
base::{id, nil},
foundation::{NSDictionary, NSString},
foundation::{NSArray, NSDictionary, NSString},
};
use core_foundation::uuid::{CFUUIDGetUUIDBytes, CFUUIDRef};
use core_graphics::display::{CGDirectDisplayID, CGDisplayBounds, CGGetActiveDisplayList};
@@ -114,4 +114,53 @@ impl PlatformDisplay for MacDisplay {
}
}
}
fn visible_bounds(&self) -> Bounds<Pixels> {
unsafe {
let dominated_screen = self.get_nsscreen();
if dominated_screen == nil {
return self.bounds();
}
let screen_frame = NSScreen::frame(dominated_screen);
let visible_frame = NSScreen::visibleFrame(dominated_screen);
// Convert from bottom-left origin (AppKit) to top-left origin
let origin_y =
screen_frame.size.height - visible_frame.origin.y - visible_frame.size.height
+ screen_frame.origin.y;
Bounds {
origin: point(
px(visible_frame.origin.x as f32 - screen_frame.origin.x as f32),
px(origin_y as f32),
),
size: size(
px(visible_frame.size.width as f32),
px(visible_frame.size.height as f32),
),
}
}
}
}
impl MacDisplay {
/// Find the NSScreen corresponding to this display
unsafe fn get_nsscreen(&self) -> id {
let screens = unsafe { NSScreen::screens(nil) };
let count = unsafe { NSArray::count(screens) };
let screen_number_key: id = unsafe { NSString::alloc(nil).init_str("NSScreenNumber") };
for i in 0..count {
let screen = unsafe { NSArray::objectAtIndex(screens, i) };
let device_description = unsafe { NSScreen::deviceDescription(screen) };
let screen_number = unsafe { device_description.objectForKey_(screen_number_key) };
let screen_id: CGDirectDisplayID = msg_send![screen_number, unsignedIntegerValue];
if screen_id == self.0 {
return screen;
}
}
nil
}
}

View File

@@ -1,4 +1,4 @@
use crate::{PlatformDispatcher, RunnableVariant, TaskLabel};
use crate::{PlatformDispatcher, Priority, RunnableVariant, TaskLabel};
use backtrace::Backtrace;
use collections::{HashMap, HashSet, VecDeque};
use parking::Unparker;
@@ -284,7 +284,7 @@ impl PlatformDispatcher for TestDispatcher {
state.start_time + state.time
}
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>) {
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>, _priority: Priority) {
{
let mut state = self.state.lock();
if label.is_some_and(|label| state.deprioritized_task_labels.contains(&label)) {
@@ -296,7 +296,7 @@ impl PlatformDispatcher for TestDispatcher {
self.unpark_all();
}
fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
fn dispatch_on_main_thread(&self, runnable: RunnableVariant, _priority: Priority) {
self.state
.lock()
.foreground
@@ -318,4 +318,10 @@ impl PlatformDispatcher for TestDispatcher {
fn as_test(&self) -> Option<&TestDispatcher> {
Some(self)
}
fn spawn_realtime(&self, _priority: crate::RealtimePriority, f: Box<dyn FnOnce() + Send>) {
std::thread::spawn(move || {
f();
});
}
}

View File

@@ -4,24 +4,31 @@ use std::{
time::{Duration, Instant},
};
use flume::Sender;
use anyhow::Context;
use util::ResultExt;
use windows::{
System::Threading::{ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler},
System::Threading::{
ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler, WorkItemPriority,
},
Win32::{
Foundation::{LPARAM, WPARAM},
System::Threading::{
GetCurrentThread, HIGH_PRIORITY_CLASS, SetPriorityClass, SetThreadPriority,
THREAD_PRIORITY_HIGHEST, THREAD_PRIORITY_TIME_CRITICAL,
},
UI::WindowsAndMessaging::PostMessageW,
},
};
use crate::{
GLOBAL_THREAD_TIMINGS, HWND, PlatformDispatcher, RunnableVariant, SafeHwnd, THREAD_TIMINGS,
TaskLabel, TaskTiming, ThreadTaskTimings, WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD,
GLOBAL_THREAD_TIMINGS, HWND, PlatformDispatcher, Priority, PriorityQueueSender,
RealtimePriority, RunnableVariant, SafeHwnd, THREAD_TIMINGS, TaskLabel, TaskTiming,
ThreadTaskTimings, WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD, profiler,
};
pub(crate) struct WindowsDispatcher {
pub(crate) wake_posted: AtomicBool,
main_sender: Sender<RunnableVariant>,
main_sender: PriorityQueueSender<RunnableVariant>,
main_thread_id: ThreadId,
pub(crate) platform_window_handle: SafeHwnd,
validation_number: usize,
@@ -29,7 +36,7 @@ pub(crate) struct WindowsDispatcher {
impl WindowsDispatcher {
pub(crate) fn new(
main_sender: Sender<RunnableVariant>,
main_sender: PriorityQueueSender<RunnableVariant>,
platform_window_handle: HWND,
validation_number: usize,
) -> Self {
@@ -45,7 +52,7 @@ impl WindowsDispatcher {
}
}
fn dispatch_on_threadpool(&self, runnable: RunnableVariant) {
fn dispatch_on_threadpool(&self, priority: WorkItemPriority, runnable: RunnableVariant) {
let handler = {
let mut task_wrapper = Some(runnable);
WorkItemHandler::new(move |_| {
@@ -53,7 +60,8 @@ impl WindowsDispatcher {
Ok(())
})
};
ThreadPool::RunAsync(&handler).log_err();
ThreadPool::RunWithPriorityAsync(&handler, priority).log_err();
}
fn dispatch_on_threadpool_after(&self, runnable: RunnableVariant, duration: Duration) {
@@ -79,7 +87,7 @@ impl WindowsDispatcher {
start,
end: None,
};
Self::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
@@ -91,7 +99,7 @@ impl WindowsDispatcher {
start,
end: None,
};
Self::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
@@ -102,23 +110,7 @@ impl WindowsDispatcher {
let end = Instant::now();
timing.end = Some(end);
Self::add_task_timing(timing);
}
pub(crate) fn add_task_timing(timing: TaskTiming) {
THREAD_TIMINGS.with(|timings| {
let mut timings = timings.lock();
let timings = &mut timings.timings;
if let Some(last_timing) = timings.iter_mut().rev().next() {
if last_timing.location == timing.location {
last_timing.end = timing.end;
return;
}
}
timings.push_back(timing);
});
profiler::add_task_timing(timing);
}
}
@@ -146,15 +138,22 @@ impl PlatformDispatcher for WindowsDispatcher {
current().id() == self.main_thread_id
}
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>) {
self.dispatch_on_threadpool(runnable);
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>, priority: Priority) {
let priority = match priority {
Priority::Realtime(_) => unreachable!(),
Priority::High => WorkItemPriority::High,
Priority::Medium => WorkItemPriority::Normal,
Priority::Low => WorkItemPriority::Low,
};
self.dispatch_on_threadpool(priority, runnable);
if let Some(label) = label {
log::debug!("TaskLabel: {label:?}");
}
}
fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
match self.main_sender.send(runnable) {
fn dispatch_on_main_thread(&self, runnable: RunnableVariant, priority: Priority) {
match self.main_sender.send(priority, runnable) {
Ok(_) => {
if !self.wake_posted.swap(true, Ordering::AcqRel) {
unsafe {
@@ -185,4 +184,28 @@ impl PlatformDispatcher for WindowsDispatcher {
fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant) {
self.dispatch_on_threadpool_after(runnable, duration);
}
fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>) {
std::thread::spawn(move || {
// SAFETY: always safe to call
let thread_handle = unsafe { GetCurrentThread() };
let thread_priority = match priority {
RealtimePriority::Audio => THREAD_PRIORITY_TIME_CRITICAL,
RealtimePriority::Other => THREAD_PRIORITY_HIGHEST,
};
// SAFETY: thread_handle is a valid handle to a thread
unsafe { SetPriorityClass(thread_handle, HIGH_PRIORITY_CLASS) }
.context("thread priority class")
.log_err();
// SAFETY: thread_handle is a valid handle to a thread
unsafe { SetThreadPriority(thread_handle, thread_priority) }
.context("thread priority")
.log_err();
f();
});
}
}

View File

@@ -23,6 +23,7 @@ pub(crate) struct WindowsDisplay {
pub display_id: DisplayId,
scale_factor: f32,
bounds: Bounds<Pixels>,
visible_bounds: Bounds<Pixels>,
physical_bounds: Bounds<DevicePixels>,
uuid: Uuid,
}
@@ -36,6 +37,7 @@ impl WindowsDisplay {
let screen = available_monitors().into_iter().nth(display_id.0 as _)?;
let info = get_monitor_info(screen).log_err()?;
let monitor_size = info.monitorInfo.rcMonitor;
let work_area = info.monitorInfo.rcWork;
let uuid = generate_uuid(&info.szDevice);
let scale_factor = get_scale_factor_for_monitor(screen).log_err()?;
let physical_size = size(
@@ -55,6 +57,14 @@ impl WindowsDisplay {
),
size: physical_size.to_pixels(scale_factor),
},
visible_bounds: Bounds {
origin: logical_point(work_area.left as f32, work_area.top as f32, scale_factor),
size: size(
(work_area.right - work_area.left) as f32 / scale_factor,
(work_area.bottom - work_area.top) as f32 / scale_factor,
)
.map(crate::px),
},
physical_bounds: Bounds {
origin: point(monitor_size.left.into(), monitor_size.top.into()),
size: physical_size,
@@ -66,6 +76,7 @@ impl WindowsDisplay {
pub fn new_with_handle(monitor: HMONITOR) -> anyhow::Result<Self> {
let info = get_monitor_info(monitor)?;
let monitor_size = info.monitorInfo.rcMonitor;
let work_area = info.monitorInfo.rcWork;
let uuid = generate_uuid(&info.szDevice);
let display_id = available_monitors()
.iter()
@@ -89,6 +100,14 @@ impl WindowsDisplay {
),
size: physical_size.to_pixels(scale_factor),
},
visible_bounds: Bounds {
origin: logical_point(work_area.left as f32, work_area.top as f32, scale_factor),
size: size(
(work_area.right - work_area.left) as f32 / scale_factor,
(work_area.bottom - work_area.top) as f32 / scale_factor,
)
.map(crate::px),
},
physical_bounds: Bounds {
origin: point(monitor_size.left.into(), monitor_size.top.into()),
size: physical_size,
@@ -100,6 +119,7 @@ impl WindowsDisplay {
fn new_with_handle_and_id(handle: HMONITOR, display_id: DisplayId) -> anyhow::Result<Self> {
let info = get_monitor_info(handle)?;
let monitor_size = info.monitorInfo.rcMonitor;
let work_area = info.monitorInfo.rcWork;
let uuid = generate_uuid(&info.szDevice);
let scale_factor = get_scale_factor_for_monitor(handle)?;
let physical_size = size(
@@ -119,6 +139,14 @@ impl WindowsDisplay {
),
size: physical_size.to_pixels(scale_factor),
},
visible_bounds: Bounds {
origin: logical_point(work_area.left as f32, work_area.top as f32, scale_factor),
size: size(
(work_area.right - work_area.left) as f32 / scale_factor,
(work_area.bottom - work_area.top) as f32 / scale_factor,
)
.map(crate::px),
},
physical_bounds: Bounds {
origin: point(monitor_size.left.into(), monitor_size.top.into()),
size: physical_size,
@@ -193,6 +221,10 @@ impl PlatformDisplay for WindowsDisplay {
fn bounds(&self) -> Bounds<Pixels> {
self.bounds
}
fn visible_bounds(&self) -> Bounds<Pixels> {
self.visible_bounds
}
}
fn available_monitors() -> SmallVec<[HMONITOR; 4]> {

View File

@@ -243,7 +243,8 @@ impl WindowsWindowInner {
fn handle_timer_msg(&self, handle: HWND, wparam: WPARAM) -> Option<isize> {
if wparam.0 == SIZE_MOVE_LOOP_TIMER_ID {
for runnable in self.main_receiver.drain() {
let mut runnables = self.main_receiver.clone().try_iter();
while let Some(Ok(runnable)) = runnables.next() {
WindowsDispatcher::execute_runnable(runnable);
}
self.handle_paint_msg(handle)

View File

@@ -51,7 +51,7 @@ struct WindowsPlatformInner {
raw_window_handles: std::sync::Weak<RwLock<SmallVec<[SafeHwnd; 4]>>>,
// The below members will never change throughout the entire lifecycle of the app.
validation_number: usize,
main_receiver: flume::Receiver<RunnableVariant>,
main_receiver: PriorityQueueReceiver<RunnableVariant>,
dispatcher: Arc<WindowsDispatcher>,
}
@@ -98,7 +98,7 @@ impl WindowsPlatform {
OleInitialize(None).context("unable to initialize Windows OLE")?;
}
let directx_devices = DirectXDevices::new().context("Creating DirectX devices")?;
let (main_sender, main_receiver) = flume::unbounded::<RunnableVariant>();
let (main_sender, main_receiver) = PriorityQueueReceiver::new();
let validation_number = if usize::BITS == 64 {
rand::random::<u64>() as usize
} else {
@@ -857,22 +857,24 @@ impl WindowsPlatformInner {
}
break 'tasks;
}
match self.main_receiver.try_recv() {
Err(_) => break 'timeout_loop,
Ok(runnable) => WindowsDispatcher::execute_runnable(runnable),
let mut main_receiver = self.main_receiver.clone();
match main_receiver.try_pop() {
Ok(Some(runnable)) => WindowsDispatcher::execute_runnable(runnable),
_ => break 'timeout_loop,
}
}
// Someone could enqueue a Runnable here. The flag is still true, so they will not PostMessage.
// We need to check for those Runnables after we clear the flag.
self.dispatcher.wake_posted.store(false, Ordering::Release);
match self.main_receiver.try_recv() {
Err(_) => break 'tasks,
Ok(runnable) => {
let mut main_receiver = self.main_receiver.clone();
match main_receiver.try_pop() {
Ok(Some(runnable)) => {
self.dispatcher.wake_posted.store(true, Ordering::Release);
WindowsDispatcher::execute_runnable(runnable);
}
_ => break 'tasks,
}
}
@@ -934,7 +936,7 @@ pub(crate) struct WindowCreationInfo {
pub(crate) windows_version: WindowsVersion,
pub(crate) drop_target_helper: IDropTargetHelper,
pub(crate) validation_number: usize,
pub(crate) main_receiver: flume::Receiver<RunnableVariant>,
pub(crate) main_receiver: PriorityQueueReceiver<RunnableVariant>,
pub(crate) platform_window_handle: HWND,
pub(crate) disable_direct_composition: bool,
pub(crate) directx_devices: DirectXDevices,
@@ -947,8 +949,8 @@ struct PlatformWindowCreateContext {
inner: Option<Result<Rc<WindowsPlatformInner>>>,
raw_window_handles: std::sync::Weak<RwLock<SmallVec<[SafeHwnd; 4]>>>,
validation_number: usize,
main_sender: Option<flume::Sender<RunnableVariant>>,
main_receiver: Option<flume::Receiver<RunnableVariant>>,
main_sender: Option<PriorityQueueSender<RunnableVariant>>,
main_receiver: Option<PriorityQueueReceiver<RunnableVariant>>,
directx_devices: Option<DirectXDevices>,
dispatcher: Option<Arc<WindowsDispatcher>>,
}

View File

@@ -81,7 +81,7 @@ pub(crate) struct WindowsWindowInner {
pub(crate) executor: ForegroundExecutor,
pub(crate) windows_version: WindowsVersion,
pub(crate) validation_number: usize,
pub(crate) main_receiver: flume::Receiver<RunnableVariant>,
pub(crate) main_receiver: PriorityQueueReceiver<RunnableVariant>,
pub(crate) platform_window_handle: HWND,
}
@@ -362,7 +362,7 @@ struct WindowCreateContext {
windows_version: WindowsVersion,
drop_target_helper: IDropTargetHelper,
validation_number: usize,
main_receiver: flume::Receiver<RunnableVariant>,
main_receiver: PriorityQueueReceiver<RunnableVariant>,
platform_window_handle: HWND,
appearance: WindowAppearance,
disable_direct_composition: bool,

View File

@@ -216,3 +216,19 @@ impl Drop for ThreadTimings {
thread_timings.swap_remove(index);
}
}
pub(crate) fn add_task_timing(timing: TaskTiming) {
THREAD_TIMINGS.with(|timings| {
let mut timings = timings.lock();
let timings = &mut timings.timings;
if let Some(last_timing) = timings.iter_mut().rev().next() {
if last_timing.location == timing.location {
last_timing.end = timing.end;
return;
}
}
timings.push_back(timing);
});
}

329
crates/gpui/src/queue.rs Normal file
View File

@@ -0,0 +1,329 @@
use std::{
fmt,
iter::FusedIterator,
sync::{Arc, atomic::AtomicUsize},
};
use rand::{Rng, SeedableRng, rngs::SmallRng};
use crate::Priority;
struct PriorityQueues<T> {
high_priority: Vec<T>,
medium_priority: Vec<T>,
low_priority: Vec<T>,
}
impl<T> PriorityQueues<T> {
fn is_empty(&self) -> bool {
self.high_priority.is_empty()
&& self.medium_priority.is_empty()
&& self.low_priority.is_empty()
}
}
struct PriorityQueueState<T> {
queues: parking_lot::Mutex<PriorityQueues<T>>,
condvar: parking_lot::Condvar,
receiver_count: AtomicUsize,
sender_count: AtomicUsize,
}
impl<T> PriorityQueueState<T> {
fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
if self
.receiver_count
.load(std::sync::atomic::Ordering::Relaxed)
== 0
{
return Err(SendError(item));
}
let mut queues = self.queues.lock();
match priority {
Priority::Realtime(_) => unreachable!(),
Priority::High => queues.high_priority.push(item),
Priority::Medium => queues.medium_priority.push(item),
Priority::Low => queues.low_priority.push(item),
};
self.condvar.notify_one();
Ok(())
}
fn recv<'a>(&'a self) -> Result<parking_lot::MutexGuard<'a, PriorityQueues<T>>, RecvError> {
let mut queues = self.queues.lock();
let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
if queues.is_empty() && sender_count == 0 {
return Err(crate::queue::RecvError);
}
// parking_lot doesn't do spurious wakeups so an if is fine
if queues.is_empty() {
self.condvar.wait(&mut queues);
}
Ok(queues)
}
fn try_recv<'a>(
&'a self,
) -> Result<Option<parking_lot::MutexGuard<'a, PriorityQueues<T>>>, RecvError> {
let mut queues = self.queues.lock();
let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
if queues.is_empty() && sender_count == 0 {
return Err(crate::queue::RecvError);
}
if queues.is_empty() {
Ok(None)
} else {
Ok(Some(queues))
}
}
}
pub(crate) struct PriorityQueueSender<T> {
state: Arc<PriorityQueueState<T>>,
}
impl<T> PriorityQueueSender<T> {
fn new(state: Arc<PriorityQueueState<T>>) -> Self {
Self { state }
}
pub(crate) fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
self.state.send(priority, item)?;
Ok(())
}
}
impl<T> Drop for PriorityQueueSender<T> {
fn drop(&mut self) {
self.state
.sender_count
.fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
}
}
pub(crate) struct PriorityQueueReceiver<T> {
state: Arc<PriorityQueueState<T>>,
rand: SmallRng,
disconnected: bool,
}
impl<T> Clone for PriorityQueueReceiver<T> {
fn clone(&self) -> Self {
self.state
.receiver_count
.fetch_add(1, std::sync::atomic::Ordering::AcqRel);
Self {
state: Arc::clone(&self.state),
rand: SmallRng::seed_from_u64(0),
disconnected: self.disconnected,
}
}
}
pub(crate) struct SendError<T>(T);
impl<T: fmt::Debug> fmt::Debug for SendError<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("SendError").field(&self.0).finish()
}
}
#[derive(Debug)]
pub(crate) struct RecvError;
#[allow(dead_code)]
impl<T> PriorityQueueReceiver<T> {
pub(crate) fn new() -> (PriorityQueueSender<T>, Self) {
let state = PriorityQueueState {
queues: parking_lot::Mutex::new(PriorityQueues {
high_priority: Vec::new(),
medium_priority: Vec::new(),
low_priority: Vec::new(),
}),
condvar: parking_lot::Condvar::new(),
receiver_count: AtomicUsize::new(1),
sender_count: AtomicUsize::new(1),
};
let state = Arc::new(state);
let sender = PriorityQueueSender::new(Arc::clone(&state));
let receiver = PriorityQueueReceiver {
state,
rand: SmallRng::seed_from_u64(0),
disconnected: false,
};
(sender, receiver)
}
/// Tries to pop one element from the priority queue without blocking.
///
/// This will early return if there are no elements in the queue.
///
/// This method is best suited if you only intend to pop one element, for better performance
/// on large queues see [`Self::try_iter`]
///
/// # Errors
///
/// If the sender was dropped
pub(crate) fn try_pop(&mut self) -> Result<Option<T>, RecvError> {
self.pop_inner(false)
}
/// Pops an element from the priority queue blocking if necessary.
///
/// This method is best suited if you only intend to pop one element, for better performance
/// on large queues see [`Self::iter``]
///
/// # Errors
///
/// If the sender was dropped
pub(crate) fn pop(&mut self) -> Result<T, RecvError> {
self.pop_inner(true).map(|e| e.unwrap())
}
/// Returns an iterator over the elements of the queue
/// this iterator will end when all elements have been consumed and will not wait for new ones.
pub(crate) fn try_iter(self) -> TryIter<T> {
TryIter {
receiver: self,
ended: false,
}
}
/// Returns an iterator over the elements of the queue
/// this iterator will wait for new elements if the queue is empty.
pub(crate) fn iter(self) -> Iter<T> {
Iter(self)
}
#[inline(always)]
// algorithm is the loaded die from biased coin from
// https://www.keithschwarz.com/darts-dice-coins/
fn pop_inner(&mut self, block: bool) -> Result<Option<T>, RecvError> {
use Priority as P;
let mut queues = if !block {
let Some(queues) = self.state.try_recv()? else {
return Ok(None);
};
queues
} else {
self.state.recv()?
};
let high = P::High.probability() * !queues.high_priority.is_empty() as u32;
let medium = P::Medium.probability() * !queues.medium_priority.is_empty() as u32;
let low = P::Low.probability() * !queues.low_priority.is_empty() as u32;
let mut mass = high + medium + low; //%
if !queues.high_priority.is_empty() {
let flip = self.rand.random_ratio(P::High.probability(), mass);
if flip {
return Ok(queues.high_priority.pop());
}
mass -= P::High.probability();
}
if !queues.medium_priority.is_empty() {
let flip = self.rand.random_ratio(P::Medium.probability(), mass);
if flip {
return Ok(queues.medium_priority.pop());
}
mass -= P::Medium.probability();
}
if !queues.low_priority.is_empty() {
let flip = self.rand.random_ratio(P::Low.probability(), mass);
if flip {
return Ok(queues.low_priority.pop());
}
}
Ok(None)
}
}
impl<T> Drop for PriorityQueueReceiver<T> {
fn drop(&mut self) {
self.state
.receiver_count
.fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
}
}
/// If None is returned the sender disconnected
pub(crate) struct Iter<T>(PriorityQueueReceiver<T>);
impl<T> Iterator for Iter<T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
self.0.pop_inner(true).ok().flatten()
}
}
impl<T> FusedIterator for Iter<T> {}
/// If None is returned there are no more elements in the queue
pub(crate) struct TryIter<T> {
receiver: PriorityQueueReceiver<T>,
ended: bool,
}
impl<T> Iterator for TryIter<T> {
type Item = Result<T, RecvError>;
fn next(&mut self) -> Option<Self::Item> {
if self.ended {
return None;
}
let res = self.receiver.pop_inner(false);
self.ended = res.is_err();
res.transpose()
}
}
impl<T> FusedIterator for TryIter<T> {}
#[cfg(test)]
mod tests {
use collections::HashSet;
use super::*;
#[test]
fn all_tasks_get_yielded() {
let (tx, mut rx) = PriorityQueueReceiver::new();
tx.send(Priority::Medium, 20).unwrap();
tx.send(Priority::High, 30).unwrap();
tx.send(Priority::Low, 10).unwrap();
tx.send(Priority::Medium, 21).unwrap();
tx.send(Priority::High, 31).unwrap();
drop(tx);
assert_eq!(
rx.iter().collect::<HashSet<_>>(),
[30, 31, 20, 21, 10].into_iter().collect::<HashSet<_>>()
)
}
#[test]
fn new_high_prio_task_get_scheduled_quickly() {
let (tx, mut rx) = PriorityQueueReceiver::new();
for _ in 0..100 {
tx.send(Priority::Low, 1).unwrap();
}
assert_eq!(rx.pop().unwrap(), 1);
tx.send(Priority::High, 3).unwrap();
assert_eq!(rx.pop().unwrap(), 3);
assert_eq!(rx.pop().unwrap(), 1);
}
}

View File

@@ -1,8 +1,9 @@
use crate::{
self as gpui, AbsoluteLength, AlignContent, AlignItems, BorderStyle, CursorStyle,
DefiniteLength, Display, Fill, FlexDirection, FlexWrap, Font, FontStyle, FontWeight,
GridPlacement, Hsla, JustifyContent, Length, SharedString, StrikethroughStyle, StyleRefinement,
TextAlign, TextOverflow, TextStyleRefinement, UnderlineStyle, WhiteSpace, px, relative, rems,
DefiniteLength, Display, Fill, FlexDirection, FlexWrap, Font, FontFeatures, FontStyle,
FontWeight, GridPlacement, Hsla, JustifyContent, Length, SharedString, StrikethroughStyle,
StyleRefinement, TextAlign, TextOverflow, TextStyleRefinement, UnderlineStyle, WhiteSpace, px,
relative, rems,
};
pub use gpui_macros::{
border_style_methods, box_shadow_style_methods, cursor_style_methods, margin_style_methods,
@@ -630,6 +631,14 @@ pub trait Styled: Sized {
self
}
/// Sets the font features of this element and its children.
fn font_features(mut self, features: FontFeatures) -> Self {
self.text_style()
.get_or_insert_with(Default::default)
.font_features = Some(features);
self
}
/// Sets the font of this element and its children.
fn font(mut self, font: Font) -> Self {
let Font {

Some files were not shown because too many files have changed in this diff Show More