Compare commits
52 Commits
revert-pre...ep-distill
| SHA1 |
|---|
| 610536201b |
| 60f4aa333b |
| a698f1bf63 |
| 636d11ebec |
| 4d0e760b04 |
| 8bd4d866b9 |
| 47c30b6da7 |
| 18d344e118 |
| 610cc1b138 |
| a07ea1a272 |
| e03fa114a7 |
| 17db7b0e99 |
| 1afe29422b |
| a8aa7622b7 |
| a66854e435 |
| 12073e10f8 |
| a2a96e4038 |
| 1186b50ca4 |
| 65130a9ca9 |
| 23d18fde8c |
| 332c0d03d1 |
| b871130220 |
| 0a1e5f93a0 |
| 8d0fff688f |
| 717d898692 |
| 1cd7563f04 |
| fc6ca38989 |
| 1029a8fbaf |
| 07748b7bae |
| 37f2ac24b8 |
| b5a0a3322d |
| eb7da26d19 |
| 9c099e7ed3 |
| 7669b05268 |
| ec26556dab |
| 2098b67304 |
| 1a8d8e9572 |
| 5a6198cc39 |
| ab893ca754 |
| cda78c12ab |
| f4378672b8 |
| ecb8d3d4dd |
| 95dbc0efc2 |
| 8572c19a02 |
| 045c14593f |
| 0ff3b68a5e |
| a6b9524d78 |
| 7ed5d42696 |
| 25d74480aa |
| 37077a8ebb |
| 7c4a85f5f1 |
| d21628c349 |
.github/ISSUE_TEMPLATE/10_bug_report.yml: 16 changes (vendored)
@@ -75,6 +75,22 @@ body:
 </details>
 validations:
 required: false
+- type: textarea
+attributes:
+label: Relevant Keymap
+description: |
+Open the command palette in Zed, then type “zed: open keymap file” and copy/paste the file's contents.
+value: |
+<details><summary>keymap.json</summary>
+
+<!-- Paste your keymap file inside the code block. -->
+```json
+
+```
+
+</details>
+validations:
+required: false
 - type: textarea
 attributes:
 label: (for AI issues) Model provider details
.github/workflows/after_release.yml: 25 changes (vendored)
@@ -5,13 +5,27 @@ on:
 release:
 types:
 - published
+workflow_dispatch:
+inputs:
+tag_name:
+description: tag_name
+required: true
+type: string
+prerelease:
+description: prerelease
+required: true
+type: boolean
+body:
+description: body
+type: string
+default: ''
 jobs:
 rebuild_releases_page:
 if: (github.repository_owner == 'zed-industries' || github.repository_owner == 'zed-extensions')
 runs-on: namespace-profile-2x4-ubuntu-2404
 steps:
 - name: after_release::rebuild_releases_page::refresh_cloud_releases
-run: curl -fX POST https://cloud.zed.dev/releases/refresh?expect_tag=${{ github.event.release.tag_name }}
+run: curl -fX POST https://cloud.zed.dev/releases/refresh?expect_tag=${{ github.event.release.tag_name || inputs.tag_name }}
 shell: bash -euxo pipefail {0}
 - name: after_release::rebuild_releases_page::redeploy_zed_dev
 run: npm exec --yes -- vercel@37 --token="$VERCEL_TOKEN" --scope zed-industries redeploy https://zed.dev

@@ -27,7 +41,7 @@ jobs:
 - id: get-release-url
 name: after_release::post_to_discord::get_release_url
 run: |
-if [ "${{ github.event.release.prerelease }}" == "true" ]; then
+if [ "${{ github.event.release.prerelease || inputs.prerelease }}" == "true" ]; then
 URL="https://zed.dev/releases/preview"
 else
 URL="https://zed.dev/releases/stable"

@@ -40,9 +54,9 @@ jobs:
 uses: 2428392/gh-truncate-string-action@b3ff790d21cf42af3ca7579146eedb93c8fb0757
 with:
 stringToTruncate: |
-📣 Zed [${{ github.event.release.tag_name }}](<${{ steps.get-release-url.outputs.URL }}>) was just released!
+📣 Zed [${{ github.event.release.tag_name || inputs.tag_name }}](<${{ steps.get-release-url.outputs.URL }}>) was just released!

-${{ github.event.release.body }}
+${{ github.event.release.body || inputs.body }}
 maxLength: 2000
 truncationSymbol: '...'
 - name: after_release::post_to_discord::discord_webhook_action

@@ -56,7 +70,7 @@ jobs:
 - id: set-package-name
 name: after_release::publish_winget::set_package_name
 run: |
-if ("${{ github.event.release.prerelease }}" -eq "true") {
+if ("${{ github.event.release.prerelease || inputs.prerelease }}" -eq "true") {
 $PACKAGE_NAME = "ZedIndustries.Zed.Preview"
 } else {
 $PACKAGE_NAME = "ZedIndustries.Zed"

@@ -68,6 +82,7 @@ jobs:
 uses: vedantmgoyal9/winget-releaser@19e706d4c9121098010096f9c495a70a7518b30f
 with:
 identifier: ${{ steps.set-package-name.outputs.PACKAGE_NAME }}
+release-tag: ${{ github.event.release.tag_name || inputs.tag_name }}
 max-versions-to-keep: 5
 token: ${{ secrets.WINGET_TOKEN }}
 create_sentry_release:
.github/workflows/run_tests.yml: 2 changes (vendored)
@@ -497,6 +497,8 @@ jobs:
 env:
 GIT_AUTHOR_NAME: Protobuf Action
 GIT_AUTHOR_EMAIL: ci@zed.dev
 GIT_COMMITTER_NAME: Protobuf Action
 GIT_COMMITTER_EMAIL: ci@zed.dev
 steps:
 - name: steps::checkout_repo
 uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
Cargo.lock: 39 changes (generated)
@@ -3111,16 +3111,6 @@ dependencies = [
 "uuid",
 ]

-[[package]]
-name = "cloud_zeta2_prompt"
-version = "0.1.0"
-dependencies = [
-"anyhow",
-"cloud_llm_client",
-"indoc",
-"serde",
-]
-
 [[package]]
 name = "cmake"
 version = "0.1.54"

@@ -5119,7 +5109,6 @@ dependencies = [
 "clock",
 "cloud_api_types",
 "cloud_llm_client",
-"cloud_zeta2_prompt",
 "collections",
 "copilot",
 "credentials_provider",

@@ -5150,8 +5139,6 @@ dependencies = [
 "serde",
 "serde_json",
 "settings",
 "smol",
 "strsim",
 "strum 0.27.2",
 "telemetry",
 "telemetry_events",

@@ -5162,6 +5149,7 @@ dependencies = [
 "workspace",
 "worktree",
 "zed_actions",
+"zeta_prompt",
 "zlog",
 ]

@@ -5175,11 +5163,10 @@ dependencies = [
 "clap",
 "client",
 "cloud_llm_client",
-"cloud_zeta2_prompt",
 "collections",
 "debug_adapter_extension",
 "dirs 4.0.0",
 "edit_prediction",
 "edit_prediction_context",
 "extension",
 "fs",
 "futures 0.3.31",

@@ -5192,6 +5179,7 @@ dependencies = [
 "language_model",
 "language_models",
 "languages",
 "libc",
 "log",
 "node_runtime",
 "paths",

@@ -5209,10 +5197,10 @@ dependencies = [
 "sqlez",
 "sqlez_macros",
 "terminal_view",
 "toml 0.8.23",
 "util",
 "wasmtime",
 "watch",
 "zlog",
 "zeta_prompt",
 ]

@@ -5239,6 +5227,7 @@ dependencies = [
 "text",
 "tree-sitter",
 "util",
+"zeta_prompt",
 "zlog",
 ]

@@ -5260,7 +5249,6 @@ dependencies = [
 "buffer_diff",
 "client",
 "cloud_llm_client",
-"cloud_zeta2_prompt",
 "codestral",
 "command_palette_hooks",
 "copilot",

@@ -5291,6 +5279,7 @@ dependencies = [
 "util",
 "workspace",
 "zed_actions",
+"zeta_prompt",
 ]

@@ -7250,6 +7239,7 @@ dependencies = [
 "libc",
 "log",
 "lyon",
 "mach2 0.5.0",
 "media",
 "metal",
 "naga",

@@ -14456,12 +14446,14 @@ dependencies = [
 "settings",
 "smol",
 "theme",
 "tracing",
 "ui",
 "unindent",
 "util",
 "util_macros",
 "workspace",
 "zed_actions",
 "ztracing",
 ]

@@ -16374,13 +16366,13 @@ dependencies = [
 "alacritty_terminal",
 "anyhow",
 "collections",
 "fancy-regex",
 "futures 0.3.31",
 "gpui",
 "itertools 0.14.0",
 "libc",
 "log",
 "rand 0.9.2",
 "regex",
 "release_channel",
 "schemars",
 "serde",

@@ -18108,6 +18100,7 @@ dependencies = [
 "language",
 "log",
 "lsp",
 "markdown_preview",
 "menu",
 "multi_buffer",
 "nvim-rs",

@@ -20933,6 +20926,13 @@ dependencies = [
 "syn 2.0.106",
 ]

+[[package]]
+name = "zeta_prompt"
+version = "0.1.0"
+dependencies = [
+"serde",
+]
+
 [[package]]
 name = "zip"
 version = "0.6.6"

@@ -21025,6 +21025,7 @@ dependencies = [
 "tracing",
 "tracing-subscriber",
 "tracing-tracy",
 "zlog",
 "ztracing_macro",
 ]
@@ -32,7 +32,6 @@ members = [
 "crates/cloud_api_client",
 "crates/cloud_api_types",
 "crates/cloud_llm_client",
-"crates/cloud_zeta2_prompt",
 "crates/collab",
 "crates/collab_ui",
 "crates/collections",

@@ -202,6 +201,7 @@ members = [
 "crates/zed_actions",
 "crates/zed_env_vars",
 "crates/edit_prediction_cli",
+"crates/zeta_prompt",
 "crates/zlog",
 "crates/zlog_settings",
 "crates/ztracing",

@@ -266,7 +266,6 @@ clock = { path = "crates/clock" }
 cloud_api_client = { path = "crates/cloud_api_client" }
 cloud_api_types = { path = "crates/cloud_api_types" }
 cloud_llm_client = { path = "crates/cloud_llm_client" }
-cloud_zeta2_prompt = { path = "crates/cloud_zeta2_prompt" }
 collab_ui = { path = "crates/collab_ui" }
 collections = { path = "crates/collections", version = "0.1.0" }
 command_palette = { path = "crates/command_palette" }

@@ -425,6 +424,7 @@ zed = { path = "crates/zed" }
 zed_actions = { path = "crates/zed_actions" }
 zed_env_vars = { path = "crates/zed_env_vars" }
 edit_prediction = { path = "crates/edit_prediction" }
+zeta_prompt = { path = "crates/zeta_prompt" }
 zlog = { path = "crates/zlog" }
 zlog_settings = { path = "crates/zlog_settings" }
 ztracing = { path = "crates/ztracing" }

@@ -631,7 +631,7 @@ shellexpand = "2.1.0"
 shlex = "1.3.0"
 simplelog = "0.12.2"
 slotmap = "1.0.6"
-smallvec = { version = "1.6", features = ["union"] }
+smallvec = { version = "1.6", features = ["union", "const_new"] }
 smol = "2.0"
 sqlformat = "0.2"
 stacksafe = "0.1"

@@ -657,6 +657,7 @@ time = { version = "0.3", features = [
 tiny_http = "0.8"
 tokio = { version = "1" }
 tokio-tungstenite = { version = "0.26", features = ["__rustls-tls"] }
+tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io", "tokio"] }
 toml = "0.8"
 toml_edit = { version = "0.22", default-features = false, features = ["display", "parse", "serde"] }
 tower-http = "0.4.4"
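One functional change hides in the workspace manifest hunks above: `smallvec` now enables the `const_new` feature, which makes inline `SmallVec`s constructible in `const`/`static` contexts. A minimal sketch of what that feature unlocks (the element type and capacity here are illustrative, not taken from the Zed codebase):

```rust
use smallvec::SmallVec;

// With the "const_new" feature, `SmallVec::new_const()` is a `const fn`,
// so an inline SmallVec can be initialized in a static or const item.
static EMPTY: SmallVec<[u32; 4]> = SmallVec::new_const();

fn main() {
    let mut v: SmallVec<[u32; 4]> = SmallVec::new_const();
    v.push(1);
    assert!(EMPTY.is_empty());
    assert_eq!(v.len(), 1);
}
```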
@@ -25,7 +25,8 @@
 "ctrl-shift-w": "workspace::CloseWindow",
 "shift-escape": "workspace::ToggleZoom",
 "open": "workspace::Open",
-"ctrl-o": "workspace::Open",
+"ctrl-o": "workspace::OpenFiles",
+"ctrl-k ctrl-o": "workspace::Open",
 "ctrl-=": ["zed::IncreaseBufferFontSize", { "persist": false }],
 "ctrl-+": ["zed::IncreaseBufferFontSize", { "persist": false }],
 "ctrl--": ["zed::DecreaseBufferFontSize", { "persist": false }],

@@ -814,7 +815,6 @@
 "ctrl-]": "agent::CycleNextInlineAssist",
 "ctrl-shift-enter": "inline_assistant::ThumbsUpResult",
 "ctrl-shift-backspace": "inline_assistant::ThumbsDownResult"

 }
 },
 {

@@ -1192,8 +1192,12 @@
 {
 "context": "MarkdownPreview",
 "bindings": {
-"pageup": "markdown::MovePageUp",
-"pagedown": "markdown::MovePageDown"
+"pageup": "markdown::ScrollPageUp",
+"pagedown": "markdown::ScrollPageDown",
+"up": "markdown::ScrollUp",
+"down": "markdown::ScrollDown",
+"alt-up": "markdown::ScrollUpByItem",
+"alt-down": "markdown::ScrollDownByItem"
 }
 },
 {
@@ -1296,8 +1296,12 @@
 {
 "context": "MarkdownPreview",
 "bindings": {
-"pageup": "markdown::MovePageUp",
-"pagedown": "markdown::MovePageDown"
+"pageup": "markdown::ScrollPageUp",
+"pagedown": "markdown::ScrollPageDown",
+"up": "markdown::ScrollUp",
+"down": "markdown::ScrollDown",
+"alt-up": "markdown::ScrollUpByItem",
+"alt-down": "markdown::ScrollDownByItem"
 }
 },
 {
@@ -489,8 +489,8 @@
 "bindings": {
 "ctrl-[": "editor::Outdent",
 "ctrl-]": "editor::Indent",
-"ctrl-shift-alt-up": ["editor::AddSelectionAbove", { "skip_soft_wrap": true }], // Insert Cursor Above
-"ctrl-shift-alt-down": ["editor::AddSelectionBelow", { "skip_soft_wrap": true }], // Insert Cursor Below
+"ctrl-alt-up": ["editor::AddSelectionAbove", { "skip_soft_wrap": true }], // Insert Cursor Above
+"ctrl-alt-down": ["editor::AddSelectionBelow", { "skip_soft_wrap": true }], // Insert Cursor Below
 "ctrl-shift-k": "editor::DeleteLine",
 "alt-up": "editor::MoveLineUp",
 "alt-down": "editor::MoveLineDown",

@@ -501,9 +501,12 @@
 "ctrl-shift-l": "editor::SelectAllMatches", // Select all occurrences of current selection
 "ctrl-f2": "editor::SelectAllMatches", // Select all occurrences of current word
 "ctrl-d": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand
+"ctrl-f3": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand
 "ctrl-k ctrl-d": ["editor::SelectNext", { "replace_newest": true }], // editor.action.moveSelectionToNextFindMatch / find_under_expand_skip
+"ctrl-shift-f3": ["editor::SelectPrevious", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand
 "ctrl-k ctrl-i": "editor::Hover",
+"ctrl-k ctrl-b": "editor::BlameHover",
 "ctrl-k ctrl-f": "editor::FormatSelections",
 "ctrl-/": ["editor::ToggleComments", { "advance_downwards": false }],
 "f8": ["editor::GoToDiagnostic", { "severity": { "min": "hint", "max": "error" } }],
 "shift-f8": ["editor::GoToPreviousDiagnostic", { "severity": { "min": "hint", "max": "error" } }],

@@ -536,7 +539,7 @@
 "ctrl-k p": "editor::CopyPath",
 "ctrl-\\": "pane::SplitRight",
 "alt-.": "editor::GoToHunk",
-"alt-,": "editor::GoToPreviousHunk"
+"alt-,": "editor::GoToPreviousHunk",
 }
 },
 {

@@ -1220,8 +1223,12 @@
 "context": "MarkdownPreview",
 "use_key_equivalents": true,
 "bindings": {
-"pageup": "markdown::MovePageUp",
-"pagedown": "markdown::MovePageDown"
+"pageup": "markdown::ScrollPageUp",
+"pagedown": "markdown::ScrollPageDown",
+"up": "markdown::ScrollUp",
+"down": "markdown::ScrollDown",
+"alt-up": "markdown::ScrollUpByItem",
+"alt-down": "markdown::ScrollDownByItem"
 }
 },
 {
@@ -1046,5 +1046,14 @@
 "g g": "settings_editor::FocusFirstNavEntry",
 "shift-g": "settings_editor::FocusLastNavEntry"
 }
 },
+{
+"context": "MarkdownPreview",
+"bindings": {
+"ctrl-u": "markdown::ScrollPageUp",
+"ctrl-d": "markdown::ScrollPageDown",
+"ctrl-y": "markdown::ScrollUp",
+"ctrl-e": "markdown::ScrollDown"
+}
+}
 ]
@@ -12,7 +12,7 @@
 "theme": {
 "mode": "system",
 "light": "One Light",
-"dark": "One Dark"
+"dark": "One Dark",
 },
 "icon_theme": "Zed (Default)",
 // The name of a base set of key bindings to use.

@@ -29,7 +29,7 @@
 // Features that can be globally enabled or disabled
 "features": {
 // Which edit prediction provider to use.
-"edit_prediction_provider": "zed"
+"edit_prediction_provider": "zed",
 },
 // The name of a font to use for rendering text in the editor
 // ".ZedMono" currently aliases to Lilex

@@ -69,7 +69,7 @@
 // The OpenType features to enable for text in the UI
 "ui_font_features": {
 // Disable ligatures:
-"calt": false
+"calt": false,
 },
 // The weight of the UI font in standard CSS units from 100 to 900.
 "ui_font_weight": 400,

@@ -87,7 +87,7 @@
 "border_size": 0.0,
 // Opacity of the inactive panes. 0 means transparent, 1 means opaque.
 // Values are clamped to the [0.0, 1.0] range.
-"inactive_opacity": 1.0
+"inactive_opacity": 1.0,
 },
 // Layout mode of the bottom dock. Defaults to "contained"
 // choices: contained, full, left_aligned, right_aligned

@@ -103,12 +103,12 @@
 "left_padding": 0.2,
 // The relative width of the right padding of the central pane from the
 // workspace when the centered layout is used.
-"right_padding": 0.2
+"right_padding": 0.2,
 },
 // Image viewer settings
 "image_viewer": {
 // The unit for image file sizes: "binary" (KiB, MiB) or decimal (KB, MB)
-"unit": "binary"
+"unit": "binary",
 },
 // Determines the modifier to be used to add multiple cursors with the mouse. The open hover link mouse gestures will adapt such that it do not conflict with the multicursor modifier.
 //

@@ -296,7 +296,7 @@
 // When true, enables drag and drop text selection in buffer.
 "enabled": true,
 // The delay in milliseconds that must elapse before drag and drop is allowed. Otherwise, a new text selection is created.
-"delay": 300
+"delay": 300,
 },
 // What to do when go to definition yields no results.
 //

@@ -400,14 +400,14 @@
 // Visible characters used to render whitespace when show_whitespaces is enabled.
 "whitespace_map": {
 "space": "•",
-"tab": "→"
+"tab": "→",
 },
 // Settings related to calls in Zed
 "calls": {
 // Join calls with the microphone live by default
 "mute_on_join": false,
 // Share your project when you are the first to join a channel
-"share_on_join": false
+"share_on_join": false,
 },
 // Toolbar related settings
 "toolbar": {

@@ -420,7 +420,7 @@
 // Whether to show agent review buttons in the editor toolbar.
 "agent_review": true,
 // Whether to show code action buttons in the editor toolbar.
-"code_actions": false
+"code_actions": false,
 },
 // Whether to allow windows to tab together based on the user’s tabbing preference (macOS only).
 "use_system_window_tabs": false,

@@ -439,7 +439,7 @@
 // Whether to show the sign in button in the titlebar.
 "show_sign_in": true,
 // Whether to show the menus in the titlebar.
-"show_menus": false
+"show_menus": false,
 },
 "audio": {
 // Opt into the new audio system.

@@ -472,7 +472,7 @@
 // the future we will migrate by setting this to false
 //
 // You need to rejoin a call for this setting to apply
-"experimental.legacy_audio_compatible": true
+"experimental.legacy_audio_compatible": true,
 },
 // Scrollbar related settings
 "scrollbar": {

@@ -511,8 +511,8 @@
 // When false, forcefully disables the horizontal scrollbar. Otherwise, obey other settings.
 "horizontal": true,
 // When false, forcefully disables the vertical scrollbar. Otherwise, obey other settings.
-"vertical": true
-}
+"vertical": true,
+},
 },
 // Minimap related settings
 "minimap": {

@@ -560,7 +560,7 @@
 // 3. "gutter" or "none" to not highlight the current line in the minimap.
 "current_line_highlight": null,
 // Maximum number of columns to display in the minimap.
-"max_width_columns": 80
+"max_width_columns": 80,
 },
 // Enable middle-click paste on Linux.
 "middle_click_paste": true,

@@ -583,7 +583,7 @@
 // Whether to show fold buttons in the gutter.
 "folds": true,
 // Minimum number of characters to reserve space for in the gutter.
-"min_line_number_digits": 4
+"min_line_number_digits": 4,
 },
 "indent_guides": {
 // Whether to show indent guides in the editor.

@@ -604,7 +604,7 @@
 //
 // 1. "disabled"
 // 2. "indent_aware"
-"background_coloring": "disabled"
+"background_coloring": "disabled",
 },
 // Whether the editor will scroll beyond the last line.
 "scroll_beyond_last_line": "one_page",

@@ -623,7 +623,7 @@
 "fast_scroll_sensitivity": 4.0,
 "sticky_scroll": {
 // Whether to stick scopes to the top of the editor.
-"enabled": false
+"enabled": false,
 },
 "relative_line_numbers": "disabled",
 // If 'search_wrap' is disabled, search result do not wrap around the end of the file.

@@ -641,7 +641,7 @@
 // Whether to interpret the search query as a regular expression.
 "regex": false,
 // Whether to center the cursor on each search match when navigating.
-"center_on_match": false
+"center_on_match": false,
 },
 // When to populate a new search's query based on the text under the cursor.
 // This setting can take the following three values:

@@ -684,8 +684,8 @@
 "shift": false,
 "alt": false,
 "platform": false,
-"function": false
-}
+"function": false,
+},
 },
 // Whether to resize all the panels in a dock when resizing the dock.
 // Can be a combination of "left", "right" and "bottom".

@@ -733,7 +733,7 @@
 // "always"
 // 5. Never show the scrollbar:
 // "never"
-"show": null
+"show": null,
 },
 // Which files containing diagnostic errors/warnings to mark in the project panel.
 // This setting can take the following three values:

@@ -756,7 +756,7 @@
 // "always"
 // 2. Never show indent guides:
 // "never"
-"show": "always"
+"show": "always",
 },
 // Sort order for entries in the project panel.
 // This setting can take three values:

@@ -781,8 +781,8 @@
 // Whether to automatically open files after pasting or duplicating them.
 "on_paste": true,
 // Whether to automatically open files dropped from external sources.
-"on_drop": true
-}
+"on_drop": true,
+},
 },
 "outline_panel": {
 // Whether to show the outline panel button in the status bar

@@ -815,7 +815,7 @@
 // "always"
 // 2. Never show indent guides:
 // "never"
-"show": "always"
+"show": "always",
 },
 // Scrollbar-related settings
 "scrollbar": {

@@ -832,11 +832,11 @@
 // "always"
 // 5. Never show the scrollbar:
 // "never"
-"show": null
+"show": null,
 },
 // Default depth to expand outline items in the current file.
 // Set to 0 to collapse all items that have children, 1 or higher to collapse items at that depth or deeper.
-"expand_outlines_with_depth": 100
+"expand_outlines_with_depth": 100,
 },
 "collaboration_panel": {
 // Whether to show the collaboration panel button in the status bar.

@@ -844,7 +844,7 @@
 // Where to dock the collaboration panel. Can be 'left' or 'right'.
 "dock": "left",
 // Default width of the collaboration panel.
-"default_width": 240
+"default_width": 240,
 },
 "git_panel": {
 // Whether to show the git panel button in the status bar.

@@ -880,12 +880,12 @@
 // Choices: always, auto, never, system
 // Default: inherits editor scrollbar settings
 // "show": null
-}
+},
 },
 "message_editor": {
 // Whether to automatically replace emoji shortcodes with emoji characters.
 // For example: typing `:wave:` gets replaced with `👋`.
-"auto_replace_emoji_shortcode": true
+"auto_replace_emoji_shortcode": true,
 },
 "notification_panel": {
 // Whether to show the notification panel button in the status bar.

@@ -893,7 +893,7 @@
 // Where to dock the notification panel. Can be 'left' or 'right'.
 "dock": "right",
 // Default width of the notification panel.
-"default_width": 380
+"default_width": 380,
 },
 "agent": {
 // Whether the agent is enabled.

@@ -915,7 +915,7 @@
 // The provider to use.
 "provider": "zed.dev",
 // The model to use.
-"model": "claude-sonnet-4"
+"model": "claude-sonnet-4",
 },
 // Additional parameters for language model requests. When making a request to a model, parameters will be taken
 // from the last entry in this list that matches the model's provider and name. In each entry, both provider

@@ -970,8 +970,8 @@
 "grep": true,
 "terminal": true,
 "thinking": true,
-"web_search": true
-}
+"web_search": true,
+},
 },
 "ask": {
 "name": "Ask",

@@ -988,14 +988,14 @@
 "open": true,
 "grep": true,
 "thinking": true,
-"web_search": true
-}
+"web_search": true,
+},
 },
 "minimal": {
 "name": "Minimal",
 "enable_all_context_servers": false,
-"tools": {}
-}
+"tools": {},
+},
 },
 // Where to show notifications when the agent has either completed
 // its response, or else needs confirmation before it can run a

@@ -1024,7 +1024,7 @@
 // Minimum number of lines to display in the agent message editor.
 //
 // Default: 4
-"message_editor_min_lines": 4
+"message_editor_min_lines": 4,
 },
 // Whether the screen sharing icon is shown in the os status bar.
 "show_call_status_icon": true,

@@ -1059,7 +1059,7 @@
 // Whether or not to show the navigation history buttons.
 "show_nav_history_buttons": true,
 // Whether or not to show the tab bar buttons.
-"show_tab_bar_buttons": true
+"show_tab_bar_buttons": true,
 },
 // Settings related to the editor's tabs
 "tabs": {

@@ -1098,7 +1098,7 @@
 // "errors"
 // 3. Mark files with errors and warnings:
 // "all"
-"show_diagnostics": "off"
+"show_diagnostics": "off",
 },
 // Settings related to preview tabs.
 "preview_tabs": {

@@ -1119,7 +1119,7 @@
 "enable_preview_file_from_code_navigation": true,
 // Whether to keep tabs in preview mode when code navigation is used to navigate away from them.
 // If `enable_preview_file_from_code_navigation` or `enable_preview_multibuffer_from_code_navigation` is also true, the new tab may replace the existing one.
-"enable_keep_preview_on_code_navigation": false
+"enable_keep_preview_on_code_navigation": false,
 },
 // Settings related to the file finder.
 "file_finder": {

@@ -1163,7 +1163,7 @@
 // * "all": Use all gitignored files
 // * "indexed": Use only the files Zed had indexed
 // * "smart": Be smart and search for ignored when called from a gitignored worktree
-"include_ignored": "smart"
+"include_ignored": "smart",
 },
 // Whether or not to remove any trailing whitespace from lines of a buffer
 // before saving it.

@@ -1234,7 +1234,7 @@
 // Send debug info like crash reports.
 "diagnostics": true,
 // Send anonymized usage data like what languages you're using Zed with.
-"metrics": true
+"metrics": true,
 },
 // Whether to disable all AI features in Zed.
 //

@@ -1268,7 +1268,7 @@
 "enabled": true,
 // Minimum time to wait before pulling diagnostics from the language server(s).
 // 0 turns the debounce off.
-"debounce_ms": 50
+"debounce_ms": 50,
 },
 // Settings for inline diagnostics
 "inline": {

@@ -1286,8 +1286,8 @@
 "min_column": 0,
 // The minimum severity of the diagnostics to show inline.
 // Inherits editor's diagnostics' max severity settings when `null`.
-"max_severity": null
-}
+"max_severity": null,
+},
 },
 // Files or globs of files that will be excluded by Zed entirely. They will be skipped during file
 // scans, file searches, and not be displayed in the project file tree. Takes precedence over `file_scan_inclusions`.

@@ -1301,7 +1301,7 @@
 "**/.DS_Store",
 "**/Thumbs.db",
 "**/.classpath",
-"**/.settings"
+"**/.settings",
 ],
 // Files or globs of files that will be included by Zed, even when ignored by git. This is useful
 // for files that are not tracked by git, but are still important to your project. Note that globs

@@ -1336,14 +1336,14 @@
 // Whether or not to display the git commit summary on the same line.
 "show_commit_summary": false,
 // The minimum column number to show the inline blame information at
-"min_column": 0
+"min_column": 0,
 },
 "blame": {
-"show_avatar": true
+"show_avatar": true,
 },
 // Control which information is shown in the branch picker.
 "branch_picker": {
-"show_author_name": true
+"show_author_name": true,
 },
 // How git hunks are displayed visually in the editor.
 // This setting can take two values:

@@ -1355,7 +1355,7 @@
 "hunk_style": "staged_hollow",
 // Should the name or path be displayed first in the git view.
 // "path_style": "file_name_first" or "file_path_first"
-"path_style": "file_name_first"
+"path_style": "file_name_first",
 },
 // The list of custom Git hosting providers.
 "git_hosting_providers": [

@@ -1389,7 +1389,7 @@
 "**/secrets.yml",
 "**/.zed/settings.json", // zed project settings
 "/**/zed/settings.json", // zed user settings
-"/**/zed/keymap.json"
+"/**/zed/keymap.json",
 ],
 // When to show edit predictions previews in buffer.
 // This setting takes two possible values:

@@ -1407,15 +1407,15 @@
 "copilot": {
 "enterprise_uri": null,
 "proxy": null,
-"proxy_no_verify": null
+"proxy_no_verify": null,
 },
 "codestral": {
 "model": null,
-"max_tokens": null
+"max_tokens": null,
 },
 // Whether edit predictions are enabled when editing text threads in the agent panel.
 // This setting has no effect if globally disabled.
-"enabled_in_text_threads": true
+"enabled_in_text_threads": true,
 },
 // Settings specific to journaling
 "journal": {

@@ -1425,7 +1425,7 @@
 // May take 2 values:
 // 1. hour12
 // 2. hour24
-"hour_format": "hour12"
+"hour_format": "hour12",
 },
 // Status bar-related settings.
 "status_bar": {

@@ -1436,7 +1436,7 @@
 // Whether to show the cursor position button in the status bar.
 "cursor_position_button": true,
 // Whether to show active line endings button in the status bar.
-"line_endings_button": false
+"line_endings_button": false,
 },
 // Settings specific to the terminal
 "terminal": {

@@ -1557,8 +1557,8 @@
 // Preferred Conda manager to use when activating Conda environments.
 // Values: "auto", "conda", "mamba", "micromamba"
 // Default: "auto"
-"conda_manager": "auto"
-}
+"conda_manager": "auto",
+},
 },
 "toolbar": {
 // Whether to display the terminal title in its toolbar's breadcrumbs.

@@ -1566,7 +1566,7 @@
 //
 // The shell running in the terminal needs to be configured to emit the title.
 // Example: `echo -e "\e]2;New Title\007";`
-"breadcrumbs": false
+"breadcrumbs": false,
 },
 // Scrollbar-related settings
 "scrollbar": {

@@ -1583,7 +1583,7 @@
 // "always"
 // 5. Never show the scrollbar:
 // "never"
-"show": null
+"show": null,
 },
 // Set the terminal's font size. If this option is not included,
 // the terminal will default to matching the buffer's font size.

@@ -1646,30 +1646,26 @@
 // surrounding symbols or quotes
 [
 "(?x)",
-"# optionally starts with 0-2 opening prefix symbols",
-"[({\\[<]{0,2}",
-"# which may be followed by an opening quote",
-"(?<quote>[\"'`])?",
-"# `path` is the shortest sequence of any non-space character",
-"(?<link>(?<path>[^ ]+?",
-" # which may end with a line and optionally a column,",
-" (?<line_column>:+[0-9]+(:[0-9]+)?|:?\\([0-9]+([,:][0-9]+)?\\))?",
-"))",
-"# which must be followed by a matching quote",
-"(?(<quote>)\\k<quote>)",
-"# and optionally a single closing symbol",
-"[)}\\]>]?",
-"# if line/column matched, may be followed by a description",
-"(?(<line_column>):[^ 0-9][^ ]*)?",
-"# which may be followed by trailing punctuation",
-"[.,:)}\\]>]*",
-"# and always includes trailing whitespace or end of line",
-"([ ]+|$)"
-]
+"(?<path>",
+" (",
+" # multi-char path: first char (not opening delimiter or space)",
+" [^({\\[<\"'`\\ ]",
+" # middle chars: non-space, and colon/paren only if not followed by digit/paren",
+" ([^\\ :(]|[:(][^0-9()])*",
+" # last char: not closing delimiter or colon",
+" [^()}\\]>\"'`.,;:\\ ]",
+" |",
+" # single-char path: not delimiter, punctuation, or space",
+" [^(){}\\[\\]<>\"'`.,;:\\ ]",
+" )",
+" # optional line/column suffix (included in path for PathWithPosition::parse_str)",
+" (:+[0-9]+(:[0-9]+)?|:?\\([0-9]+([,:]?[0-9]+)?\\))?",
+")",
+],
 ],
 // Timeout for hover and Cmd-click path hyperlink discovery in milliseconds. Specifying a
 // timeout of `0` will disable path hyperlinking in terminal.
-"path_hyperlink_timeout_ms": 1
+"path_hyperlink_timeout_ms": 1,
 },
 "code_actions_on_format": {},
 // Settings related to running tasks.

@@ -1685,7 +1681,7 @@
 // * Zed task from history (e.g. one-off task was spawned before)
 //
 // Default: true
-"prefer_lsp": true
+"prefer_lsp": true,
 },
 // An object whose keys are language names, and whose values
 // are arrays of filenames or extensions of files that should

@@ -1702,7 +1698,7 @@
 "file_types": {
 "JSONC": ["**/.zed/**/*.json", "**/zed/**/*.json", "**/Zed/**/*.json", "**/.vscode/**/*.json", "tsconfig*.json"],
 "Markdown": [".rules", ".cursorrules", ".windsurfrules", ".clinerules"],
-"Shell Script": [".env.*"]
+"Shell Script": [".env.*"],
 },
 // Settings for which version of Node.js and NPM to use when installing
 // language servers and Copilot.

@@ -1718,14 +1714,14 @@
 // `path`, but not `npm_path`, Zed will assume that `npm` is located at
 // `${path}/../npm`.
 "path": null,
-"npm_path": null
+"npm_path": null,
 },
 // The extensions that Zed should automatically install on startup.
 //
 // If you don't want any of these extensions, add this field to your settings
 // and change the value to `false`.
 "auto_install_extensions": {
-"html": true
+"html": true,
 },
 // The capabilities granted to extensions.
 //

@@ -1733,7 +1729,7 @@
 "granted_extension_capabilities": [
 { "kind": "process:exec", "command": "*", "args": ["**"] },
 { "kind": "download_file", "host": "*", "path": ["**"] },
-{ "kind": "npm:install", "package": "*" }
+{ "kind": "npm:install", "package": "*" },
 ],
 // Controls how completions are processed for this language.
 "completions": {

@@ -1784,7 +1780,7 @@
 // 4. "replace_suffix"
 // Behaves like `"replace"` if the text after the cursor is a suffix of the completion, and like
 // `"insert"` otherwise.
-"lsp_insert_mode": "replace_suffix"
+"lsp_insert_mode": "replace_suffix",
 },
 // Different settings for specific languages.
 "languages": {

@@ -1792,116 +1788,116 @@
 "language_servers": ["astro-language-server", "..."],
 "prettier": {
 "allowed": true,
-"plugins": ["prettier-plugin-astro"]
-}
+"plugins": ["prettier-plugin-astro"],
+},
 },
 "Blade": {
 "prettier": {
-"allowed": true
-}
+"allowed": true,
+},
 },
 "C": {
 "format_on_save": "off",
 "use_on_type_format": false,
 "prettier": {
-"allowed": false
-}
+"allowed": false,
+},
 },
 "C++": {
 "format_on_save": "off",
 "use_on_type_format": false,
 "prettier": {
-"allowed": false
-}
+"allowed": false,
+},
 },
 "CSharp": {
-"language_servers": ["roslyn", "!omnisharp", "..."]
+"language_servers": ["roslyn", "!omnisharp", "..."],
 },
 "CSS": {
 "prettier": {
-"allowed": true
-}
+"allowed": true,
+},
 },
 "Dart": {
-"tab_size": 2
+"tab_size": 2,
 },
 "Diff": {
 "show_edit_predictions": false,
 "remove_trailing_whitespace_on_save": false,
-"ensure_final_newline_on_save": false
+"ensure_final_newline_on_save": false,
 },
 "Elixir": {
-"language_servers": ["elixir-ls", "!expert", "!next-ls", "!lexical", "..."]
+"language_servers": ["elixir-ls", "!expert", "!next-ls", "!lexical", "..."],
 },
 "Elm": {
-"tab_size": 4
+"tab_size": 4,
 },
 "Erlang": {
-"language_servers": ["erlang-ls", "!elp", "..."]
+"language_servers": ["erlang-ls", "!elp", "..."],
 },
 "Git Commit": {
 "allow_rewrap": "anywhere",
 "soft_wrap": "editor_width",
-"preferred_line_length": 72
+"preferred_line_length": 72,
 },
 "Go": {
 "hard_tabs": true,
 "code_actions_on_format": {
-"source.organizeImports": true
+"source.organizeImports": true,
 },
-"debuggers": ["Delve"]
+"debuggers": ["Delve"],
 },
 "GraphQL": {
 "prettier": {
-"allowed": true
-}
+"allowed": true,
+},
 },
 "HEEX": {
-"language_servers": ["elixir-ls", "!expert", "!next-ls", "!lexical", "..."]
+"language_servers": ["elixir-ls", "!expert", "!next-ls", "!lexical", "..."],
 },
 "HTML": {
 "prettier": {
-"allowed": true
-}
+"allowed": true,
+},
 },
 "HTML+ERB": {
-"language_servers": ["herb", "!ruby-lsp", "..."]
+"language_servers": ["herb", "!ruby-lsp", "..."],
 },
 "Java": {
 "prettier": {
 "allowed": true,
-"plugins": ["prettier-plugin-java"]
-}
+"plugins": ["prettier-plugin-java"],
+},
 },
 "JavaScript": {
 "language_servers": ["!typescript-language-server", "vtsls", "..."],
 "prettier": {
-"allowed": true
-}
+"allowed": true,
+},
 },
 "JSON": {
 "prettier": {
-"allowed": true
-}
+"allowed": true,
+},
 },
 "JSONC": {
 "prettier": {
-"allowed": true
-}
+"allowed": true,
+},
 },
 "JS+ERB": {
-"language_servers": ["!ruby-lsp", "..."]
+"language_servers": ["!ruby-lsp", "..."],
 },
 "Kotlin": {
-"language_servers": ["!kotlin-language-server", "kotlin-lsp", "..."]
+"language_servers": ["!kotlin-language-server", "kotlin-lsp", "..."],
 },
 "LaTeX": {
 "formatter": "language_server",
 "language_servers": ["texlab", "..."],
 "prettier": {
 "allowed": true,
-"plugins": ["prettier-plugin-latex"]
-}
+"plugins": ["prettier-plugin-latex"],
+},
 },
 "Markdown": {
 "format_on_save": "off",

@@ -1909,136 +1905,142 @@
 "remove_trailing_whitespace_on_save": false,
 "allow_rewrap": "anywhere",
 "soft_wrap": "editor_width",
+"completions": {
+"words": "disabled",
+},
 "prettier": {
-"allowed": true
-}
+"allowed": true,
+},
 },
 "PHP": {
 "language_servers": ["phpactor", "!intelephense", "!phptools", "..."],
 "prettier": {
 "allowed": true,
 "plugins": ["@prettier/plugin-php"],
-"parser": "php"
-}
+"parser": "php",
+},
 },
 "Plain Text": {
 "allow_rewrap": "anywhere",
-"soft_wrap": "editor_width"
+"soft_wrap": "editor_width",
+"completions": {
+"words": "disabled",
+},
 },
 "Python": {
 "code_actions_on_format": {
-"source.organizeImports.ruff": true
+"source.organizeImports.ruff": true,
 },
 "formatter": {
 "language_server": {
-"name": "ruff"
-}
+"name": "ruff",
+},
 },
 "debuggers": ["Debugpy"],
-"language_servers": ["basedpyright", "ruff", "!ty", "!pyrefly", "!pyright", "!pylsp", "..."]
+"language_servers": ["basedpyright", "ruff", "!ty", "!pyrefly", "!pyright", "!pylsp", "..."],
 },
 "Ruby": {
-"language_servers": ["solargraph", "!ruby-lsp", "!rubocop", "!sorbet", "!steep", "..."]
+"language_servers": ["solargraph", "!ruby-lsp", "!rubocop", "!sorbet", "!steep", "..."],
 },
 "Rust": {
-"debuggers": ["CodeLLDB"]
+"debuggers": ["CodeLLDB"],
 },
 "SCSS": {
 "prettier": {
-"allowed": true
-}
+"allowed": true,
+},
 },
 "Starlark": {
-"language_servers": ["starpls", "!buck2-lsp", "..."]
+"language_servers": ["starpls", "!buck2-lsp", "..."],
 },
 "Svelte": {
 "language_servers": ["svelte-language-server", "..."],
 "prettier": {
 "allowed": true,
-"plugins": ["prettier-plugin-svelte"]
-}
+"plugins": ["prettier-plugin-svelte"],
+},
 },
 "TSX": {
 "language_servers": ["!typescript-language-server", "vtsls", "..."],
 "prettier": {
-"allowed": true
-}
+"allowed": true,
+},
 },
 "Twig": {
 "prettier": {
-"allowed": true
-}
+"allowed": true,
+},
 },
 "TypeScript": {
 "language_servers": ["!typescript-language-server", "vtsls", "..."],
 "prettier": {
-"allowed": true
-}
+"allowed": true,
+},
 },
 "SystemVerilog": {
 "format_on_save": "off",
 "language_servers": ["!slang", "..."],
-"use_on_type_format": false
+"use_on_type_format": false,
 },
 "Vue.js": {
 "language_servers": ["vue-language-server", "vtsls", "..."],
 "prettier": {
-"allowed": true
-}
+"allowed": true,
+},
 },
 "XML": {
 "prettier": {
 "allowed": true,
-"plugins": ["@prettier/plugin-xml"]
-}
+"plugins": ["@prettier/plugin-xml"],
+},
 },
 "YAML": {
 "prettier": {
-"allowed": true
-}
+"allowed": true,
+},
 },
 "YAML+ERB": {
-"language_servers": ["!ruby-lsp", "..."]
+"language_servers": ["!ruby-lsp", "..."],
 },
 "Zig": {
-"language_servers": ["zls", "..."]
-}
+"language_servers": ["zls", "..."],
+},
 },
 // Different settings for specific language models.
 "language_models": {
 "anthropic": {
-"api_url": "https://api.anthropic.com"
+"api_url": "https://api.anthropic.com",
 },
 "bedrock": {},
 "google": {
-"api_url": "https://generativelanguage.googleapis.com"
+"api_url": "https://generativelanguage.googleapis.com",
 },
 "ollama": {
-"api_url": "http://localhost:11434"
+"api_url": "http://localhost:11434",
 },
 "openai": {
-"api_url": "https://api.openai.com/v1"
+"api_url": "https://api.openai.com/v1",
 },
 "openai_compatible": {},
 "open_router": {
-"api_url": "https://openrouter.ai/api/v1"
+"api_url": "https://openrouter.ai/api/v1",
 },
 "lmstudio": {
-"api_url": "http://localhost:1234/api/v0"
+"api_url": "http://localhost:1234/api/v0",
 },
 "deepseek": {
-"api_url": "https://api.deepseek.com/v1"
+"api_url": "https://api.deepseek.com/v1",
 },
 "mistral": {
-"api_url": "https://api.mistral.ai/v1"
+"api_url": "https://api.mistral.ai/v1",
 },
 "vercel": {
-"api_url": "https://api.v0.dev/v1"
+"api_url": "https://api.v0.dev/v1",
 },
 "x_ai": {
-"api_url": "https://api.x.ai/v1"
+"api_url": "https://api.x.ai/v1",
 },
-"zed.dev": {}
+"zed.dev": {},
 },
 "session": {
 // Whether or not to restore unsaved buffers on restart.

@@ -2047,7 +2049,7 @@
 // dirty files when closing the application.
 //
 // Default: true
-"restore_unsaved_buffers": true
+"restore_unsaved_buffers": true,
 },
 // Zed's Prettier integration settings.
 // Allows to enable/disable formatting with Prettier

@@ -2065,11 +2067,11 @@
 // "singleQuote": true
 // Forces Prettier integration to use a specific parser name when formatting files with the language
 // when set to a non-empty string.
-"parser": ""
+"parser": "",
 },
 // Settings for auto-closing of JSX tags.
 "jsx_tag_auto_close": {
-"enabled": true
+"enabled": true,
 },
 // LSP Specific settings.
 "lsp": {

@@ -2090,19 +2092,19 @@
 // Specify the DAP name as a key here.
 "CodeLLDB": {
 "env": {
-"RUST_LOG": "info"
-}
-}
+"RUST_LOG": "info",
+},
+},
 },
 // Common language server settings.
 "global_lsp_settings": {
 // Whether to show the LSP servers button in the status bar.
-"button": true
+"button": true,
 },
 // Jupyter settings
 "jupyter": {
 "enabled": true,
-"kernel_selections": {}
+"kernel_selections": {},
 // Specify the language name as the key and the kernel name as the value.
 // "kernel_selections": {
 // "python": "conda-base"

@@ -2116,7 +2118,7 @@
 "max_columns": 128,
 // Maximum number of lines to keep in REPL's scrollback buffer.
 // Clamped with [4, 256] range.
-"max_lines": 32
+"max_lines": 32,
 },
 // Vim settings
 "vim": {

@@ -2130,7 +2132,7 @@
 // Specify the mode as the key and the shape as the value.
 // The mode can be one of the following: "normal", "replace", "insert", "visual".
 // The shape can be one of the following: "block", "bar", "underline", "hollow".
-"cursor_shape": {}
+"cursor_shape": {},
 },
 // The server to connect to. If the environment variable
 // ZED_SERVER_URL is set, it will override this setting.

@@ -2163,9 +2165,9 @@
 "windows": {
 "languages": {
 "PHP": {
-"language_servers": ["intelephense", "!phpactor", "!phptools", "..."]
-}
-}
+"language_servers": ["intelephense", "!phpactor", "!phptools", "..."],
+},
+},
 },
 // Whether to show full labels in line indicator or short ones
 //

@@ -2224,7 +2226,7 @@
 "dock": "bottom",
 "log_dap_communications": true,
 "format_dap_log_messages": true,
-"button": true
+"button": true,
 },
 // Configures any number of settings profiles that are temporarily applied on
 // top of your existing user settings when selected from

@@ -2251,5 +2253,5 @@
 // Useful for filtering out noisy logs or enabling more verbose logging.
 //
 // Example: {"log": {"client": "warn"}}
-"log": {}
+"log": {},
 }
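The largest behavioral change in the settings diff above replaces the terminal's path-hyperlink regex: instead of matching up to two opening symbols plus an optional quote around a lazily matched `path`, the new pattern defines `path` directly by its first, middle, and last characters and folds the line/column suffix into the capture. A sketch that reassembles the new pattern and runs it with the `regex` crate (how Zed actually compiles and wires this up is not shown in the diff):

```rust
use regex::Regex;

fn main() {
    // The new default pattern from the hunk above, joined into one
    // verbose-mode regex string.
    let path_regex = Regex::new(
        r#"(?x)
        (?<path>
            (
                # multi-char path: first char (not opening delimiter or space)
                [^({\[<"'`\ ]
                # middle chars: non-space, and colon/paren only if not followed by digit/paren
                ([^\ :(]|[:(][^0-9()])*
                # last char: not closing delimiter or colon
                [^()}\]>"'`.,;:\ ]
                |
                # single-char path: not delimiter, punctuation, or space
                [^(){}\[\]<>"'`.,;:\ ]
            )
            # optional line/column suffix (included in path for PathWithPosition::parse_str)
            (:+[0-9]+(:[0-9]+)?|:?\([0-9]+([,:]?[0-9]+)?\))?
        )"#,
    )
    .unwrap();

    // The line/column suffix stays inside the `path` capture, so downstream
    // parsing can split "file:line:column" in one place.
    let caps = path_regex.captures("src/main.rs:10:5").unwrap();
    assert_eq!(&caps["path"], "src/main.rs:10:5");
}
```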
@@ -11,8 +11,6 @@ use project::agent_server_store::AgentServerCommand;
 use serde::Deserialize;
 use settings::Settings as _;
 use task::ShellBuilder;
-#[cfg(windows)]
-use task::ShellKind;
 use util::ResultExt as _;

 use std::path::PathBuf;

@@ -92,23 +90,8 @@ impl AcpConnection {
 ) -> Result<Self> {
 let shell = cx.update(|cx| TerminalSettings::get(None, cx).shell.clone())?;
 let builder = ShellBuilder::new(&shell, cfg!(windows));
-#[cfg(windows)]
-let kind = builder.kind();
-let (cmd, args) = builder.build(Some(command.path.display().to_string()), &command.args);
-
-let mut child = util::command::new_smol_command(cmd);
-#[cfg(windows)]
-if kind == ShellKind::Cmd {
-use smol::process::windows::CommandExt;
-for arg in args {
-child.raw_arg(arg);
-}
-} else {
-child.args(args);
-}
-#[cfg(not(windows))]
-child.args(args);
-
+let mut child =
+builder.build_command(Some(command.path.display().to_string()), &command.args);
 child
 .envs(command.env.iter().flatten())
 .stdin(std::process::Stdio::piped())
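This hunk collapses the per-platform process construction into a single `build_command` call on the shell builder, so the cmd.exe argument quirks live behind one API instead of at every call site. A rough sketch of the shape of that pattern (this `ShellBuilder` is a stand-in, not Zed's actual type or signatures):

```rust
use std::process::{Command, Stdio};

// Stand-in shell builder: callers get a fully configured Command back
// instead of a (program, args) pair they must assemble themselves.
struct ShellBuilder {
    shell: String,
}

impl ShellBuilder {
    fn build_command(&self, program: Option<String>, args: &[String]) -> Command {
        let mut cmd = Command::new(&self.shell);
        cmd.arg("-c");
        if let Some(program) = program {
            // Any platform/shell-specific escaping (e.g. raw args for
            // cmd.exe on Windows) would be handled here, behind this one
            // method, rather than at each call site.
            cmd.arg(format!("{} {}", program, args.join(" ")));
        }
        cmd
    }
}

fn main() {
    let builder = ShellBuilder { shell: "sh".into() };
    let mut child = builder.build_command(Some("echo".into()), &["hello".into()]);
    child.stdin(Stdio::piped());
    let output = child.output().expect("failed to run");
    assert!(output.status.success());
}
```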
@@ -53,7 +53,7 @@ text.workspace = true
 thiserror.workspace = true
 time.workspace = true
 tiny_http.workspace = true
-tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io"] }
+tokio-socks.workspace = true
 tokio.workspace = true
 url.workspace = true
 util.workspace = true
@@ -1,18 +0,0 @@
-[package]
-name = "cloud_zeta2_prompt"
-version = "0.1.0"
-publish.workspace = true
-edition.workspace = true
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[lib]
-path = "src/cloud_zeta2_prompt.rs"
-
-[dependencies]
-anyhow.workspace = true
-cloud_llm_client.workspace = true
-indoc.workspace = true
-serde.workspace = true
@@ -1,485 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use cloud_llm_client::predict_edits_v3::{
|
||||
self, DiffPathFmt, Event, Excerpt, Line, Point, PromptFormat, RelatedFile,
|
||||
};
|
||||
use indoc::indoc;
|
||||
use std::cmp;
|
||||
use std::fmt::Write;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;
|
||||
|
||||
pub const CURSOR_MARKER: &str = "<|user_cursor|>";
|
||||
/// NOTE: Differs from zed version of constant - includes a newline
|
||||
pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
|
||||
/// NOTE: Differs from zed version of constant - includes a newline
|
||||
pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
|
||||
|
||||
const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#"
|
||||
You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.
|
||||
|
||||
## Edit History
|
||||
|
||||
"#};
|
||||
|
||||
const MINIMAL_PROMPT_REMINDER: &str = indoc! {"
|
||||
---
|
||||
|
||||
Please analyze the edit history and the files, then provide the unified diff for your predicted edits.
|
||||
Do not include the cursor marker in your output.
|
||||
If you're editing multiple files, be sure to reflect filename in the hunk's header.
|
||||
"};
|
||||
|
||||
const XML_TAGS_INSTRUCTIONS: &str = indoc! {r#"
|
||||
# Instructions
|
||||
|
||||
You are an edit prediction agent in a code editor.
|
||||
|
||||
Analyze the history of edits made by the user in order to infer what they are currently trying to accomplish.
|
||||
Then complete the remainder of the current change if it is incomplete, or predict the next edit the user intends to make.
|
||||
Always continue along the user's current trajectory, rather than changing course.
|
||||
|
||||
## Output Format
|
||||
|
||||
You should briefly explain your understanding of the user's overall goal in one sentence, then explain what the next change
|
||||
along the users current trajectory will be in another, and finally specify the next edit using the following XML-like format:
|
||||
|
||||
<edits path="my-project/src/myapp/cli.py">
|
||||
<old_text>
|
||||
OLD TEXT 1 HERE
|
||||
</old_text>
|
||||
<new_text>
|
||||
NEW TEXT 1 HERE
|
||||
</new_text>
|
||||
|
||||
<old_text>
|
||||
OLD TEXT 1 HERE
|
||||
</old_text>
|
||||
<new_text>
|
||||
NEW TEXT 1 HERE
|
||||
</new_text>
|
||||
</edits>
|
||||
|
||||
- Specify the file to edit using the `path` attribute.
|
||||
- Use `<old_text>` and `<new_text>` tags to replace content
|
||||
- `<old_text>` must exactly match existing file content, including indentation
|
||||
- `<old_text>` cannot be empty
|
||||
- Do not escape quotes, newlines, or other characters within tags
|
||||
- Always close all tags properly
|
||||
- Don't include the <|user_cursor|> marker in your output.
|
||||
|
||||
## Edit History
|
||||
|
||||
"#};
|
||||
|
||||
const OLD_TEXT_NEW_TEXT_REMINDER: &str = indoc! {r#"
|
||||
---
|
||||
|
||||
Remember that the edits in the edit history have already been applied.
|
||||
"#};
|
||||
|
||||
pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<String> {
|
||||
let prompt_data = PromptData {
|
||||
events: request.events.clone(),
|
||||
cursor_point: request.cursor_point,
|
||||
cursor_path: request.excerpt_path.clone(),
|
||||
included_files: request.related_files.clone(),
|
||||
};
|
||||
match request.prompt_format {
|
||||
PromptFormat::MinimalQwen => {
|
||||
return Ok(MinimalQwenPrompt.render(&prompt_data));
|
||||
}
|
||||
PromptFormat::SeedCoder1120 => {
|
||||
return Ok(SeedCoder1120Prompt.render(&prompt_data));
|
||||
}
|
||||
_ => (),
|
||||
};
|
||||
|
||||
let insertions = match request.prompt_format {
|
||||
PromptFormat::Minimal | PromptFormat::OldTextNewText => {
|
||||
vec![(request.cursor_point, CURSOR_MARKER)]
|
||||
}
|
||||
PromptFormat::OnlySnippets => vec![],
|
||||
PromptFormat::MinimalQwen => unreachable!(),
|
||||
PromptFormat::SeedCoder1120 => unreachable!(),
|
||||
};
|
||||
|
||||
let mut prompt = match request.prompt_format {
|
||||
PromptFormat::OldTextNewText => XML_TAGS_INSTRUCTIONS.to_string(),
|
||||
PromptFormat::OnlySnippets => String::new(),
|
||||
PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(),
|
||||
PromptFormat::MinimalQwen => unreachable!(),
|
||||
PromptFormat::SeedCoder1120 => unreachable!(),
|
||||
};
|
||||
|
||||
if request.events.is_empty() {
|
||||
prompt.push_str("(No edit history)\n\n");
|
||||
} else {
|
||||
let edit_preamble = if request.prompt_format == PromptFormat::Minimal {
|
||||
"The following are the latest edits made by the user, from earlier to later.\n\n"
|
||||
} else {
|
||||
"Here are the latest edits made by the user, from earlier to later.\n\n"
|
||||
};
|
||||
prompt.push_str(edit_preamble);
|
||||
push_events(&mut prompt, &request.events);
|
||||
}
|
||||
|
||||
let excerpts_preamble = match request.prompt_format {
|
||||
PromptFormat::Minimal => indoc! {"
|
||||
## Part of the file under the cursor
|
||||
|
||||
(The cursor marker <|user_cursor|> indicates the current user cursor position.
|
||||
The file is in current state, edits from edit history has been applied.
|
||||
We only show part of the file around the cursor.
|
||||
You can only edit exactly this part of the file.
|
||||
We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.)
|
||||
"},
|
||||
PromptFormat::OldTextNewText => indoc! {"
|
||||
## Code Excerpts
|
||||
|
||||
Here is some excerpts of code that you should take into account to predict the next edit.
|
||||
|
||||
The cursor position is marked by `<|user_cursor|>` as it stands after the last edit in the history.
|
||||
|
||||
In addition other excerpts are included to better understand what the edit will be, including the declaration
|
||||
or references of symbols around the cursor, or other similar code snippets that may need to be updated
|
||||
following patterns that appear in the edit history.
|
||||
|
||||
Consider each of them carefully in relation to the edit history, and that the user may not have navigated
|
||||
to the next place they want to edit yet.
|
||||
|
||||
Lines starting with `…` indicate omitted line ranges. These may appear inside multi-line code constructs.
|
||||
"},
|
||||
PromptFormat::OnlySnippets | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
|
||||
indoc! {"
|
||||
## Code Excerpts
|
||||
|
||||
The cursor marker <|user_cursor|> indicates the current user cursor position.
|
||||
The file is in current state, edits from edit history have been applied.
|
||||
"}
|
||||
}
|
||||
};
|
||||
|
||||
prompt.push_str(excerpts_preamble);
|
||||
prompt.push('\n');
|
||||
|
||||
let include_line_numbers = matches!(request.prompt_format, PromptFormat::Minimal);
|
||||
for related_file in &request.related_files {
|
||||
if request.prompt_format == PromptFormat::Minimal {
|
||||
write_codeblock_with_filename(
|
||||
&related_file.path,
|
||||
&related_file.excerpts,
|
||||
if related_file.path == request.excerpt_path {
|
||||
&insertions
|
||||
} else {
|
||||
&[]
|
||||
},
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut prompt,
|
||||
);
|
||||
} else {
|
||||
write_codeblock(
|
||||
&related_file.path,
|
||||
&related_file.excerpts,
|
||||
if related_file.path == request.excerpt_path {
|
||||
&insertions
|
||||
} else {
|
||||
&[]
|
||||
},
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut prompt,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
match request.prompt_format {
|
||||
PromptFormat::OldTextNewText => {
|
||||
prompt.push_str(OLD_TEXT_NEW_TEXT_REMINDER);
|
||||
}
|
||||
PromptFormat::Minimal => {
|
||||
prompt.push_str(MINIMAL_PROMPT_REMINDER);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(prompt)
|
||||
}
|
||||
|
||||
pub fn generation_params(prompt_format: PromptFormat) -> GenerationParams {
|
||||
match prompt_format {
|
||||
PromptFormat::SeedCoder1120 => SeedCoder1120Prompt::generation_params(),
|
||||
_ => GenerationParams::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write_codeblock<'a>(
|
||||
path: &Path,
|
||||
excerpts: impl IntoIterator<Item = &'a Excerpt>,
|
||||
sorted_insertions: &[(Point, &str)],
|
||||
file_line_count: Line,
|
||||
include_line_numbers: bool,
|
||||
output: &'a mut String,
|
||||
) {
|
||||
writeln!(output, "`````{}", DiffPathFmt(path)).unwrap();
|
||||
|
||||
write_excerpts(
|
||||
excerpts,
|
||||
sorted_insertions,
|
||||
file_line_count,
|
||||
include_line_numbers,
|
||||
output,
|
||||
);
|
||||
write!(output, "`````\n\n").unwrap();
|
||||
}
|
||||
|
||||
fn write_codeblock_with_filename<'a>(
|
||||
path: &Path,
|
||||
excerpts: impl IntoIterator<Item = &'a Excerpt>,
|
||||
sorted_insertions: &[(Point, &str)],
|
||||
file_line_count: Line,
|
||||
include_line_numbers: bool,
|
||||
output: &'a mut String,
|
||||
) {
|
||||
writeln!(output, "`````filename={}", DiffPathFmt(path)).unwrap();
|
||||
|
||||
write_excerpts(
|
||||
excerpts,
|
||||
sorted_insertions,
|
||||
file_line_count,
|
||||
include_line_numbers,
|
||||
output,
|
||||
);
|
||||
write!(output, "`````\n\n").unwrap();
|
||||
}
|
||||
|
||||
pub fn write_excerpts<'a>(
|
||||
excerpts: impl IntoIterator<Item = &'a Excerpt>,
|
||||
sorted_insertions: &[(Point, &str)],
|
||||
file_line_count: Line,
|
||||
include_line_numbers: bool,
|
||||
output: &mut String,
|
||||
) {
|
||||
let mut current_row = Line(0);
|
||||
let mut sorted_insertions = sorted_insertions.iter().peekable();
|
||||
|
||||
for excerpt in excerpts {
|
||||
if excerpt.start_line > current_row {
|
||||
writeln!(output, "…").unwrap();
|
||||
}
|
||||
if excerpt.text.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
current_row = excerpt.start_line;
|
||||
|
||||
for mut line in excerpt.text.lines() {
|
||||
if include_line_numbers {
|
||||
write!(output, "{}|", current_row.0 + 1).unwrap();
|
||||
}
|
||||
|
||||
while let Some((insertion_location, insertion_marker)) = sorted_insertions.peek() {
|
||||
match current_row.cmp(&insertion_location.line) {
|
||||
cmp::Ordering::Equal => {
|
||||
let (prefix, suffix) = line.split_at(insertion_location.column as usize);
|
||||
output.push_str(prefix);
|
||||
output.push_str(insertion_marker);
|
||||
line = suffix;
|
||||
sorted_insertions.next();
|
||||
}
|
||||
cmp::Ordering::Less => break,
|
||||
cmp::Ordering::Greater => {
|
||||
sorted_insertions.next();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
output.push_str(line);
|
||||
output.push('\n');
|
||||
current_row.0 += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if current_row < file_line_count {
|
||||
writeln!(output, "…").unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push_events(output: &mut String, events: &[Arc<predict_edits_v3::Event>]) {
|
||||
if events.is_empty() {
|
||||
return;
|
||||
};
|
||||
|
||||
writeln!(output, "`````diff").unwrap();
|
||||
for event in events {
|
||||
writeln!(output, "{}", event).unwrap();
|
||||
}
|
||||
writeln!(output, "`````\n").unwrap();
|
||||
}
|
||||
|
||||
struct PromptData {
|
||||
events: Vec<Arc<Event>>,
|
||||
cursor_point: Point,
|
||||
cursor_path: Arc<Path>, // TODO: make a common struct with cursor_point
|
||||
included_files: Vec<RelatedFile>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct GenerationParams {
|
||||
pub temperature: Option<f32>,
|
||||
pub top_p: Option<f32>,
|
||||
pub stop: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
trait PromptFormatter {
|
||||
fn render(&self, data: &PromptData) -> String;
|
||||
|
||||
fn generation_params() -> GenerationParams {
|
||||
return GenerationParams::default();
|
||||
}
|
||||
}
|
||||
|
||||
struct MinimalQwenPrompt;
|
||||
|
||||
impl PromptFormatter for MinimalQwenPrompt {
|
||||
fn render(&self, data: &PromptData) -> String {
|
||||
let edit_history = self.fmt_edit_history(data);
|
||||
let context = self.fmt_context(data);
|
||||
|
||||
format!(
|
||||
"{instructions}\n\n{edit_history}\n\n{context}",
|
||||
instructions = MinimalQwenPrompt::INSTRUCTIONS,
|
||||
edit_history = edit_history,
|
||||
context = context
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl MinimalQwenPrompt {
|
||||
const INSTRUCTIONS: &str = "You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.\n";
|
||||
|
||||
fn fmt_edit_history(&self, data: &PromptData) -> String {
|
||||
if data.events.is_empty() {
|
||||
"(No edit history)\n\n".to_string()
|
||||
} else {
|
||||
let mut events_str = String::new();
|
||||
push_events(&mut events_str, &data.events);
|
||||
format!(
|
||||
"The following are the latest edits made by the user, from earlier to later.\n\n{}",
|
||||
events_str
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn fmt_context(&self, data: &PromptData) -> String {
|
||||
let mut context = String::new();
|
||||
let include_line_numbers = true;
|
||||
|
||||
for related_file in &data.included_files {
|
||||
writeln!(context, "<|file_sep|>{}", DiffPathFmt(&related_file.path)).unwrap();
|
||||
|
||||
if related_file.path == data.cursor_path {
|
||||
write!(context, "<|fim_prefix|>").unwrap();
|
||||
write_excerpts(
|
||||
&related_file.excerpts,
|
||||
&[(data.cursor_point, "<|fim_suffix|>")],
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut context,
|
||||
);
|
||||
writeln!(context, "<|fim_middle|>").unwrap();
|
||||
} else {
|
||||
write_excerpts(
|
||||
&related_file.excerpts,
|
||||
&[],
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut context,
|
||||
);
|
||||
}
|
||||
}
|
||||
context
|
||||
}
|
||||
}
|
||||
|
||||
struct SeedCoder1120Prompt;
|
||||
|
||||
impl PromptFormatter for SeedCoder1120Prompt {
|
||||
fn render(&self, data: &PromptData) -> String {
|
||||
let edit_history = self.fmt_edit_history(data);
|
||||
let context = self.fmt_context(data);
|
||||
|
||||
format!(
|
||||
"# Edit History:\n{edit_history}\n\n{context}",
|
||||
edit_history = edit_history,
|
||||
context = context
|
||||
)
|
||||
}
|
||||
|
||||
fn generation_params() -> GenerationParams {
|
||||
GenerationParams {
|
||||
temperature: Some(0.2),
|
||||
top_p: Some(0.9),
|
||||
stop: Some(vec!["<[end_of_sentence]>".into()]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SeedCoder1120Prompt {
|
||||
fn fmt_edit_history(&self, data: &PromptData) -> String {
|
||||
if data.events.is_empty() {
|
||||
"(No edit history)\n\n".to_string()
|
||||
} else {
|
||||
let mut events_str = String::new();
|
||||
push_events(&mut events_str, &data.events);
|
||||
events_str
|
||||
}
|
||||
}
|
||||
|
||||
fn fmt_context(&self, data: &PromptData) -> String {
|
||||
let mut context = String::new();
|
||||
let include_line_numbers = true;
|
||||
|
||||
for related_file in &data.included_files {
|
||||
writeln!(context, "# Path: {}\n", DiffPathFmt(&related_file.path)).unwrap();
|
||||
|
||||
if related_file.path == data.cursor_path {
|
||||
let fim_prompt = self.fmt_fim(&related_file, data.cursor_point);
|
||||
context.push_str(&fim_prompt);
|
||||
} else {
|
||||
write_excerpts(
|
||||
&related_file.excerpts,
|
||||
&[],
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut context,
|
||||
);
|
||||
}
|
||||
}
|
||||
context
|
||||
}
|
||||
|
||||
fn fmt_fim(&self, file: &RelatedFile, cursor_point: Point) -> String {
|
||||
let mut buf = String::new();
|
||||
const FIM_SUFFIX: &str = "<[fim-suffix]>";
|
||||
const FIM_PREFIX: &str = "<[fim-prefix]>";
|
||||
const FIM_MIDDLE: &str = "<[fim-middle]>";
|
||||
write!(buf, "{}", FIM_PREFIX).unwrap();
|
||||
write_excerpts(
|
||||
&file.excerpts,
|
||||
&[(cursor_point, FIM_SUFFIX)],
|
||||
file.max_row,
|
||||
true,
|
||||
&mut buf,
|
||||
);
|
||||
|
||||
// Swap prefix and suffix parts
|
||||
let index = buf.find(FIM_SUFFIX).unwrap();
|
||||
let prefix = &buf[..index];
|
||||
let suffix = &buf[index..];
|
||||
|
||||
format!("{}{}{}", suffix, prefix, FIM_MIDDLE)
|
||||
}
|
||||
}
|
||||
@@ -33,12 +33,10 @@ impl StdioTransport {
|
||||
) -> Result<Self> {
|
||||
let shell = cx.update(|cx| TerminalSettings::get(None, cx).shell.clone())?;
|
||||
let builder = ShellBuilder::new(&shell, cfg!(windows));
|
||||
let (command, args) =
|
||||
builder.build(Some(binary.executable.display().to_string()), &binary.args);
|
||||
let mut command =
|
||||
builder.build_command(Some(binary.executable.display().to_string()), &binary.args);
|
||||
|
||||
let mut command = util::command::new_smol_command(command);
|
||||
command
|
||||
.args(args)
|
||||
.envs(binary.env.unwrap_or_default())
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
|
||||
@@ -1045,54 +1045,47 @@ async fn heuristic_syntactic_expand(
|
||||
let node_range = node_start..node_end;
|
||||
let row_count = node_end.row - node_start.row + 1;
|
||||
let mut ancestor_range = None;
|
||||
let reached_outline_node = cx.background_executor().scoped({
|
||||
let node_range = node_range.clone();
|
||||
let outline_range = outline_range.clone();
|
||||
let ancestor_range = &mut ancestor_range;
|
||||
|scope| {
|
||||
scope.spawn(async move {
|
||||
// Stop if we've exceeded the row count or reached an outline node. Then, find the interval
|
||||
// of node children which contains the query range. For example, this allows just returning
|
||||
// the header of a declaration rather than the entire declaration.
|
||||
if row_count > max_row_count || outline_range == Some(node_range.clone()) {
|
||||
let mut cursor = node.walk();
|
||||
let mut included_child_start = None;
|
||||
let mut included_child_end = None;
|
||||
let mut previous_end = node_start;
|
||||
if cursor.goto_first_child() {
|
||||
loop {
|
||||
let child_node = cursor.node();
|
||||
let child_range =
|
||||
previous_end..Point::from_ts_point(child_node.end_position());
|
||||
if included_child_start.is_none()
|
||||
&& child_range.contains(&input_range.start)
|
||||
{
|
||||
included_child_start = Some(child_range.start);
|
||||
}
|
||||
if child_range.contains(&input_range.end) {
|
||||
included_child_end = Some(child_range.end);
|
||||
}
|
||||
previous_end = child_range.end;
|
||||
if !cursor.goto_next_sibling() {
|
||||
break;
|
||||
}
|
||||
cx.background_executor()
|
||||
.await_on_background(async {
|
||||
// Stop if we've exceeded the row count or reached an outline node. Then, find the interval
|
||||
// of node children which contains the query range. For example, this allows just returning
|
||||
// the header of a declaration rather than the entire declaration.
|
||||
if row_count > max_row_count || outline_range == Some(node_range.clone()) {
|
||||
let mut cursor = node.walk();
|
||||
let mut included_child_start = None;
|
||||
let mut included_child_end = None;
|
||||
let mut previous_end = node_start;
|
||||
if cursor.goto_first_child() {
|
||||
loop {
|
||||
let child_node = cursor.node();
|
||||
let child_range =
|
||||
previous_end..Point::from_ts_point(child_node.end_position());
|
||||
if included_child_start.is_none()
|
||||
&& child_range.contains(&input_range.start)
|
||||
{
|
||||
included_child_start = Some(child_range.start);
|
||||
}
|
||||
if child_range.contains(&input_range.end) {
|
||||
included_child_end = Some(child_range.end);
|
||||
}
|
||||
previous_end = child_range.end;
|
||||
if !cursor.goto_next_sibling() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let end = included_child_end.unwrap_or(node_range.end);
|
||||
if let Some(start) = included_child_start {
|
||||
let row_count = end.row - start.row;
|
||||
if row_count < max_row_count {
|
||||
*ancestor_range =
|
||||
Some(Some(RangeInclusive::new(start.row, end.row)));
|
||||
return;
|
||||
}
|
||||
}
|
||||
*ancestor_range = Some(None);
|
||||
}
|
||||
})
|
||||
}
|
||||
});
|
||||
reached_outline_node.await;
|
||||
let end = included_child_end.unwrap_or(node_range.end);
|
||||
if let Some(start) = included_child_start {
|
||||
let row_count = end.row - start.row;
|
||||
if row_count < max_row_count {
|
||||
ancestor_range = Some(Some(RangeInclusive::new(start.row, end.row)));
|
||||
return;
|
||||
}
|
||||
}
|
||||
ancestor_range = Some(None);
|
||||
}
|
||||
})
|
||||
.await;
|
||||
if let Some(node) = ancestor_range {
|
||||
return node;
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ workspace = true
|
||||
path = "src/edit_prediction.rs"
|
||||
|
||||
[features]
|
||||
eval-support = []
|
||||
cli-support = []
|
||||
|
||||
[dependencies]
|
||||
ai_onboarding.workspace = true
|
||||
@@ -21,7 +21,6 @@ arrayvec.workspace = true
|
||||
brotli.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
cloud_zeta2_prompt.workspace = true
|
||||
collections.workspace = true
|
||||
copilot.workspace = true
|
||||
credentials_provider.workspace = true
|
||||
@@ -50,8 +49,6 @@ semver.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
strsim.workspace = true
|
||||
strum.workspace = true
|
||||
telemetry.workspace = true
|
||||
telemetry_events.workspace = true
|
||||
@@ -62,6 +59,7 @@ uuid.workspace = true
|
||||
workspace.workspace = true
|
||||
worktree.workspace = true
|
||||
zed_actions.workspace = true
|
||||
zeta_prompt.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
clock = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
use anyhow::Result;
|
||||
use arrayvec::ArrayVec;
|
||||
use client::{Client, EditPredictionUsage, UserStore};
|
||||
use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
|
||||
use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
|
||||
use cloud_llm_client::{
|
||||
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
|
||||
EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
|
||||
MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
|
||||
ZED_VERSION_HEADER_NAME,
|
||||
};
|
||||
use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
|
||||
use collections::{HashMap, HashSet};
|
||||
use db::kvp::{Dismissable, KEY_VALUE_STORE};
|
||||
use edit_prediction_context::EditPredictionExcerptOptions;
|
||||
@@ -16,10 +15,7 @@ use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, Rel
|
||||
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
|
||||
use futures::{
|
||||
AsyncReadExt as _, FutureExt as _, StreamExt as _,
|
||||
channel::{
|
||||
mpsc::{self, UnboundedReceiver},
|
||||
oneshot,
|
||||
},
|
||||
channel::mpsc::{self, UnboundedReceiver},
|
||||
select_biased,
|
||||
};
|
||||
use gpui::BackgroundExecutor;
|
||||
@@ -58,8 +54,10 @@ mod onboarding_modal;
|
||||
pub mod open_ai_response;
|
||||
mod prediction;
|
||||
pub mod sweep_ai;
|
||||
|
||||
#[cfg(any(test, feature = "test-support", feature = "cli-support"))]
|
||||
pub mod udiff;
|
||||
mod xml_edits;
|
||||
|
||||
mod zed_edit_prediction_delegate;
|
||||
pub mod zeta1;
|
||||
pub mod zeta2;
|
||||
@@ -72,7 +70,6 @@ use crate::mercury::Mercury;
|
||||
use crate::onboarding_modal::ZedPredictModal;
|
||||
pub use crate::prediction::EditPrediction;
|
||||
pub use crate::prediction::EditPredictionId;
|
||||
pub use crate::prediction::EditPredictionInputs;
|
||||
use crate::prediction::EditPredictionResult;
|
||||
pub use crate::sweep_ai::SweepAi;
|
||||
pub use telemetry_events::EditPredictionRating;
|
||||
@@ -112,7 +109,6 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
|
||||
min_bytes: 128,
|
||||
target_before_cursor_over_total_bytes: 0.5,
|
||||
},
|
||||
max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
|
||||
prompt_format: PromptFormat::DEFAULT,
|
||||
};
|
||||
|
||||
@@ -162,8 +158,7 @@ pub struct EditPredictionStore {
|
||||
use_context: bool,
|
||||
options: ZetaOptions,
|
||||
update_required: bool,
|
||||
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
|
||||
#[cfg(feature = "eval-support")]
|
||||
#[cfg(feature = "cli-support")]
|
||||
eval_cache: Option<Arc<dyn EvalCache>>,
|
||||
edit_prediction_model: EditPredictionModel,
|
||||
pub sweep_ai: SweepAi,
|
||||
@@ -183,10 +178,22 @@ pub enum EditPredictionModel {
|
||||
Mercury,
|
||||
}
|
||||
|
||||
pub struct EditPredictionModelInput {
|
||||
project: Entity<Project>,
|
||||
buffer: Entity<Buffer>,
|
||||
snapshot: BufferSnapshot,
|
||||
position: Anchor,
|
||||
events: Vec<Arc<zeta_prompt::Event>>,
|
||||
related_files: Arc<[RelatedFile]>,
|
||||
recent_paths: VecDeque<ProjectPath>,
|
||||
trigger: PredictEditsRequestTrigger,
|
||||
diagnostic_search_range: Range<Point>,
|
||||
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct ZetaOptions {
|
||||
pub context: EditPredictionExcerptOptions,
|
||||
pub max_prompt_bytes: usize,
|
||||
pub prompt_format: predict_edits_v3::PromptFormat,
|
||||
}
|
||||
|
||||
@@ -194,7 +201,8 @@ pub struct ZetaOptions {
|
||||
pub enum DebugEvent {
|
||||
ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
|
||||
ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
|
||||
EditPredictionRequested(EditPredictionRequestedDebugEvent),
|
||||
EditPredictionStarted(EditPredictionStartedDebugEvent),
|
||||
EditPredictionFinished(EditPredictionFinishedDebugEvent),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -212,27 +220,30 @@ pub struct ContextRetrievalFinishedDebugEvent {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EditPredictionRequestedDebugEvent {
|
||||
pub inputs: EditPredictionInputs,
|
||||
pub retrieval_time: Duration,
|
||||
pub struct EditPredictionStartedDebugEvent {
|
||||
pub buffer: WeakEntity<Buffer>,
|
||||
pub position: Anchor,
|
||||
pub local_prompt: Result<String, String>,
|
||||
pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
|
||||
pub prompt: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EditPredictionFinishedDebugEvent {
|
||||
pub buffer: WeakEntity<Buffer>,
|
||||
pub position: Anchor,
|
||||
pub model_output: Option<String>,
|
||||
}
|
||||
|
||||
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
|
||||
|
||||
struct ProjectState {
|
||||
events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
|
||||
events: VecDeque<Arc<zeta_prompt::Event>>,
|
||||
last_event: Option<LastEvent>,
|
||||
recent_paths: VecDeque<ProjectPath>,
|
||||
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
|
||||
current_prediction: Option<CurrentEditPrediction>,
|
||||
next_pending_prediction_id: usize,
|
||||
pending_predictions: ArrayVec<PendingPrediction, 2>,
|
||||
context_updates_tx: smol::channel::Sender<()>,
|
||||
context_updates_rx: smol::channel::Receiver<()>,
|
||||
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
|
||||
last_prediction_refresh: Option<(EntityId, Instant)>,
|
||||
cancelled_predictions: HashSet<usize>,
|
||||
context: Entity<RelatedExcerptStore>,
|
||||
@@ -241,7 +252,7 @@ struct ProjectState {
|
||||
}
|
||||
|
||||
impl ProjectState {
|
||||
pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
|
||||
pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
|
||||
self.events
|
||||
.iter()
|
||||
.cloned()
|
||||
@@ -272,6 +283,18 @@ impl ProjectState {
|
||||
})
|
||||
.detach()
|
||||
}
|
||||
|
||||
fn active_buffer(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
cx: &App,
|
||||
) -> Option<(Entity<Buffer>, Option<Anchor>)> {
|
||||
let project = project.read(cx);
|
||||
let active_path = project.path_for_entry(project.active_entry()?, cx)?;
|
||||
let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
|
||||
let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
|
||||
Some((active_buffer, registered_buffer.last_position))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -362,6 +385,7 @@ impl std::ops::Deref for BufferEditPrediction<'_> {
|
||||
|
||||
struct RegisteredBuffer {
|
||||
snapshot: BufferSnapshot,
|
||||
last_position: Option<Anchor>,
|
||||
_subscriptions: [gpui::Subscription; 2],
|
||||
}
|
||||
|
||||
@@ -376,7 +400,7 @@ impl LastEvent {
|
||||
&self,
|
||||
license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
|
||||
cx: &App,
|
||||
) -> Option<Arc<predict_edits_v3::Event>> {
|
||||
) -> Option<Arc<zeta_prompt::Event>> {
|
||||
let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
|
||||
let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
|
||||
|
||||
@@ -396,7 +420,7 @@ impl LastEvent {
|
||||
if path == old_path && diff.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(Arc::new(predict_edits_v3::Event::BufferChange {
|
||||
Some(Arc::new(zeta_prompt::Event::BufferChange {
|
||||
old_path,
|
||||
path,
|
||||
diff,
|
||||
@@ -481,8 +505,7 @@ impl EditPredictionStore {
|
||||
},
|
||||
),
|
||||
update_required: false,
|
||||
debug_tx: None,
|
||||
#[cfg(feature = "eval-support")]
|
||||
#[cfg(feature = "cli-support")]
|
||||
eval_cache: None,
|
||||
edit_prediction_model: EditPredictionModel::Zeta2,
|
||||
sweep_ai: SweepAi::new(cx),
|
||||
@@ -531,17 +554,11 @@ impl EditPredictionStore {
|
||||
.is_some()
|
||||
}
|
||||
|
||||
#[cfg(feature = "eval-support")]
|
||||
#[cfg(feature = "cli-support")]
|
||||
pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
|
||||
self.eval_cache = Some(cache);
|
||||
}
|
||||
|
||||
pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<DebugEvent> {
|
||||
let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
|
||||
self.debug_tx = Some(debug_watch_tx);
|
||||
debug_watch_rx
|
||||
}
|
||||
|
||||
pub fn options(&self) -> &ZetaOptions {
|
||||
&self.options
|
||||
}
|
||||
@@ -560,15 +577,41 @@ impl EditPredictionStore {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
|
||||
if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
|
||||
project_state.events.clear();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn edit_history_for_project(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
) -> Vec<Arc<zeta_prompt::Event>> {
|
||||
self.projects
|
||||
.get(&project.entity_id())
|
||||
.map(|project_state| project_state.events.iter().cloned().collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn context_for_project<'a>(
|
||||
&'a self,
|
||||
project: &Entity<Project>,
|
||||
cx: &'a App,
|
||||
) -> &'a [RelatedFile] {
|
||||
) -> Arc<[RelatedFile]> {
|
||||
self.projects
|
||||
.get(&project.entity_id())
|
||||
.map(|project| project.context.read(cx).related_files())
|
||||
.unwrap_or(&[])
|
||||
.unwrap_or_else(|| vec![].into())
|
||||
}
|
||||
|
||||
pub fn context_for_project_with_buffers<'a>(
|
||||
&'a self,
|
||||
project: &Entity<Project>,
|
||||
cx: &'a App,
|
||||
) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
|
||||
self.projects
|
||||
.get(&project.entity_id())
|
||||
.map(|project| project.context.read(cx).related_files_with_buffers())
|
||||
}
|
||||
|
||||
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
|
||||
@@ -599,85 +642,21 @@ impl EditPredictionStore {
|
||||
cx: &mut Context<Self>,
|
||||
) -> &mut ProjectState {
|
||||
let entity_id = project.entity_id();
|
||||
let (context_updates_tx, context_updates_rx) = smol::channel::unbounded();
|
||||
self.projects
|
||||
.entry(entity_id)
|
||||
.or_insert_with(|| ProjectState {
|
||||
context: {
|
||||
let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
|
||||
cx.subscribe(
|
||||
&related_excerpt_store,
|
||||
move |this, _, event, _| match event {
|
||||
RelatedExcerptStoreEvent::StartedRefresh => {
|
||||
if let Some(debug_tx) = this.debug_tx.clone() {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::ContextRetrievalStarted(
|
||||
ContextRetrievalStartedDebugEvent {
|
||||
project_entity_id: entity_id,
|
||||
timestamp: Instant::now(),
|
||||
search_prompt: String::new(),
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
RelatedExcerptStoreEvent::FinishedRefresh {
|
||||
cache_hit_count,
|
||||
cache_miss_count,
|
||||
mean_definition_latency,
|
||||
max_definition_latency,
|
||||
} => {
|
||||
if let Some(debug_tx) = this.debug_tx.clone() {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::ContextRetrievalFinished(
|
||||
ContextRetrievalFinishedDebugEvent {
|
||||
project_entity_id: entity_id,
|
||||
timestamp: Instant::now(),
|
||||
metadata: vec![
|
||||
(
|
||||
"Cache Hits",
|
||||
format!(
|
||||
"{}/{}",
|
||||
cache_hit_count,
|
||||
cache_hit_count + cache_miss_count
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
(
|
||||
"Max LSP Time",
|
||||
format!(
|
||||
"{} ms",
|
||||
max_definition_latency.as_millis()
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
(
|
||||
"Mean LSP Time",
|
||||
format!(
|
||||
"{} ms",
|
||||
mean_definition_latency.as_millis()
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
],
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
if let Some(project_state) = this.projects.get(&entity_id) {
|
||||
project_state.context_updates_tx.send_blocking(()).ok();
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
|
||||
this.handle_excerpt_store_event(entity_id, event);
|
||||
})
|
||||
.detach();
|
||||
related_excerpt_store
|
||||
},
|
||||
events: VecDeque::new(),
|
||||
last_event: None,
|
||||
recent_paths: VecDeque::new(),
|
||||
context_updates_rx,
|
||||
context_updates_tx,
|
||||
debug_tx: None,
|
||||
registered_buffers: HashMap::default(),
|
||||
current_prediction: None,
|
||||
cancelled_predictions: HashSet::default(),
|
||||
@@ -689,12 +668,79 @@ impl EditPredictionStore {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn project_context_updates(
|
||||
&self,
|
||||
pub fn remove_project(&mut self, project: &Entity<Project>) {
|
||||
self.projects.remove(&project.entity_id());
|
||||
}
|
||||
|
||||
fn handle_excerpt_store_event(
|
||||
&mut self,
|
||||
project_entity_id: EntityId,
|
||||
event: &RelatedExcerptStoreEvent,
|
||||
) {
|
||||
if let Some(project_state) = self.projects.get(&project_entity_id) {
|
||||
if let Some(debug_tx) = project_state.debug_tx.clone() {
|
||||
match event {
|
||||
RelatedExcerptStoreEvent::StartedRefresh => {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::ContextRetrievalStarted(
|
||||
ContextRetrievalStartedDebugEvent {
|
||||
project_entity_id: project_entity_id,
|
||||
timestamp: Instant::now(),
|
||||
search_prompt: String::new(),
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
RelatedExcerptStoreEvent::FinishedRefresh {
|
||||
cache_hit_count,
|
||||
cache_miss_count,
|
||||
mean_definition_latency,
|
||||
max_definition_latency,
|
||||
} => {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::ContextRetrievalFinished(
|
||||
ContextRetrievalFinishedDebugEvent {
|
||||
project_entity_id: project_entity_id,
|
||||
timestamp: Instant::now(),
|
||||
metadata: vec![
|
||||
(
|
||||
"Cache Hits",
|
||||
format!(
|
||||
"{}/{}",
|
||||
cache_hit_count,
|
||||
cache_hit_count + cache_miss_count
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
(
|
||||
"Max LSP Time",
|
||||
format!("{} ms", max_definition_latency.as_millis())
|
||||
.into(),
|
||||
),
|
||||
(
|
||||
"Mean LSP Time",
|
||||
format!("{} ms", mean_definition_latency.as_millis())
|
||||
.into(),
|
||||
),
|
||||
],
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn debug_info(
|
||||
&mut self,
|
||||
project: &Entity<Project>,
|
||||
) -> Option<smol::channel::Receiver<()>> {
|
||||
let project_state = self.projects.get(&project.entity_id())?;
|
||||
Some(project_state.context_updates_rx.clone())
|
||||
cx: &mut Context<Self>,
|
||||
) -> mpsc::UnboundedReceiver<DebugEvent> {
|
||||
let project_state = self.get_or_init_project(project, cx);
|
||||
let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
|
||||
project_state.debug_tx = Some(debug_watch_tx);
|
||||
debug_watch_rx
|
||||
}
|
||||
|
||||
fn handle_project_event(
|
||||
@@ -768,6 +814,7 @@ impl EditPredictionStore {
|
||||
let project_entity_id = project.entity_id();
|
||||
entry.insert(RegisteredBuffer {
|
||||
snapshot,
|
||||
last_position: None,
|
||||
_subscriptions: [
|
||||
cx.subscribe(buffer, {
|
||||
let project = project.downgrade();
|
||||
@@ -855,13 +902,21 @@ impl EditPredictionStore {
|
||||
});
|
||||
}
|
||||
|
||||
fn current_prediction_for_buffer(
|
||||
&self,
|
||||
fn prediction_at(
|
||||
&mut self,
|
||||
buffer: &Entity<Buffer>,
|
||||
position: Option<language::Anchor>,
|
||||
project: &Entity<Project>,
|
||||
cx: &App,
|
||||
) -> Option<BufferEditPrediction<'_>> {
|
||||
let project_state = self.projects.get(&project.entity_id())?;
|
||||
let project_state = self.projects.get_mut(&project.entity_id())?;
|
||||
if let Some(position) = position
|
||||
&& let Some(buffer) = project_state
|
||||
.registered_buffers
|
||||
.get_mut(&buffer.entity_id())
|
||||
{
|
||||
buffer.last_position = Some(position);
|
||||
}
|
||||
|
||||
let CurrentEditPrediction {
|
||||
requested_by,
|
||||
@@ -1104,12 +1159,21 @@ impl EditPredictionStore {
|
||||
};
|
||||
|
||||
self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
|
||||
let Some(open_buffer_task) = project
|
||||
.update(cx, |project, cx| {
|
||||
project
|
||||
.active_entry()
|
||||
.and_then(|entry| project.path_for_entry(entry, cx))
|
||||
.map(|path| project.open_buffer(path, cx))
|
||||
let Some((active_buffer, snapshot, cursor_point)) = this
|
||||
.read_with(cx, |this, cx| {
|
||||
let project_state = this.projects.get(&project.entity_id())?;
|
||||
let (buffer, position) = project_state.active_buffer(&project, cx)?;
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
|
||||
if !Self::predictions_enabled_at(&snapshot, position, cx) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let cursor_point = position
|
||||
.map(|pos| pos.to_point(&snapshot))
|
||||
.unwrap_or_default();
|
||||
|
||||
Some((buffer, snapshot, cursor_point))
|
||||
})
|
||||
.log_err()
|
||||
.flatten()
|
||||
@@ -1118,14 +1182,11 @@ impl EditPredictionStore {
|
||||
};
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let active_buffer = open_buffer_task.await?;
|
||||
let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
|
||||
let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
|
||||
active_buffer,
|
||||
&snapshot,
|
||||
Default::default(),
|
||||
Default::default(),
|
||||
cursor_point,
|
||||
&project,
|
||||
cx,
|
||||
)
|
||||
@@ -1170,6 +1231,37 @@ impl EditPredictionStore {
|
||||
});
|
||||
}
|
||||
|
||||
fn predictions_enabled_at(
|
||||
snapshot: &BufferSnapshot,
|
||||
position: Option<language::Anchor>,
|
||||
cx: &App,
|
||||
) -> bool {
|
||||
let file = snapshot.file();
|
||||
let all_settings = all_language_settings(file, cx);
|
||||
if !all_settings.show_edit_predictions(snapshot.language(), cx)
|
||||
|| file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if let Some(last_position) = position {
|
||||
let settings = snapshot.settings_at(last_position, cx);
|
||||
|
||||
if !settings.edit_predictions_disabled_in.is_empty()
|
||||
&& let Some(scope) = snapshot.language_scope_at(last_position)
|
||||
&& let Some(scope_name) = scope.override_name()
|
||||
&& settings
|
||||
.edit_predictions_disabled_in
|
||||
.iter()
|
||||
.any(|s| s == scope_name)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(not(test))]
|
||||
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
|
||||
#[cfg(test)]
|
||||
@@ -1348,6 +1440,7 @@ impl EditPredictionStore {
|
||||
let project_state = self.projects.get(&project.entity_id()).unwrap();
|
||||
let events = project_state.events(cx);
|
||||
let has_events = !events.is_empty();
|
||||
let debug_tx = project_state.debug_tx.clone();
|
||||
|
||||
let snapshot = active_buffer.read(cx).snapshot();
|
||||
let cursor_point = position.to_point(&snapshot);
|
||||
@@ -1357,55 +1450,29 @@ impl EditPredictionStore {
|
||||
Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
|
||||
|
||||
let related_files = if self.use_context {
|
||||
self.context_for_project(&project, cx).to_vec()
|
||||
self.context_for_project(&project, cx)
|
||||
} else {
|
||||
Vec::new()
|
||||
Vec::new().into()
|
||||
};
|
||||
|
||||
let inputs = EditPredictionModelInput {
|
||||
project: project.clone(),
|
||||
buffer: active_buffer.clone(),
|
||||
snapshot: snapshot.clone(),
|
||||
position,
|
||||
events,
|
||||
related_files,
|
||||
recent_paths: project_state.recent_paths.clone(),
|
||||
trigger,
|
||||
diagnostic_search_range: diagnostic_search_range.clone(),
|
||||
debug_tx,
|
||||
};
|
||||
|
||||
let task = match self.edit_prediction_model {
|
||||
EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(
|
||||
self,
|
||||
&project,
|
||||
&active_buffer,
|
||||
snapshot.clone(),
|
||||
position,
|
||||
events,
|
||||
trigger,
|
||||
cx,
|
||||
),
|
||||
EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
|
||||
self,
|
||||
&project,
|
||||
&active_buffer,
|
||||
snapshot.clone(),
|
||||
position,
|
||||
events,
|
||||
related_files,
|
||||
trigger,
|
||||
cx,
|
||||
),
|
||||
EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
|
||||
&project,
|
||||
&active_buffer,
|
||||
snapshot.clone(),
|
||||
position,
|
||||
events,
|
||||
&project_state.recent_paths,
|
||||
related_files,
|
||||
diagnostic_search_range.clone(),
|
||||
cx,
|
||||
),
|
||||
EditPredictionModel::Mercury => self.mercury.request_prediction(
|
||||
&project,
|
||||
&active_buffer,
|
||||
snapshot.clone(),
|
||||
position,
|
||||
events,
|
||||
&project_state.recent_paths,
|
||||
related_files,
|
||||
diagnostic_search_range.clone(),
|
||||
cx,
|
||||
),
|
||||
EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
|
||||
EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
|
||||
EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
|
||||
EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
|
||||
};
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
@@ -1529,8 +1596,8 @@ impl EditPredictionStore {
|
||||
client: Arc<Client>,
|
||||
llm_token: LlmApiToken,
|
||||
app_version: Version,
|
||||
#[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
|
||||
#[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
|
||||
#[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
|
||||
#[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
|
||||
) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
|
||||
let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
|
||||
http_client::Url::parse(&predict_edits_url)?
|
||||
@@ -1540,7 +1607,7 @@ impl EditPredictionStore {
|
||||
.build_zed_llm_url("/predict_edits/raw", &[])?
|
||||
};
|
||||
|
||||
#[cfg(feature = "eval-support")]
|
||||
#[cfg(feature = "cli-support")]
|
||||
let cache_key = if let Some(cache) = eval_cache {
|
||||
use collections::FxHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
@@ -1574,7 +1641,7 @@ impl EditPredictionStore {
|
||||
)
|
||||
.await?;
|
||||
|
||||
#[cfg(feature = "eval-support")]
|
||||
#[cfg(feature = "cli-support")]
|
||||
if let Some((cache, request, key)) = cache_key {
|
||||
cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
|
||||
}
|
||||
@@ -1706,6 +1773,20 @@ impl EditPredictionStore {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cli-support")]
|
||||
pub fn set_context_for_buffer(
|
||||
&mut self,
|
||||
project: &Entity<Project>,
|
||||
related_files: Vec<RelatedFile>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.get_or_init_project(project, cx)
|
||||
.context
|
||||
.update(cx, |store, _| {
|
||||
store.set_related_files(related_files);
|
||||
});
|
||||
}
|
||||
|
||||
fn is_file_open_source(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
@@ -1729,14 +1810,14 @@ impl EditPredictionStore {
|
||||
self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
|
||||
}
|
||||
|
||||
fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
|
||||
fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
|
||||
if !self.data_collection_choice.is_enabled() {
|
||||
return false;
|
||||
}
|
||||
events.iter().all(|event| {
|
||||
matches!(
|
||||
event.as_ref(),
|
||||
Event::BufferChange {
|
||||
zeta_prompt::Event::BufferChange {
|
||||
in_open_source_repo: true,
|
||||
..
|
||||
}
|
||||
@@ -1817,10 +1898,10 @@ pub struct ZedUpdateRequiredError {
|
||||
minimum_version: Version,
|
||||
}
|
||||
|
||||
#[cfg(feature = "eval-support")]
|
||||
#[cfg(feature = "cli-support")]
|
||||
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
|
||||
|
||||
#[cfg(feature = "eval-support")]
|
||||
#[cfg(feature = "cli-support")]
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum EvalCacheEntryKind {
|
||||
Context,
|
||||
@@ -1828,7 +1909,7 @@ pub enum EvalCacheEntryKind {
|
||||
Prediction,
|
||||
}
|
||||
|
||||
#[cfg(feature = "eval-support")]
|
||||
#[cfg(feature = "cli-support")]
|
||||
impl std::fmt::Display for EvalCacheEntryKind {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
@@ -1839,7 +1920,7 @@ impl std::fmt::Display for EvalCacheEntryKind {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "eval-support")]
|
||||
#[cfg(feature = "cli-support")]
|
||||
pub trait EvalCache: Send + Sync {
|
||||
fn read(&self, key: EvalCacheKey) -> Option<String>;
|
||||
fn write(&self, key: EvalCacheKey, input: &str, value: &str);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use super::*;
|
||||
use crate::zeta1::MAX_EVENT_TOKENS;
|
||||
use crate::{udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
|
||||
use client::{UserStore, test::FakeServer};
|
||||
use clock::{FakeSystemClock, ReplicaId};
|
||||
use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
|
||||
@@ -7,7 +7,6 @@ use cloud_llm_client::{
|
||||
EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
|
||||
RejectEditPredictionsBody,
|
||||
};
|
||||
use edit_prediction_context::Line;
|
||||
use futures::{
|
||||
AsyncReadExt, StreamExt,
|
||||
channel::{mpsc, oneshot},
|
||||
@@ -28,6 +27,7 @@ use settings::SettingsStore;
|
||||
use std::{path::Path, sync::Arc, time::Duration};
|
||||
use util::{path, rel_path::rel_path};
|
||||
use uuid::Uuid;
|
||||
use zeta_prompt::ZetaPromptInput;
|
||||
|
||||
use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
|
||||
|
||||
@@ -45,10 +45,6 @@ async fn test_current_state(cx: &mut TestAppContext) {
|
||||
.await;
|
||||
let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.register_project(&project, cx);
|
||||
});
|
||||
|
||||
let buffer1 = project
|
||||
.update(cx, |project, cx| {
|
||||
let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
|
||||
@@ -60,30 +56,38 @@ async fn test_current_state(cx: &mut TestAppContext) {
|
||||
let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
let position = snapshot1.anchor_before(language::Point::new(1, 3));
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.register_project(&project, cx);
|
||||
ep_store.register_buffer(&buffer1, &project, cx);
|
||||
});
|
||||
|
||||
// Prediction for current file
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
|
||||
});
|
||||
let (_request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
|
||||
respond_tx
|
||||
.send(model_response(indoc! {r"
|
||||
--- a/root/1.txt
|
||||
+++ b/root/1.txt
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are you?
|
||||
Bye
|
||||
"}))
|
||||
.send(model_response(
|
||||
request,
|
||||
indoc! {r"
|
||||
--- a/root/1.txt
|
||||
+++ b/root/1.txt
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are you?
|
||||
Bye
|
||||
"},
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
let prediction = ep_store
|
||||
.current_prediction_for_buffer(&buffer1, &project, cx)
|
||||
.prediction_at(&buffer1, None, &project, cx)
|
||||
.unwrap();
|
||||
assert_matches!(prediction, BufferEditPrediction::Local { .. });
|
||||
});
|
||||
@@ -120,22 +124,26 @@ async fn test_current_state(cx: &mut TestAppContext) {
|
||||
});
|
||||
});
|
||||
|
||||
let (_request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
respond_tx
|
||||
.send(model_response(indoc! {r#"
|
||||
--- a/root/2.txt
|
||||
+++ b/root/2.txt
|
||||
Hola!
|
||||
-Como
|
||||
+Como estas?
|
||||
Adios
|
||||
"#}))
|
||||
.send(model_response(
|
||||
request,
|
||||
indoc! {r#"
|
||||
--- a/root/2.txt
|
||||
+++ b/root/2.txt
|
||||
@@ ... @@
|
||||
Hola!
|
||||
-Como
|
||||
+Como estas?
|
||||
Adios
|
||||
"#},
|
||||
))
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
let prediction = ep_store
|
||||
.current_prediction_for_buffer(&buffer1, &project, cx)
|
||||
.prediction_at(&buffer1, None, &project, cx)
|
||||
.unwrap();
|
||||
assert_matches!(
|
||||
prediction,
|
||||
@@ -151,9 +159,9 @@ async fn test_current_state(cx: &mut TestAppContext) {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
let prediction = ep_store
|
||||
.current_prediction_for_buffer(&buffer2, &project, cx)
|
||||
.prediction_at(&buffer2, None, &project, cx)
|
||||
.unwrap();
|
||||
assert_matches!(prediction, BufferEditPrediction::Local { .. });
|
||||
});
|
||||
@@ -186,7 +194,7 @@ async fn test_simple_request(cx: &mut TestAppContext) {
|
||||
ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
|
||||
});
|
||||
|
||||
let (_, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
|
||||
// TODO Put back when we have a structured request again
|
||||
// assert_eq!(
|
||||
@@ -202,15 +210,18 @@ async fn test_simple_request(cx: &mut TestAppContext) {
|
||||
// );
|
||||
|
||||
respond_tx
|
||||
.send(model_response(indoc! { r"
|
||||
--- a/root/foo.md
|
||||
+++ b/root/foo.md
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are you?
|
||||
Bye
|
||||
"}))
|
||||
.send(model_response(
|
||||
request,
|
||||
indoc! { r"
|
||||
--- a/root/foo.md
|
||||
+++ b/root/foo.md
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are you?
|
||||
Bye
|
||||
"},
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
|
||||
@@ -276,15 +287,18 @@ async fn test_request_events(cx: &mut TestAppContext) {
|
||||
);
|
||||
|
||||
respond_tx
|
||||
.send(model_response(indoc! {r#"
|
||||
--- a/root/foo.md
|
||||
+++ b/root/foo.md
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are you?
|
||||
Bye
|
||||
"#}))
|
||||
.send(model_response(
|
||||
request,
|
||||
indoc! {r#"
|
||||
--- a/root/foo.md
|
||||
+++ b/root/foo.md
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are you?
|
||||
Bye
|
||||
"#},
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
|
||||
@@ -324,27 +338,17 @@ async fn test_empty_prediction(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
const NO_OP_DIFF: &str = indoc! { r"
|
||||
--- a/root/foo.md
|
||||
+++ b/root/foo.md
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How
|
||||
Bye
|
||||
"};
|
||||
|
||||
let (_, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let response = model_response(NO_OP_DIFF);
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let response = model_response(request, "");
|
||||
let id = response.id.clone();
|
||||
respond_tx.send(response).unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
assert!(
|
||||
ep_store
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.is_none()
|
||||
);
|
||||
});
|
||||
@@ -389,22 +393,22 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer.set_text("Hello!\nHow are you?\nBye", cx);
|
||||
});
|
||||
|
||||
let response = model_response(SIMPLE_DIFF);
|
||||
let response = model_response(request, SIMPLE_DIFF);
|
||||
let id = response.id.clone();
|
||||
respond_tx.send(response).unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
assert!(
|
||||
ep_store
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.is_none()
|
||||
);
|
||||
});
|
||||
@@ -459,17 +463,17 @@ async fn test_replace_current(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let first_response = model_response(SIMPLE_DIFF);
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let first_response = model_response(request, SIMPLE_DIFF);
|
||||
let first_id = first_response.id.clone();
|
||||
respond_tx.send(first_response).unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
assert_eq!(
|
||||
ep_store
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.unwrap()
|
||||
.id
|
||||
.0,
|
||||
@@ -482,18 +486,18 @@ async fn test_replace_current(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let second_response = model_response(SIMPLE_DIFF);
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let second_response = model_response(request, SIMPLE_DIFF);
|
||||
let second_id = second_response.id.clone();
|
||||
respond_tx.send(second_response).unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
// second replaces first
|
||||
assert_eq!(
|
||||
ep_store
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.unwrap()
|
||||
.id
|
||||
.0,
|
||||
@@ -541,17 +545,17 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let first_response = model_response(SIMPLE_DIFF);
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let first_response = model_response(request, SIMPLE_DIFF);
|
||||
let first_id = first_response.id.clone();
|
||||
respond_tx.send(first_response).unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
assert_eq!(
|
||||
ep_store
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.unwrap()
|
||||
.id
|
||||
.0,
|
||||
@@ -564,27 +568,30 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
// worse than current prediction
|
||||
let second_response = model_response(indoc! { r"
|
||||
--- a/root/foo.md
|
||||
+++ b/root/foo.md
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are
|
||||
Bye
|
||||
"});
|
||||
let second_response = model_response(
|
||||
request,
|
||||
indoc! { r"
|
||||
--- a/root/foo.md
|
||||
+++ b/root/foo.md
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are
|
||||
Bye
|
||||
"},
|
||||
);
|
||||
let second_id = second_response.id.clone();
|
||||
respond_tx.send(second_response).unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
// first is preferred over second
|
||||
assert_eq!(
|
||||
ep_store
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.unwrap()
|
||||
.id
|
||||
.0,
|
||||
@@ -633,29 +640,29 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_first) = requests.predict.next().await.unwrap();
|
||||
let (request1, respond_first) = requests.predict.next().await.unwrap();
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});

let (_, respond_second) = requests.predict.next().await.unwrap();
let (request, respond_second) = requests.predict.next().await.unwrap();

// wait for throttle
cx.run_until_parked();

// second responds first
let second_response = model_response(SIMPLE_DIFF);
let second_response = model_response(request, SIMPLE_DIFF);
let second_id = second_response.id.clone();
respond_second.send(second_response).unwrap();

cx.run_until_parked();

ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
// current prediction is second
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -663,17 +670,17 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
);
});

let first_response = model_response(SIMPLE_DIFF);
let first_response = model_response(request1, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_first.send(first_response).unwrap();

cx.run_until_parked();

ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
// current prediction is still second, since first was cancelled
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -724,13 +731,13 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});

let (_, respond_first) = requests.predict.next().await.unwrap();
let (request1, respond_first) = requests.predict.next().await.unwrap();

ep_store.update(cx, |ep_store, cx| {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});

let (_, respond_second) = requests.predict.next().await.unwrap();
let (request2, respond_second) = requests.predict.next().await.unwrap();

// wait for throttle, so requests are sent
cx.run_until_parked();
@@ -754,19 +761,19 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
// wait for throttle
cx.run_until_parked();

let (_, respond_third) = requests.predict.next().await.unwrap();
let (request3, respond_third) = requests.predict.next().await.unwrap();

let first_response = model_response(SIMPLE_DIFF);
let first_response = model_response(request1, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_first.send(first_response).unwrap();

cx.run_until_parked();

ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
// current prediction is first
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -774,17 +781,17 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
);
});

let cancelled_response = model_response(SIMPLE_DIFF);
let cancelled_response = model_response(request2, SIMPLE_DIFF);
let cancelled_id = cancelled_response.id.clone();
respond_second.send(cancelled_response).unwrap();

cx.run_until_parked();

ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
// current prediction is still first, since second was cancelled
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -792,17 +799,17 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
);
});

let third_response = model_response(SIMPLE_DIFF);
let third_response = model_response(request3, SIMPLE_DIFF);
let third_response_id = third_response.id.clone();
respond_third.send(third_response).unwrap();

cx.run_until_parked();

ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
// third completes and replaces first
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -1036,7 +1043,24 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
// );
// }

fn model_response(text: &str) -> open_ai::Response {
// Generate a model response that would apply the given diff to the active file.
fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response {
let prompt = match &request.messages[0] {
open_ai::RequestMessage::User {
content: open_ai::MessageContent::Plain(content),
} => content,
_ => panic!("unexpected request {request:?}"),
};

let open = "<editable_region>\n";
let close = "</editable_region>";
let cursor = "<|user_cursor|>";

let start_ix = open.len() + prompt.find(open).unwrap();
let end_ix = start_ix + &prompt[start_ix..].find(close).unwrap();
let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();

open_ai::Response {
id: Uuid::new_v4().to_string(),
object: "response".into(),
@@ -1045,7 +1069,7 @@ fn model_response(text: &str) -> open_ai::Response {
choices: vec![open_ai::Choice {
index: 0,
message: open_ai::RequestMessage::Assistant {
content: Some(open_ai::MessageContent::Plain(text.to_string())),
content: Some(open_ai::MessageContent::Plain(new_excerpt)),
tool_calls: vec![],
},
finish_reason: None,
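The reworked `model_response` derives its output from the request itself: it pulls the editable region out of the prompt, strips the cursor marker, and applies the given diff to it. A condensed sketch of how the tests above drive this fake model (names such as `requests`, `ep_store`, and `SIMPLE_DIFF` come from this test file):

```rust
// Sketch only: the fake server hands back (request, responder) pairs, and the
// test answers with a response computed from that same request, so the edit
// always targets the exact excerpt the store sent.
ep_store.update(cx, |ep_store, cx| {
    ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (request, respond) = requests.predict.next().await.unwrap();
respond.send(model_response(request, SIMPLE_DIFF)).unwrap();
cx.run_until_parked();
```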
@@ -1160,20 +1184,19 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
.read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
.await;

let completion = EditPrediction {
let prediction = EditPrediction {
edits,
edit_preview,
buffer: buffer.clone(),
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
id: EditPredictionId("the-id".into()),
inputs: EditPredictionInputs {
inputs: ZetaPromptInput {
events: Default::default(),
included_files: Default::default(),
cursor_point: cloud_llm_client::predict_edits_v3::Point {
line: Line(0),
column: 0,
},
related_files: Default::default(),
cursor_path: Path::new("").into(),
cursor_excerpt: "".into(),
editable_range_in_excerpt: 0..0,
cursor_offset_in_excerpt: 0,
},
buffer_snapshotted_at: Instant::now(),
response_received_at: Instant::now(),
@@ -1182,7 +1205,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
cx.update(|cx| {
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1192,7 +1215,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1202,7 +1225,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.undo(cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1212,7 +1235,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1222,7 +1245,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1232,7 +1255,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1242,7 +1265,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1252,7 +1275,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1260,7 +1283,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
);

buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
})
}
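The renames above (`completion` to `prediction`) track the terminology change; the behavior under test is unchanged: `interpolate` re-maps the predicted edits onto the current buffer snapshot and gives up once the user's typing conflicts with them. A minimal sketch of that contract (an illustration, not the actual implementation, which goes through `edit_prediction_types::interpolate_edits`):

```rust
// Hypothetical helper illustrating the invariant the assertions above rely on:
// after any user edit, either the remaining predicted edits still apply
// (Some(edits)) or the prediction is dead (None) and must be rejected.
fn prediction_is_alive(prediction: &EditPrediction, snapshot: &language::BufferSnapshot) -> bool {
    prediction.interpolate(snapshot).is_some()
}
```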
@@ -735,6 +735,7 @@ mod tests {
true,
fs.clone(),
Default::default(),
true,
&mut cx.to_async(),
)
.await
@@ -758,6 +759,7 @@ mod tests {
true,
fs.clone(),
Default::default(),
true,
&mut cx.to_async(),
)
.await
@@ -816,6 +818,7 @@ mod tests {
true,
fs.clone(),
Default::default(),
true,
&mut cx.to_async(),
)
.await
@@ -1,20 +1,17 @@
use anyhow::{Context as _, Result};
use cloud_llm_client::predict_edits_v3::Event;
use credentials_provider::CredentialsProvider;
use edit_prediction_context::RelatedFile;
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
use gpui::{
App, AppContext as _, Entity, Task,
App, AppContext as _, Task,
http_client::{self, AsyncBody, Method},
};
use language::{Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _};
use project::{Project, ProjectPath};
use std::{
collections::VecDeque, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc, time::Instant,
};
use language::{OffsetRangeExt as _, ToOffset, ToPoint as _};
use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant};
use zeta_prompt::ZetaPromptInput;

use crate::{
EditPredictionId, EditPredictionInputs, open_ai_response::text_from_response,
DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
EditPredictionStartedDebugEvent, open_ai_response::text_from_response,
prediction::EditPredictionResult,
};

@@ -38,16 +35,17 @@ impl Mercury {
store_api_token_in_keychain(api_token, cx)
}

pub fn request_prediction(
pub(crate) fn request_prediction(
&self,
_project: &Entity<Project>,
active_buffer: &Entity<Buffer>,
snapshot: BufferSnapshot,
position: language::Anchor,
events: Vec<Arc<Event>>,
_recent_paths: &VecDeque<ProjectPath>,
related_files: Vec<RelatedFile>,
_diagnostic_search_range: Range<Point>,
EditPredictionModelInput {
buffer,
snapshot,
position,
events,
related_files,
debug_tx,
..
}: EditPredictionModelInput,
cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
@@ -62,6 +60,7 @@ impl Mercury {
let http_client = cx.http_client();
let cursor_point = position.to_point(&snapshot);
let buffer_snapshotted_at = Instant::now();
let active_buffer = buffer.clone();

let result = cx.background_spawn(async move {
let (editable_range, context_range) =
@@ -72,39 +71,39 @@ impl Mercury {
MAX_REWRITE_TOKENS,
);

let offset_range = editable_range.to_offset(&snapshot);
let prompt = build_prompt(
&events,
&related_files,
&snapshot,
full_path.as_ref(),
cursor_point,
editable_range,
context_range.clone(),
);
let context_offset_range = context_range.to_offset(&snapshot);

let inputs = EditPredictionInputs {
events: events,
included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
path: full_path.clone(),
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
start_line: cloud_llm_client::predict_edits_v3::Line(
context_range.start.row,
),
text: snapshot
.text_for_range(context_range.clone())
.collect::<String>()
.into(),
}],
}],
cursor_point: cloud_llm_client::predict_edits_v3::Point {
column: cursor_point.column,
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
},
let editable_offset_range = editable_range.to_offset(&snapshot);

let inputs = zeta_prompt::ZetaPromptInput {
events,
related_files,
cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot)
- context_range.start.to_offset(&snapshot),
cursor_path: full_path.clone(),
cursor_excerpt: snapshot
.text_for_range(context_range)
.collect::<String>()
.into(),
editable_range_in_excerpt: (editable_offset_range.start
- context_offset_range.start)
..(editable_offset_range.end - context_offset_range.start),
};

let prompt = build_prompt(&inputs);

if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionStarted(
EditPredictionStartedDebugEvent {
buffer: active_buffer.downgrade(),
prompt: Some(prompt.clone()),
position,
},
))
.ok();
}

let request_body = open_ai::Request {
model: "mercury-coder".into(),
messages: vec![open_ai::RequestMessage::User {
@@ -160,6 +159,18 @@ impl Mercury {
let id = mem::take(&mut response.id);
let response_str = text_from_response(response).unwrap_or_default();

if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionFinished(
EditPredictionFinishedDebugEvent {
buffer: active_buffer.downgrade(),
model_output: Some(response_str.clone()),
position,
},
))
.ok();
}

let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);

@@ -168,15 +179,16 @@ impl Mercury {

if response_str != NO_PREDICTION_OUTPUT {
let old_text = snapshot
.text_for_range(offset_range.clone())
.text_for_range(editable_offset_range.clone())
.collect::<String>();
edits.extend(
language::text_diff(&old_text, &response_str)
.into_iter()
.map(|(range, text)| {
(
snapshot.anchor_after(offset_range.start + range.start)
..snapshot.anchor_before(offset_range.start + range.end),
snapshot.anchor_after(editable_offset_range.start + range.start)
..snapshot
.anchor_before(editable_offset_range.start + range.end),
text,
)
}),
@@ -186,8 +198,6 @@ impl Mercury {
anyhow::Ok((id, edits, snapshot, response_received_at, inputs))
});

let buffer = active_buffer.clone();

cx.spawn(async move |cx| {
let (id, edits, old_snapshot, response_received_at, inputs) =
result.await.context("Mercury edit prediction failed")?;
@@ -208,15 +218,7 @@ impl Mercury {
}
}

fn build_prompt(
events: &[Arc<Event>],
related_files: &[RelatedFile],
cursor_buffer: &BufferSnapshot,
cursor_buffer_path: &Path,
cursor_point: Point,
editable_range: Range<Point>,
context_range: Range<Point>,
) -> String {
fn build_prompt(inputs: &ZetaPromptInput) -> String {
const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
@@ -237,14 +239,14 @@ fn build_prompt(
&mut prompt,
RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
|prompt| {
for related_file in related_files {
for related_file in inputs.related_files.iter() {
for related_excerpt in &related_file.excerpts {
push_delimited(
prompt,
RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
|prompt| {
prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
prompt.push_str(related_file.path.path.as_unix_str());
prompt.push_str(related_file.path.to_string_lossy().as_ref());
prompt.push('\n');
prompt.push_str(&related_excerpt.text.to_string());
},
@@ -259,21 +261,22 @@ fn build_prompt(
CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
|prompt| {
prompt.push_str(CURRENT_FILE_PATH_PREFIX);
prompt.push_str(cursor_buffer_path.as_os_str().to_string_lossy().as_ref());
prompt.push_str(inputs.cursor_path.as_os_str().to_string_lossy().as_ref());
prompt.push('\n');

let prefix_range = context_range.start..editable_range.start;
let suffix_range = editable_range.end..context_range.end;

prompt.extend(cursor_buffer.text_for_range(prefix_range));
prompt.push_str(&inputs.cursor_excerpt[0..inputs.editable_range_in_excerpt.start]);
push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
let range_before_cursor = editable_range.start..cursor_point;
let range_after_cursor = cursor_point..editable_range.end;
prompt.extend(cursor_buffer.text_for_range(range_before_cursor));
prompt.push_str(
&inputs.cursor_excerpt
[inputs.editable_range_in_excerpt.start..inputs.cursor_offset_in_excerpt],
);
prompt.push_str(CURSOR_TAG);
prompt.extend(cursor_buffer.text_for_range(range_after_cursor));
prompt.push_str(
&inputs.cursor_excerpt
[inputs.cursor_offset_in_excerpt..inputs.editable_range_in_excerpt.end],
);
});
prompt.extend(cursor_buffer.text_for_range(suffix_range));
prompt.push_str(&inputs.cursor_excerpt[inputs.editable_range_in_excerpt.end..]);
},
);

@@ -281,8 +284,8 @@ fn build_prompt(
&mut prompt,
EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
|prompt| {
for event in events {
writeln!(prompt, "{event}").unwrap();
for event in inputs.events.iter() {
zeta_prompt::write_event(prompt, &event);
}
},
);
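`push_delimited` is used throughout `build_prompt` but not shown in this diff. From the call sites, a plausible shape (an assumption, not the actual definition) is a helper that writes an opening tag, lets a closure fill in the body, then writes the closing tag:

```rust
// Assumed signature, inferred from calls like
// `push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| { ... })`.
fn push_delimited(prompt: &mut String, tags: std::ops::Range<&'static str>, f: impl FnOnce(&mut String)) {
    prompt.push_str(tags.start); // e.g. "<|code_to_edit|>\n"
    f(prompt);
    prompt.push_str(tags.end); // e.g. "<|/code_to_edit|>\n"
}
```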
@@ -1,6 +1,5 @@
use std::{
ops::Range,
path::Path,
sync::Arc,
time::{Duration, Instant},
};
@@ -9,7 +8,7 @@ use cloud_llm_client::EditPredictionRejectReason;
use edit_prediction_types::interpolate_edits;
use gpui::{AsyncApp, Entity, SharedString};
use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
use serde::Serialize;
use zeta_prompt::ZetaPromptInput;

#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(pub SharedString);
@@ -40,7 +39,7 @@ impl EditPredictionResult {
edits: Arc<[(Range<Anchor>, Arc<str>)]>,
buffer_snapshotted_at: Instant,
response_received_at: Instant,
inputs: EditPredictionInputs,
inputs: ZetaPromptInput,
cx: &mut AsyncApp,
) -> Self {
if edits.is_empty() {
@@ -94,15 +93,7 @@ pub struct EditPrediction {
pub buffer: Entity<Buffer>,
pub buffer_snapshotted_at: Instant,
pub response_received_at: Instant,
pub inputs: EditPredictionInputs,
}

#[derive(Debug, Clone, Serialize)]
pub struct EditPredictionInputs {
pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
pub included_files: Vec<cloud_llm_client::predict_edits_v3::RelatedFile>,
pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
pub cursor_path: Arc<Path>,
pub inputs: zeta_prompt::ZetaPromptInput,
}

impl EditPrediction {
@@ -133,9 +124,12 @@ impl std::fmt::Debug for EditPrediction {

#[cfg(test)]
mod tests {
use std::path::Path;

use super::*;
use gpui::{App, Entity, TestAppContext, prelude::*};
use language::{Buffer, ToOffset as _};
use zeta_prompt::ZetaPromptInput;

#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
@@ -154,14 +148,13 @@ mod tests {
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
buffer: buffer.clone(),
edit_preview,
inputs: EditPredictionInputs {
inputs: ZetaPromptInput {
events: vec![],
included_files: vec![],
cursor_point: cloud_llm_client::predict_edits_v3::Point {
line: cloud_llm_client::predict_edits_v3::Line(0),
column: 0,
},
related_files: vec![].into(),
cursor_path: Path::new("path.txt").into(),
cursor_offset_in_excerpt: 0,
cursor_excerpt: "".into(),
editable_range_in_excerpt: 0..0,
},
buffer_snapshotted_at: Instant::now(),
response_received_at: Instant::now(),
@@ -1,26 +1,21 @@
use anyhow::{Context as _, Result};
use cloud_llm_client::predict_edits_v3::Event;
use credentials_provider::CredentialsProvider;
use edit_prediction_context::RelatedFile;
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
use gpui::{
App, AppContext as _, Entity, Task,
App, AppContext as _, Task,
http_client::{self, AsyncBody, Method},
};
use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
use language::{Point, ToOffset as _};
use lsp::DiagnosticSeverity;
use project::{Project, ProjectPath};
use serde::{Deserialize, Serialize};
use std::{
collections::VecDeque,
fmt::{self, Write as _},
ops::Range,
path::Path,
sync::Arc,
time::Instant,
};

use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionResult};
use crate::{EditPredictionId, EditPredictionModelInput, prediction::EditPredictionResult};

const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";

@@ -44,40 +39,34 @@ impl SweepAi {

pub fn request_prediction_with_sweep(
&self,
project: &Entity<Project>,
active_buffer: &Entity<Buffer>,
snapshot: BufferSnapshot,
position: language::Anchor,
events: Vec<Arc<Event>>,
recent_paths: &VecDeque<ProjectPath>,
related_files: Vec<RelatedFile>,
diagnostic_search_range: Range<Point>,
inputs: EditPredictionModelInput,
cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
let debug_info = self.debug_info.clone();
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
return Task::ready(Ok(None));
};
let full_path: Arc<Path> = snapshot
let full_path: Arc<Path> = inputs
.snapshot
.file()
.map(|file| file.full_path(cx))
.unwrap_or_else(|| "untitled".into())
.into();

let project_file = project::File::from_dyn(snapshot.file());
let project_file = project::File::from_dyn(inputs.snapshot.file());
let repo_name = project_file
.map(|file| file.worktree.read(cx).root_name_str())
.unwrap_or("untitled")
.into();
let offset = position.to_offset(&snapshot);
let offset = inputs.position.to_offset(&inputs.snapshot);

let recent_buffers = recent_paths.iter().cloned();
let recent_buffers = inputs.recent_paths.iter().cloned();
let http_client = cx.http_client();

let recent_buffer_snapshots = recent_buffers
.filter_map(|project_path| {
let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
if active_buffer == &buffer {
let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
if inputs.buffer == buffer {
None
} else {
Some(buffer.read(cx).snapshot())
@@ -86,14 +75,13 @@ impl SweepAi {
.take(3)
.collect::<Vec<_>>();

let cursor_point = position.to_point(&snapshot);
let buffer_snapshotted_at = Instant::now();

let result = cx.background_spawn(async move {
let text = snapshot.text();
let text = inputs.snapshot.text();

let mut recent_changes = String::new();
for event in &events {
for event in &inputs.events {
write_event(event.as_ref(), &mut recent_changes).unwrap();
}

@@ -122,20 +110,23 @@ impl SweepAi {
})
.collect::<Vec<_>>();

let retrieval_chunks = related_files
let retrieval_chunks = inputs
.related_files
.iter()
.flat_map(|related_file| {
related_file.excerpts.iter().map(|excerpt| FileChunk {
file_path: related_file.path.path.as_unix_str().to_string(),
start_line: excerpt.point_range.start.row as usize,
end_line: excerpt.point_range.end.row as usize,
file_path: related_file.path.to_string_lossy().to_string(),
start_line: excerpt.row_range.start as usize,
end_line: excerpt.row_range.end as usize,
content: excerpt.text.to_string(),
timestamp: None,
})
})
.collect();

let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
let diagnostic_entries = inputs
.snapshot
.diagnostics_in_range(inputs.diagnostic_search_range, false);
let mut diagnostic_content = String::new();
let mut diagnostic_count = 0;

@@ -195,21 +186,14 @@ impl SweepAi {
serde_json::to_writer(writer, &request_body)?;
let body: AsyncBody = buf.into();

let inputs = EditPredictionInputs {
events,
included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
path: full_path.clone(),
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
start_line: cloud_llm_client::predict_edits_v3::Line(0),
text: request_body.file_contents.into(),
}],
}],
cursor_point: cloud_llm_client::predict_edits_v3::Point {
column: cursor_point.column,
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
},
let ep_inputs = zeta_prompt::ZetaPromptInput {
events: inputs.events,
related_files: inputs.related_files.clone(),
cursor_path: full_path.clone(),
cursor_excerpt: request_body.file_contents.into(),
// we don't actually know the editable range here, so assume the whole file
editable_range_in_excerpt: 0..inputs.snapshot.len(),
cursor_offset_in_excerpt: request_body.cursor_position,
};

let request = http_client::Request::builder()
@@ -237,15 +221,20 @@ impl SweepAi {

let response: AutocompleteResponse = serde_json::from_slice(&body)?;

let old_text = snapshot
let old_text = inputs
.snapshot
.text_for_range(response.start_index..response.end_index)
.collect::<String>();
let edits = language::text_diff(&old_text, &response.completion)
.into_iter()
.map(|(range, text)| {
(
snapshot.anchor_after(response.start_index + range.start)
..snapshot.anchor_before(response.start_index + range.end),
inputs
.snapshot
.anchor_after(response.start_index + range.start)
..inputs
.snapshot
.anchor_before(response.start_index + range.end),
text,
)
})
@@ -254,13 +243,13 @@ impl SweepAi {
anyhow::Ok((
response.autocomplete_id,
edits,
snapshot,
inputs.snapshot,
response_received_at,
inputs,
ep_inputs,
))
});

let buffer = active_buffer.clone();
let buffer = inputs.buffer.clone();

cx.spawn(async move |cx| {
let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
@@ -403,12 +392,9 @@ struct AdditionalCompletion {
pub finish_reason: Option<String>,
}

fn write_event(
event: &cloud_llm_client::predict_edits_v3::Event,
f: &mut impl fmt::Write,
) -> fmt::Result {
fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
match event {
cloud_llm_client::predict_edits_v3::Event::BufferChange {
zeta_prompt::Event::BufferChange {
old_path,
path,
diff,
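Both Mercury and Sweep now take a single `EditPredictionModelInput` instead of a long parameter list. The struct's definition is outside this diff; from the destructuring in `Mercury::request_prediction` and the field accesses above, a plausible shape is the following (inferred, not the actual definition):

```rust
// Assumed fields only; the real struct lives elsewhere in the crate.
pub(crate) struct EditPredictionModelInput {
    pub project: Entity<Project>,
    pub buffer: Entity<Buffer>,
    pub snapshot: BufferSnapshot,
    pub position: language::Anchor,
    pub events: Vec<Arc<zeta_prompt::Event>>,
    pub related_files: Vec<RelatedFile>,
    pub recent_paths: VecDeque<ProjectPath>,
    pub diagnostic_search_range: Range<Point>,
    pub debug_tx: Option<futures::channel::mpsc::UnboundedSender<DebugEvent>>,
}
```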
@@ -14,87 +14,48 @@ use anyhow::anyhow;
use collections::HashMap;
use gpui::AsyncApp;
use gpui::Entity;
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, TextBufferSnapshot};
use project::Project;
use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot};
use project::{Project, ProjectPath};
use util::paths::PathStyle;
use util::rel_path::RelPath;

pub async fn parse_diff<'a>(
diff_str: &'a str,
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
let mut diff = DiffParser::new(diff_str);
let mut edited_buffer = None;
let mut edits = Vec::new();

while let Some(event) = diff.next()? {
match event {
DiffEvent::Hunk {
path: file_path,
hunk,
} => {
let (buffer, ranges) = match edited_buffer {
None => {
edited_buffer = get_buffer(&Path::new(file_path.as_ref()));
edited_buffer
.as_ref()
.context("Model tried to edit a file that wasn't included")?
}
Some(ref current) => current,
};

edits.extend(
resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges)
.with_context(|| format!("Diff:\n{diff_str}"))?,
);
}
DiffEvent::FileEnd { renamed_to } => {
let (buffer, _) = edited_buffer
.take()
.context("Got a FileEnd event before a Hunk event")?;

if renamed_to.is_some() {
anyhow::bail!("edit predictions cannot rename files");
}

if diff.next()?.is_some() {
anyhow::bail!("Edited more than one file");
}

return Ok((buffer, edits));
}
}
}

Err(anyhow::anyhow!("No EOF"))
}

#[derive(Debug)]
pub struct OpenedBuffers<'a>(#[allow(unused)] HashMap<Cow<'a, str>, Entity<Buffer>>);
#[derive(Clone, Debug)]
pub struct OpenedBuffers(#[allow(unused)] HashMap<String, Entity<Buffer>>);

#[must_use]
pub async fn apply_diff<'a>(
diff_str: &'a str,
pub async fn apply_diff(
diff_str: &str,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<OpenedBuffers<'a>> {
) -> Result<OpenedBuffers> {
let mut included_files = HashMap::default();

let worktree_id = project.read_with(cx, |project, cx| {
anyhow::Ok(
project
.visible_worktrees(cx)
.next()
.context("no worktrees")?
.read(cx)
.id(),
)
})??;

for line in diff_str.lines() {
let diff_line = DiffLine::parse(line);

if let DiffLine::OldPath { path } = diff_line {
let buffer = project
.update(cx, |project, cx| {
let project_path =
project
.find_project_path(path.as_ref(), cx)
.with_context(|| {
format!("Failed to find worktree for new path: {}", path)
})?;
let project_path = ProjectPath {
worktree_id,
path: RelPath::new(Path::new(path.as_ref()), PathStyle::Posix)?.into_arc(),
};
anyhow::Ok(project.open_buffer(project_path, cx))
})??
.await?;

included_files.insert(path, buffer);
included_files.insert(path.to_string(), buffer);
}
}

@@ -113,7 +74,7 @@ pub async fn apply_diff<'a>(
let (buffer, ranges) = match current_file {
None => {
let buffer = included_files
.get_mut(&file_path)
.get_mut(file_path.as_ref())
.expect("Opened all files in diff");

current_file = Some((buffer, ranges.as_slice()));
@@ -167,6 +128,29 @@ pub async fn apply_diff<'a>(
Ok(OpenedBuffers(included_files))
}

pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
let mut diff = DiffParser::new(diff_str);

let mut text = text.to_string();

while let Some(event) = diff.next()? {
match event {
DiffEvent::Hunk { hunk, .. } => {
let hunk_offset = text
.find(&hunk.context)
.ok_or_else(|| anyhow!("couldn't resolve hunk {:?}", hunk.context))?;
for edit in hunk.edits.iter().rev() {
let range = (hunk_offset + edit.range.start)..(hunk_offset + edit.range.end);
text.replace_range(range, &edit.text);
}
}
DiffEvent::FileEnd { .. } => {}
}
}

Ok(text)
}
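The `.rev()` in the new `apply_diff_to_string` is load-bearing: hunk edits are expressed as byte ranges relative to the hunk's context, and applying them front-to-back would invalidate the offsets of later edits as the string grows or shrinks. A self-contained illustration:

```rust
// Applying ranged edits back-to-front keeps earlier offsets valid.
let mut s = String::from("abcdef");
let edits = [(1..2, "BB"), (4..5, "EE")]; // hypothetical hunk edits, sorted by start
for (range, text) in edits.iter().rev() {
    s.replace_range(range.clone(), *text);
}
assert_eq!(s, "aBBcdEEf");
```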
|
||||
struct PatchFile<'a> {
|
||||
old_path: Cow<'a, str>,
|
||||
new_path: Cow<'a, str>,
|
||||
@@ -492,7 +476,6 @@ mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
use indoc::indoc;
|
||||
use language::Point;
|
||||
use pretty_assertions::assert_eq;
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
@@ -754,38 +737,38 @@ mod tests {
|
||||
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
|
||||
|
||||
let diff = indoc! {r#"
|
||||
--- a/root/file1
|
||||
+++ b/root/file1
|
||||
--- a/file1
|
||||
+++ b/file1
|
||||
one
|
||||
two
|
||||
-three
|
||||
+3
|
||||
four
|
||||
five
|
||||
--- a/root/file1
|
||||
+++ b/root/file1
|
||||
--- a/file1
|
||||
+++ b/file1
|
||||
3
|
||||
-four
|
||||
-five
|
||||
+4
|
||||
+5
|
||||
--- a/root/file1
|
||||
+++ b/root/file1
|
||||
--- a/file1
|
||||
+++ b/file1
|
||||
-one
|
||||
-two
|
||||
3
|
||||
4
|
||||
--- a/root/file2
|
||||
+++ b/root/file2
|
||||
--- a/file2
|
||||
+++ b/file2
|
||||
+5
|
||||
six
|
||||
--- a/root/file2
|
||||
+++ b/root/file2
|
||||
--- a/file2
|
||||
+++ b/file2
|
||||
seven
|
||||
+7.5
|
||||
eight
|
||||
--- a/root/file2
|
||||
+++ b/root/file2
|
||||
--- a/file2
|
||||
+++ b/file2
|
||||
ten
|
||||
+11
|
||||
"#};
|
||||
@@ -817,137 +800,6 @@ mod tests {
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_apply_diff_non_unique(cx: &mut TestAppContext) {
|
||||
let fs = init_test(cx);
|
||||
|
||||
let buffer_1_text = indoc! {r#"
|
||||
one
|
||||
two
|
||||
three
|
||||
four
|
||||
five
|
||||
one
|
||||
two
|
||||
three
|
||||
four
|
||||
five
|
||||
"# };
|
||||
|
||||
fs.insert_tree(
|
||||
path!("/root"),
|
||||
json!({
|
||||
"file1": buffer_1_text,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_local_buffer(path!("/root/file1"), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
|
||||
|
||||
let diff = indoc! {r#"
|
||||
--- a/root/file1
|
||||
+++ b/root/file1
|
||||
one
|
||||
two
|
||||
-three
|
||||
+3
|
||||
four
|
||||
five
|
||||
"#};
|
||||
|
||||
let final_text = indoc! {r#"
|
||||
one
|
||||
two
|
||||
three
|
||||
four
|
||||
five
|
||||
one
|
||||
two
|
||||
3
|
||||
four
|
||||
five
|
||||
"#};
|
||||
|
||||
apply_diff(diff, &project, &mut cx.to_async())
|
||||
.await
|
||||
.expect_err("Non-unique edits should fail");
|
||||
|
||||
let ranges = [buffer_snapshot.anchor_before(Point::new(1, 0))
|
||||
..buffer_snapshot.anchor_after(buffer_snapshot.max_point())];
|
||||
|
||||
let (edited_snapshot, edits) = parse_diff(diff, |_path| Some((&buffer_snapshot, &ranges)))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(edited_snapshot.remote_id(), buffer_snapshot.remote_id());
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer.edit(edits, None, cx);
|
||||
assert_eq!(buffer.text(), final_text);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_parse_diff_with_edits_within_line(cx: &mut TestAppContext) {
|
||||
let fs = init_test(cx);
|
||||
|
||||
let buffer_1_text = indoc! {r#"
|
||||
one two three four
|
||||
five six seven eight
|
||||
nine ten eleven twelve
|
||||
"# };
|
||||
|
||||
fs.insert_tree(
|
||||
path!("/root"),
|
||||
json!({
|
||||
"file1": buffer_1_text,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_local_buffer(path!("/root/file1"), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
|
||||
|
||||
let diff = indoc! {r#"
|
||||
--- a/root/file1
|
||||
+++ b/root/file1
|
||||
one two three four
|
||||
-five six seven eight
|
||||
+five SIX seven eight!
|
||||
nine ten eleven twelve
|
||||
"#};
|
||||
|
||||
let (buffer, edits) = parse_diff(diff, |_path| {
|
||||
Some((&buffer_snapshot, &[(Anchor::MIN..Anchor::MAX)] as &[_]))
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let edits = edits
|
||||
.into_iter()
|
||||
.map(|(range, text)| (range.to_point(&buffer), text))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(
|
||||
edits,
|
||||
&[
|
||||
(Point::new(1, 5)..Point::new(1, 8), "SIX".into()),
|
||||
(Point::new(1, 20)..Point::new(1, 20), "!".into())
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) {
|
||||
let fs = init_test(cx);
|
||||
@@ -985,8 +837,8 @@ mod tests {
|
||||
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
|
||||
|
||||
let diff = indoc! {r#"
|
||||
--- a/root/file1
|
||||
+++ b/root/file1
|
||||
--- a/file1
|
||||
+++ b/file1
|
||||
one
|
||||
two
|
||||
-three
|
||||
|
||||
@@ -1,637 +0,0 @@
use anyhow::{Context as _, Result};
use language::{Anchor, BufferSnapshot, OffsetRangeExt as _, Point};
use std::{cmp, ops::Range, path::Path, sync::Arc};

const EDITS_TAG_NAME: &'static str = "edits";
const OLD_TEXT_TAG_NAME: &'static str = "old_text";
const NEW_TEXT_TAG_NAME: &'static str = "new_text";
const XML_TAGS: &[&str] = &[EDITS_TAG_NAME, OLD_TEXT_TAG_NAME, NEW_TEXT_TAG_NAME];

pub async fn parse_xml_edits<'a>(
input: &'a str,
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
parse_xml_edits_inner(input, get_buffer)
.await
.with_context(|| format!("Failed to parse XML edits:\n{input}"))
}

async fn parse_xml_edits_inner<'a>(
input: &'a str,
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
let xml_edits = extract_xml_replacements(input)?;

let (buffer, context_ranges) = get_buffer(xml_edits.file_path.as_ref())
.with_context(|| format!("no buffer for file {}", xml_edits.file_path))?;

let mut all_edits = vec![];
for (old_text, new_text) in xml_edits.replacements {
let match_range = fuzzy_match_in_ranges(old_text, buffer, context_ranges)?;
let matched_old_text = buffer
.text_for_range(match_range.clone())
.collect::<String>();
let edits_within_hunk = language::text_diff(&matched_old_text, new_text);
all_edits.extend(
edits_within_hunk
.into_iter()
.map(move |(inner_range, inner_text)| {
(
buffer.anchor_after(match_range.start + inner_range.start)
..buffer.anchor_before(match_range.start + inner_range.end),
inner_text,
)
}),
);
}

Ok((buffer, all_edits))
}

fn fuzzy_match_in_ranges(
old_text: &str,
buffer: &BufferSnapshot,
context_ranges: &[Range<Anchor>],
) -> Result<Range<usize>> {
let mut state = FuzzyMatcher::new(buffer, old_text);
let mut best_match = None;
let mut tie_match_range = None;

for range in context_ranges {
let best_match_cost = best_match.as_ref().map(|(score, _)| *score);
match (best_match_cost, state.match_range(range.to_offset(buffer))) {
(Some(lowest_cost), Some((new_cost, new_range))) => {
if new_cost == lowest_cost {
tie_match_range = Some(new_range);
} else if new_cost < lowest_cost {
tie_match_range.take();
best_match = Some((new_cost, new_range));
}
}
(None, Some(new_match)) => {
best_match = Some(new_match);
}
(None, None) | (Some(_), None) => {}
};
}

if let Some((_, best_match_range)) = best_match {
if let Some(tie_match_range) = tie_match_range {
anyhow::bail!(
"Multiple ambiguous matches:\n{:?}:\n{}\n\n{:?}:\n{}",
best_match_range.clone(),
buffer.text_for_range(best_match_range).collect::<String>(),
tie_match_range.clone(),
buffer.text_for_range(tie_match_range).collect::<String>()
);
}
return Ok(best_match_range);
}

anyhow::bail!(
"Failed to fuzzy match `old_text`:\n{}\nin:\n```\n{}\n```",
old_text,
context_ranges
.iter()
.map(|range| buffer.text_for_range(range.clone()).collect::<String>())
.collect::<Vec<String>>()
.join("```\n```")
);
}

#[derive(Debug)]
struct XmlEdits<'a> {
file_path: &'a str,
/// Vec of (old_text, new_text) pairs
replacements: Vec<(&'a str, &'a str)>,
}

fn extract_xml_replacements(input: &str) -> Result<XmlEdits<'_>> {
let mut cursor = 0;

let (edits_body_start, edits_attrs) =
find_tag_open(input, &mut cursor, EDITS_TAG_NAME)?.context("No edits tag found")?;

let file_path = edits_attrs
.trim_start()
.strip_prefix("path")
.context("no path attribute on edits tag")?
.trim_end()
.strip_prefix('=')
.context("no value for path attribute")?
.trim()
.trim_start_matches('"')
.trim_end_matches('"');

cursor = edits_body_start;
let mut edits_list = Vec::new();

while let Some((old_body_start, _)) = find_tag_open(input, &mut cursor, OLD_TEXT_TAG_NAME)? {
let old_body_end = find_tag_close(input, &mut cursor)?;
let old_text = trim_surrounding_newlines(&input[old_body_start..old_body_end]);

let (new_body_start, _) = find_tag_open(input, &mut cursor, NEW_TEXT_TAG_NAME)?
.context("no new_text tag following old_text")?;
let new_body_end = find_tag_close(input, &mut cursor)?;
let new_text = trim_surrounding_newlines(&input[new_body_start..new_body_end]);

edits_list.push((old_text, new_text));
}

Ok(XmlEdits {
file_path,
replacements: edits_list,
})
}

/// Trims a single leading and trailing newline
fn trim_surrounding_newlines(input: &str) -> &str {
let start = input.strip_prefix('\n').unwrap_or(input);
let end = start.strip_suffix('\n').unwrap_or(start);
end
}

fn find_tag_open<'a>(
input: &'a str,
cursor: &mut usize,
expected_tag: &str,
) -> Result<Option<(usize, &'a str)>> {
let mut search_pos = *cursor;

while search_pos < input.len() {
let Some(tag_start) = input[search_pos..].find("<") else {
break;
};
let tag_start = search_pos + tag_start;
if !input[tag_start + 1..].starts_with(expected_tag) {
search_pos = search_pos + tag_start + 1;
continue;
};

let after_tag_name = tag_start + expected_tag.len() + 1;
let close_bracket = input[after_tag_name..]
.find('>')
.with_context(|| format!("missing > after <{}", expected_tag))?;
let attrs_end = after_tag_name + close_bracket;
let body_start = attrs_end + 1;

let attributes = input[after_tag_name..attrs_end].trim();
*cursor = body_start;

return Ok(Some((body_start, attributes)));
}

Ok(None)
}

fn find_tag_close(input: &str, cursor: &mut usize) -> Result<usize> {
let mut depth = 1;
let mut search_pos = *cursor;

while search_pos < input.len() && depth > 0 {
let Some(bracket_offset) = input[search_pos..].find('<') else {
break;
};
let bracket_pos = search_pos + bracket_offset;

if input[bracket_pos..].starts_with("</")
&& let Some(close_end) = input[bracket_pos + 2..].find('>')
{
let close_start = bracket_pos + 2;
let tag_name = input[close_start..close_start + close_end].trim();

if XML_TAGS.contains(&tag_name) {
depth -= 1;
if depth == 0 {
*cursor = close_start + close_end + 1;
return Ok(bracket_pos);
}
}
search_pos = close_start + close_end + 1;
continue;
} else if let Some(close_bracket_offset) = input[bracket_pos..].find('>') {
let close_bracket_pos = bracket_pos + close_bracket_offset;
let tag_name = &input[bracket_pos + 1..close_bracket_pos].trim();
if XML_TAGS.contains(&tag_name) {
depth += 1;
}
}

search_pos = bracket_pos + 1;
}

anyhow::bail!("no closing tag found")
}

const REPLACEMENT_COST: u32 = 1;
const INSERTION_COST: u32 = 3;
const DELETION_COST: u32 = 10;

/// A fuzzy matcher that can process text chunks incrementally
/// and return the best match found so far at each step.
struct FuzzyMatcher<'a> {
snapshot: &'a BufferSnapshot,
query_lines: Vec<&'a str>,
matrix: SearchMatrix,
}

impl<'a> FuzzyMatcher<'a> {
fn new(snapshot: &'a BufferSnapshot, old_text: &'a str) -> Self {
let query_lines = old_text.lines().collect();
Self {
snapshot,
query_lines,
matrix: SearchMatrix::new(0),
}
}

fn match_range(&mut self, range: Range<usize>) -> Option<(u32, Range<usize>)> {
let point_range = range.to_point(&self.snapshot);
let buffer_line_count = (point_range.end.row - point_range.start.row + 1) as usize;

self.matrix
.reset(self.query_lines.len() + 1, buffer_line_count + 1);
let query_line_count = self.query_lines.len();

for row in 0..query_line_count {
let query_line = self.query_lines[row].trim();
let leading_deletion_cost = (row + 1) as u32 * DELETION_COST;

self.matrix.set(
row + 1,
0,
SearchState::new(leading_deletion_cost, SearchDirection::Up),
);

let mut buffer_lines = self.snapshot.text_for_range(range.clone()).lines();

let mut col = 0;
while let Some(buffer_line) = buffer_lines.next() {
let buffer_line = buffer_line.trim();
let up = SearchState::new(
self.matrix
.get(row, col + 1)
.cost
.saturating_add(DELETION_COST),
SearchDirection::Up,
);
let left = SearchState::new(
self.matrix
.get(row + 1, col)
.cost
.saturating_add(INSERTION_COST),
SearchDirection::Left,
);
let diagonal = SearchState::new(
if query_line == buffer_line {
self.matrix.get(row, col).cost
} else if fuzzy_eq(query_line, buffer_line) {
self.matrix.get(row, col).cost + REPLACEMENT_COST
} else {
self.matrix
.get(row, col)
.cost
.saturating_add(DELETION_COST + INSERTION_COST)
},
SearchDirection::Diagonal,
);
self.matrix
.set(row + 1, col + 1, up.min(left).min(diagonal));
col += 1;
}
}

// Find all matches with the best cost
let mut best_cost = u32::MAX;
let mut matches_with_best_cost = Vec::new();

for col in 1..=buffer_line_count {
let cost = self.matrix.get(query_line_count, col).cost;
if cost < best_cost {
best_cost = cost;
matches_with_best_cost.clear();
matches_with_best_cost.push(col as u32);
} else if cost == best_cost {
matches_with_best_cost.push(col as u32);
}
}

// Find ranges for the matches
for &match_end_col in &matches_with_best_cost {
let mut matched_lines = 0;
let mut query_row = query_line_count;
let mut match_start_col = match_end_col;
while query_row > 0 && match_start_col > 0 {
let current = self.matrix.get(query_row, match_start_col as usize);
match current.direction {
SearchDirection::Diagonal => {
query_row -= 1;
match_start_col -= 1;
matched_lines += 1;
}
SearchDirection::Up => {
query_row -= 1;
}
SearchDirection::Left => {
match_start_col -= 1;
}
}
}

let buffer_row_start = match_start_col + point_range.start.row;
let buffer_row_end = match_end_col + point_range.start.row;

let matched_buffer_row_count = buffer_row_end - buffer_row_start;
let matched_ratio = matched_lines as f32
/ (matched_buffer_row_count as f32).max(query_line_count as f32);
if matched_ratio >= 0.8 {
let buffer_start_ix = self
.snapshot
.point_to_offset(Point::new(buffer_row_start, 0));
let buffer_end_ix = self.snapshot.point_to_offset(Point::new(
buffer_row_end - 1,
self.snapshot.line_len(buffer_row_end - 1),
));
return Some((best_cost, buffer_start_ix..buffer_end_ix));
}
}

None
}
}
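The cost constants in this (now-deleted) matcher shape the alignment: skipping a query line (deletion, 10) is far more expensive than absorbing an extra buffer line (insertion, 3) or tolerating a near-match (replacement, 1), so the DP strongly prefers to account for every line of `old_text`. A quick illustration of the asymmetry:

```rust
// Skipping query lines should always lose to absorbing buffer lines.
let skip_two_query_lines = 2 * DELETION_COST; // 20
let absorb_two_buffer_lines = 2 * INSERTION_COST; // 6
assert!(absorb_two_buffer_lines < skip_two_query_lines);
```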
fn fuzzy_eq(left: &str, right: &str) -> bool {
const THRESHOLD: f64 = 0.8;

let min_levenshtein = left.len().abs_diff(right.len());
let min_normalized_levenshtein =
1. - (min_levenshtein as f64 / cmp::max(left.len(), right.len()) as f64);
if min_normalized_levenshtein < THRESHOLD {
return false;
}

strsim::normalized_levenshtein(left, right) >= THRESHOLD
}
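`fuzzy_eq` avoids the full Levenshtein computation when a cheap bound already rules the pair out: the length difference `left.len().abs_diff(right.len())` is a lower bound on edit distance, so it yields an upper bound on normalized similarity. For example:

```rust
// "kitten" (6 chars) vs "sit" (3 chars): the bound is 1 - 3/6 = 0.5 < 0.8,
// so strsim::normalized_levenshtein is never called.
assert!(!fuzzy_eq("kitten", "sit"));
```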
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
|
||||
enum SearchDirection {
|
||||
Up,
|
||||
Left,
|
||||
Diagonal,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
struct SearchState {
|
||||
cost: u32,
|
||||
direction: SearchDirection,
|
||||
}
|
||||
|
||||
impl SearchState {
|
||||
fn new(cost: u32, direction: SearchDirection) -> Self {
|
||||
Self { cost, direction }
|
||||
}
|
||||
}
|
||||
|
||||
struct SearchMatrix {
|
||||
cols: usize,
|
||||
rows: usize,
|
||||
data: Vec<SearchState>,
|
||||
}
|
||||
|
||||
impl SearchMatrix {
|
||||
fn new(cols: usize) -> Self {
|
||||
SearchMatrix {
|
||||
cols,
|
||||
rows: 0,
|
||||
data: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn reset(&mut self, rows: usize, cols: usize) {
|
||||
self.rows = rows;
|
||||
self.cols = cols;
|
||||
self.data
|
||||
.fill(SearchState::new(0, SearchDirection::Diagonal));
|
||||
self.data.resize(
|
||||
self.rows * self.cols,
|
||||
SearchState::new(0, SearchDirection::Diagonal),
|
||||
);
|
||||
}
|
||||
|
||||
fn get(&self, row: usize, col: usize) -> SearchState {
|
||||
debug_assert!(row < self.rows);
|
||||
debug_assert!(col < self.cols);
|
||||
self.data[row * self.cols + col]
|
||||
}
|
||||
|
||||
fn set(&mut self, row: usize, col: usize, state: SearchState) {
|
||||
debug_assert!(row < self.rows && col < self.cols);
|
||||
self.data[row * self.cols + col] = state;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
use indoc::indoc;
|
||||
use language::Point;
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use util::path;
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits() {
|
||||
let input = indoc! {r#"
|
||||
<edits path="test.rs">
|
||||
<old_text>
|
||||
old content
|
||||
</old_text>
|
||||
<new_text>
|
||||
new content
|
||||
</new_text>
|
||||
</edits>
|
||||
"#};
|
||||
|
||||
let result = extract_xml_replacements(input).unwrap();
|
||||
assert_eq!(result.file_path, "test.rs");
|
||||
assert_eq!(result.replacements.len(), 1);
|
||||
assert_eq!(result.replacements[0].0, "old content");
|
||||
assert_eq!(result.replacements[0].1, "new content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits_with_wrong_closing_tags() {
|
||||
let input = indoc! {r#"
|
||||
<edits path="test.rs">
|
||||
<old_text>
|
||||
old content
|
||||
</new_text>
|
||||
<new_text>
|
||||
new content
|
||||
</old_text>
|
||||
</ edits >
|
||||
"#};
|
||||
|
||||
let result = extract_xml_replacements(input).unwrap();
|
||||
assert_eq!(result.file_path, "test.rs");
|
||||
assert_eq!(result.replacements.len(), 1);
|
||||
assert_eq!(result.replacements[0].0, "old content");
|
||||
assert_eq!(result.replacements[0].1, "new content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits_with_xml_like_content() {
|
||||
let input = indoc! {r#"
|
||||
<edits path="component.tsx">
|
||||
<old_text>
|
||||
<foo><bar></bar></foo>
|
||||
</old_text>
|
||||
<new_text>
|
||||
<foo><bar><baz></baz></bar></foo>
|
||||
</new_text>
|
||||
</edits>
|
||||
"#};
|
||||
|
||||
let result = extract_xml_replacements(input).unwrap();
|
||||
assert_eq!(result.file_path, "component.tsx");
|
||||
assert_eq!(result.replacements.len(), 1);
|
||||
assert_eq!(result.replacements[0].0, "<foo><bar></bar></foo>");
|
||||
assert_eq!(
|
||||
result.replacements[0].1,
|
||||
"<foo><bar><baz></baz></bar></foo>"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits_with_conflicting_content() {
|
||||
let input = indoc! {r#"
|
||||
<edits path="component.tsx">
|
||||
<old_text>
|
||||
<new_text></new_text>
|
||||
</old_text>
|
||||
<new_text>
|
||||
<old_text></old_text>
|
||||
</new_text>
|
||||
</edits>
|
||||
"#};
|
||||
|
||||
let result = extract_xml_replacements(input).unwrap();
|
||||
assert_eq!(result.file_path, "component.tsx");
|
||||
assert_eq!(result.replacements.len(), 1);
|
||||
assert_eq!(result.replacements[0].0, "<new_text></new_text>");
|
||||
assert_eq!(result.replacements[0].1, "<old_text></old_text>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits_multiple_pairs() {
|
||||
let input = indoc! {r#"
|
||||
Some reasoning before edits. Lots of thinking going on here
|
||||
|
||||
<edits path="test.rs">
|
||||
<old_text>
|
||||
first old
|
||||
</old_text>
|
||||
<new_text>
|
||||
first new
|
||||
</new_text>
|
||||
<old_text>
|
||||
second old
|
||||
</edits>
|
||||
<new_text>
|
||||
second new
|
||||
</old_text>
|
||||
</edits>
|
||||
"#};
|
||||
|
||||
let result = extract_xml_replacements(input).unwrap();
|
||||
assert_eq!(result.file_path, "test.rs");
|
||||
assert_eq!(result.replacements.len(), 2);
|
||||
assert_eq!(result.replacements[0].0, "first old");
|
||||
assert_eq!(result.replacements[0].1, "first new");
|
||||
assert_eq!(result.replacements[1].0, "second old");
|
||||
assert_eq!(result.replacements[1].1, "second new");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits_unexpected_eof() {
|
||||
let input = indoc! {r#"
|
||||
<edits path="test.rs">
|
||||
<old_text>
|
||||
first old
|
||||
</
|
||||
"#};
|
||||
|
||||
extract_xml_replacements(input).expect_err("Unexpected end of file");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_parse_xml_edits(cx: &mut TestAppContext) {
|
||||
let fs = init_test(cx);
|
||||
|
||||
let buffer_1_text = indoc! {r#"
|
||||
one two three four
|
||||
five six seven eight
|
||||
nine ten eleven twelve
|
||||
thirteen fourteen fifteen
|
||||
sixteen seventeen eighteen
|
||||
"#};
|
||||
|
||||
fs.insert_tree(
|
||||
path!("/root"),
|
||||
json!({
|
||||
"file1": buffer_1_text,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_local_buffer(path!("/root/file1"), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
|
||||
|
||||
let edits = indoc! {r#"
|
||||
<edits path="root/file1">
|
||||
<old_text>
|
||||
nine ten eleven twelve
|
||||
</old_text>
|
||||
<new_text>
|
||||
nine TEN eleven twelve!
|
||||
</new_text>
|
||||
</edits>
|
||||
"#};
|
||||
|
||||
let included_ranges = [(buffer_snapshot.anchor_before(Point::new(1, 0))..Anchor::MAX)];
|
||||
let (buffer, edits) = parse_xml_edits(edits, |_path| {
|
||||
Some((&buffer_snapshot, included_ranges.as_slice()))
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let edits = edits
|
||||
.into_iter()
|
||||
.map(|(range, text)| (range.to_point(&buffer), text))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(
|
||||
edits,
|
||||
&[
|
||||
(Point::new(2, 5)..Point::new(2, 8), "TEN".into()),
|
||||
(Point::new(2, 22)..Point::new(2, 22), "!".into())
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
});
|
||||
|
||||
FakeFs::new(cx.background_executor.clone())
|
||||
}
|
||||
}
|
||||
@@ -125,14 +125,15 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(current) = store.current_prediction_for_buffer(&buffer, &self.project, cx)
|
||||
&& let BufferEditPrediction::Local { prediction } = current
|
||||
&& prediction.interpolate(buffer.read(cx)).is_some()
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
self.store.update(cx, |store, cx| {
|
||||
if let Some(current) =
|
||||
store.prediction_at(&buffer, Some(cursor_position), &self.project, cx)
|
||||
&& let BufferEditPrediction::Local { prediction } = current
|
||||
&& prediction.interpolate(buffer.read(cx)).is_some()
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
store.refresh_context(&self.project, &buffer, cursor_position, cx);
|
||||
store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
|
||||
});
|
||||
@@ -171,69 +172,68 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
|
||||
cursor_position: language::Anchor,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<edit_prediction_types::EditPrediction> {
|
||||
let prediction =
|
||||
self.store
|
||||
.read(cx)
|
||||
.current_prediction_for_buffer(buffer, &self.project, cx)?;
|
||||
self.store.update(cx, |store, cx| {
|
||||
let prediction =
|
||||
store.prediction_at(buffer, Some(cursor_position), &self.project, cx)?;
|
||||
|
||||
let prediction = match prediction {
|
||||
BufferEditPrediction::Local { prediction } => prediction,
|
||||
BufferEditPrediction::Jump { prediction } => {
|
||||
return Some(edit_prediction_types::EditPrediction::Jump {
|
||||
id: Some(prediction.id.to_string().into()),
|
||||
snapshot: prediction.snapshot.clone(),
|
||||
target: prediction.edits.first().unwrap().0.start,
|
||||
});
|
||||
}
|
||||
};
|
||||
let prediction = match prediction {
|
||||
BufferEditPrediction::Local { prediction } => prediction,
|
||||
BufferEditPrediction::Jump { prediction } => {
|
||||
return Some(edit_prediction_types::EditPrediction::Jump {
|
||||
id: Some(prediction.id.to_string().into()),
|
||||
snapshot: prediction.snapshot.clone(),
|
||||
target: prediction.edits.first().unwrap().0.start,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let buffer = buffer.read(cx);
|
||||
let snapshot = buffer.snapshot();
|
||||
let buffer = buffer.read(cx);
|
||||
let snapshot = buffer.snapshot();
|
||||
|
||||
let Some(edits) = prediction.interpolate(&snapshot) else {
|
||||
self.store.update(cx, |store, _cx| {
|
||||
let Some(edits) = prediction.interpolate(&snapshot) else {
|
||||
store.reject_current_prediction(
|
||||
EditPredictionRejectReason::InterpolatedEmpty,
|
||||
&self.project,
|
||||
);
|
||||
});
|
||||
return None;
|
||||
};
|
||||
return None;
|
||||
};
|
||||
|
||||
let cursor_row = cursor_position.to_point(&snapshot).row;
|
||||
let (closest_edit_ix, (closest_edit_range, _)) =
|
||||
edits.iter().enumerate().min_by_key(|(_, (range, _))| {
|
||||
let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
|
||||
let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
|
||||
cmp::min(distance_from_start, distance_from_end)
|
||||
})?;
|
||||
let cursor_row = cursor_position.to_point(&snapshot).row;
|
||||
let (closest_edit_ix, (closest_edit_range, _)) =
|
||||
edits.iter().enumerate().min_by_key(|(_, (range, _))| {
|
||||
let distance_from_start =
|
||||
cursor_row.abs_diff(range.start.to_point(&snapshot).row);
|
||||
let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
|
||||
cmp::min(distance_from_start, distance_from_end)
|
||||
})?;
|
||||
|
||||
let mut edit_start_ix = closest_edit_ix;
|
||||
for (range, _) in edits[..edit_start_ix].iter().rev() {
|
||||
let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
|
||||
- range.end.to_point(&snapshot).row;
|
||||
if distance_from_closest_edit <= 1 {
|
||||
edit_start_ix -= 1;
|
||||
} else {
|
||||
break;
|
||||
let mut edit_start_ix = closest_edit_ix;
|
||||
for (range, _) in edits[..edit_start_ix].iter().rev() {
|
||||
let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
|
||||
- range.end.to_point(&snapshot).row;
|
||||
if distance_from_closest_edit <= 1 {
|
||||
edit_start_ix -= 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut edit_end_ix = closest_edit_ix + 1;
|
||||
for (range, _) in &edits[edit_end_ix..] {
|
||||
let distance_from_closest_edit =
|
||||
range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
|
||||
if distance_from_closest_edit <= 1 {
|
||||
edit_end_ix += 1;
|
||||
} else {
|
||||
break;
|
||||
let mut edit_end_ix = closest_edit_ix + 1;
|
||||
for (range, _) in &edits[edit_end_ix..] {
|
||||
let distance_from_closest_edit = range.start.to_point(buffer).row
|
||||
- closest_edit_range.end.to_point(&snapshot).row;
|
||||
if distance_from_closest_edit <= 1 {
|
||||
edit_end_ix += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(edit_prediction_types::EditPrediction::Local {
|
||||
id: Some(prediction.id.to_string().into()),
|
||||
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
|
||||
edit_preview: Some(prediction.edit_preview.clone()),
|
||||
Some(edit_prediction_types::EditPrediction::Local {
|
||||
id: Some(prediction.id.to_string().into()),
|
||||
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
|
||||
edit_preview: Some(prediction.edit_preview.clone()),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,22 +1,23 @@
|
||||
use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
|
||||
|
||||
use crate::{
|
||||
EditPredictionId, EditPredictionStore, ZedUpdateRequiredError,
|
||||
DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
|
||||
EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
|
||||
cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
|
||||
prediction::{EditPredictionInputs, EditPredictionResult},
|
||||
prediction::EditPredictionResult,
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use cloud_llm_client::{
|
||||
PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
|
||||
predict_edits_v3::Event,
|
||||
};
|
||||
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
|
||||
use language::{
|
||||
Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
|
||||
Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
|
||||
};
|
||||
use project::{Project, ProjectPath};
|
||||
use release_channel::AppVersion;
|
||||
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
|
||||
use zeta_prompt::{Event, ZetaPromptInput};
|
||||
|
||||
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
|
||||
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
|
||||
@@ -29,24 +30,27 @@ pub(crate) const MAX_EVENT_TOKENS: usize = 500;
|
||||
|
||||
pub(crate) fn request_prediction_with_zeta1(
|
||||
store: &mut EditPredictionStore,
|
||||
project: &Entity<Project>,
|
||||
buffer: &Entity<Buffer>,
|
||||
snapshot: BufferSnapshot,
|
||||
position: language::Anchor,
|
||||
events: Vec<Arc<Event>>,
|
||||
trigger: PredictEditsRequestTrigger,
|
||||
EditPredictionModelInput {
|
||||
project,
|
||||
buffer,
|
||||
snapshot,
|
||||
position,
|
||||
events,
|
||||
trigger,
|
||||
debug_tx,
|
||||
..
|
||||
}: EditPredictionModelInput,
|
||||
cx: &mut Context<EditPredictionStore>,
|
||||
) -> Task<Result<Option<EditPredictionResult>>> {
|
||||
let buffer = buffer.clone();
|
||||
let buffer_snapshotted_at = Instant::now();
|
||||
let client = store.client.clone();
|
||||
let llm_token = store.llm_token.clone();
|
||||
let app_version = AppVersion::global(cx);
|
||||
|
||||
let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
|
||||
let can_collect_file = store.can_collect_file(project, file, cx);
|
||||
let can_collect_file = store.can_collect_file(&project, file, cx);
|
||||
let git_info = if can_collect_file {
|
||||
git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
|
||||
git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -120,33 +124,33 @@ pub(crate) fn request_prediction_with_zeta1(
|
||||
)
|
||||
.await;
|
||||
|
||||
let inputs = EditPredictionInputs {
|
||||
let context_start_offset = context_range.start.to_offset(&snapshot);
|
||||
let editable_offset_range = editable_range.to_offset(&snapshot);
|
||||
|
||||
let inputs = ZetaPromptInput {
|
||||
events: included_events.into(),
|
||||
included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
|
||||
path: full_path.clone(),
|
||||
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
|
||||
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
|
||||
start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row),
|
||||
text: snapshot
|
||||
.text_for_range(context_range)
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
}],
|
||||
}],
|
||||
cursor_point: cloud_llm_client::predict_edits_v3::Point {
|
||||
column: cursor_point.column,
|
||||
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
|
||||
},
|
||||
related_files: vec![].into(),
|
||||
cursor_path: full_path,
|
||||
cursor_excerpt: snapshot
|
||||
.text_for_range(context_range)
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
editable_range_in_excerpt: (editable_range.start - context_start_offset)
|
||||
..(editable_offset_range.end - context_start_offset),
|
||||
cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
|
||||
};
|
||||
|
||||
// let response = perform_predict_edits(PerformPredictEditsParams {
|
||||
// client,
|
||||
// llm_token,
|
||||
// app_version,
|
||||
// body,
|
||||
// })
|
||||
// .await;
|
||||
if let Some(debug_tx) = &debug_tx {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::EditPredictionStarted(
|
||||
EditPredictionStartedDebugEvent {
|
||||
buffer: buffer.downgrade(),
|
||||
prompt: Some(serde_json::to_string(&inputs).unwrap()),
|
||||
position,
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
|
||||
let (response, usage) = match response {
|
||||
Ok(response) => response,
|
||||
@@ -189,6 +193,18 @@ pub(crate) fn request_prediction_with_zeta1(
|
||||
.ok();
|
||||
}
|
||||
|
||||
if let Some(debug_tx) = &debug_tx {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::EditPredictionFinished(
|
||||
EditPredictionFinishedDebugEvent {
|
||||
buffer: buffer.downgrade(),
|
||||
model_output: Some(response.output_excerpt.clone()),
|
||||
position,
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
|
||||
let edit_prediction = process_completion_response(
|
||||
response,
|
||||
buffer,
|
||||
@@ -226,7 +242,7 @@ fn process_completion_response(
|
||||
buffer: Entity<Buffer>,
|
||||
snapshot: &BufferSnapshot,
|
||||
editable_range: Range<usize>,
|
||||
inputs: EditPredictionInputs,
|
||||
inputs: ZetaPromptInput,
|
||||
buffer_snapshotted_at: Instant,
|
||||
received_response_at: Instant,
|
||||
cx: &AsyncApp,
|
||||
|
||||
@@ -1,48 +1,41 @@
|
||||
#[cfg(feature = "eval-support")]
|
||||
#[cfg(feature = "cli-support")]
|
||||
use crate::EvalCacheEntryKind;
|
||||
use crate::open_ai_response::text_from_response;
|
||||
use crate::prediction::EditPredictionResult;
|
||||
use crate::{
|
||||
DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs,
|
||||
EditPredictionRequestedDebugEvent, EditPredictionStore,
|
||||
DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent, EditPredictionId,
|
||||
EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
|
||||
};
|
||||
use anyhow::{Result, anyhow, bail};
|
||||
use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
|
||||
use cloud_llm_client::{EditPredictionRejectReason, PredictEditsRequestTrigger};
|
||||
use cloud_zeta2_prompt::CURSOR_MARKER;
|
||||
use edit_prediction_context::{EditPredictionExcerpt, Line};
|
||||
use edit_prediction_context::{RelatedExcerpt, RelatedFile};
|
||||
use futures::channel::oneshot;
|
||||
use gpui::{Entity, Task, prelude::*};
|
||||
use language::{Anchor, BufferSnapshot};
|
||||
use language::{Buffer, Point, ToOffset as _, ToPoint};
|
||||
use project::{Project, ProjectItem as _};
|
||||
use anyhow::{Result, anyhow};
|
||||
use cloud_llm_client::EditPredictionRejectReason;
|
||||
use gpui::{Task, prelude::*};
|
||||
use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
|
||||
use release_channel::AppVersion;
|
||||
use std::{
|
||||
env,
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use std::{path::Path, sync::Arc, time::Instant};
|
||||
use zeta_prompt::CURSOR_MARKER;
|
||||
use zeta_prompt::format_zeta_prompt;
|
||||
|
||||
const MAX_CONTEXT_TOKENS: usize = 150;
|
||||
const MAX_REWRITE_TOKENS: usize = 350;
|
||||
|
||||
pub fn request_prediction_with_zeta2(
|
||||
store: &mut EditPredictionStore,
|
||||
project: &Entity<Project>,
|
||||
active_buffer: &Entity<Buffer>,
|
||||
active_snapshot: BufferSnapshot,
|
||||
position: Anchor,
|
||||
events: Vec<Arc<Event>>,
|
||||
mut included_files: Vec<RelatedFile>,
|
||||
trigger: PredictEditsRequestTrigger,
|
||||
EditPredictionModelInput {
|
||||
buffer,
|
||||
snapshot,
|
||||
position,
|
||||
related_files,
|
||||
events,
|
||||
debug_tx,
|
||||
..
|
||||
}: EditPredictionModelInput,
|
||||
cx: &mut Context<EditPredictionStore>,
|
||||
) -> Task<Result<Option<EditPredictionResult>>> {
|
||||
let options = store.options.clone();
|
||||
let buffer_snapshotted_at = Instant::now();
|
||||
|
||||
let Some((excerpt_path, active_project_path)) = active_snapshot
|
||||
let Some(excerpt_path) = snapshot
|
||||
.file()
|
||||
.map(|file| -> Arc<Path> { file.full_path(cx).into() })
|
||||
.zip(active_buffer.read(cx).project_path(cx))
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("No file path for excerpt")));
|
||||
};
|
||||
@@ -50,148 +43,35 @@ pub fn request_prediction_with_zeta2(
|
||||
let client = store.client.clone();
|
||||
let llm_token = store.llm_token.clone();
|
||||
let app_version = AppVersion::global(cx);
|
||||
let debug_tx = store.debug_tx.clone();
|
||||
|
||||
let file = active_buffer.read(cx).file();
|
||||
|
||||
let active_file_full_path = file.as_ref().map(|f| f.full_path(cx));
|
||||
|
||||
// TODO data collection
|
||||
let can_collect_data = file
|
||||
.as_ref()
|
||||
.map_or(false, |file| store.can_collect_file(project, file, cx));
|
||||
|
||||
#[cfg(feature = "eval-support")]
|
||||
#[cfg(feature = "cli-support")]
|
||||
let eval_cache = store.eval_cache.clone();
|
||||
|
||||
let request_task = cx.background_spawn({
|
||||
let active_buffer = active_buffer.clone();
|
||||
async move {
|
||||
let cursor_offset = position.to_offset(&active_snapshot);
|
||||
let cursor_point = cursor_offset.to_point(&active_snapshot);
|
||||
|
||||
let before_retrieval = Instant::now();
|
||||
|
||||
let excerpt_options = options.context;
|
||||
|
||||
let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
|
||||
cursor_point,
|
||||
&active_snapshot,
|
||||
&excerpt_options,
|
||||
) else {
|
||||
return Ok((None, None));
|
||||
};
|
||||
|
||||
let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
|
||||
..active_snapshot.anchor_before(excerpt.range.end);
|
||||
let related_excerpt = RelatedExcerpt {
|
||||
anchor_range: excerpt_anchor_range.clone(),
|
||||
point_range: Point::new(excerpt.line_range.start.0, 0)
|
||||
..Point::new(excerpt.line_range.end.0, 0),
|
||||
text: active_snapshot.as_rope().slice(excerpt.range),
|
||||
};
|
||||
|
||||
if let Some(buffer_ix) = included_files
|
||||
.iter()
|
||||
.position(|file| file.buffer.entity_id() == active_buffer.entity_id())
|
||||
{
|
||||
let file = &mut included_files[buffer_ix];
|
||||
file.excerpts.push(related_excerpt);
|
||||
file.merge_excerpts();
|
||||
let last_ix = included_files.len() - 1;
|
||||
included_files.swap(buffer_ix, last_ix);
|
||||
} else {
|
||||
let active_file = RelatedFile {
|
||||
path: active_project_path,
|
||||
buffer: active_buffer.downgrade(),
|
||||
excerpts: vec![related_excerpt],
|
||||
max_row: active_snapshot.max_point().row,
|
||||
};
|
||||
included_files.push(active_file);
|
||||
}
|
||||
|
||||
let included_files = included_files
|
||||
.iter()
|
||||
.map(|related_file| predict_edits_v3::RelatedFile {
|
||||
path: Arc::from(related_file.path.path.as_std_path()),
|
||||
max_row: Line(related_file.max_row),
|
||||
excerpts: related_file
|
||||
.excerpts
|
||||
.iter()
|
||||
.map(|excerpt| predict_edits_v3::Excerpt {
|
||||
start_line: Line(excerpt.point_range.start.row),
|
||||
text: excerpt.text.to_string().into(),
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let cloud_request = predict_edits_v3::PredictEditsRequest {
|
||||
excerpt_path,
|
||||
excerpt: String::new(),
|
||||
excerpt_line_range: Line(0)..Line(0),
|
||||
excerpt_range: 0..0,
|
||||
cursor_point: predict_edits_v3::Point {
|
||||
line: predict_edits_v3::Line(cursor_point.row),
|
||||
column: cursor_point.column,
|
||||
},
|
||||
related_files: included_files,
|
||||
let cursor_offset = position.to_offset(&snapshot);
|
||||
let (editable_offset_range, prompt_input) = zeta2_prompt_input(
|
||||
&snapshot,
|
||||
related_files,
|
||||
events,
|
||||
can_collect_data,
|
||||
debug_info: debug_tx.is_some(),
|
||||
prompt_max_bytes: Some(options.max_prompt_bytes),
|
||||
prompt_format: options.prompt_format,
|
||||
excerpt_parent: None,
|
||||
git_info: None,
|
||||
trigger,
|
||||
};
|
||||
excerpt_path,
|
||||
cursor_offset,
|
||||
);
|
||||
|
||||
let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
|
||||
|
||||
let inputs = EditPredictionInputs {
|
||||
included_files: cloud_request.related_files,
|
||||
events: cloud_request.events,
|
||||
cursor_point: cloud_request.cursor_point,
|
||||
cursor_path: cloud_request.excerpt_path,
|
||||
};
|
||||
|
||||
let retrieval_time = Instant::now() - before_retrieval;
|
||||
|
||||
let debug_response_tx = if let Some(debug_tx) = &debug_tx {
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
let prompt = format_zeta_prompt(&prompt_input);
|
||||
|
||||
if let Some(debug_tx) = &debug_tx {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::EditPredictionRequested(
|
||||
EditPredictionRequestedDebugEvent {
|
||||
inputs: inputs.clone(),
|
||||
retrieval_time,
|
||||
buffer: active_buffer.downgrade(),
|
||||
local_prompt: match prompt_result.as_ref() {
|
||||
Ok(prompt) => Ok(prompt.clone()),
|
||||
Err(err) => Err(err.to_string()),
|
||||
},
|
||||
.unbounded_send(DebugEvent::EditPredictionStarted(
|
||||
EditPredictionStartedDebugEvent {
|
||||
buffer: buffer.downgrade(),
|
||||
prompt: Some(prompt.clone()),
|
||||
position,
|
||||
response_rx,
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
Some(response_tx)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
|
||||
if let Some(debug_response_tx) = debug_response_tx {
|
||||
debug_response_tx
|
||||
.send((Err("Request skipped".to_string()), Duration::ZERO))
|
||||
.ok();
|
||||
}
|
||||
anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
|
||||
}
|
||||
|
||||
let prompt = prompt_result?;
|
||||
let generation_params =
|
||||
cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
|
||||
let request = open_ai::Request {
|
||||
model: EDIT_PREDICTIONS_MODEL_ID.clone(),
|
||||
messages: vec![open_ai::RequestMessage::User {
|
||||
@@ -199,8 +79,8 @@ pub fn request_prediction_with_zeta2(
|
||||
}],
|
||||
stream: false,
|
||||
max_completion_tokens: None,
|
||||
stop: generation_params.stop.unwrap_or_default(),
|
||||
temperature: generation_params.temperature.or(Some(0.7)),
|
||||
stop: Default::default(),
|
||||
temperature: Default::default(),
|
||||
tool_choice: None,
|
||||
parallel_tool_calls: None,
|
||||
tools: vec![],
|
||||
@@ -210,81 +90,65 @@ pub fn request_prediction_with_zeta2(
|
||||
|
||||
log::trace!("Sending edit prediction request");
|
||||
|
||||
let before_request = Instant::now();
|
||||
let response = EditPredictionStore::send_raw_llm_request(
|
||||
request,
|
||||
client,
|
||||
llm_token,
|
||||
app_version,
|
||||
#[cfg(feature = "eval-support")]
|
||||
#[cfg(feature = "cli-support")]
|
||||
eval_cache,
|
||||
#[cfg(feature = "eval-support")]
|
||||
#[cfg(feature = "cli-support")]
|
||||
EvalCacheEntryKind::Prediction,
|
||||
)
|
||||
.await;
|
||||
let received_response_at = Instant::now();
|
||||
let request_time = received_response_at - before_request;
|
||||
|
||||
log::trace!("Got edit prediction response");
|
||||
|
||||
if let Some(debug_response_tx) = debug_response_tx {
|
||||
debug_response_tx
|
||||
.send((
|
||||
response
|
||||
.as_ref()
|
||||
.map_err(|err| err.to_string())
|
||||
.map(|response| response.0.clone()),
|
||||
request_time,
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
|
||||
let (res, usage) = response?;
|
||||
let request_id = EditPredictionId(res.id.clone().into());
|
||||
let Some(mut output_text) = text_from_response(res) else {
|
||||
return Ok((Some((request_id, None)), usage));
|
||||
};
|
||||
|
||||
if let Some(debug_tx) = &debug_tx {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::EditPredictionFinished(
|
||||
EditPredictionFinishedDebugEvent {
|
||||
buffer: buffer.downgrade(),
|
||||
position,
|
||||
model_output: Some(output_text.clone()),
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
|
||||
if output_text.contains(CURSOR_MARKER) {
|
||||
log::trace!("Stripping out {CURSOR_MARKER} from response");
|
||||
output_text = output_text.replace(CURSOR_MARKER, "");
|
||||
}
|
||||
|
||||
let get_buffer_from_context = |path: &Path| {
|
||||
if Some(path) == active_file_full_path.as_deref() {
|
||||
Some((
|
||||
&active_snapshot,
|
||||
std::slice::from_ref(&excerpt_anchor_range),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let (_, edits) = match options.prompt_format {
|
||||
PromptFormat::Minimal | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
|
||||
if output_text.contains("--- a/\n+++ b/\nNo edits") {
|
||||
let edits = vec![];
|
||||
(&active_snapshot, edits)
|
||||
} else {
|
||||
crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
|
||||
}
|
||||
}
|
||||
PromptFormat::OldTextNewText => {
|
||||
crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context).await?
|
||||
}
|
||||
_ => {
|
||||
bail!("unsupported prompt format {}", options.prompt_format)
|
||||
}
|
||||
};
|
||||
let old_text = snapshot
|
||||
.text_for_range(editable_offset_range.clone())
|
||||
.collect::<String>();
|
||||
let edits: Vec<_> = language::text_diff(&old_text, &output_text)
|
||||
.into_iter()
|
||||
.map(|(range, text)| {
|
||||
(
|
||||
snapshot.anchor_after(editable_offset_range.start + range.start)
|
||||
..snapshot.anchor_before(editable_offset_range.start + range.end),
|
||||
text,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
anyhow::Ok((
|
||||
Some((
|
||||
request_id,
|
||||
Some((
|
||||
inputs,
|
||||
active_buffer,
|
||||
active_snapshot.clone(),
|
||||
prompt_input,
|
||||
buffer,
|
||||
snapshot.clone(),
|
||||
edits,
|
||||
received_response_at,
|
||||
)),
|
||||
@@ -325,3 +189,52 @@ pub fn request_prediction_with_zeta2(
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn zeta2_prompt_input(
|
||||
snapshot: &language::BufferSnapshot,
|
||||
related_files: Arc<[zeta_prompt::RelatedFile]>,
|
||||
events: Vec<Arc<zeta_prompt::Event>>,
|
||||
excerpt_path: Arc<Path>,
|
||||
cursor_offset: usize,
|
||||
) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
|
||||
let cursor_point = cursor_offset.to_point(snapshot);
|
||||
|
||||
let (editable_range, context_range) =
|
||||
crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
|
||||
cursor_point,
|
||||
snapshot,
|
||||
MAX_CONTEXT_TOKENS,
|
||||
MAX_REWRITE_TOKENS,
|
||||
);
|
||||
|
||||
let context_start_offset = context_range.start.to_offset(snapshot);
|
||||
let editable_offset_range = editable_range.to_offset(snapshot);
|
||||
let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
|
||||
let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
|
||||
..(editable_offset_range.end - context_start_offset);
|
||||
|
||||
let prompt_input = zeta_prompt::ZetaPromptInput {
|
||||
cursor_path: excerpt_path,
|
||||
cursor_excerpt: snapshot
|
||||
.text_for_range(context_range)
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
editable_range_in_excerpt,
|
||||
cursor_offset_in_excerpt,
|
||||
events,
|
||||
related_files,
|
||||
};
|
||||
(editable_offset_range, prompt_input)
|
||||
}
|
||||
|
||||
#[cfg(feature = "cli-support")]
|
||||
pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> String {
|
||||
eprintln!("{}", patch);
|
||||
eprintln!("---------------------");
|
||||
eprintln!("{}", input.cursor_excerpt);
|
||||
crate::udiff::apply_diff_to_string(
|
||||
patch,
|
||||
&input.cursor_excerpt[input.editable_range_in_excerpt.clone()],
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ license = "GPL-3.0-or-later"
|
||||
workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "ep_cli"
|
||||
name = "ep"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
@@ -20,10 +20,9 @@ chrono.workspace = true
|
||||
clap.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace= true
|
||||
cloud_zeta2_prompt.workspace = true
|
||||
collections.workspace = true
|
||||
debug_adapter_extension.workspace = true
|
||||
edit_prediction_context.workspace = true
|
||||
dirs.workspace = true
|
||||
extension.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
@@ -35,6 +34,7 @@ language_extension.workspace = true
|
||||
language_model.workspace = true
|
||||
language_models.workspace = true
|
||||
languages = { workspace = true, features = ["load-grammars"] }
|
||||
libc.workspace = true
|
||||
log.workspace = true
|
||||
node_runtime.workspace = true
|
||||
paths.workspace = true
|
||||
@@ -51,11 +51,19 @@ smol.workspace = true
|
||||
sqlez.workspace = true
|
||||
sqlez_macros.workspace = true
|
||||
terminal_view.workspace = true
|
||||
toml.workspace = true
|
||||
util.workspace = true
|
||||
watch.workspace = true
|
||||
edit_prediction = { workspace = true, features = ["eval-support"] }
|
||||
zlog.workspace = true
|
||||
edit_prediction = { workspace = true, features = ["cli-support"] }
|
||||
wasmtime.workspace = true
|
||||
zeta_prompt.workspace = true
|
||||
|
||||
# Wasmtime is included as a dependency in order to enable the same
|
||||
# features that are enabled in Zed.
|
||||
#
|
||||
# If we don't enable these features we get crashes when creating
|
||||
# a Tree-sitter WasmStore.
|
||||
[package.metadata.cargo-machete]
|
||||
ignored = ["wasmtime"]
|
||||
|
||||
[dev-dependencies]
|
||||
indoc.workspace = true
|
||||
|
||||
@@ -5,11 +5,13 @@ use anthropic::{
|
||||
use anyhow::Result;
|
||||
use http_client::HttpClient;
|
||||
use indoc::indoc;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use sqlez::bindable::Bind;
|
||||
use sqlez::bindable::StaticColumnCount;
|
||||
use sqlez_macros::sql;
|
||||
use std::hash::Hash;
|
||||
use std::hash::Hasher;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct PlainLlmClient {
|
||||
@@ -18,7 +20,8 @@ pub struct PlainLlmClient {
|
||||
}
|
||||
|
||||
impl PlainLlmClient {
|
||||
fn new(http_client: Arc<dyn HttpClient>) -> Result<Self> {
|
||||
fn new() -> Result<Self> {
|
||||
let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
|
||||
let api_key = std::env::var("ANTHROPIC_API_KEY")
|
||||
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
|
||||
Ok(Self {
|
||||
@@ -29,12 +32,12 @@ impl PlainLlmClient {
|
||||
|
||||
async fn generate(
|
||||
&self,
|
||||
model: String,
|
||||
model: &str,
|
||||
max_tokens: u64,
|
||||
messages: Vec<Message>,
|
||||
) -> Result<AnthropicResponse> {
|
||||
let request = AnthropicRequest {
|
||||
model,
|
||||
model: model.to_string(),
|
||||
max_tokens,
|
||||
messages,
|
||||
tools: Vec::new(),
|
||||
@@ -105,11 +108,12 @@ struct SerializableMessage {
|
||||
}
|
||||
|
||||
impl BatchingLlmClient {
|
||||
fn new(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
|
||||
fn new(cache_path: &Path) -> Result<Self> {
|
||||
let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
|
||||
let api_key = std::env::var("ANTHROPIC_API_KEY")
|
||||
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
|
||||
|
||||
let connection = sqlez::connection::Connection::open_file(&cache_path);
|
||||
let connection = sqlez::connection::Connection::open_file(&cache_path.to_str().unwrap());
|
||||
let mut statement = sqlez::statement::Statement::prepare(
|
||||
&connection,
|
||||
indoc! {"
|
||||
@@ -182,16 +186,16 @@ impl BatchingLlmClient {
|
||||
|
||||
async fn generate(
|
||||
&self,
|
||||
model: String,
|
||||
model: &str,
|
||||
max_tokens: u64,
|
||||
messages: Vec<Message>,
|
||||
) -> Result<Option<AnthropicResponse>> {
|
||||
let response = self.lookup(&model, max_tokens, &messages)?;
|
||||
let response = self.lookup(model, max_tokens, &messages)?;
|
||||
if let Some(response) = response {
|
||||
return Ok(Some(response));
|
||||
}
|
||||
|
||||
self.mark_for_batch(&model, max_tokens, &messages)?;
|
||||
self.mark_for_batch(model, max_tokens, &messages)?;
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
@@ -258,7 +262,7 @@ impl BatchingLlmClient {
|
||||
}
|
||||
}
|
||||
}
|
||||
log::info!("Uploaded {} successful requests", success_count);
|
||||
log::info!("Downloaded {} successful requests", success_count);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -363,23 +367,20 @@ fn message_content_to_string(content: &[RequestContent]) -> String {
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
pub enum LlmClient {
|
||||
pub enum AnthropicClient {
|
||||
// No batching
|
||||
Plain(PlainLlmClient),
|
||||
Batch(BatchingLlmClient),
|
||||
Dummy,
|
||||
}
|
||||
|
||||
impl LlmClient {
|
||||
pub fn plain(http_client: Arc<dyn HttpClient>) -> Result<Self> {
|
||||
Ok(Self::Plain(PlainLlmClient::new(http_client)?))
|
||||
impl AnthropicClient {
|
||||
pub fn plain() -> Result<Self> {
|
||||
Ok(Self::Plain(PlainLlmClient::new()?))
|
||||
}
|
||||
|
||||
pub fn batch(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
|
||||
Ok(Self::Batch(BatchingLlmClient::new(
|
||||
cache_path,
|
||||
http_client,
|
||||
)?))
|
||||
pub fn batch(cache_path: &Path) -> Result<Self> {
|
||||
Ok(Self::Batch(BatchingLlmClient::new(cache_path)?))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
@@ -389,29 +390,29 @@ impl LlmClient {
|
||||
|
||||
pub async fn generate(
|
||||
&self,
|
||||
model: String,
|
||||
model: &str,
|
||||
max_tokens: u64,
|
||||
messages: Vec<Message>,
|
||||
) -> Result<Option<AnthropicResponse>> {
|
||||
match self {
|
||||
LlmClient::Plain(plain_llm_client) => plain_llm_client
|
||||
AnthropicClient::Plain(plain_llm_client) => plain_llm_client
|
||||
.generate(model, max_tokens, messages)
|
||||
.await
|
||||
.map(Some),
|
||||
LlmClient::Batch(batching_llm_client) => {
|
||||
AnthropicClient::Batch(batching_llm_client) => {
|
||||
batching_llm_client
|
||||
.generate(model, max_tokens, messages)
|
||||
.await
|
||||
}
|
||||
LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
|
||||
AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn sync_batches(&self) -> Result<()> {
|
||||
match self {
|
||||
LlmClient::Plain(_) => Ok(()),
|
||||
LlmClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
|
||||
LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
|
||||
AnthropicClient::Plain(_) => Ok(()),
|
||||
AnthropicClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
|
||||
AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
|
||||
}
|
||||
}
|
||||
}
|
||||
22
crates/edit_prediction_cli/src/distill.rs
Normal file
22
crates/edit_prediction_cli/src/distill.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
use anyhow::{Result, anyhow};
|
||||
use std::mem;
|
||||
|
||||
use crate::example::Example;
|
||||
|
||||
pub async fn run_distill(example: &mut Example) -> Result<()> {
|
||||
let [prediction]: [_; 1] =
|
||||
mem::take(&mut example.predictions)
|
||||
.try_into()
|
||||
.map_err(|preds: Vec<_>| {
|
||||
anyhow!(
|
||||
"Example has {} predictions, but it should have exactly one",
|
||||
preds.len()
|
||||
)
|
||||
})?;
|
||||
|
||||
example.expected_patch = prediction.actual_patch;
|
||||
example.prompt = None;
|
||||
example.predictions = Vec::new();
|
||||
example.score = Vec::new();
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,641 +0,0 @@
|
||||
use crate::metrics::{self, Scores};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
io::{IsTerminal, Write},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use edit_prediction::{EditPredictionStore, udiff::DiffLine};
|
||||
use gpui::{AsyncApp, Entity};
|
||||
use project::Project;
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::{
|
||||
EvaluateArguments, PredictionOptions,
|
||||
example::{Example, NamedExample},
|
||||
headless::ZetaCliAppState,
|
||||
paths::print_run_data_dir,
|
||||
predict::{PredictionDetails, perform_predict, setup_store},
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct ExecutionData {
|
||||
execution_id: String,
|
||||
diff: String,
|
||||
reasoning: String,
|
||||
}
|
||||
|
||||
pub async fn run_evaluate(
|
||||
args: EvaluateArguments,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) {
|
||||
if args.example_paths.is_empty() {
|
||||
eprintln!("No examples provided");
|
||||
return;
|
||||
}
|
||||
|
||||
let all_tasks = args.example_paths.into_iter().map(|path| {
|
||||
let options = args.options.clone();
|
||||
let app_state = app_state.clone();
|
||||
let example = NamedExample::load(&path).expect("Failed to load example");
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let project = example.setup_project(&app_state, cx).await.unwrap();
|
||||
|
||||
let providers = (0..args.repetitions)
|
||||
.map(|_| setup_store(args.options.provider, &project, &app_state, cx).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
|
||||
|
||||
let tasks = providers
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(move |(repetition_ix, store)| {
|
||||
let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
|
||||
let example = example.clone();
|
||||
let project = project.clone();
|
||||
let options = options.clone();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let name = example.name.clone();
|
||||
run_evaluate_one(
|
||||
example,
|
||||
repetition_ix,
|
||||
project,
|
||||
store,
|
||||
options,
|
||||
!args.skip_prediction,
|
||||
cx,
|
||||
)
|
||||
.await
|
||||
.map_err(|err| (err, name, repetition_ix))
|
||||
})
|
||||
});
|
||||
futures::future::join_all(tasks).await
|
||||
})
|
||||
});
|
||||
let all_results = futures::future::join_all(all_tasks).await;
|
||||
|
||||
write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
|
||||
if let Some(mut output_file) =
|
||||
std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
|
||||
{
|
||||
write_aggregated_scores(&mut output_file, &all_results).log_err();
|
||||
};
|
||||
|
||||
if args.repetitions > 1 {
|
||||
if let Err(e) = write_bucketed_analysis(&all_results) {
|
||||
eprintln!("Failed to write bucketed analysis: {:?}", e);
|
||||
}
|
||||
}
|
||||
|
||||
print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
|
||||
}
|
||||
|
||||
fn write_aggregated_scores(
|
||||
w: &mut impl std::io::Write,
|
||||
all_results: &Vec<
|
||||
Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
|
||||
>,
|
||||
) -> Result<()> {
|
||||
let mut successful = Vec::new();
|
||||
let mut failed_count = 0;
|
||||
|
||||
for result in all_results.iter().flatten() {
|
||||
match result {
|
||||
Ok((eval_result, _execution_data)) => successful.push(eval_result),
|
||||
Err((err, name, repetition_ix)) => {
|
||||
if failed_count == 0 {
|
||||
writeln!(w, "## Errors\n")?;
|
||||
}
|
||||
|
||||
failed_count += 1;
|
||||
writeln!(w, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if successful.len() > 1 {
|
||||
let edit_scores = successful
|
||||
.iter()
|
||||
.filter_map(|r| r.edit_scores.clone())
|
||||
.collect::<Vec<_>>();
|
||||
let has_edit_predictions = edit_scores.len() > 0;
|
||||
let aggregated_result = EvaluationResult {
|
||||
context_scores: Scores::aggregate(successful.iter().map(|r| &r.context_scores)),
|
||||
edit_scores: has_edit_predictions.then(|| EditScores::aggregate(&edit_scores)),
|
||||
prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
|
||||
generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
|
||||
/ successful.len(),
|
||||
};
|
||||
|
||||
writeln!(w, "\n{}", "-".repeat(80))?;
|
||||
writeln!(w, "\n## TOTAL SCORES")?;
|
||||
writeln!(w, "{:#}", aggregated_result)?;
|
||||
}
|
||||
|
||||
if successful.len() + failed_count > 1 {
|
||||
writeln!(
|
||||
w,
|
||||
"\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
|
||||
successful.len(),
|
||||
successful.len() + failed_count,
|
||||
(successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_evaluate_one(
|
||||
example: NamedExample,
|
||||
repetition_ix: Option<u16>,
|
||||
project: Entity<Project>,
|
||||
store: Entity<EditPredictionStore>,
|
||||
prediction_options: PredictionOptions,
|
||||
predict: bool,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<(EvaluationResult, ExecutionData)> {
|
||||
let predict_result = perform_predict(
|
||||
example.clone(),
|
||||
project,
|
||||
store,
|
||||
repetition_ix,
|
||||
prediction_options,
|
||||
cx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let evaluation_result = evaluate(&example.example, &predict_result, predict);
|
||||
|
||||
if repetition_ix.is_none() {
|
||||
write_eval_result(
|
||||
&example,
|
||||
&predict_result,
|
||||
&evaluation_result,
|
||||
&mut std::io::stdout(),
|
||||
std::io::stdout().is_terminal(),
|
||||
predict,
|
||||
)?;
|
||||
}
|
||||
|
||||
if let Some(mut results_file) =
|
||||
std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
|
||||
{
|
||||
write_eval_result(
|
||||
&example,
|
||||
&predict_result,
|
||||
&evaluation_result,
|
||||
&mut results_file,
|
||||
false,
|
||||
predict,
|
||||
)
|
||||
.log_err();
|
||||
}
|
||||
|
||||
let execution_data = ExecutionData {
|
||||
execution_id: if let Some(rep_ix) = repetition_ix {
|
||||
format!("{:03}", rep_ix)
|
||||
} else {
|
||||
example.name.clone()
|
||||
},
|
||||
diff: predict_result.diff.clone(),
|
||||
reasoning: std::fs::read_to_string(
|
||||
predict_result
|
||||
.run_example_dir
|
||||
.join("prediction_response.md"),
|
||||
)
|
||||
.unwrap_or_default(),
|
||||
};
|
||||
|
||||
anyhow::Ok((evaluation_result, execution_data))
|
||||
}
|
||||
|
||||
fn write_eval_result(
|
||||
example: &NamedExample,
|
||||
predictions: &PredictionDetails,
|
||||
evaluation_result: &EvaluationResult,
|
||||
out: &mut impl Write,
|
||||
use_color: bool,
|
||||
predict: bool,
|
||||
) -> Result<()> {
|
||||
if predict {
|
||||
writeln!(
|
||||
out,
|
||||
"## Expected edit prediction:\n\n```diff\n{}\n```\n",
|
||||
compare_diffs(
|
||||
&example.example.expected_patch,
|
||||
&predictions.diff,
|
||||
use_color
|
||||
)
|
||||
)?;
|
||||
writeln!(
|
||||
out,
|
||||
"## Actual edit prediction:\n\n```diff\n{}\n```\n",
|
||||
compare_diffs(
|
||||
&predictions.diff,
|
||||
&example.example.expected_patch,
|
||||
use_color
|
||||
)
|
||||
)?;
|
||||
}
|
||||
|
||||
writeln!(out, "{:#}", evaluation_result)?;
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct EditScores {
|
||||
pub line_match: Scores,
|
||||
pub chr_f: f64,
|
||||
}
|
||||
|
||||
impl EditScores {
|
||||
pub fn aggregate(scores: &[EditScores]) -> EditScores {
|
||||
let line_match = Scores::aggregate(scores.iter().map(|s| &s.line_match));
|
||||
let chr_f = scores.iter().map(|s| s.chr_f).sum::<f64>() / scores.len() as f64;
|
||||
|
||||
EditScores { line_match, chr_f }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct EvaluationResult {
|
||||
pub edit_scores: Option<EditScores>,
|
||||
pub context_scores: Scores,
|
||||
pub prompt_len: usize,
|
||||
pub generated_len: usize,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for EvaluationResult {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
if f.alternate() {
|
||||
self.fmt_table(f)
|
||||
} else {
|
||||
self.fmt_markdown(f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EvaluationResult {
|
||||
fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
r#"
|
||||
### Context Scores
|
||||
{}
|
||||
"#,
|
||||
self.context_scores.to_markdown(),
|
||||
)?;
|
||||
if let Some(scores) = &self.edit_scores {
|
||||
write!(
|
||||
f,
|
||||
r#"
|
||||
### Edit Prediction Scores
|
||||
{}"#,
|
||||
scores.line_match.to_markdown()
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
writeln!(f, "#### Prompt Statistics")?;
|
||||
writeln!(f, "─────────────────────────")?;
|
||||
writeln!(f, "Prompt_len Generated_len")?;
|
||||
writeln!(f, "─────────────────────────")?;
|
||||
writeln!(f, "{:<11} {:<14}", self.prompt_len, self.generated_len,)?;
|
||||
writeln!(f)?;
|
||||
writeln!(f)?;
|
||||
writeln!(f, "#### Performance Scores")?;
|
||||
writeln!(
|
||||
f,
|
||||
"──────────────────────────────────────────────────────────────────"
|
||||
)?;
|
||||
writeln!(
|
||||
f,
|
||||
" TP FP FN Precision Recall F1"
|
||||
)?;
|
||||
writeln!(
|
||||
f,
|
||||
"──────────────────────────────────────────────────────────────────"
|
||||
)?;
|
||||
writeln!(
|
||||
f,
|
||||
"Context Retrieval {:<6} {:<6} {:<6} {:>8.2} {:>7.2} {:>6.2}",
|
||||
self.context_scores.true_positives,
|
||||
self.context_scores.false_positives,
|
||||
self.context_scores.false_negatives,
|
||||
self.context_scores.precision() * 100.0,
|
||||
self.context_scores.recall() * 100.0,
|
||||
self.context_scores.f1_score() * 100.0
|
||||
)?;
|
||||
if let Some(edit_scores) = &self.edit_scores {
|
||||
let line_match = &edit_scores.line_match;
|
||||
writeln!(f, "Edit Prediction")?;
|
||||
writeln!(
|
||||
f,
|
||||
" ├─ exact lines {:<6} {:<6} {:<6} {:>8.2} {:>7.2} {:>6.2}",
|
||||
line_match.true_positives,
|
||||
line_match.false_positives,
|
||||
line_match.false_negatives,
|
||||
line_match.precision() * 100.0,
|
||||
line_match.recall() * 100.0,
|
||||
line_match.f1_score() * 100.0
|
||||
)?;
|
||||
writeln!(
|
||||
f,
|
||||
" └─ diff chrF {:<6} {:<6} {:<6} {:>8} {:>8} {:>6.2}",
|
||||
"-", "-", "-", "-", "-", edit_scores.chr_f
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
|
||||
let mut eval_result = EvaluationResult {
|
||||
prompt_len: preds.prompt_len,
|
||||
generated_len: preds.generated_len,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
if predict {
|
||||
// todo: alternatives for patches
|
||||
let expected_patch = example
|
||||
.expected_patch
|
||||
.lines()
|
||||
.map(DiffLine::parse)
|
||||
.collect::<Vec<_>>();
|
||||
let actual_patch = preds.diff.lines().map(DiffLine::parse).collect::<Vec<_>>();
|
||||
|
||||
let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
|
||||
let chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch);
|
||||
|
||||
eval_result.edit_scores = Some(EditScores { line_match, chr_f });
|
||||
}
|
||||
|
||||
eval_result
|
||||
}
|
||||
|
||||
/// Return annotated `patch_a` so that:
|
||||
/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
|
||||
/// Additions and deletions that are present in `patch_b` will be highlighted in green.
|
||||
pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
|
||||
let green = if use_color { "\x1b[32m✓ " } else { "" };
|
||||
let red = if use_color { "\x1b[31m✗ " } else { "" };
|
||||
let neutral = if use_color { " " } else { "" };
|
||||
let reset = if use_color { "\x1b[0m" } else { "" };
|
||||
let lines_a = patch_a.lines().map(DiffLine::parse);
|
||||
let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
|
||||
|
||||
let annotated = lines_a
|
||||
.map(|line| match line {
|
||||
DiffLine::Addition(_) | DiffLine::Deletion(_) => {
|
||||
if lines_b.contains(&line) {
|
||||
format!("{green}{line}{reset}")
|
||||
} else {
|
||||
format!("{red}{line}{reset}")
|
||||
}
|
||||
}
|
||||
_ => format!("{neutral}{line}{reset}"),
|
||||
})
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
annotated.join("\n")
|
||||
}
|
||||
|
||||
fn write_bucketed_analysis(
|
||||
all_results: &Vec<
|
||||
Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
|
||||
>,
|
||||
) -> Result<()> {
|
||||
#[derive(Debug)]
|
||||
struct EditBucket {
|
||||
diff: String,
|
||||
is_correct: bool,
|
||||
execution_indices: Vec<String>,
|
||||
reasoning_samples: Vec<String>,
|
||||
}
|
||||
|
||||
let mut total_executions = 0;
|
||||
let mut empty_predictions = Vec::new();
|
||||
let mut errors = Vec::new();
|
||||
|
||||
let mut buckets: HashMap<String, EditBucket> = HashMap::new();
|
||||
|
||||
for result in all_results.iter().flatten() {
|
||||
total_executions += 1;
|
||||
|
||||
let (evaluation_result, execution_data) = match result {
|
||||
Ok((eval_result, execution_data)) => {
|
||||
if execution_data.diff.is_empty() {
|
||||
empty_predictions.push(execution_data);
|
||||
continue;
|
||||
}
|
||||
(eval_result, execution_data)
|
||||
}
|
||||
Err(err) => {
|
||||
errors.push(err);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
buckets
|
||||
.entry(execution_data.diff.clone())
|
||||
.and_modify(|bucket| {
|
||||
bucket
|
||||
.execution_indices
|
||||
.push(execution_data.execution_id.clone());
|
||||
bucket
|
||||
.reasoning_samples
|
||||
.push(execution_data.reasoning.clone());
|
||||
})
|
||||
.or_insert_with(|| EditBucket {
|
||||
diff: execution_data.diff.clone(),
|
||||
is_correct: {
|
||||
evaluation_result
|
||||
.edit_scores
|
||||
.as_ref()
|
||||
.map_or(false, |edit_scores| {
|
||||
edit_scores.line_match.false_positives == 0
|
||||
&& edit_scores.line_match.false_negatives == 0
|
||||
&& edit_scores.line_match.true_positives > 0
|
||||
})
|
||||
},
|
||||
execution_indices: vec![execution_data.execution_id.clone()],
|
||||
reasoning_samples: vec![execution_data.reasoning.clone()],
|
||||
});
|
||||
}
|
||||
|
||||
let mut sorted_buckets = buckets.into_values().collect::<Vec<_>>();
|
||||
sorted_buckets.sort_by(|a, b| match (a.is_correct, b.is_correct) {
|
||||
(true, false) => std::cmp::Ordering::Less,
|
||||
(false, true) => std::cmp::Ordering::Greater,
|
||||
_ => b.execution_indices.len().cmp(&a.execution_indices.len()),
|
||||
});
|
||||
|
||||
let output_path = crate::paths::RUN_DIR.join("bucketed_analysis.md");
|
||||
let mut output = std::fs::File::create(&output_path)?;
|
||||
|
||||
writeln!(output, "# Bucketed Edit Analysis\n")?;
|
||||
|
||||
writeln!(output, "## Summary\n")?;
|
||||
writeln!(output, "- **Total executions**: {}", total_executions)?;
|
||||
|
||||
let correct_count: usize = sorted_buckets
|
||||
.iter()
|
||||
.filter(|b| b.is_correct)
|
||||
.map(|b| b.execution_indices.len())
|
||||
.sum();
|
||||
|
||||
let incorrect_count: usize = sorted_buckets
|
||||
.iter()
|
||||
.filter(|b| !b.is_correct)
|
||||
.map(|b| b.execution_indices.len())
|
||||
.sum();
|
||||
|
||||
writeln!(
|
||||
output,
|
||||
"- **Correct predictions**: {} ({:.1}%)",
|
||||
correct_count,
|
||||
(correct_count as f64 / total_executions as f64) * 100.0
|
||||
)?;
|
||||
|
||||
writeln!(
|
||||
output,
|
||||
"- **Incorrect predictions**: {} ({:.1}%)",
|
||||
incorrect_count,
|
||||
(incorrect_count as f64 / total_executions as f64) * 100.0
|
||||
)?;
|
||||
|
||||
writeln!(
|
||||
output,
|
||||
"- **No Predictions**: {} ({:.1}%)",
|
||||
empty_predictions.len(),
|
||||
(empty_predictions.len() as f64 / total_executions as f64) * 100.0
|
||||
)?;
|
||||
|
||||
let unique_incorrect = sorted_buckets.iter().filter(|b| !b.is_correct).count();
|
||||
writeln!(
|
||||
output,
|
||||
"- **Unique incorrect edit patterns**: {}\n",
|
||||
unique_incorrect
|
||||
)?;
|
||||
|
||||
writeln!(output, "---\n")?;
|
||||
|
||||
for (idx, bucket) in sorted_buckets.iter().filter(|b| b.is_correct).enumerate() {
|
||||
if idx == 0 {
|
||||
writeln!(
|
||||
output,
|
||||
"## Correct Predictions ({} occurrences)\n",
|
||||
bucket.execution_indices.len()
|
||||
)?;
|
||||
}
|
||||
|
||||
writeln!(output, "**Predicted Edit:**\n")?;
|
||||
writeln!(output, "```diff")?;
|
||||
writeln!(output, "{}", bucket.diff)?;
|
||||
writeln!(output, "```\n")?;
|
||||
|
||||
writeln!(
|
||||
output,
|
||||
"**Executions:** {}\n",
|
||||
bucket.execution_indices.join(", ")
|
||||
)?;
|
||||
writeln!(output, "---\n")?;
|
||||
}
|
||||
|
||||
for (idx, bucket) in sorted_buckets.iter().filter(|b| !b.is_correct).enumerate() {
|
||||
writeln!(
|
||||
output,
|
||||
"## Incorrect Prediction #{} ({} occurrences)\n",
|
||||
idx + 1,
|
||||
bucket.execution_indices.len()
|
||||
)?;
|
||||
|
||||
writeln!(output, "**Predicted Edit:**\n")?;
|
||||
writeln!(output, "```diff")?;
|
||||
writeln!(output, "{}", bucket.diff)?;
|
||||
writeln!(output, "```\n")?;
|
||||
|
||||
writeln!(
|
||||
output,
|
||||
"**Executions:** {}\n",
|
||||
bucket.execution_indices.join(", ")
|
||||
)?;
|
||||
|
||||
for (exec_id, reasoning) in bucket
|
||||
.execution_indices
|
||||
.iter()
|
||||
.zip(bucket.reasoning_samples.iter())
|
||||
{
|
||||
writeln!(output, "{}", fmt_execution(exec_id, reasoning))?;
|
||||
}
|
||||
|
||||
writeln!(output, "\n---\n")?;
|
||||
}
|
||||
|
||||
if !empty_predictions.is_empty() {
|
||||
writeln!(
|
||||
output,
|
||||
"## No Predictions ({} occurrences)\n",
|
||||
empty_predictions.len()
|
||||
)?;
|
||||
|
||||
for execution_data in &empty_predictions {
|
||||
writeln!(
|
||||
output,
|
||||
"{}",
|
||||
fmt_execution(&execution_data.execution_id, &execution_data.reasoning)
|
||||
)?;
|
||||
}
|
||||
writeln!(output, "\n---\n")?;
|
||||
}
|
||||
|
||||
if !errors.is_empty() {
|
||||
writeln!(output, "## Errors ({} occurrences)\n", errors.len())?;
|
||||
|
||||
for (err, name, repetition_ix) in &errors {
|
||||
writeln!(output, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
|
||||
}
|
||||
writeln!(output, "\n---\n")?;
|
||||
}
|
||||
|
||||
fn fmt_execution(exec_id: &str, reasoning: &str) -> String {
|
||||
let exec_content = format!(
|
||||
"\n### Execution {} `{}/{}/prediction_response.md`{}",
|
||||
exec_id,
|
||||
crate::paths::RUN_DIR.display(),
|
||||
exec_id,
|
||||
indent_text(&format!("\n\n```\n{}\n```\n", reasoning,), 2)
|
||||
);
|
||||
indent_text(&exec_content, 2)
|
||||
}
|
||||
|
||||
fn indent_text(text: &str, spaces: usize) -> String {
|
||||
let indent = " ".repeat(spaces);
|
||||
text.lines()
|
||||
.collect::<Vec<_>>()
|
||||
.join(&format!("\n{}", indent))
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn fmt_evaluation_error(err: &anyhow::Error, name: &str, repetition_ix: &Option<u16>) -> String {
|
||||
let err = format!("{err:?}")
|
||||
.replace("<edits", "```xml\n<edits")
|
||||
.replace("</edits>", "</edits>\n```");
|
||||
format!(
|
||||
"### ERROR {name}{}\n\n{err}\n",
|
||||
repetition_ix
|
||||
.map(|ix| format!(" [RUN {ix:03}]"))
|
||||
.unwrap_or_default()
|
||||
)
|
||||
}
|
||||
@@ -1,63 +1,105 @@
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
cell::RefCell,
|
||||
fmt::{self, Display},
|
||||
fs,
|
||||
hash::Hash,
|
||||
hash::Hasher,
|
||||
io::Write,
|
||||
mem,
|
||||
path::{Path, PathBuf},
|
||||
sync::{Arc, OnceLock},
|
||||
};
|
||||
|
||||
use crate::headless::ZetaCliAppState;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use clap::ValueEnum;
|
||||
use cloud_zeta2_prompt::CURSOR_MARKER;
|
||||
use crate::{PredictionProvider, PromptFormat, metrics::ClassificationMetrics};
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use edit_prediction::udiff::OpenedBuffers;
|
||||
use futures::{
|
||||
AsyncWriteExt as _,
|
||||
lock::{Mutex, OwnedMutexGuard},
|
||||
};
|
||||
use futures::{FutureExt as _, future::Shared};
|
||||
use gpui::{AsyncApp, Entity, Task, http_client::Url};
|
||||
use gpui::Entity;
|
||||
use http_client::Url;
|
||||
use language::{Anchor, Buffer};
|
||||
use project::{Project, ProjectPath};
|
||||
use pulldown_cmark::CowStr;
|
||||
use project::Project;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::{paths::PathStyle, rel_path::RelPath};
|
||||
use std::sync::Arc;
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
io::{Read, Write},
|
||||
mem,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
use zeta_prompt::RelatedFile;
|
||||
|
||||
use crate::paths::{REPOS_DIR, WORKTREES_DIR};
|
||||
|
||||
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
|
||||
const EDIT_HISTORY_HEADING: &str = "Edit History";
|
||||
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
|
||||
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
|
||||
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
|
||||
const REPOSITORY_URL_FIELD: &str = "repository_url";
|
||||
const REVISION_FIELD: &str = "revision";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NamedExample {
|
||||
pub name: String,
|
||||
pub example: Example,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, Serialize, Deserialize)]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Example {
|
||||
#[serde(default)]
|
||||
pub name: String,
|
||||
pub repository_url: String,
|
||||
pub revision: String,
|
||||
#[serde(default)]
|
||||
pub uncommitted_diff: String,
|
||||
pub cursor_path: PathBuf,
|
||||
pub cursor_path: Arc<Path>,
|
||||
pub cursor_position: String,
|
||||
pub edit_history: String,
|
||||
pub expected_patch: String,
|
||||
|
||||
/// The full content of the file where an edit is being predicted, and the
|
||||
/// actual cursor offset.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub buffer: Option<ExampleBuffer>,
|
||||
|
||||
/// The context retrieved for the prediction. This requires the worktree to
|
||||
/// be loaded and the language server to be started.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub context: Option<ExampleContext>,
|
||||
|
||||
/// The input and expected output from the edit prediction model.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt: Option<ExamplePrompt>,
|
||||
|
||||
/// The actual predictions from the model.
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub predictions: Vec<ExamplePrediction>,
|
||||
|
||||
/// The scores, for how well the actual predictions match the expected
|
||||
/// predictions.
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub score: Vec<ExampleScore>,
|
||||
|
||||
/// The application state used to process this example.
|
||||
#[serde(skip)]
|
||||
pub state: Option<ExampleState>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ExampleState {
|
||||
pub project: Entity<Project>,
|
||||
pub buffer: Entity<Buffer>,
|
||||
pub cursor_position: Anchor,
|
||||
pub _open_buffers: OpenedBuffers,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExampleContext {
|
||||
pub files: Arc<[RelatedFile]>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExampleBuffer {
|
||||
pub content: String,
|
||||
pub cursor_row: u32,
|
||||
pub cursor_column: u32,
|
||||
pub cursor_offset: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExamplePrompt {
|
||||
pub input: String,
|
||||
pub expected_output: String,
|
||||
pub format: PromptFormat,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExamplePrediction {
|
||||
pub actual_patch: String,
|
||||
pub actual_output: String,
|
||||
pub provider: PredictionProvider,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExampleScore {
|
||||
pub delta_chr_f: f32,
|
||||
pub line_match: ClassificationMetrics,
|
||||
}
|
||||
|
||||
impl Example {
|
||||
fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
|
||||
pub fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
|
||||
// git@github.com:owner/repo.git
|
||||
if self.repository_url.contains('@') {
|
||||
let (owner, repo) = self
|
||||
@@ -89,486 +131,249 @@ impl Example {
|
||||
Ok((owner.into(), repo.into()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn setup_worktree(&self, file_name: String) -> Result<PathBuf> {
|
||||
let (repo_owner, repo_name) = self.repo_name()?;
|
||||
pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
|
||||
let mut examples = Vec::new();
|
||||
|
||||
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
|
||||
let repo_lock = lock_repo(&repo_dir).await;
|
||||
let stdin_path: PathBuf = PathBuf::from("-");
|
||||
|
||||
if !repo_dir.is_dir() {
|
||||
fs::create_dir_all(&repo_dir)?;
|
||||
run_git(&repo_dir, &["init"]).await?;
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["remote", "add", "origin", &self.repository_url],
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
let inputs = if inputs.is_empty() {
|
||||
&[stdin_path]
|
||||
} else {
|
||||
inputs
|
||||
};
|
||||
|
||||
// Resolve the example to a revision, fetching it if needed.
|
||||
let revision = run_git(
|
||||
&repo_dir,
|
||||
&["rev-parse", &format!("{}^{{commit}}", self.revision)],
|
||||
)
|
||||
.await;
|
||||
let revision = if let Ok(revision) = revision {
|
||||
revision
|
||||
for path in inputs {
|
||||
let is_stdin = path.as_path() == Path::new("-");
|
||||
let content = if is_stdin {
|
||||
let mut buffer = String::new();
|
||||
std::io::stdin()
|
||||
.read_to_string(&mut buffer)
|
||||
.expect("Failed to read from stdin");
|
||||
buffer
|
||||
} else {
|
||||
if run_git(
|
||||
&repo_dir,
|
||||
&["fetch", "--depth", "1", "origin", &self.revision],
|
||||
)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
run_git(&repo_dir, &["fetch", "origin"]).await?;
|
||||
}
|
||||
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
|
||||
if revision != self.revision {
|
||||
run_git(&repo_dir, &["tag", &self.revision, &revision]).await?;
|
||||
}
|
||||
revision
|
||||
std::fs::read_to_string(path)
|
||||
.unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
|
||||
};
|
||||
let filename = path.file_stem().unwrap().to_string_lossy().to_string();
|
||||
let ext = if !is_stdin {
|
||||
path.extension()
|
||||
.map(|ext| ext.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|| panic!("{} should have an extension", path.display()))
|
||||
} else {
|
||||
"jsonl".to_string()
|
||||
};
|
||||
|
||||
// Create the worktree for this example if needed.
|
||||
let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
|
||||
if worktree_path.is_dir() {
|
||||
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
|
||||
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
|
||||
run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
|
||||
} else {
|
||||
let worktree_path_string = worktree_path.to_string_lossy();
|
||||
run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["worktree", "add", "-f", &worktree_path_string, &file_name],
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
drop(repo_lock);
|
||||
|
||||
// Apply the uncommitted diff for this example.
|
||||
if !self.uncommitted_diff.is_empty() {
|
||||
let mut apply_process = smol::process::Command::new("git")
|
||||
.current_dir(&worktree_path)
|
||||
.args(&["apply", "-"])
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.spawn()?;
|
||||
|
||||
let mut stdin = apply_process.stdin.take().unwrap();
|
||||
stdin.write_all(self.uncommitted_diff.as_bytes()).await?;
|
||||
stdin.close().await?;
|
||||
drop(stdin);
|
||||
|
||||
let apply_result = apply_process.output().await?;
|
||||
if !apply_result.status.success() {
|
||||
anyhow::bail!(
|
||||
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
apply_result.status,
|
||||
String::from_utf8_lossy(&apply_result.stderr),
|
||||
String::from_utf8_lossy(&apply_result.stdout),
|
||||
);
|
||||
match ext.as_ref() {
|
||||
"json" => {
|
||||
let mut example =
|
||||
serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
|
||||
panic!("Failed to parse example file: {}\n{error}", path.display())
|
||||
});
|
||||
if example.name.is_empty() {
|
||||
example.name = filename;
|
||||
}
|
||||
examples.push(example);
|
||||
}
|
||||
"jsonl" => examples.extend(
|
||||
content
|
||||
.lines()
|
||||
.enumerate()
|
||||
.map(|(line_ix, line)| {
|
||||
let mut example =
|
||||
serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
|
||||
panic!(
|
||||
"Failed to parse example on {}:{}\n{error}",
|
||||
path.display(),
|
||||
line_ix + 1
|
||||
)
|
||||
});
|
||||
if example.name.is_empty() {
|
||||
example.name = format!("{filename}-{line_ix}")
|
||||
}
|
||||
example
|
||||
})
|
||||
.collect::<Vec<Example>>(),
|
||||
),
|
||||
"md" => {
|
||||
examples.push(parse_markdown_example(filename, &content).unwrap());
|
||||
}
|
||||
ext => {
|
||||
panic!("{} has invalid example extension `{ext}`", path.display())
|
||||
}
|
||||
}
|
||||
|
||||
Ok(worktree_path)
|
||||
}
|
||||
|
||||
pub fn unique_name(&self) -> String {
|
||||
let mut hasher = std::hash::DefaultHasher::new();
|
||||
self.hash(&mut hasher);
|
||||
let disambiguator = hasher.finish();
|
||||
let hash = format!("{:04x}", disambiguator);
|
||||
format!("{}_{}", &self.revision[..8], &hash[..4])
|
||||
sort_examples_by_repo_and_rev(&mut examples);
|
||||
examples
|
||||
}
|
||||
|
||||
pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
|
||||
let mut content = String::new();
|
||||
for example in examples {
|
||||
let line = serde_json::to_string(example).unwrap();
|
||||
content.push_str(&line);
|
||||
content.push('\n');
|
||||
}
|
||||
if let Some(output_path) = output_path {
|
||||
std::fs::write(output_path, content).expect("Failed to write examples");
|
||||
} else {
|
||||
std::io::stdout().write_all(&content.as_bytes()).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub type ActualExcerpt = Excerpt;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Excerpt {
|
||||
pub path: PathBuf,
|
||||
pub text: String,
|
||||
pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
|
||||
examples.sort_by(|a, b| {
|
||||
a.repository_url
|
||||
.cmp(&b.repository_url)
|
||||
.then(b.revision.cmp(&a.revision))
|
||||
});
|
||||
}
|
||||
|
||||
#[derive(ValueEnum, Debug, Clone)]
|
||||
pub enum ExampleFormat {
|
||||
Json,
|
||||
Toml,
|
||||
Md,
|
||||
pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
|
||||
let mut examples_by_repo = HashMap::default();
|
||||
for example in examples.iter_mut() {
|
||||
examples_by_repo
|
||||
.entry(example.repository_url.clone())
|
||||
.or_insert_with(Vec::new)
|
||||
.push(example);
|
||||
}
|
||||
examples_by_repo.into_values().collect()
|
||||
}
|
||||
|
||||
impl NamedExample {
|
||||
pub fn load(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let path = path.as_ref();
|
||||
let content = std::fs::read_to_string(path)?;
|
||||
let ext = path.extension();
|
||||
fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
|
||||
use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
|
||||
|
||||
match ext.and_then(|s| s.to_str()) {
|
||||
Some("json") => Ok(Self {
|
||||
name: path.file_stem().unwrap_or_default().display().to_string(),
|
||||
example: serde_json::from_str(&content)?,
|
||||
}),
|
||||
Some("toml") => Ok(Self {
|
||||
name: path.file_stem().unwrap_or_default().display().to_string(),
|
||||
example: toml::from_str(&content)?,
|
||||
}),
|
||||
Some("md") => Self::parse_md(&content),
|
||||
Some(_) => {
|
||||
anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display());
|
||||
}
|
||||
None => {
|
||||
anyhow::bail!(
|
||||
"Failed to determine example type since the file does not have an extension."
|
||||
);
|
||||
}
|
||||
}
|
||||
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
|
||||
const EDIT_HISTORY_HEADING: &str = "Edit History";
|
||||
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
|
||||
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
|
||||
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
|
||||
const REPOSITORY_URL_FIELD: &str = "repository_url";
|
||||
const REVISION_FIELD: &str = "revision";
|
||||
|
||||
let parser = Parser::new(input);
|
||||
|
||||
let mut example = Example {
|
||||
name: id,
|
||||
repository_url: String::new(),
|
||||
revision: String::new(),
|
||||
uncommitted_diff: String::new(),
|
||||
cursor_path: PathBuf::new().into(),
|
||||
cursor_position: String::new(),
|
||||
edit_history: String::new(),
|
||||
expected_patch: String::new(),
|
||||
buffer: None,
|
||||
context: None,
|
||||
prompt: None,
|
||||
predictions: Vec::new(),
|
||||
score: Vec::new(),
|
||||
state: None,
|
||||
};
|
||||
|
||||
let mut text = String::new();
|
||||
let mut block_info: CowStr = "".into();
|
||||
|
||||
#[derive(PartialEq)]
|
||||
enum Section {
|
||||
Start,
|
||||
UncommittedDiff,
|
||||
EditHistory,
|
||||
CursorPosition,
|
||||
ExpectedExcerpts,
|
||||
ExpectedPatch,
|
||||
Other,
|
||||
}
|
||||
|
||||
pub fn parse_md(input: &str) -> Result<Self> {
|
||||
use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd};
|
||||
let mut current_section = Section::Start;
|
||||
|
||||
let parser = Parser::new(input);
|
||||
for event in parser {
|
||||
match event {
|
||||
Event::Text(line) => {
|
||||
text.push_str(&line);
|
||||
|
||||
let mut named = NamedExample {
|
||||
name: String::new(),
|
||||
example: Example {
|
||||
repository_url: String::new(),
|
||||
revision: String::new(),
|
||||
uncommitted_diff: String::new(),
|
||||
cursor_path: PathBuf::new(),
|
||||
cursor_position: String::new(),
|
||||
edit_history: String::new(),
|
||||
expected_patch: String::new(),
|
||||
},
|
||||
};
|
||||
|
||||
let mut text = String::new();
|
||||
let mut block_info: CowStr = "".into();
|
||||
|
||||
#[derive(PartialEq)]
|
||||
enum Section {
|
||||
UncommittedDiff,
|
||||
EditHistory,
|
||||
CursorPosition,
|
||||
ExpectedExcerpts,
|
||||
ExpectedPatch,
|
||||
Other,
|
||||
}
|
||||
|
||||
let mut current_section = Section::Other;
|
||||
|
||||
for event in parser {
|
||||
match event {
|
||||
Event::Text(line) => {
|
||||
text.push_str(&line);
|
||||
|
||||
if !named.name.is_empty()
|
||||
&& current_section == Section::Other
|
||||
// in h1 section
|
||||
&& let Some((field, value)) = line.split_once('=')
|
||||
{
|
||||
match field.trim() {
|
||||
REPOSITORY_URL_FIELD => {
|
||||
named.example.repository_url = value.trim().to_string();
|
||||
}
|
||||
REVISION_FIELD => {
|
||||
named.example.revision = value.trim().to_string();
|
||||
}
|
||||
_ => {}
|
||||
if let Section::Start = current_section
|
||||
&& let Some((field, value)) = line.split_once('=')
|
||||
{
|
||||
match field.trim() {
|
||||
REPOSITORY_URL_FIELD => {
|
||||
example.repository_url = value.trim().to_string();
|
||||
}
|
||||
REVISION_FIELD => {
|
||||
example.revision = value.trim().to_string();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
|
||||
if !named.name.is_empty() {
|
||||
anyhow::bail!(
|
||||
"Found multiple H1 headings. There should only be one with the name of the example."
|
||||
);
|
||||
}
|
||||
named.name = mem::take(&mut text);
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
|
||||
let title = mem::take(&mut text);
|
||||
current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
|
||||
Section::UncommittedDiff
|
||||
} else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
|
||||
Section::EditHistory
|
||||
} else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
|
||||
Section::CursorPosition
|
||||
} else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
|
||||
Section::ExpectedPatch
|
||||
} else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
|
||||
Section::ExpectedExcerpts
|
||||
} else {
|
||||
Section::Other
|
||||
};
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Event::End(TagEnd::Heading(level)) => {
|
||||
anyhow::bail!("Unexpected heading level: {level}");
|
||||
}
|
||||
Event::Start(Tag::CodeBlock(kind)) => {
|
||||
match kind {
|
||||
CodeBlockKind::Fenced(info) => {
|
||||
block_info = info;
|
||||
}
|
||||
CodeBlockKind::Indented => {
|
||||
anyhow::bail!("Unexpected indented codeblock");
|
||||
}
|
||||
};
|
||||
}
|
||||
Event::Start(_) => {
|
||||
text.clear();
|
||||
block_info = "".into();
|
||||
}
|
||||
Event::End(TagEnd::CodeBlock) => {
|
||||
let block_info = block_info.trim();
|
||||
match current_section {
|
||||
Section::UncommittedDiff => {
|
||||
named.example.uncommitted_diff = mem::take(&mut text);
|
||||
}
|
||||
Section::EditHistory => {
|
||||
named.example.edit_history.push_str(&mem::take(&mut text));
|
||||
}
|
||||
Section::CursorPosition => {
|
||||
named.example.cursor_path = block_info.into();
|
||||
named.example.cursor_position = mem::take(&mut text);
|
||||
}
|
||||
Section::ExpectedExcerpts => {
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Section::ExpectedPatch => {
|
||||
named.example.expected_patch = mem::take(&mut text);
|
||||
}
|
||||
Section::Other => {}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if named.example.cursor_path.as_path() == Path::new("")
|
||||
|| named.example.cursor_position.is_empty()
|
||||
{
|
||||
anyhow::bail!("Missing cursor position codeblock");
|
||||
}
|
||||
|
||||
Ok(named)
|
||||
}
|
||||
|
||||
pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> {
|
||||
match format {
|
||||
ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?),
|
||||
ExampleFormat::Toml => {
|
||||
Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?)
|
||||
}
|
||||
ExampleFormat::Md => Ok(write!(out, "{}", self)?),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn setup_project(
|
||||
&self,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Entity<Project>> {
|
||||
let worktree_path = self.setup_worktree().await?;
|
||||
|
||||
static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
|
||||
|
||||
AUTHENTICATED
|
||||
.get_or_init(|| {
|
||||
let client = app_state.client.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
client
|
||||
.sign_in_with_optional_connect(true, cx)
|
||||
.await
|
||||
.unwrap();
|
||||
})
|
||||
.shared()
|
||||
})
|
||||
.clone()
|
||||
.await;
|
||||
|
||||
let project = cx.update(|cx| {
|
||||
Project::local(
|
||||
app_state.client.clone(),
|
||||
app_state.node_runtime.clone(),
|
||||
app_state.user_store.clone(),
|
||||
app_state.languages.clone(),
|
||||
app_state.fs.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
|
||||
let worktree = project
|
||||
.update(cx, |project, cx| {
|
||||
project.create_worktree(&worktree_path, true, cx)
|
||||
})?
|
||||
.await?;
|
||||
worktree
|
||||
.read_with(cx, |worktree, _cx| {
|
||||
worktree.as_local().unwrap().scan_complete()
|
||||
})?
|
||||
.await;
|
||||
|
||||
anyhow::Ok(project)
|
||||
}
|
||||
|
||||
pub async fn setup_worktree(&self) -> Result<PathBuf> {
|
||||
self.example.setup_worktree(self.file_name()).await
|
||||
}
|
||||
|
||||
pub fn file_name(&self) -> String {
|
||||
self.name
|
||||
.chars()
|
||||
.map(|c| {
|
||||
if c.is_whitespace() {
|
||||
'-'
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
|
||||
let title = mem::take(&mut text);
|
||||
current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
|
||||
Section::UncommittedDiff
|
||||
} else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
|
||||
Section::EditHistory
|
||||
} else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
|
||||
Section::CursorPosition
|
||||
} else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
|
||||
Section::ExpectedPatch
|
||||
} else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
|
||||
Section::ExpectedExcerpts
|
||||
} else {
|
||||
c.to_ascii_lowercase()
|
||||
Section::Other
|
||||
};
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Event::End(TagEnd::Heading(level)) => {
|
||||
anyhow::bail!("Unexpected heading level: {level}");
|
||||
}
|
||||
Event::Start(Tag::CodeBlock(kind)) => {
|
||||
match kind {
|
||||
CodeBlockKind::Fenced(info) => {
|
||||
block_info = info;
|
||||
}
|
||||
CodeBlockKind::Indented => {
|
||||
anyhow::bail!("Unexpected indented codeblock");
|
||||
}
|
||||
};
|
||||
}
|
||||
Event::Start(_) => {
|
||||
text.clear();
|
||||
block_info = "".into();
|
||||
}
|
||||
Event::End(TagEnd::CodeBlock) => {
|
||||
let block_info = block_info.trim();
|
||||
match current_section {
|
||||
Section::UncommittedDiff => {
|
||||
example.uncommitted_diff = mem::take(&mut text);
|
||||
}
|
||||
Section::EditHistory => {
|
||||
example.edit_history.push_str(&mem::take(&mut text));
|
||||
}
|
||||
Section::CursorPosition => {
|
||||
example.cursor_path = Path::new(block_info).into();
|
||||
example.cursor_position = mem::take(&mut text);
|
||||
}
|
||||
Section::ExpectedExcerpts => {
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Section::ExpectedPatch => {
|
||||
example.expected_patch = mem::take(&mut text);
|
||||
}
|
||||
Section::Start | Section::Other => {}
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub async fn cursor_position(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<(Entity<Buffer>, Anchor)> {
|
||||
let worktree = project.read_with(cx, |project, cx| {
|
||||
project.visible_worktrees(cx).next().unwrap()
|
||||
})?;
|
||||
let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc();
|
||||
let cursor_buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_buffer(
|
||||
ProjectPath {
|
||||
worktree_id: worktree.read(cx).id(),
|
||||
path: cursor_path,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
let cursor_offset_within_excerpt = self
|
||||
.example
|
||||
.cursor_position
|
||||
.find(CURSOR_MARKER)
|
||||
.ok_or_else(|| anyhow!("missing cursor marker"))?;
|
||||
let mut cursor_excerpt = self.example.cursor_position.clone();
|
||||
cursor_excerpt.replace_range(
|
||||
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
|
||||
"",
|
||||
);
|
||||
let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
|
||||
let text = buffer.text();
|
||||
|
||||
let mut matches = text.match_indices(&cursor_excerpt);
|
||||
let Some((excerpt_offset, _)) = matches.next() else {
|
||||
anyhow::bail!(
|
||||
"\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
|
||||
);
|
||||
};
|
||||
assert!(matches.next().is_none());
|
||||
|
||||
Ok(excerpt_offset)
|
||||
})??;
|
||||
|
||||
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
|
||||
let cursor_anchor =
|
||||
cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
|
||||
Ok((cursor_buffer, cursor_anchor))
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub async fn apply_edit_history(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<OpenedBuffers<'_>> {
|
||||
edit_prediction::udiff::apply_diff(&self.example.edit_history, project, cx).await
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
|
||||
let output = smol::process::Command::new("git")
|
||||
.current_dir(repo_path)
|
||||
.args(args)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
anyhow::ensure!(
|
||||
output.status.success(),
|
||||
"`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
args.join(" "),
|
||||
repo_path.display(),
|
||||
output.status,
|
||||
String::from_utf8_lossy(&output.stderr),
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
);
|
||||
Ok(String::from_utf8(output.stdout)?.trim().to_string())
|
||||
}
|
||||
|
||||
impl Display for NamedExample {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "# {}\n\n", self.name)?;
|
||||
write!(
|
||||
f,
|
||||
"{REPOSITORY_URL_FIELD} = {}\n",
|
||||
self.example.repository_url
|
||||
)?;
|
||||
write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
|
||||
|
||||
write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
|
||||
write!(f, "`````diff\n")?;
|
||||
write!(f, "{}", self.example.uncommitted_diff)?;
|
||||
write!(f, "`````\n")?;
|
||||
|
||||
if !self.example.edit_history.is_empty() {
|
||||
write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
write!(
|
||||
f,
|
||||
"## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
|
||||
self.example.cursor_path.display(),
|
||||
self.example.cursor_position
|
||||
)?;
|
||||
write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
|
||||
|
||||
if !self.example.expected_patch.is_empty() {
|
||||
write!(
|
||||
f,
|
||||
"\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
|
||||
self.example.expected_patch
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
if example.cursor_path.as_ref() == Path::new("") || example.cursor_position.is_empty() {
|
||||
anyhow::bail!("Missing cursor position codeblock");
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
|
||||
REPO_LOCKS
|
||||
.with(|cell| {
|
||||
cell.borrow_mut()
|
||||
.entry(path.as_ref().to_path_buf())
|
||||
.or_default()
|
||||
.clone()
|
||||
})
|
||||
.lock_owned()
|
||||
.await
|
||||
Ok(example)
|
||||
}
|
||||
|
||||
287
crates/edit_prediction_cli/src/format_prompt.rs
Normal file
287
crates/edit_prediction_cli/src/format_prompt.rs
Normal file
@@ -0,0 +1,287 @@
|
||||
use crate::{
|
||||
PromptFormat,
|
||||
example::{Example, ExamplePrompt},
|
||||
headless::EpAppState,
|
||||
load_project::run_load_project,
|
||||
progress::{Progress, Step},
|
||||
retrieve_context::run_context_retrieval,
|
||||
};
|
||||
use anyhow::{Context as _, Result, ensure};
|
||||
use edit_prediction::{
|
||||
EditPredictionStore,
|
||||
zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
|
||||
};
|
||||
use gpui::AsyncApp;
|
||||
use std::sync::Arc;
|
||||
use zeta_prompt::format_zeta_prompt;
|
||||
|
||||
pub async fn run_format_prompt(
|
||||
example: &mut Example,
|
||||
prompt_format: PromptFormat,
|
||||
app_state: Arc<EpAppState>,
|
||||
mut cx: AsyncApp,
|
||||
) -> Result<()> {
|
||||
run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
|
||||
|
||||
let _step_progress = Progress::global().start(Step::FormatPrompt, &example.name);
|
||||
|
||||
match prompt_format {
|
||||
PromptFormat::Teacher => {
|
||||
let prompt = TeacherPrompt::format_prompt(example);
|
||||
example.prompt = Some(ExamplePrompt {
|
||||
input: prompt,
|
||||
expected_output: example.expected_patch.clone(), // TODO
|
||||
format: prompt_format,
|
||||
});
|
||||
}
|
||||
PromptFormat::Zeta2 => {
|
||||
run_load_project(example, app_state, cx.clone()).await?;
|
||||
|
||||
let ep_store = cx.update(|cx| {
|
||||
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
|
||||
})??;
|
||||
|
||||
let state = example.state.as_ref().context("state must be set")?;
|
||||
let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
|
||||
let project = state.project.clone();
|
||||
let (_, input) = ep_store.update(&mut cx, |ep_store, _cx| {
|
||||
anyhow::Ok(zeta2_prompt_input(
|
||||
&snapshot,
|
||||
example
|
||||
.context
|
||||
.as_ref()
|
||||
.context("context must be set")?
|
||||
.files
|
||||
.clone(),
|
||||
ep_store.edit_history_for_project(&project),
|
||||
example.cursor_path.clone(),
|
||||
example
|
||||
.buffer
|
||||
.as_ref()
|
||||
.context("buffer must be set")?
|
||||
.cursor_offset,
|
||||
))
|
||||
})??;
|
||||
let prompt = format_zeta_prompt(&input);
|
||||
let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone());
|
||||
example.prompt = Some(ExamplePrompt {
|
||||
input: prompt,
|
||||
expected_output,
|
||||
format: prompt_format,
|
||||
});
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub struct TeacherPrompt;
|
||||
|
||||
impl TeacherPrompt {
|
||||
const PROMPT: &str = include_str!("teacher.prompt.md");
|
||||
pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
|
||||
pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
|
||||
|
||||
/// Truncate edit history to this number of last lines
|
||||
const MAX_HISTORY_LINES: usize = 128;
|
||||
|
||||
pub fn format_prompt(example: &Example) -> String {
|
||||
let edit_history = Self::format_edit_history(&example.edit_history);
|
||||
let context = Self::format_context(example);
|
||||
let editable_region = Self::format_editable_region(example);
|
||||
|
||||
let prompt = Self::PROMPT
|
||||
.replace("{{context}}", &context)
|
||||
.replace("{{edit_history}}", &edit_history)
|
||||
.replace("{{editable_region}}", &editable_region);
|
||||
|
||||
prompt
|
||||
}
|
||||
|
||||
pub fn parse(example: &Example, response: &str) -> Result<String> {
|
||||
// Ideally, we should always be able to find cursor position in the retrieved context.
|
||||
// In reality, sometimes we don't find it for these reasons:
|
||||
// 1. `example.cursor_position` contains _more_ context than included in the retrieved context
|
||||
// (can be fixed by getting cursor coordinates at the load_example stage)
|
||||
// 2. Context retriever just didn't include cursor line.
|
||||
//
|
||||
// In that case, fallback to using `cursor_position` as excerpt.
|
||||
let cursor_file = &example
|
||||
.buffer
|
||||
.as_ref()
|
||||
.context("`buffer` should be filled in in the context collection step")?
|
||||
.content;
|
||||
|
||||
// Extract updated (new) editable region from the model response
|
||||
let new_editable_region = extract_last_codeblock(response);
|
||||
|
||||
// Reconstruct old editable region we sent to the model
|
||||
let old_editable_region = Self::format_editable_region(example);
|
||||
let old_editable_region = Self::extract_editable_region(&old_editable_region);
|
||||
ensure!(
|
||||
cursor_file.contains(&old_editable_region),
|
||||
"Something's wrong: editable_region is not found in the cursor file"
|
||||
);
|
||||
|
||||
// Apply editable region to a larger context and compute diff.
|
||||
// This is needed to get a better context lines around the editable region
|
||||
let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
|
||||
let diff = language::unified_diff(&cursor_file, &edited_file);
|
||||
|
||||
let diff = indoc::formatdoc! {"
|
||||
--- a/{path}
|
||||
+++ b/{path}
|
||||
{diff}",
|
||||
path = example.cursor_path.to_string_lossy(),
|
||||
diff = diff,
|
||||
};
|
||||
|
||||
Ok(diff)
|
||||
}
|
||||
|
||||
fn format_edit_history(edit_history: &str) -> String {
|
||||
// Strip comments ("garbage lines") from edit history
|
||||
let lines = edit_history
|
||||
.lines()
|
||||
.filter(|&s| Self::is_udiff_content_line(s))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
|
||||
&lines[lines.len() - Self::MAX_HISTORY_LINES..]
|
||||
} else {
|
||||
&lines
|
||||
};
|
||||
|
||||
if history_lines.is_empty() {
|
||||
return "(No edit history)".to_string();
|
||||
}
|
||||
|
||||
history_lines.join("\n")
|
||||
}
|
||||
|
||||
fn format_context(example: &Example) -> String {
|
||||
assert!(example.context.is_some(), "Missing context retriever step");
|
||||
|
||||
let mut prompt = String::new();
|
||||
zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
|
||||
|
||||
prompt
|
||||
}
|
||||
|
||||
fn format_editable_region(example: &Example) -> String {
|
||||
let mut result = String::new();
|
||||
|
||||
let path_str = example.cursor_path.to_string_lossy();
|
||||
result.push_str(&format!("`````path=\"{path_str}\"\n"));
|
||||
result.push_str(Self::EDITABLE_REGION_START);
|
||||
|
||||
// TODO: control number of lines around cursor
|
||||
result.push_str(&example.cursor_position);
|
||||
if !example.cursor_position.ends_with('\n') {
|
||||
result.push('\n');
|
||||
}
|
||||
|
||||
result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END));
|
||||
result.push_str("`````");
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn extract_editable_region(text: &str) -> String {
|
||||
let start = text
|
||||
.find(Self::EDITABLE_REGION_START)
|
||||
.map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
|
||||
let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
|
||||
|
||||
let region = &text[start..end];
|
||||
|
||||
region.replace("<|user_cursor|>", "")
|
||||
}
|
||||
|
||||
fn is_udiff_content_line(s: &str) -> bool {
|
||||
s.starts_with("-")
|
||||
|| s.starts_with("+")
|
||||
|| s.starts_with(" ")
|
||||
|| s.starts_with("---")
|
||||
|| s.starts_with("+++")
|
||||
|| s.starts_with("@@")
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_last_codeblock(text: &str) -> String {
|
||||
let mut last_block = None;
|
||||
let mut search_start = 0;
|
||||
|
||||
while let Some(start) = text[search_start..].find("```") {
|
||||
let start = start + search_start;
|
||||
let bytes = text.as_bytes();
|
||||
let mut backtick_end = start;
|
||||
|
||||
while backtick_end < bytes.len() && bytes[backtick_end] == b'`' {
|
||||
backtick_end += 1;
|
||||
}
|
||||
|
||||
let backtick_count = backtick_end - start;
|
||||
let closing_backticks = "`".repeat(backtick_count);
|
||||
|
||||
while backtick_end < bytes.len() && bytes[backtick_end] != b'\n' {
|
||||
backtick_end += 1;
|
||||
}
|
||||
|
||||
if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
|
||||
let code_block = &text[backtick_end + 1..backtick_end + end_pos];
|
||||
last_block = Some(code_block.to_string());
|
||||
search_start = backtick_end + end_pos + backtick_count;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
last_block.unwrap_or_else(|| text.to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_extract_last_code_block() {
|
||||
let text = indoc::indoc! {"
|
||||
Some thinking
|
||||
|
||||
```
|
||||
first block
|
||||
```
|
||||
|
||||
`````path='something' lines=1:2
|
||||
last block
|
||||
`````
|
||||
"};
|
||||
let last_block = extract_last_codeblock(text);
|
||||
assert_eq!(last_block, "last block\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_editable_region() {
|
||||
let text = indoc::indoc! {"
|
||||
some lines
|
||||
are
|
||||
here
|
||||
<|editable_region_start|>
|
||||
one
|
||||
two three
|
||||
|
||||
<|editable_region_end|>
|
||||
more
|
||||
lines here
|
||||
"};
|
||||
let parsed = TeacherPrompt::extract_editable_region(text);
|
||||
assert_eq!(
|
||||
parsed,
|
||||
indoc::indoc! {"
|
||||
one
|
||||
two three
|
||||
|
||||
"}
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
use client::{Client, ProxySettings, UserStore};
|
||||
use collections::HashMap;
|
||||
use extension::ExtensionHostProxy;
|
||||
use fs::RealFs;
|
||||
use gpui::http_client::read_proxy_from_env;
|
||||
@@ -7,25 +8,39 @@ use gpui_tokio::Tokio;
|
||||
use language::LanguageRegistry;
|
||||
use language_extension::LspAccess;
|
||||
use node_runtime::{NodeBinaryOptions, NodeRuntime};
|
||||
use project::Project;
|
||||
use project::project_settings::ProjectSettings;
|
||||
use release_channel::{AppCommitSha, AppVersion};
|
||||
use reqwest_client::ReqwestClient;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use util::ResultExt as _;
|
||||
|
||||
/// Headless subset of `workspace::AppState`.
|
||||
pub struct ZetaCliAppState {
|
||||
pub struct EpAppState {
|
||||
pub languages: Arc<LanguageRegistry>,
|
||||
pub client: Arc<Client>,
|
||||
pub user_store: Entity<UserStore>,
|
||||
pub fs: Arc<dyn fs::Fs>,
|
||||
pub node_runtime: NodeRuntime,
|
||||
pub project_cache: ProjectCache,
|
||||
}
|
||||
|
||||
// TODO: dedupe with crates/eval/src/eval.rs
|
||||
pub fn init(cx: &mut App) -> ZetaCliAppState {
|
||||
#[derive(Default)]
|
||||
pub struct ProjectCache(Mutex<HashMap<String, Entity<Project>>>);
|
||||
|
||||
impl ProjectCache {
|
||||
pub fn insert(&self, repository_url: String, project: Entity<Project>) {
|
||||
self.0.lock().unwrap().insert(repository_url, project);
|
||||
}
|
||||
|
||||
pub fn get(&self, repository_url: &String) -> Option<Entity<Project>> {
|
||||
self.0.lock().unwrap().get(repository_url).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init(cx: &mut App) -> EpAppState {
|
||||
let app_commit_sha = option_env!("ZED_COMMIT_SHA").map(|s| AppCommitSha::new(s.to_owned()));
|
||||
|
||||
let app_version = AppVersion::load(
|
||||
@@ -112,11 +127,14 @@ pub fn init(cx: &mut App) -> ZetaCliAppState {
|
||||
prompt_store::init(cx);
|
||||
terminal_view::init(cx);
|
||||
|
||||
ZetaCliAppState {
|
||||
let project_cache = ProjectCache::default();
|
||||
|
||||
EpAppState {
|
||||
languages,
|
||||
client,
|
||||
user_store,
|
||||
fs,
|
||||
node_runtime,
|
||||
project_cache,
|
||||
}
|
||||
}
|
||||
|
||||
346
crates/edit_prediction_cli/src/load_project.rs
Normal file
346
crates/edit_prediction_cli/src/load_project.rs
Normal file
@@ -0,0 +1,346 @@
|
||||
use crate::{
|
||||
example::{Example, ExampleBuffer, ExampleState},
|
||||
headless::EpAppState,
|
||||
paths::{REPOS_DIR, WORKTREES_DIR},
|
||||
progress::{InfoStyle, Progress, Step, StepProgress},
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use edit_prediction::EditPredictionStore;
|
||||
use edit_prediction::udiff::OpenedBuffers;
|
||||
use futures::{
|
||||
AsyncWriteExt as _,
|
||||
lock::{Mutex, OwnedMutexGuard},
|
||||
};
|
||||
use gpui::{AsyncApp, Entity};
|
||||
use language::{Anchor, Buffer, LanguageNotFound, ToOffset, ToPoint};
|
||||
use project::buffer_store::BufferStoreEvent;
|
||||
use project::{Project, ProjectPath};
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
fs,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
use util::{paths::PathStyle, rel_path::RelPath};
|
||||
use zeta_prompt::CURSOR_MARKER;
|
||||
|
||||
pub async fn run_load_project(
|
||||
example: &mut Example,
|
||||
app_state: Arc<EpAppState>,
|
||||
mut cx: AsyncApp,
|
||||
) -> Result<()> {
|
||||
if example.state.is_some() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let progress = Progress::global().start(Step::LoadProject, &example.name);
|
||||
|
||||
let project = setup_project(example, &app_state, &progress, &mut cx).await?;
|
||||
|
||||
let _open_buffers = apply_edit_history(example, &project, &mut cx).await?;
|
||||
|
||||
let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await?;
|
||||
let (example_buffer, language_name) = buffer.read_with(&cx, |buffer, _cx| {
|
||||
let cursor_point = cursor_position.to_point(&buffer);
|
||||
let language_name = buffer
|
||||
.language()
|
||||
.map(|l| l.name().to_string())
|
||||
.unwrap_or_else(|| "Unknown".to_string());
|
||||
(
|
||||
ExampleBuffer {
|
||||
content: buffer.text(),
|
||||
cursor_row: cursor_point.row,
|
||||
cursor_column: cursor_point.column,
|
||||
cursor_offset: cursor_position.to_offset(&buffer),
|
||||
},
|
||||
language_name,
|
||||
)
|
||||
})?;
|
||||
|
||||
progress.set_info(language_name, InfoStyle::Normal);
|
||||
|
||||
example.buffer = Some(example_buffer);
|
||||
example.state = Some(ExampleState {
|
||||
buffer,
|
||||
project,
|
||||
cursor_position,
|
||||
_open_buffers,
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cursor_position(
|
||||
example: &Example,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<(Entity<Buffer>, Anchor)> {
|
||||
let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
|
||||
let result = language_registry
|
||||
.load_language_for_file_path(&example.cursor_path)
|
||||
.await;
|
||||
|
||||
if let Err(error) = result
|
||||
&& !error.is::<LanguageNotFound>()
|
||||
{
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
let worktree = project.read_with(cx, |project, cx| {
|
||||
project
|
||||
.visible_worktrees(cx)
|
||||
.next()
|
||||
.context("No visible worktrees")
|
||||
})??;
|
||||
|
||||
let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix)
|
||||
.context("Failed to create RelPath")?
|
||||
.into_arc();
|
||||
let cursor_buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_buffer(
|
||||
ProjectPath {
|
||||
worktree_id: worktree.read(cx).id(),
|
||||
path: cursor_path,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
let cursor_offset_within_excerpt = example
|
||||
.cursor_position
|
||||
.find(CURSOR_MARKER)
|
||||
.context("missing cursor marker")?;
|
||||
let mut cursor_excerpt = example.cursor_position.clone();
|
||||
cursor_excerpt.replace_range(
|
||||
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
|
||||
"",
|
||||
);
|
||||
let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
|
||||
let text = buffer.text();
|
||||
|
||||
let mut matches = text.match_indices(&cursor_excerpt);
|
||||
let (excerpt_offset, _) = matches.next().with_context(|| {
|
||||
format!(
|
||||
"\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Example: {}\nCursor excerpt did not exist in buffer.",
|
||||
example.name
|
||||
)
|
||||
})?;
|
||||
anyhow::ensure!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
|
||||
Ok(excerpt_offset)
|
||||
})??;
|
||||
|
||||
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
|
||||
let cursor_anchor =
|
||||
cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
|
||||
|
||||
Ok((cursor_buffer, cursor_anchor))
|
||||
}
|
||||
|
||||
async fn setup_project(
|
||||
example: &mut Example,
|
||||
app_state: &Arc<EpAppState>,
|
||||
step_progress: &StepProgress,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Entity<Project>> {
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx))?
|
||||
.context("Store should be initialized at init")?;
|
||||
|
||||
let worktree_path = setup_worktree(example, step_progress).await?;
|
||||
|
||||
if let Some(project) = app_state.project_cache.get(&example.repository_url) {
|
||||
ep_store.update(cx, |ep_store, _| {
|
||||
ep_store.clear_history_for_project(&project);
|
||||
})?;
|
||||
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
|
||||
let buffers = buffer_store.read_with(cx, |buffer_store, _| {
|
||||
buffer_store.buffers().collect::<Vec<_>>()
|
||||
})?;
|
||||
for buffer in buffers {
|
||||
buffer
|
||||
.update(cx, |buffer, cx| buffer.reload(cx))?
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
return Ok(project);
|
||||
}
|
||||
|
||||
let project = cx.update(|cx| {
|
||||
Project::local(
|
||||
app_state.client.clone(),
|
||||
app_state.node_runtime.clone(),
|
||||
app_state.user_store.clone(),
|
||||
app_state.languages.clone(),
|
||||
app_state.fs.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
|
||||
project
|
||||
.update(cx, |project, cx| {
|
||||
project.disable_worktree_scanner(cx);
|
||||
project.create_worktree(&worktree_path, true, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
app_state
|
||||
.project_cache
|
||||
.insert(example.repository_url.clone(), project.clone());
|
||||
|
||||
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
|
||||
cx.subscribe(&buffer_store, {
|
||||
let project = project.clone();
|
||||
move |_, event, cx| match event {
|
||||
BufferStoreEvent::BufferAdded(buffer) => {
|
||||
ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
})?
|
||||
.detach();
|
||||
|
||||
Ok(project)
|
||||
}
|
||||
|
||||
async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Result<PathBuf> {
|
||||
let (repo_owner, repo_name) = example.repo_name().context("failed to get repo name")?;
|
||||
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
|
||||
let worktree_path = WORKTREES_DIR
|
||||
.join(repo_owner.as_ref())
|
||||
.join(repo_name.as_ref());
|
||||
let repo_lock = lock_repo(&repo_dir).await;
|
||||
|
||||
if !repo_dir.is_dir() {
|
||||
step_progress.set_substatus(format!("cloning {}", repo_name));
|
||||
fs::create_dir_all(&repo_dir)?;
|
||||
run_git(&repo_dir, &["init"]).await?;
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["remote", "add", "origin", &example.repository_url],
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Resolve the example to a revision, fetching it if needed.
|
||||
let revision = run_git(
|
||||
&repo_dir,
|
||||
&["rev-parse", &format!("{}^{{commit}}", example.revision)],
|
||||
)
|
||||
.await;
|
||||
let revision = if let Ok(revision) = revision {
|
||||
revision
|
||||
} else {
|
||||
step_progress.set_substatus("fetching");
|
||||
if run_git(
|
||||
&repo_dir,
|
||||
&["fetch", "--depth", "1", "origin", &example.revision],
|
||||
)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
run_git(&repo_dir, &["fetch", "origin"]).await?;
|
||||
}
|
||||
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
|
||||
revision
|
||||
};
|
||||
|
||||
// Create the worktree for this example if needed.
|
||||
step_progress.set_substatus("preparing worktree");
|
||||
if worktree_path.is_dir() {
|
||||
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
|
||||
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
|
||||
run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
|
||||
} else {
|
||||
let worktree_path_string = worktree_path.to_string_lossy();
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["branch", "-f", &example.name, revision.as_str()],
|
||||
)
|
||||
.await?;
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&[
|
||||
"worktree",
|
||||
"add",
|
||||
"-f",
|
||||
&worktree_path_string,
|
||||
&example.name,
|
||||
],
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
drop(repo_lock);
|
||||
|
||||
// Apply the uncommitted diff for this example.
|
||||
if !example.uncommitted_diff.is_empty() {
|
||||
step_progress.set_substatus("applying diff");
|
||||
let mut apply_process = smol::process::Command::new("git")
|
||||
.current_dir(&worktree_path)
|
||||
.args(&["apply", "-"])
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.spawn()?;
|
||||
|
||||
let mut stdin = apply_process.stdin.take().context("Failed to get stdin")?;
|
||||
stdin.write_all(example.uncommitted_diff.as_bytes()).await?;
|
||||
stdin.close().await?;
|
||||
drop(stdin);
|
||||
|
||||
let apply_result = apply_process.output().await?;
|
||||
anyhow::ensure!(
|
||||
apply_result.status.success(),
|
||||
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
apply_result.status,
|
||||
String::from_utf8_lossy(&apply_result.stderr),
|
||||
String::from_utf8_lossy(&apply_result.stdout),
|
||||
);
|
||||
}
|
||||
|
||||
step_progress.clear_substatus();
|
||||
Ok(worktree_path)
|
||||
}
|
||||
|
||||
async fn apply_edit_history(
|
||||
example: &Example,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<OpenedBuffers> {
|
||||
edit_prediction::udiff::apply_diff(&example.edit_history, project, cx).await
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
|
||||
REPO_LOCKS
|
||||
.with(|cell| {
|
||||
cell.borrow_mut()
|
||||
.entry(path.as_ref().to_path_buf())
|
||||
.or_default()
|
||||
.clone()
|
||||
})
|
||||
.lock_owned()
|
||||
.await
|
||||
}
|
||||
|
||||
async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
|
||||
let output = smol::process::Command::new("git")
|
||||
.current_dir(repo_path)
|
||||
.args(args)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
anyhow::ensure!(
|
||||
output.status.success(),
|
||||
"`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
args.join(" "),
|
||||
repo_path.display(),
|
||||
output.status,
|
||||
String::from_utf8_lossy(&output.stderr),
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
);
|
||||
Ok(String::from_utf8(output.stdout)?.trim().to_string())
|
||||
}
|
||||
@@ -1,523 +1,340 @@
|
||||
mod evaluate;
|
||||
mod anthropic_client;
|
||||
mod distill;
|
||||
mod example;
|
||||
mod format_prompt;
|
||||
mod headless;
|
||||
mod load_project;
|
||||
mod metrics;
|
||||
mod paths;
|
||||
mod predict;
|
||||
mod source_location;
|
||||
mod training;
|
||||
mod util;
|
||||
mod progress;
|
||||
mod retrieve_context;
|
||||
mod score;
|
||||
|
||||
use crate::{
|
||||
evaluate::run_evaluate,
|
||||
example::{ExampleFormat, NamedExample},
|
||||
headless::ZetaCliAppState,
|
||||
predict::run_predict,
|
||||
source_location::SourceLocation,
|
||||
training::{context::ContextType, distill::run_distill},
|
||||
util::{open_buffer, open_buffer_with_language_server},
|
||||
};
|
||||
use ::util::{ResultExt, paths::PathStyle};
|
||||
use anyhow::{Result, anyhow};
|
||||
use clap::{Args, Parser, Subcommand, ValueEnum};
|
||||
use cloud_llm_client::predict_edits_v3;
|
||||
use edit_prediction::udiff::DiffLine;
|
||||
use edit_prediction_context::EditPredictionExcerptOptions;
|
||||
use gpui::{Application, AsyncApp, Entity, prelude::*};
|
||||
use language::{Bias, Buffer, BufferSnapshot, Point};
|
||||
use metrics::delta_chr_f;
|
||||
use project::{Project, Worktree, lsp_store::OpenLspBufferHandle};
|
||||
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
|
||||
use edit_prediction::EditPredictionStore;
|
||||
use gpui::Application;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use std::io::{self};
|
||||
use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Display;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
|
||||
use crate::distill::run_distill;
|
||||
use crate::example::{group_examples_by_repo, read_examples, write_examples};
|
||||
use crate::format_prompt::run_format_prompt;
|
||||
use crate::load_project::run_load_project;
|
||||
use crate::paths::FAILED_EXAMPLES_DIR;
|
||||
use crate::predict::run_prediction;
|
||||
use crate::progress::Progress;
|
||||
use crate::retrieve_context::run_context_retrieval;
|
||||
use crate::score::run_scoring;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "zeta")]
|
||||
struct ZetaCliArgs {
|
||||
#[command(name = "ep")]
|
||||
struct EpArgs {
|
||||
#[arg(long, default_value_t = false)]
|
||||
printenv: bool,
|
||||
#[clap(long, default_value_t = 10, global = true)]
|
||||
max_parallelism: usize,
|
||||
#[command(subcommand)]
|
||||
command: Option<Command>,
|
||||
#[clap(global = true)]
|
||||
inputs: Vec<PathBuf>,
|
||||
#[arg(long, short, global = true)]
|
||||
output: Option<PathBuf>,
|
||||
#[arg(long, short, global = true)]
|
||||
in_place: bool,
|
||||
#[arg(long, short, global = true)]
|
||||
failfast: bool,
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
enum Command {
|
||||
Context(ContextArgs),
|
||||
Predict(PredictArguments),
|
||||
Eval(EvaluateArguments),
|
||||
Distill(DistillArguments),
|
||||
ConvertExample {
|
||||
path: PathBuf,
|
||||
#[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
|
||||
output_format: ExampleFormat,
|
||||
},
|
||||
Score {
|
||||
golden_patch: PathBuf,
|
||||
actual_patch: PathBuf,
|
||||
},
|
||||
/// Parse markdown examples and output a combined .jsonl file
|
||||
ParseExample,
|
||||
/// Create git worktrees for each example and load file contents
|
||||
LoadProject,
|
||||
/// Retrieve context for input examples.
|
||||
Context,
|
||||
/// Generate a prompt string for a specific model
|
||||
FormatPrompt(FormatPromptArgs),
|
||||
/// Runs edit prediction
|
||||
Predict(PredictArgs),
|
||||
/// Computes a score based on actual and expected patches
|
||||
Score(PredictArgs),
|
||||
/// Prepares a distillation dataset by copying expected outputs to
|
||||
/// predicted outputs and removing actual outputs and prompts.
|
||||
Distill,
|
||||
/// Print aggregated scores
|
||||
Eval(PredictArgs),
|
||||
/// Remove git repositories and worktrees
|
||||
Clean,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
struct ContextArgs {
|
||||
#[arg(long)]
|
||||
provider: ContextProvider,
|
||||
#[arg(long)]
|
||||
worktree: PathBuf,
|
||||
#[arg(long)]
|
||||
cursor: SourceLocation,
|
||||
#[arg(long)]
|
||||
use_language_server: bool,
|
||||
#[arg(long)]
|
||||
edit_history: Option<FileOrStdin>,
|
||||
#[clap(flatten)]
|
||||
zeta2_args: Zeta2Args,
|
||||
impl Display for Command {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Command::ParseExample => write!(f, "parse-example"),
|
||||
Command::LoadProject => write!(f, "load-project"),
|
||||
Command::Context => write!(f, "context"),
|
||||
Command::FormatPrompt(format_prompt_args) => write!(
|
||||
f,
|
||||
"format-prompt --prompt-format={}",
|
||||
format_prompt_args
|
||||
.prompt_format
|
||||
.to_possible_value()
|
||||
.unwrap()
|
||||
.get_name()
|
||||
),
|
||||
Command::Predict(predict_args) => {
|
||||
write!(
|
||||
f,
|
||||
"predict --provider={:?}",
|
||||
predict_args
|
||||
.provider
|
||||
.to_possible_value()
|
||||
.unwrap()
|
||||
.get_name()
|
||||
)
|
||||
}
|
||||
Command::Score(predict_args) => {
|
||||
write!(
|
||||
f,
|
||||
"score --provider={:?}",
|
||||
predict_args
|
||||
.provider
|
||||
.to_possible_value()
|
||||
.unwrap()
|
||||
.get_name()
|
||||
)
|
||||
}
|
||||
Command::Distill => write!(f, "distill"),
|
||||
Command::Eval(predict_args) => write!(
|
||||
f,
|
||||
"eval --provider={:?}",
|
||||
predict_args
|
||||
.provider
|
||||
.to_possible_value()
|
||||
.unwrap()
|
||||
.get_name()
|
||||
),
|
||||
Command::Clean => write!(f, "clean"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
|
||||
enum ContextProvider {
|
||||
Zeta1,
|
||||
#[default]
|
||||
#[derive(Debug, Args)]
|
||||
struct FormatPromptArgs {
|
||||
#[clap(long)]
|
||||
prompt_format: PromptFormat,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
|
||||
enum PromptFormat {
|
||||
Teacher,
|
||||
Zeta2,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Args)]
|
||||
struct Zeta2Args {
|
||||
#[arg(long, default_value_t = 8192)]
|
||||
max_prompt_bytes: usize,
|
||||
#[arg(long, default_value_t = 2048)]
|
||||
max_excerpt_bytes: usize,
|
||||
#[arg(long, default_value_t = 1024)]
|
||||
min_excerpt_bytes: usize,
|
||||
#[arg(long, default_value_t = 0.66)]
|
||||
target_before_cursor_over_total_bytes: f32,
|
||||
#[arg(long, default_value_t = 1024)]
|
||||
max_diagnostic_bytes: usize,
|
||||
#[arg(long, value_enum, default_value_t = PromptFormat::default())]
|
||||
prompt_format: PromptFormat,
|
||||
#[arg(long, value_enum, default_value_t = Default::default())]
|
||||
output_format: OutputFormat,
|
||||
#[arg(long, default_value_t = 42)]
|
||||
file_indexing_parallelism: usize,
|
||||
#[arg(long, default_value_t = false)]
|
||||
disable_imports_gathering: bool,
|
||||
#[arg(long, default_value_t = u8::MAX)]
|
||||
max_retrieved_definitions: u8,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
pub struct PredictArguments {
|
||||
#[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
|
||||
format: PredictionsOutputFormat,
|
||||
example_path: PathBuf,
|
||||
#[clap(flatten)]
|
||||
options: PredictionOptions,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
pub struct DistillArguments {
|
||||
split_commit_dataset: PathBuf,
|
||||
#[clap(long, value_enum, default_value_t = ContextType::CurrentFile)]
|
||||
context_type: ContextType,
|
||||
#[clap(long)]
|
||||
batch: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Args)]
|
||||
pub struct PredictionOptions {
|
||||
#[clap(flatten)]
|
||||
zeta2: Zeta2Args,
|
||||
struct PredictArgs {
|
||||
#[clap(long)]
|
||||
provider: PredictionProvider,
|
||||
#[clap(long, value_enum, default_value_t = CacheMode::default())]
|
||||
cache: CacheMode,
|
||||
#[clap(long, default_value_t = 1)]
|
||||
repetitions: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
|
||||
pub enum CacheMode {
|
||||
/// Use cached LLM requests and responses, except when multiple repetitions are requested
|
||||
#[default]
|
||||
Auto,
|
||||
/// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
|
||||
#[value(alias = "request")]
|
||||
Requests,
|
||||
/// Ignore existing cache entries for both LLM and search.
|
||||
Skip,
|
||||
/// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
|
||||
/// Useful for reproducing results and fixing bugs outside of search queries
|
||||
Force,
|
||||
}
|
||||
|
||||
impl CacheMode {
|
||||
fn use_cached_llm_responses(&self) -> bool {
|
||||
self.assert_not_auto();
|
||||
matches!(self, CacheMode::Requests | CacheMode::Force)
|
||||
}
|
||||
|
||||
fn use_cached_search_results(&self) -> bool {
|
||||
self.assert_not_auto();
|
||||
matches!(self, CacheMode::Force)
|
||||
}
|
||||
|
||||
fn assert_not_auto(&self) {
|
||||
assert_ne!(
|
||||
*self,
|
||||
CacheMode::Auto,
|
||||
"Cache mode should not be auto at this point!"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, Debug, Clone)]
|
||||
pub enum PredictionsOutputFormat {
|
||||
Json,
|
||||
Md,
|
||||
Diff,
|
||||
}

#[derive(Debug, Args)]
pub struct EvaluateArguments {
    example_paths: Vec<PathBuf>,
    #[clap(flatten)]
    options: PredictionOptions,
    #[clap(short, long, default_value_t = 1, alias = "repeat")]
    repetitions: u16,
    #[arg(long)]
    skip_prediction: bool,
}

#[derive(Clone, Copy, Debug, ValueEnum, Default, PartialEq, Serialize, Deserialize)]
enum PredictionProvider {
    Zeta1,
    #[default]
    Zeta2,
    Sweep,
    Mercury,
    Teacher,
    TeacherNonBatching,
}

fn zeta2_args_to_options(args: &Zeta2Args) -> edit_prediction::ZetaOptions {
    edit_prediction::ZetaOptions {
        context: EditPredictionExcerptOptions {
            max_bytes: args.max_excerpt_bytes,
            min_bytes: args.min_excerpt_bytes,
            target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
        },
        max_prompt_bytes: args.max_prompt_bytes,
        prompt_format: args.prompt_format.into(),
    }
}

#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum PromptFormat {
    OnlySnippets,
    #[default]
    OldTextNewText,
    Minimal,
    MinimalQwen,
    SeedCoder1120,
}

impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
    fn into(self) -> predict_edits_v3::PromptFormat {
        match self {
            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
            Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
            Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
            Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
            Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
        }
    }
}
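
Note: `impl Into` works, but the idiomatic spelling is `impl From`, which yields `Into` for free so existing `.into()` call sites keep compiling. A sketch of the equivalent form:

```rust
// Equivalent, more idiomatic conversion (sketch): From gives Into for free.
impl From<PromptFormat> for predict_edits_v3::PromptFormat {
    fn from(format: PromptFormat) -> Self {
        match format {
            PromptFormat::OnlySnippets => Self::OnlySnippets,
            PromptFormat::OldTextNewText => Self::OldTextNewText,
            PromptFormat::Minimal => Self::Minimal,
            PromptFormat::MinimalQwen => Self::MinimalQwen,
            PromptFormat::SeedCoder1120 => Self::SeedCoder1120,
        }
    }
}
```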

#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    #[default]
    Prompt,
    Request,
    Full,
}

#[derive(Debug, Clone)]
enum FileOrStdin {
    File(PathBuf),
    Stdin,
}

impl FileOrStdin {
    async fn read_to_string(&self) -> Result<String, std::io::Error> {
        match self {
            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
        }
    }
}

impl FromStr for FileOrStdin {
    type Err = <PathBuf as FromStr>::Err;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "-" => Ok(Self::Stdin),
            _ => Ok(Self::File(PathBuf::from_str(s)?)),
        }
    }
}
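
Usage sketch: `-` selects stdin, anything else is treated as a file path (the file name below is illustrative only):

```rust
// Hypothetical inputs, for illustration.
let history: FileOrStdin = "-".parse().unwrap(); // FileOrStdin::Stdin
let file: FileOrStdin = "edit_history.md".parse().unwrap(); // FileOrStdin::File(_)
```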

struct LoadedContext {
    full_path_str: String,
    snapshot: BufferSnapshot,
    clipped_cursor: Point,
    worktree: Entity<Worktree>,
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    lsp_open_handle: Option<OpenLspBufferHandle>,
}

async fn load_context(
    args: &ContextArgs,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<LoadedContext> {
    let ContextArgs {
        worktree: worktree_path,
        cursor,
        use_language_server,
        ..
    } = args;

    let worktree_path = worktree_path.canonicalize()?;

    let project = cx.update(|cx| {
        Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        )
    })?;

    let worktree = project
        .update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        })?
        .await?;

    let mut ready_languages = HashSet::default();
    let (lsp_open_handle, buffer) = if *use_language_server {
        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
            project.clone(),
            worktree.clone(),
            cursor.path.clone(),
            &mut ready_languages,
            cx,
        )
        .await?;
        (Some(lsp_open_handle), buffer)
    } else {
        let buffer =
            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
        (None, buffer)
    };

    let full_path_str = worktree
        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
        .display(PathStyle::local())
        .to_string();

    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
    if clipped_cursor != cursor.point {
        let max_row = snapshot.max_point().row;
        if cursor.point.row < max_row {
            return Err(anyhow!(
                "Cursor position {:?} is out of bounds (line length is {})",
                cursor.point,
                snapshot.line_len(cursor.point.row)
            ));
        } else {
            return Err(anyhow!(
                "Cursor position {:?} is out of bounds (max row is {})",
                cursor.point,
                max_row
            ));
        }
    }

    Ok(LoadedContext {
        full_path_str,
        snapshot,
        clipped_cursor,
        worktree,
        project,
        buffer,
        lsp_open_handle,
    })
}

impl EpArgs {
    fn output_path(&self) -> Option<PathBuf> {
        if self.in_place {
            if self.inputs.len() == 1 {
                self.inputs.first().cloned()
            } else {
                panic!("--in-place requires exactly one input file")
            }
        } else {
            self.output.clone()
        }
    }
}

async fn zeta2_context(
    args: ContextArgs,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<String> {
    let LoadedContext {
        worktree,
        project,
        buffer,
        clipped_cursor,
        lsp_open_handle: _handle,
        ..
    } = load_context(&args, app_state, cx).await?;

    // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
    // the whole worktree.
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;
    let output = cx
        .update(|cx| {
            let store = cx.new(|cx| {
                edit_prediction::EditPredictionStore::new(
                    app_state.client.clone(),
                    app_state.user_store.clone(),
                    cx,
                )
            });
            store.update(cx, |store, cx| {
                store.set_options(zeta2_args_to_options(&args.zeta2_args));
                store.register_buffer(&buffer, &project, cx);
            });
            cx.spawn(async move |cx| {
                let updates_rx = store.update(cx, |store, cx| {
                    let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
                    store.set_use_context(true);
                    store.refresh_context(&project, &buffer, cursor, cx);
                    store.project_context_updates(&project).unwrap()
                })?;

                updates_rx.recv().await.ok();

                let context = store.update(cx, |store, cx| {
                    store.context_for_project(&project, cx).to_vec()
                })?;

                anyhow::Ok(serde_json::to_string_pretty(&context).unwrap())
            })
        })?
        .await?;

    Ok(output)
}

async fn zeta1_context(
    args: ContextArgs,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<edit_prediction::zeta1::GatherContextOutput> {
    let LoadedContext {
        full_path_str,
        snapshot,
        clipped_cursor,
        ..
    } = load_context(&args, app_state, cx).await?;

    let events = match args.edit_history {
        Some(events) => events.read_to_string().await?,
        None => String::new(),
    };

    let prompt_for_events = move || (events, 0);
    cx.update(|cx| {
        edit_prediction::zeta1::gather_context(
            full_path_str,
            &snapshot,
            clipped_cursor,
            prompt_for_events,
            cloud_llm_client::PredictEditsRequestTrigger::Cli,
            cx,
        )
    })?
    .await
}

fn main() {
    zlog::init();
    zlog::init_output_stderr();
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    let command = match args.command {
        Some(cmd) => cmd,
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        _ => {}
    }

    let mut examples = read_examples(&args.inputs);
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                if let Command::Predict(args) = &command {
                    predict::sync_batches(&args.provider).await?;
                }

                let total_examples = examples.len();
                Progress::global().set_total_examples(total_examples);

                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            let result = async {
                                match &command {
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args.prompt_format,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            Some(args.provider),
                                            args.repetitions,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Clean => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

                            if let Err(e) = result {
                                Progress::global().increment_failed();
                                let failed_example_path =
                                    FAILED_EXAMPLES_DIR.join(format!("{}.json", example.name));
                                app_state
                                    .fs
                                    .write(
                                        &failed_example_path,
                                        &serde_json::to_vec_pretty(&example).unwrap(),
                                    )
                                    .await
                                    .unwrap();
                                let err_path =
                                    FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example.name));
                                app_state
                                    .fs
                                    .write(&err_path, e.to_string().as_bytes())
                                    .await
                                    .unwrap();

                                let msg = format!(
                                    indoc::indoc! {"
                                        While processing {}:

                                        {:?}

                                        Written to: \x1b[36m{}\x1b[0m

                                        Explore this example data with:
                                        fx \x1b[36m{}\x1b[0m

                                        Re-run this example with:
                                        cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
                                    "},
                                    example.name,
                                    e,
                                    err_path.display(),
                                    failed_example_path.display(),
                                    command,
                                    failed_example_path.display(),
                                );
                                if args.failfast || total_examples == 1 {
                                    Progress::global().finalize();
                                    panic!("{}", msg);
                                } else {
                                    log::error!("{}", msg);
                                }
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }

                Progress::global().finalize();

                if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                    write_examples(&examples, output.as_ref());
                }

                match &command {
                    Command::Predict(args) => predict::sync_batches(&args.provider).await?,
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
@@ -1,30 +1,34 @@
use collections::{HashMap, HashSet};
use edit_prediction::udiff::DiffLine;
use serde::{Deserialize, Serialize};

type Counts = HashMap<String, usize>;
type CountsDelta = HashMap<String, isize>;

#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct ClassificationMetrics {
    pub true_positives: usize,
    pub false_positives: usize,
    pub false_negatives: usize,
}

impl ClassificationMetrics {
    pub fn from_sets(
        expected: &HashSet<String>,
        actual: &HashSet<String>,
    ) -> ClassificationMetrics {
        let true_positives = expected.intersection(actual).count();
        let false_positives = actual.difference(expected).count();
        let false_negatives = expected.difference(actual).count();

        ClassificationMetrics {
            true_positives,
            false_positives,
            false_negatives,
        }
    }

    pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
        let mut true_positives = 0;
        let mut false_positives = 0;
        let mut false_negatives = 0;
@@ -45,32 +49,16 @@ impl Scores {
            }
        }

        ClassificationMetrics {
            true_positives,
            false_positives,
            false_negatives,
        }
    }

    pub fn to_markdown(&self) -> String {
        format!(
            "
Precision : {:.4}
Recall : {:.4}
F1 Score : {:.4}
True Positives : {}
False Positives : {}
False Negatives : {}",
            self.precision(),
            self.recall(),
            self.f1_score(),
            self.true_positives,
            self.false_positives,
            self.false_negatives
        )
    }
    pub fn aggregate<'a>(
        scores: impl Iterator<Item = &'a ClassificationMetrics>,
    ) -> ClassificationMetrics {
        let mut true_positives = 0;
        let mut false_positives = 0;
        let mut false_negatives = 0;
@@ -81,7 +69,7 @@ False Negatives : {}",
            false_negatives += score.false_negatives;
        }

        ClassificationMetrics {
            true_positives,
            false_positives,
            false_negatives,
@@ -115,7 +103,10 @@ False Negatives : {}",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn line_match_score(expected_patch: &[DiffLine], actual_patch: &[DiffLine]) -> Scores {
|
||||
pub fn line_match_score(
|
||||
expected_patch: &[DiffLine],
|
||||
actual_patch: &[DiffLine],
|
||||
) -> ClassificationMetrics {
|
||||
let expected_change_lines = expected_patch
|
||||
.iter()
|
||||
.filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
|
||||
@@ -128,7 +119,7 @@ pub fn line_match_score(expected_patch: &[DiffLine], actual_patch: &[DiffLine])
|
||||
.map(|line| line.to_string())
|
||||
.collect();
|
||||
|
||||
Scores::from_sets(&expected_change_lines, &actual_change_lines)
|
||||
ClassificationMetrics::from_sets(&expected_change_lines, &actual_change_lines)
|
||||
}
|
||||
|
||||
enum ChrfWhitespace {
|
||||
@@ -204,7 +195,7 @@ pub fn delta_chr_f(expected: &[DiffLine], actual: &[DiffLine]) -> f64 {
|
||||
let expected_counts = ngram_delta_to_counts(&expected_delta);
|
||||
let actual_counts = ngram_delta_to_counts(&actual_delta);
|
||||
|
||||
let score = Scores::from_counts(&expected_counts, &actual_counts);
|
||||
let score = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
|
||||
total_precision += score.precision();
|
||||
total_recall += score.recall();
|
||||
}
|
||||
|
||||
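
A small worked example of the set-based metrics (illustrative only; `precision` and `recall` are among the methods elided by the hunks above): one shared changed line, one spurious line, and one missed line give precision = recall = 0.5.

```rust
use collections::HashSet;

let expected: HashSet<String> = ["+foo", "-bar"].into_iter().map(String::from).collect();
let actual: HashSet<String> = ["+foo", "+baz"].into_iter().map(String::from).collect();

let m = ClassificationMetrics::from_sets(&expected, &actual);
assert_eq!(
    (m.true_positives, m.false_positives, m.false_negatives),
    (1, 1, 1)
); // precision = 1/2, recall = 1/2
```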
@@ -1,57 +1,27 @@
use std::{
    path::{Path, PathBuf},
    sync::LazyLock,
};

pub static DATA_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
    let dir = dirs::home_dir().unwrap().join(".zed_ep");
    ensure_dir(&dir)
});
pub static CACHE_DIR: LazyLock<PathBuf> = LazyLock::new(|| ensure_dir(&DATA_DIR.join("cache")));
pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| ensure_dir(&DATA_DIR.join("repos")));
pub static WORKTREES_DIR: LazyLock<PathBuf> =
    LazyLock::new(|| ensure_dir(&DATA_DIR.join("worktrees")));
pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
    DATA_DIR
        .join("runs")
        .join(chrono::Local::now().format("%d-%m-%y-%H_%M_%S").to_string())
});
pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("latest"));
pub static LLM_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("llm_cache.sqlite"));
pub static FAILED_EXAMPLES_DIR: LazyLock<PathBuf> =
    LazyLock::new(|| ensure_dir(&RUN_DIR.join("failed")));

fn ensure_dir(path: &Path) -> PathBuf {
    std::fs::create_dir_all(path).expect("Failed to create directory");
    path.to_path_buf()
}
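
Because every static above funnels through `ensure_dir`, merely touching a path constant materializes it on disk; a quick sketch of that behavior:

```rust
// Sketch: dereferencing a LazyLock static resolves the path and creates the
// directory tree on first use (CACHE_DIR maps to ~/.zed_ep/cache here).
let cache_root = &*CACHE_DIR;
println!("cache lives at {}", cache_root.display());
```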

@@ -1,374 +1,291 @@
use crate::{
    PredictionProvider, PromptFormat,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{InfoStyle, Progress, Step},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};

pub async fn run_prediction(
    example: &mut Example,
    provider: Option<PredictionProvider>,
    repetition_count: usize,
    app_state: Arc<EpAppState>,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    if !example.predictions.is_empty() {
        return Ok(());
    }

    let provider = provider.context("provider is required")?;

    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;

    if matches!(
        provider,
        PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
    ) {
        let _step_progress = Progress::global().start(Step::Predict, &example.name);

        if example.prompt.is_none() {
            run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?;
        }

        let batched = matches!(provider, PredictionProvider::Teacher);
        return predict_anthropic(example, repetition_count, batched).await;
    }

    run_load_project(example, app_state.clone(), cx.clone()).await?;

    let _step_progress = Progress::global().start(Step::Predict, &example.name);

    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2
    ) {
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx.update(|cx| {
        EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
    })??;

    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);
    })?;

    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

    let mut debug_rx =
        ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))?;
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }

                        if run_ix >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };
        fs::create_dir_all(&run_dir)?;
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: String::new(),
                actual_output: String::new(),
                provider,
            });

        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })?
            .await?;

        let actual_patch = prediction
            .and_then(|prediction| {
                let prediction = prediction.prediction.ok()?;
                prediction.edit_preview.as_unified_diff(&prediction.edits)
            })
            .unwrap_or_default();

        let has_prediction = !actual_patch.is_empty();

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            _step_progress.set_info(info, style);
        }
    }

    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    })?;
    debug_task.await?;

    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

async fn predict_anthropic(
    example: &mut Example,
    _repetition_count: usize,
    batched: bool,
) -> anyhow::Result<()> {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
    let llm_client = if batched {
        AnthropicClient::batch(&crate::paths::LLM_CACHE_DB.as_ref())
    } else {
        AnthropicClient::plain()
    };
    let llm_client = llm_client.context("Failed to create LLM client")?;

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await?
    else {
        // Request stashed for batched processing
        return Ok(());
    };

    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    let actual_patch = TeacherPrompt::parse(example, &actual_output)?;

    let prediction = ExamplePrediction {
        actual_patch,
        actual_output,
        provider: PredictionProvider::Teacher,
    };

    example.predictions.push(prediction);
    Ok(())
}

pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
    match provider {
        PredictionProvider::Teacher => {
            let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
            let llm_client =
                AnthropicClient::batch(cache_path).context("Failed to create LLM client")?;
            llm_client
                .sync_batches()
                .await
                .context("Failed to sync batches")?;
        }
        _ => (),
    };
    Ok(())
}
508 crates/edit_prediction_cli/src/progress.rs Normal file
@@ -0,0 +1,508 @@
use std::{
    borrow::Cow,
    collections::HashMap,
    io::{IsTerminal, Write},
    sync::{Arc, Mutex, OnceLock},
    time::{Duration, Instant},
};

use log::{Level, Log, Metadata, Record};

pub struct Progress {
    inner: Mutex<ProgressInner>,
}

struct ProgressInner {
    completed: Vec<CompletedTask>,
    in_progress: HashMap<String, InProgressTask>,
    is_tty: bool,
    terminal_width: usize,
    max_example_name_len: usize,
    status_lines_displayed: usize,
    total_examples: usize,
    failed_examples: usize,
    last_line_is_logging: bool,
}

#[derive(Clone)]
struct InProgressTask {
    step: Step,
    started_at: Instant,
    substatus: Option<String>,
    info: Option<(String, InfoStyle)>,
}

struct CompletedTask {
    step: Step,
    example_name: String,
    duration: Duration,
    info: Option<(String, InfoStyle)>,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Step {
    LoadProject,
    Context,
    FormatPrompt,
    Predict,
    Score,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InfoStyle {
    Normal,
    Warning,
}

impl Step {
    pub fn label(&self) -> &'static str {
        match self {
            Step::LoadProject => "Load",
            Step::Context => "Context",
            Step::FormatPrompt => "Format",
            Step::Predict => "Predict",
            Step::Score => "Score",
        }
    }

    fn color_code(&self) -> &'static str {
        match self {
            Step::LoadProject => "\x1b[33m",
            Step::Context => "\x1b[35m",
            Step::FormatPrompt => "\x1b[34m",
            Step::Predict => "\x1b[32m",
            Step::Score => "\x1b[31m",
        }
    }
}

static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
static LOGGER: ProgressLogger = ProgressLogger;

const MARGIN: usize = 4;
const MAX_STATUS_LINES: usize = 10;

impl Progress {
    /// Returns the global Progress instance, initializing it if necessary.
    pub fn global() -> Arc<Progress> {
        GLOBAL
            .get_or_init(|| {
                let progress = Arc::new(Self {
                    inner: Mutex::new(ProgressInner {
                        completed: Vec::new(),
                        in_progress: HashMap::new(),
                        is_tty: std::io::stderr().is_terminal(),
                        terminal_width: get_terminal_width(),
                        max_example_name_len: 0,
                        status_lines_displayed: 0,
                        total_examples: 0,
                        failed_examples: 0,
                        last_line_is_logging: false,
                    }),
                });
                let _ = log::set_logger(&LOGGER);
                log::set_max_level(log::LevelFilter::Error);
                progress
            })
            .clone()
    }

    pub fn set_total_examples(&self, total: usize) {
        let mut inner = self.inner.lock().unwrap();
        inner.total_examples = total;
    }

    pub fn increment_failed(&self) {
        let mut inner = self.inner.lock().unwrap();
        inner.failed_examples += 1;
    }

    /// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
    /// This should be used for any output that needs to appear above the status lines.
    fn log(&self, message: &str) {
        let mut inner = self.inner.lock().unwrap();
        Self::clear_status_lines(&mut inner);

        if !inner.last_line_is_logging {
            let reset = "\x1b[0m";
            let dim = "\x1b[2m";
            let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
            eprintln!("{dim}{divider}{reset}");
            inner.last_line_is_logging = true;
        }

        eprintln!("{}", message);
    }

    pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
        let mut inner = self.inner.lock().unwrap();

        Self::clear_status_lines(&mut inner);

        inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
        inner.in_progress.insert(
            example_name.to_string(),
            InProgressTask {
                step,
                started_at: Instant::now(),
                substatus: None,
                info: None,
            },
        );

        Self::print_status_lines(&mut inner);

        StepProgress {
            progress: self.clone(),
            step,
            example_name: example_name.to_string(),
        }
    }

    fn finish(&self, step: Step, example_name: &str) {
        let mut inner = self.inner.lock().unwrap();

        let Some(task) = inner.in_progress.remove(example_name) else {
            return;
        };

        if task.step == step {
            inner.completed.push(CompletedTask {
                step: task.step,
                example_name: example_name.to_string(),
                duration: task.started_at.elapsed(),
                info: task.info,
            });

            Self::clear_status_lines(&mut inner);
            Self::print_logging_closing_divider(&mut inner);
            Self::print_completed(&inner, inner.completed.last().unwrap());
            Self::print_status_lines(&mut inner);
        } else {
            inner.in_progress.insert(example_name.to_string(), task);
        }
    }

    fn print_logging_closing_divider(inner: &mut ProgressInner) {
        if inner.last_line_is_logging {
            let reset = "\x1b[0m";
            let dim = "\x1b[2m";
            let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
            eprintln!("{dim}{divider}{reset}");
            inner.last_line_is_logging = false;
        }
    }

    fn clear_status_lines(inner: &mut ProgressInner) {
        if inner.is_tty && inner.status_lines_displayed > 0 {
            // Move up and clear each line we previously displayed
            for _ in 0..inner.status_lines_displayed {
                eprint!("\x1b[A\x1b[K");
            }
            let _ = std::io::stderr().flush();
            inner.status_lines_displayed = 0;
        }
    }

    fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
        let duration = format_duration(task.duration);
        let name_width = inner.max_example_name_len;

        if inner.is_tty {
            let reset = "\x1b[0m";
            let bold = "\x1b[1m";
            let dim = "\x1b[2m";

            let yellow = "\x1b[33m";
            let info_part = task
                .info
                .as_ref()
                .map(|(s, style)| {
                    if *style == InfoStyle::Warning {
                        format!("{yellow}{s}{reset}")
                    } else {
                        s.to_string()
                    }
                })
                .unwrap_or_default();

            let prefix = format!(
                "{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}│{reset} {info_part}",
                color = task.step.color_code(),
                label = task.step.label(),
                name = task.example_name,
            );

            let duration_with_margin = format!("{duration} ");
            let padding_needed = inner
                .terminal_width
                .saturating_sub(MARGIN)
                .saturating_sub(duration_with_margin.len())
                .saturating_sub(strip_ansi_len(&prefix));
            let padding = " ".repeat(padding_needed);

            eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
        } else {
            let info_part = task
                .info
                .as_ref()
                .map(|(s, _)| format!(" | {}", s))
                .unwrap_or_default();

            eprintln!(
                "{label:>12} {name:<name_width$}{info_part} {duration}",
                label = task.step.label(),
                name = task.example_name,
            );
        }
    }

    fn print_status_lines(inner: &mut ProgressInner) {
        if !inner.is_tty || inner.in_progress.is_empty() {
            inner.status_lines_displayed = 0;
            return;
        }

        let reset = "\x1b[0m";
        let bold = "\x1b[1m";
        let dim = "\x1b[2m";

        // Build the done/in-progress/total label
        let done_count = inner.completed.len();
        let in_progress_count = inner.in_progress.len();
        let failed_count = inner.failed_examples;

        let failed_label = if failed_count > 0 {
            format!(" {} failed ", failed_count)
        } else {
            String::new()
        };

        let range_label = format!(
            " {}/{}/{} ",
            done_count, in_progress_count, inner.total_examples
        );

        // Print a divider line with failed count on left, range label on right
        let failed_visible_len = strip_ansi_len(&failed_label);
        let range_visible_len = range_label.len();
        let middle_divider_len = inner
            .terminal_width
            .saturating_sub(MARGIN * 2)
            .saturating_sub(failed_visible_len)
            .saturating_sub(range_visible_len);
        let left_divider = "─".repeat(MARGIN);
        let middle_divider = "─".repeat(middle_divider_len);
        let right_divider = "─".repeat(MARGIN);
        eprintln!(
            "{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
        );

        let mut tasks: Vec<_> = inner.in_progress.iter().collect();
        tasks.sort_by_key(|(name, _)| *name);

        let total_tasks = tasks.len();
        let mut lines_printed = 0;

        for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
            let elapsed = format_duration(task.started_at.elapsed());
            let substatus_part = task
                .substatus
                .as_ref()
                .map(|s| truncate_with_ellipsis(s, 30))
                .unwrap_or_default();

            let step_label = task.step.label();
            let step_color = task.step.color_code();
            let name_width = inner.max_example_name_len;

            let prefix = format!(
                "{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}│{reset} {substatus_part}",
                name = name,
            );

            let duration_with_margin = format!("{elapsed} ");
            let padding_needed = inner
                .terminal_width
                .saturating_sub(MARGIN)
                .saturating_sub(duration_with_margin.len())
                .saturating_sub(strip_ansi_len(&prefix));
            let padding = " ".repeat(padding_needed);

            eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
            lines_printed += 1;
        }

        // Show "+N more" on its own line if there are more tasks
        if total_tasks > MAX_STATUS_LINES {
            let remaining = total_tasks - MAX_STATUS_LINES;
            eprintln!("{:>12} +{remaining} more", "");
            lines_printed += 1;
        }

        inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
        let _ = std::io::stderr().flush();
    }

    pub fn finalize(&self) {
        let mut inner = self.inner.lock().unwrap();
        Self::clear_status_lines(&mut inner);

        // Print summary if there were failures
        if inner.failed_examples > 0 {
            let total_processed = inner.completed.len() + inner.failed_examples;
            let percentage = if total_processed > 0 {
                inner.failed_examples as f64 / total_processed as f64 * 100.0
            } else {
                0.0
            };
            eprintln!(
                "\n{} of {} examples failed ({:.1}%)",
                inner.failed_examples, total_processed, percentage
            );
        }
    }
}
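
Taken together, the intended call pattern looks like this (a sketch; the `StepProgress` guard defined below finishes the step when it drops):

```rust
// Illustrative flow for one example.
let progress = Progress::global();
progress.set_total_examples(1);
{
    let step = progress.start(Step::Predict, "my-example");
    step.set_substatus("requesting prediction");
    // ... work happens here; status lines redraw as other tasks log ...
    step.set_info("predicted", InfoStyle::Normal);
} // Drop marks the step complete and prints its summary line.
```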

pub struct StepProgress {
    progress: Arc<Progress>,
    step: Step,
    example_name: String,
}

impl StepProgress {
    pub fn set_substatus(&self, substatus: impl Into<Cow<'static, str>>) {
        let mut inner = self.progress.inner.lock().unwrap();
        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
            task.substatus = Some(substatus.into().into_owned());
            Progress::clear_status_lines(&mut inner);
            Progress::print_status_lines(&mut inner);
        }
    }

    pub fn clear_substatus(&self) {
        let mut inner = self.progress.inner.lock().unwrap();
        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
            task.substatus = None;
            Progress::clear_status_lines(&mut inner);
            Progress::print_status_lines(&mut inner);
        }
    }

    pub fn set_info(&self, info: impl Into<String>, style: InfoStyle) {
        let mut inner = self.progress.inner.lock().unwrap();
        if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
            task.info = Some((info.into(), style));
        }
    }
}

impl Drop for StepProgress {
    fn drop(&mut self) {
        self.progress.finish(self.step, &self.example_name);
    }
}

struct ProgressLogger;

impl Log for ProgressLogger {
    fn enabled(&self, metadata: &Metadata) -> bool {
        metadata.level() <= Level::Info
    }

    fn log(&self, record: &Record) {
        if !self.enabled(record.metadata()) {
            return;
        }

        let level_color = match record.level() {
            Level::Error => "\x1b[31m",
            Level::Warn => "\x1b[33m",
            Level::Info => "\x1b[32m",
            Level::Debug => "\x1b[34m",
            Level::Trace => "\x1b[35m",
        };
        let reset = "\x1b[0m";
        let bold = "\x1b[1m";

        let level_label = match record.level() {
            Level::Error => "Error",
            Level::Warn => "Warn",
            Level::Info => "Info",
            Level::Debug => "Debug",
            Level::Trace => "Trace",
        };

        let message = format!(
            "{bold}{level_color}{level_label:>12}{reset} {}",
            record.args()
        );

        if let Some(progress) = GLOBAL.get() {
            progress.log(&message);
        } else {
            eprintln!("{}", message);
        }
    }

    fn flush(&self) {
        let _ = std::io::stderr().flush();
    }
}

#[cfg(unix)]
fn get_terminal_width() -> usize {
    unsafe {
        let mut winsize: libc::winsize = std::mem::zeroed();
        if libc::ioctl(libc::STDERR_FILENO, libc::TIOCGWINSZ, &mut winsize) == 0
            && winsize.ws_col > 0
        {
            winsize.ws_col as usize
        } else {
            80
        }
    }
}

#[cfg(not(unix))]
fn get_terminal_width() -> usize {
    80
}

fn strip_ansi_len(s: &str) -> usize {
    let mut len = 0;
    let mut in_escape = false;
    for c in s.chars() {
        if c == '\x1b' {
            in_escape = true;
        } else if in_escape {
            if c == 'm' {
                in_escape = false;
            }
        } else {
            len += 1;
        }
    }
    len
}
|
||||
|
||||
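`strip_ansi_len` is what keeps truncation honest when status lines carry SGR color codes: byte length overcounts the escapes, while visible width does not. A self-contained sanity check (duplicating the helper so it runs on its own):

```rust
// Standalone copy of the escape-aware length logic above.
fn strip_ansi_len(s: &str) -> usize {
    let mut len = 0;
    let mut in_escape = false;
    for c in s.chars() {
        if c == '\x1b' {
            in_escape = true;
        } else if in_escape {
            if c == 'm' {
                in_escape = false;
            }
        } else {
            len += 1;
        }
    }
    len
}

fn main() {
    let plain = "Context example-1";
    let colored = "\x1b[1m\x1b[32mContext\x1b[0m example-1";
    // Both occupy 17 visible columns...
    assert_eq!(plain.len(), strip_ansi_len(colored));
    // ...but the raw byte length counts the escape sequences too.
    assert_ne!(plain.len(), colored.len());
}
```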
fn truncate_with_ellipsis(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        s.to_string()
    } else {
        format!("{}…", &s[..max_len.saturating_sub(1)])
    }
}

fn format_duration(duration: Duration) -> String {
    const MINUTE_IN_MILLIS: f32 = 60. * 1000.;

    let millis = duration.as_millis() as f32;
    if millis < 1000.0 {
        format!("{}ms", millis)
    } else if millis < MINUTE_IN_MILLIS {
        format!("{:.1}s", millis / 1_000.0)
    } else {
        format!("{:.1}m", millis / MINUTE_IN_MILLIS)
    }
}
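The thresholds above pick one unit per magnitude. A quick standalone check of the expected renderings (again with a local copy of the helper):

```rust
use std::time::Duration;

// Standalone copy of the formatting thresholds above.
fn format_duration(duration: Duration) -> String {
    const MINUTE_IN_MILLIS: f32 = 60. * 1000.;
    let millis = duration.as_millis() as f32;
    if millis < 1000.0 {
        format!("{}ms", millis)
    } else if millis < MINUTE_IN_MILLIS {
        format!("{:.1}s", millis / 1_000.0)
    } else {
        format!("{:.1}m", millis / MINUTE_IN_MILLIS)
    }
}

fn main() {
    assert_eq!(format_duration(Duration::from_millis(250)), "250ms");
    assert_eq!(format_duration(Duration::from_millis(1_500)), "1.5s");
    assert_eq!(format_duration(Duration::from_secs(90)), "1.5m");
}
```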
192 crates/edit_prediction_cli/src/retrieve_context.rs Normal file
@@ -0,0 +1,192 @@
use crate::{
    example::{Example, ExampleContext},
    headless::EpAppState,
    load_project::run_load_project,
    progress::{InfoStyle, Progress, Step, StepProgress},
};
use anyhow::Context as _;
use collections::HashSet;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
use gpui::{AsyncApp, Entity};
use language::Buffer;
use project::Project;
use std::sync::Arc;
use std::time::Duration;

pub async fn run_context_retrieval(
    example: &mut Example,
    app_state: Arc<EpAppState>,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    if example.context.is_some() {
        return Ok(());
    }

    run_load_project(example, app_state.clone(), cx.clone()).await?;

    let step_progress: Arc<StepProgress> = Progress::global()
        .start(Step::Context, &example.name)
        .into();

    let state = example.state.as_ref().unwrap();
    let project = state.project.clone();

    let _lsp_handle = project.update(&mut cx, |project, cx| {
        project.register_buffer_with_language_servers(&state.buffer, cx)
    })?;
    wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;

    let ep_store = cx.update(|cx| {
        EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
    })??;

    let mut events = ep_store.update(&mut cx, |store, cx| {
        store.register_buffer(&state.buffer, &project, cx);
        store.set_use_context(true);
        store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
        store.debug_info(&project, cx)
    })?;

    while let Some(event) = events.next().await {
        match event {
            DebugEvent::ContextRetrievalFinished(_) => {
                break;
            }
            _ => {}
        }
    }

    let context_files =
        ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx))?;

    let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
    step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);

    example.context = Some(ExampleContext {
        files: context_files,
    });
    Ok(())
}

async fn wait_for_language_servers_to_start(
    project: &Entity<Project>,
    buffer: &Entity<Buffer>,
    step_progress: &Arc<StepProgress>,
    cx: &mut AsyncApp,
) -> anyhow::Result<()> {
    let lsp_store = project.read_with(cx, |project, _| project.lsp_store())?;

    let (language_server_ids, mut starting_language_server_ids) = buffer
        .update(cx, |buffer, cx| {
            lsp_store.update(cx, |lsp_store, cx| {
                let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
                let starting_ids = ids
                    .iter()
                    .copied()
                    .filter(|id| !lsp_store.language_server_statuses.contains_key(id))
                    .collect::<HashSet<_>>();
                (ids, starting_ids)
            })
        })
        .unwrap_or_default();

    step_progress.set_substatus(format!("waiting for {} LSPs", language_server_ids.len()));

    let timeout = cx
        .background_executor()
        .timer(Duration::from_secs(60 * 5))
        .shared();

    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
    let added_subscription = cx.subscribe(project, {
        let step_progress = step_progress.clone();
        move |_, event, _| match event {
            project::Event::LanguageServerAdded(language_server_id, name, _) => {
                step_progress.set_substatus(format!("LSP started: {}", name));
                tx.try_send(*language_server_id).ok();
            }
            _ => {}
        }
    });

    while !starting_language_server_ids.is_empty() {
        futures::select! {
            language_server_id = rx.next() => {
                if let Some(id) = language_server_id {
                    starting_language_server_ids.remove(&id);
                }
            },
            _ = timeout.clone().fuse() => {
                return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
            }
        }
    }

    drop(added_subscription);

    if !language_server_ids.is_empty() {
        project
            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
            .detach();
    }

    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
    let subscriptions = [
        cx.subscribe(&lsp_store, {
            let step_progress = step_progress.clone();
            move |_, event, _| {
                if let project::LspStoreEvent::LanguageServerUpdate {
                    message:
                        client::proto::update_language_server::Variant::WorkProgress(
                            client::proto::LspWorkProgress {
                                message: Some(message),
                                ..
                            },
                        ),
                    ..
                } = event
                {
                    step_progress.set_substatus(message.clone());
                }
            }
        }),
        cx.subscribe(project, {
            let step_progress = step_progress.clone();
            move |_, event, cx| match event {
                project::Event::DiskBasedDiagnosticsFinished { language_server_id } => {
                    let lsp_store = lsp_store.read(cx);
                    let name = lsp_store
                        .language_server_adapter_for_id(*language_server_id)
                        .unwrap()
                        .name();
                    step_progress.set_substatus(format!("LSP idle: {}", name));
                    tx.try_send(*language_server_id).ok();
                }
                _ => {}
            }
        }),
    ];

    project
        .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
        .await?;

    let mut pending_language_server_ids = HashSet::from_iter(language_server_ids.into_iter());
    while !pending_language_server_ids.is_empty() {
        futures::select! {
            language_server_id = rx.next() => {
                if let Some(id) = language_server_id {
                    pending_language_server_ids.remove(&id);
                }
            },
            _ = timeout.clone().fuse() => {
                return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
            }
        }
    }

    drop(subscriptions);
    step_progress.clear_substatus();
    Ok(())
}
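Both wait loops share a single five-minute timer by cloning one `.shared()` future into each `select!`. A minimal sketch of that shape using plain `futures` primitives — channel names and the oneshot standing in for the timer are illustrative:

```rust
use futures::channel::{mpsc, oneshot};
use futures::{FutureExt as _, StreamExt as _, executor::block_on};
use std::collections::HashSet;

fn main() {
    let (mut ready_tx, mut ready_rx) = mpsc::channel::<u32>(4);
    // Stand-in for the shared 5-minute timer: a oneshot a watchdog task would fire.
    let (timeout_tx, timeout_rx) = oneshot::channel::<()>();
    let timeout = timeout_rx.shared();
    let _keep_timer_alive = timeout_tx; // never fired here, so the timeout never wins

    let mut pending: HashSet<u32> = [1, 2].into_iter().collect();
    ready_tx.try_send(1).unwrap();
    ready_tx.try_send(2).unwrap();

    block_on(async {
        // Same shape as both wait loops above: drain readiness events,
        // bailing out if the shared timeout resolves first.
        while !pending.is_empty() {
            futures::select! {
                id = ready_rx.next() => {
                    if let Some(id) = id {
                        pending.remove(&id);
                    }
                }
                _ = timeout.clone().fuse() => {
                    eprintln!("timed out waiting for language servers");
                    break;
                }
            }
        }
    });
    assert!(pending.is_empty());
}
```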
123 crates/edit_prediction_cli/src/score.rs Normal file
@@ -0,0 +1,123 @@
use crate::{
    PredictArgs,
    example::{Example, ExampleScore},
    headless::EpAppState,
    metrics::{self, ClassificationMetrics},
    predict::run_prediction,
    progress::{Progress, Step},
};
use edit_prediction::udiff::DiffLine;
use gpui::AsyncApp;
use std::sync::Arc;

pub async fn run_scoring(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    cx: AsyncApp,
) -> anyhow::Result<()> {
    run_prediction(
        example,
        Some(args.provider),
        args.repetitions,
        app_state,
        cx,
    )
    .await?;

    let _progress = Progress::global().start(Step::Score, &example.name);

    let expected_patch = parse_patch(&example.expected_patch);

    let mut scores = vec![];

    for pred in &example.predictions {
        let actual_patch = parse_patch(&pred.actual_patch);
        let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
        let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;

        scores.push(ExampleScore {
            delta_chr_f,
            line_match,
        });
    }

    example.score = scores;
    Ok(())
}

fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
    patch.lines().map(DiffLine::parse).collect()
}

pub fn print_report(examples: &[Example]) {
    eprintln!(
        "──────────────────────────────────────────────────────────────────────────────────────"
    );
    eprintln!(
        "{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
        "Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
    );
    eprintln!(
        "──────────────────────────────────────────────────────────────────────────────────────"
    );

    let mut all_line_match_scores = Vec::new();
    let mut all_delta_chr_f_scores = Vec::new();

    for example in examples {
        for score in example.score.iter() {
            let line_match = &score.line_match;

            eprintln!(
                "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
                truncate_name(&example.name, 30),
                line_match.true_positives,
                line_match.false_positives,
                line_match.false_negatives,
                line_match.precision() * 100.0,
                line_match.recall() * 100.0,
                line_match.f1_score() * 100.0,
                score.delta_chr_f
            );

            all_line_match_scores.push(line_match.clone());
            all_delta_chr_f_scores.push(score.delta_chr_f);
        }
    }

    eprintln!(
        "──────────────────────────────────────────────────────────────────────────────────────"
    );

    if !all_line_match_scores.is_empty() {
        let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
        let avg_delta_chr_f: f32 =
            all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;

        eprintln!(
            "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
            "TOTAL",
            total_line_match.true_positives,
            total_line_match.false_positives,
            total_line_match.false_negatives,
            total_line_match.precision() * 100.0,
            total_line_match.recall() * 100.0,
            total_line_match.f1_score() * 100.0,
            avg_delta_chr_f
        );
        eprintln!(
            "──────────────────────────────────────────────────────────────────────────────────────"
        );
    }

    eprintln!("\n");
}

fn truncate_name(name: &str, max_len: usize) -> String {
    if name.len() <= max_len {
        name.to_string()
    } else {
        format!("{}...", &name[..max_len - 3])
    }
}
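The report micro-averages: it sums raw TP/FP/FN counts across examples and derives precision/recall/F1 once from the totals, rather than averaging per-example ratios. A hedged standalone sketch of that arithmetic, assuming `ClassificationMetrics` carries raw counts as the field names above suggest:

```rust
// Hypothetical standalone version of the counts the report aggregates.
#[derive(Clone, Default)]
struct Counts {
    true_positives: u32,
    false_positives: u32,
    false_negatives: u32,
}

impl Counts {
    fn precision(&self) -> f32 {
        let denom = self.true_positives + self.false_positives;
        if denom == 0 { 0.0 } else { self.true_positives as f32 / denom as f32 }
    }

    fn recall(&self) -> f32 {
        let denom = self.true_positives + self.false_negatives;
        if denom == 0 { 0.0 } else { self.true_positives as f32 / denom as f32 }
    }

    fn f1_score(&self) -> f32 {
        let (p, r) = (self.precision(), self.recall());
        if p + r == 0.0 { 0.0 } else { 2.0 * p * r / (p + r) }
    }

    // Micro-averaging: sum the raw counts, then derive the ratios once.
    fn aggregate<'a>(items: impl Iterator<Item = &'a Counts>) -> Counts {
        items.fold(Counts::default(), |mut acc, c| {
            acc.true_positives += c.true_positives;
            acc.false_positives += c.false_positives;
            acc.false_negatives += c.false_negatives;
            acc
        })
    }
}

fn main() {
    let per_example = vec![
        Counts { true_positives: 8, false_positives: 2, false_negatives: 0 },
        Counts { true_positives: 1, false_positives: 0, false_negatives: 3 },
    ];
    let total = Counts::aggregate(per_example.iter());
    println!(
        "precision {:.2} recall {:.2} f1 {:.2}",
        total.precision(),
        total.recall(),
        total.f1_score()
    );
}
```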
@@ -1,70 +0,0 @@
use std::{fmt, fmt::Display, path::Path, str::FromStr, sync::Arc};

use ::util::{paths::PathStyle, rel_path::RelPath};
use anyhow::{Result, anyhow};
use language::Point;
use serde::{Deserialize, Deserializer, Serialize, Serializer};

#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct SourceLocation {
    pub path: Arc<RelPath>,
    pub point: Point,
}

impl Serialize for SourceLocation {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(&self.to_string())
    }
}

impl<'de> Deserialize<'de> for SourceLocation {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;
        s.parse().map_err(serde::de::Error::custom)
    }
}

impl Display for SourceLocation {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "{}:{}:{}",
            self.path.display(PathStyle::Posix),
            self.point.row + 1,
            self.point.column + 1
        )
    }
}

impl FromStr for SourceLocation {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self> {
        let parts: Vec<&str> = s.split(':').collect();
        if parts.len() != 3 {
            return Err(anyhow!(
                "Invalid source location. Expected 'file.rs:line:column', got '{}'",
                s
            ));
        }

        let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
        let line: u32 = parts[1]
            .parse()
            .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
        let column: u32 = parts[2]
            .parse()
            .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;

        // Convert from 1-based to 0-based indexing
        let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));

        Ok(SourceLocation { path, point })
    }
}
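The deleted `SourceLocation` used a 1-based `file:line:column` text form over a 0-based `Point`. A standalone sketch of the same round-trip convention, with plain tuples in place of the zed types:

```rust
// A standalone sketch of the 1-based-text / 0-based-Point convention above.
fn parse(s: &str) -> Result<(String, u32, u32), String> {
    let parts: Vec<&str> = s.split(':').collect();
    if parts.len() != 3 {
        return Err(format!("expected 'file.rs:line:column', got '{}'", s));
    }
    let line: u32 = parts[1].parse().map_err(|_| "bad line".to_string())?;
    let column: u32 = parts[2].parse().map_err(|_| "bad column".to_string())?;
    // Stored 0-based, displayed 1-based.
    Ok((
        parts[0].to_string(),
        line.saturating_sub(1),
        column.saturating_sub(1),
    ))
}

fn main() {
    let (path, row, col) = parse("src/main.rs:10:5").unwrap();
    assert_eq!((path.as_str(), row, col), ("src/main.rs", 9, 4));
    // Display adds the 1 back, so the round trip is stable:
    assert_eq!(format!("{}:{}:{}", path, row + 1, col + 1), "src/main.rs:10:5");
}
```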
@@ -18,6 +18,7 @@ Focus on:
Rules:
- Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals.
- Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code.
- Keep existing formatting unless changing it is absolutely necessary.

Input format:
- You receive small code fragments called context (structs, field definitions, function signatures, etc.). They may or may not be relevant.
@@ -46,3 +47,7 @@ Output example:
## Code Context

{{context}}

## Editable region

{{editable_region}}
@@ -1,89 +0,0 @@
use std::path::Path;

use crate::{source_location::SourceLocation, training::teacher::TeacherModel};

#[derive(Debug, Clone, Default, clap::ValueEnum)]
pub enum ContextType {
    #[default]
    CurrentFile,
}

const MAX_CONTEXT_SIZE: usize = 32768;

pub fn collect_context(
    context_type: &ContextType,
    worktree_dir: &Path,
    cursor: SourceLocation,
) -> String {
    let context = match context_type {
        ContextType::CurrentFile => {
            let file_path = worktree_dir.join(cursor.path.as_std_path());
            let context = std::fs::read_to_string(&file_path).unwrap_or_default();

            let context = add_special_tags(&context, worktree_dir, cursor);
            context
        }
    };

    let region_end_offset = context.find(TeacherModel::REGION_END);

    if context.len() <= MAX_CONTEXT_SIZE {
        return context;
    }

    if let Some(region_end_offset) = region_end_offset
        && region_end_offset + TeacherModel::REGION_END.len() > MAX_CONTEXT_SIZE
    {
        let to_truncate = context.len() - MAX_CONTEXT_SIZE;
        format!(
            "[...{} bytes truncated]\n{}\n",
            to_truncate,
            &context[to_truncate..]
        )
    } else {
        format!(
            "{}\n[...{} bytes truncated]\n",
            &context[..MAX_CONTEXT_SIZE],
            context.len() - MAX_CONTEXT_SIZE
        )
    }
}

/// Add <|editable_region_start/end|> tags
fn add_special_tags(context: &str, worktree_dir: &Path, cursor: SourceLocation) -> String {
    let path = worktree_dir.join(cursor.path.as_std_path());
    let file = std::fs::read_to_string(&path).unwrap_or_default();
    let lines = file.lines().collect::<Vec<_>>();
    let cursor_row = cursor.point.row as usize;
    let start_line = cursor_row.saturating_sub(TeacherModel::LEFT_CONTEXT_SIZE);
    let end_line = (cursor_row + TeacherModel::RIGHT_CONTEXT_SIZE).min(lines.len());

    let snippet = lines[start_line..end_line].join("\n");

    if context.contains(&snippet) {
        let mut cursor_line = lines[cursor_row].to_string();
        cursor_line.insert_str(cursor.point.column as usize, TeacherModel::USER_CURSOR);

        let mut snippet_with_tags_lines = vec![];
        snippet_with_tags_lines.push(TeacherModel::REGION_START);
        snippet_with_tags_lines.extend(&lines[start_line..cursor_row]);
        snippet_with_tags_lines.push(&cursor_line);
        snippet_with_tags_lines.extend(&lines[cursor_row + 1..end_line]);
        snippet_with_tags_lines.push(TeacherModel::REGION_END);
        let snippet_with_tags = snippet_with_tags_lines.join("\n");

        context.replace(&snippet, &snippet_with_tags)
    } else {
        log::warn!(
            "Can't find area around the cursor in the context; proceeding without special tags"
        );
        context.to_string()
    }
}

pub fn strip_special_tags(context: &str) -> String {
    context
        .replace(TeacherModel::REGION_START, "")
        .replace(TeacherModel::REGION_END, "")
        .replace(TeacherModel::USER_CURSOR, "")
}
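The truncation branch above keeps whichever end of the file still contains the editable-region end tag. A small standalone sketch of that decision, with a toy limit and tag in place of `MAX_CONTEXT_SIZE` and `REGION_END`:

```rust
// Hedged standalone sketch of the head-vs-tail truncation choice above.
const MAX: usize = 16; // stand-in for MAX_CONTEXT_SIZE
const END_TAG: &str = "<end>"; // stand-in for the editable-region end tag

fn truncate(context: &str) -> String {
    if context.len() <= MAX {
        return context.to_string();
    }
    let end = context.find(END_TAG);
    if end.map_or(false, |pos| pos + END_TAG.len() > MAX) {
        // The editable region would be cut off: drop bytes from the front instead.
        let to_truncate = context.len() - MAX;
        format!("[...{} bytes truncated]\n{}", to_truncate, &context[to_truncate..])
    } else {
        format!("{}\n[...{} bytes truncated]", &context[..MAX], context.len() - MAX)
    }
}

fn main() {
    // Tag near the end survives because the tail is kept.
    println!("{}", truncate("aaaaaaaaaaaaaaaa<end>b"));
    // No tag: keep the head.
    println!("{}", truncate("cccccccccccccccccccccc"));
}
```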
@@ -1,94 +0,0 @@
use serde::Deserialize;
use std::sync::Arc;

use crate::{
    DistillArguments,
    example::Example,
    source_location::SourceLocation,
    training::{
        context::ContextType,
        llm_client::LlmClient,
        teacher::{TeacherModel, TeacherOutput},
    },
};
use anyhow::Result;
use reqwest_client::ReqwestClient;

#[derive(Debug, Deserialize)]
pub struct SplitCommit {
    repo_url: String,
    commit_sha: String,
    edit_history: String,
    expected_patch: String,
    cursor_position: String,
}

pub async fn run_distill(arguments: DistillArguments) -> Result<()> {
    let split_commits: Vec<SplitCommit> = std::fs::read_to_string(&arguments.split_commit_dataset)
        .expect("Failed to read split commit dataset")
        .lines()
        .map(|line| serde_json::from_str(line).expect("Failed to parse JSON line"))
        .collect();

    let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());

    let llm_client = if let Some(cache_path) = arguments.batch {
        LlmClient::batch(&cache_path, http_client)?
    } else {
        LlmClient::plain(http_client)?
    };

    let mut teacher = TeacherModel::new(
        "claude-sonnet-4-5".to_string(),
        ContextType::CurrentFile,
        llm_client,
    );

    let mut num_marked_for_batching = 0;

    for commit in split_commits {
        if let Some(distilled) = distill_one(&mut teacher, commit).await? {
            println!("{}", serde_json::to_string(&distilled)?);
        } else {
            if num_marked_for_batching == 0 {
                log::warn!("Marked for batching");
            }
            num_marked_for_batching += 1;
        }
    }

    eprintln!(
        "{} requests are marked for batching",
        num_marked_for_batching
    );
    let llm_client = teacher.client;
    llm_client.sync_batches().await?;

    Ok(())
}

pub async fn distill_one(
    teacher: &mut TeacherModel,
    commit: SplitCommit,
) -> Result<Option<TeacherOutput>> {
    let cursor: SourceLocation = commit
        .cursor_position
        .parse()
        .expect("Failed to parse cursor position");

    let path = cursor.path.to_rel_path_buf();

    let example = Example {
        repository_url: commit.repo_url,
        revision: commit.commit_sha,
        uncommitted_diff: commit.edit_history.clone(),
        cursor_path: path.as_std_path().to_path_buf(),
        cursor_position: commit.cursor_position,
        edit_history: commit.edit_history, // todo: trim
        expected_patch: commit.expected_patch,
    };

    let prediction = teacher.predict(example).await;

    prediction
}
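`run_distill` consumes the dataset as JSON Lines, one `SplitCommit` per line. A minimal sketch of that parsing shape with serde (the `Record` fields mirror two of `SplitCommit`'s):

```rust
use serde::Deserialize;

// Illustrative record with a subset of the SplitCommit fields above.
#[derive(Debug, Deserialize)]
struct Record {
    repo_url: String,
    commit_sha: String,
}

fn main() -> Result<(), serde_json::Error> {
    // One JSON object per line, as in the split-commit dataset.
    let data = r#"{"repo_url":"https://example.com/repo.git","commit_sha":"abc123"}
{"repo_url":"https://example.com/other.git","commit_sha":"def456"}"#;

    let records: Vec<Record> = data
        .lines()
        .map(serde_json::from_str)
        .collect::<Result<_, _>>()?;
    assert_eq!(records.len(), 2);
    println!("{:?}", records[0]);
    Ok(())
}
```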
@@ -1,4 +0,0 @@
pub mod context;
pub mod distill;
pub mod llm_client;
pub mod teacher;
@@ -1,266 +0,0 @@
use crate::{
    example::Example,
    source_location::SourceLocation,
    training::{
        context::{ContextType, collect_context, strip_special_tags},
        llm_client::LlmClient,
    },
};
use anthropic::{Message, RequestContent, ResponseContent, Role};
use anyhow::Result;

pub struct TeacherModel {
    pub llm_name: String,
    pub context: ContextType,
    pub client: LlmClient,
}

#[derive(Debug, serde::Serialize)]
pub struct TeacherOutput {
    parsed_output: String,
    prompt: String,
    raw_llm_response: String,
    context: String,
    diff: String,
}

impl TeacherModel {
    const PROMPT: &str = include_str!("teacher.prompt.md");
    pub(crate) const REGION_START: &str = "<|editable_region_start|>\n";
    pub(crate) const REGION_END: &str = "<|editable_region_end|>";
    pub(crate) const USER_CURSOR: &str = "<|user_cursor|>";

    /// Number of lines to include before the cursor position
    pub(crate) const LEFT_CONTEXT_SIZE: usize = 5;

    /// Number of lines to include after the cursor position
    pub(crate) const RIGHT_CONTEXT_SIZE: usize = 5;

    /// Truncate edit history to this number of last lines
    const MAX_HISTORY_LINES: usize = 128;

    pub fn new(llm_name: String, context: ContextType, client: LlmClient) -> Self {
        TeacherModel {
            llm_name,
            context,
            client,
        }
    }

    pub async fn predict(&self, input: Example) -> Result<Option<TeacherOutput>> {
        let name = input.unique_name();
        let worktree_dir = input.setup_worktree(name).await?;
        let cursor: SourceLocation = input
            .cursor_position
            .parse()
            .expect("Failed to parse cursor position");

        let context = collect_context(&self.context, &worktree_dir, cursor.clone());
        let edit_history = Self::format_edit_history(&input.edit_history);

        let prompt = Self::PROMPT
            .replace("{{context}}", &context)
            .replace("{{edit_history}}", &edit_history);

        let messages = vec![Message {
            role: Role::User,
            content: vec![RequestContent::Text {
                text: prompt.clone(),
                cache_control: None,
            }],
        }];

        let Some(response) = self
            .client
            .generate(self.llm_name.clone(), 16384, messages)
            .await?
        else {
            return Ok(None);
        };

        let response_text = response
            .content
            .into_iter()
            .filter_map(|content| match content {
                ResponseContent::Text { text } => Some(text),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let parsed_output = self.parse_response(&response_text);

        let original_editable_region = Self::extract_editable_region(&context);
        let context_after_edit = context.replace(&original_editable_region, &parsed_output);
        let context_after_edit = strip_special_tags(&context_after_edit);
        let context_before_edit = strip_special_tags(&context);
        let diff = language::unified_diff(&context_before_edit, &context_after_edit);

        // zeta distill --batch batch_results.txt
        // zeta distill
        // 1. Run `zeta distill <2000 examples <- all examples>` for the first time
        //    - store LLM requests in a batch, don't actually send the request
        //    - send the batch (2000 requests) after all inputs are processed
        // 2. `zeta send-batches`
        //    - upload the batch to Anthropic

        // https://platform.claude.com/docs/en/build-with-claude/batch-processing
        // https://crates.io/crates/anthropic-sdk-rust

        //    - poll for results
        //    - when ready, store results in cache (a database)
        // 3. `zeta distill` again
        //    - use the cached results this time

        Ok(Some(TeacherOutput {
            parsed_output,
            prompt,
            raw_llm_response: response_text,
            context,
            diff,
        }))
    }

    fn parse_response(&self, content: &str) -> String {
        let codeblock = Self::extract_last_codeblock(content);
        let editable_region = Self::extract_editable_region(&codeblock);

        editable_region
    }

    /// Extract content from the last code-fenced block if any, or else return content as is
    fn extract_last_codeblock(text: &str) -> String {
        let mut last_block = None;
        let mut search_start = 0;

        while let Some(start) = text[search_start..].find("```") {
            let start = start + search_start;
            let bytes = text.as_bytes();
            let mut backtick_end = start;

            while backtick_end < bytes.len() && bytes[backtick_end] == b'`' {
                backtick_end += 1;
            }

            let backtick_count = backtick_end - start;
            let closing_backticks = "`".repeat(backtick_count);

            if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
                let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1];
                last_block = Some(code_block.to_string());
                search_start = backtick_end + end_pos + backtick_count;
            } else {
                break;
            }
        }

        last_block.unwrap_or_else(|| text.to_string())
    }

    fn extract_editable_region(text: &str) -> String {
        let start = text
            .find(Self::REGION_START)
            .map_or(0, |pos| pos + Self::REGION_START.len());
        let end = text.find(Self::REGION_END).unwrap_or(text.len());

        text[start..end].to_string()
    }

    /// Truncates edit history to a maximum length and removes comments (unified diff garbage lines)
    fn format_edit_history(edit_history: &str) -> String {
        let lines = edit_history
            .lines()
            .filter(|&s| Self::is_content_line(s))
            .collect::<Vec<_>>();

        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
        } else {
            &lines
        };
        history_lines.join("\n")
    }

    fn is_content_line(s: &str) -> bool {
        s.starts_with("-")
            || s.starts_with("+")
            || s.starts_with(" ")
            || s.starts_with("---")
            || s.starts_with("+++")
            || s.starts_with("@@")
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_response() {
        let teacher = TeacherModel::new(
            "test".to_string(),
            ContextType::CurrentFile,
            LlmClient::dummy(),
        );
        let response = "This is a test response.";
        let parsed = teacher.parse_response(response);
        assert_eq!(parsed, response.to_string());

        let response = indoc::indoc! {"
            Some thinking

            `````
            actual response
            `````
        "};
        let parsed = teacher.parse_response(response);
        assert_eq!(parsed, "actual response");
    }

    #[test]
    fn test_extract_last_code_block() {
        let text = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````
            last block
            `````
        "};
        let last_block = TeacherModel::extract_last_codeblock(text);
        assert_eq!(last_block, "last block");
    }

    #[test]
    fn test_extract_editable_region() {
        let teacher = TeacherModel::new(
            "test".to_string(),
            ContextType::CurrentFile,
            LlmClient::dummy(),
        );
        let response = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
        "};
        let parsed = teacher.parse_response(response);
        assert_eq!(
            parsed,
            indoc::indoc! {"
                one
                two three

            "}
        );
    }
}
@@ -1,198 +0,0 @@
use anyhow::{Result, anyhow};
use futures::channel::mpsc;
use futures::{FutureExt as _, StreamExt as _};
use gpui::{AsyncApp, Entity, Task};
use language::{Buffer, LanguageId, LanguageNotFound, LanguageServerId, ParseStatus};
use project::lsp_store::OpenLspBufferHandle;
use project::{Project, ProjectPath, Worktree};
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use util::rel_path::RelPath;

pub fn open_buffer(
    project: Entity<Project>,
    worktree: Entity<Worktree>,
    path: Arc<RelPath>,
    cx: &AsyncApp,
) -> Task<Result<Entity<Buffer>>> {
    cx.spawn(async move |cx| {
        let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
            worktree_id: worktree.id(),
            path,
        })?;

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
            .await?;

        let mut parse_status = buffer.read_with(cx, |buffer, _cx| buffer.parse_status())?;
        while *parse_status.borrow() != ParseStatus::Idle {
            parse_status.changed().await?;
        }

        Ok(buffer)
    })
}

pub async fn open_buffer_with_language_server(
    project: Entity<Project>,
    worktree: Entity<Worktree>,
    path: Arc<RelPath>,
    ready_languages: &mut HashSet<LanguageId>,
    cx: &mut AsyncApp,
) -> Result<(OpenLspBufferHandle, LanguageServerId, Entity<Buffer>)> {
    let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?;

    let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
        (
            project.register_buffer_with_language_servers(&buffer, cx),
            project.path_style(cx),
        )
    })?;

    let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
    let result = language_registry
        .load_language_for_file_path(path.as_std_path())
        .await;

    if let Err(error) = result
        && !error.is::<LanguageNotFound>()
    {
        anyhow::bail!(error);
    }

    let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
        buffer.language().map(|language| language.id())
    })?
    else {
        return Err(anyhow!("No language for {}", path.display(path_style)));
    };

    let log_prefix = format!("{} | ", path.display(path_style));
    if !ready_languages.contains(&language_id) {
        wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
        ready_languages.insert(language_id);
    }

    let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;

    // hacky wait for buffer to be registered with the language server
    for _ in 0..100 {
        let Some(language_server_id) = lsp_store.update(cx, |lsp_store, cx| {
            buffer.update(cx, |buffer, cx| {
                lsp_store
                    .language_servers_for_local_buffer(&buffer, cx)
                    .next()
                    .map(|(_, language_server)| language_server.server_id())
            })
        })?
        else {
            cx.background_executor()
                .timer(Duration::from_millis(10))
                .await;
            continue;
        };

        return Ok((lsp_open_handle, language_server_id, buffer));
    }

    return Err(anyhow!("No language server found for buffer"));
}

// TODO: Dedupe with similar function in crates/eval/src/instance.rs
pub fn wait_for_lang_server(
    project: &Entity<Project>,
    buffer: &Entity<Buffer>,
    log_prefix: String,
    cx: &mut AsyncApp,
) -> Task<Result<()>> {
    eprintln!("{}⏵ Waiting for language server", log_prefix);

    let (mut tx, mut rx) = mpsc::channel(1);

    let lsp_store = project
        .read_with(cx, |project, _| project.lsp_store())
        .unwrap();

    let has_lang_server = buffer
        .update(cx, |buffer, cx| {
            lsp_store.update(cx, |lsp_store, cx| {
                lsp_store
                    .language_servers_for_local_buffer(buffer, cx)
                    .next()
                    .is_some()
            })
        })
        .unwrap_or(false);

    if has_lang_server {
        project
            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
            .unwrap()
            .detach();
    }
    let (mut added_tx, mut added_rx) = mpsc::channel(1);

    let subscriptions = [
        cx.subscribe(&lsp_store, {
            let log_prefix = log_prefix.clone();
            move |_, event, _| {
                if let project::LspStoreEvent::LanguageServerUpdate {
                    message:
                        client::proto::update_language_server::Variant::WorkProgress(
                            client::proto::LspWorkProgress {
                                message: Some(message),
                                ..
                            },
                        ),
                    ..
                } = event
                {
                    eprintln!("{}⟲ {message}", log_prefix)
                }
            }
        }),
        cx.subscribe(project, {
            let buffer = buffer.clone();
            move |project, event, cx| match event {
                project::Event::LanguageServerAdded(_, _, _) => {
                    let buffer = buffer.clone();
                    project
                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
                        .detach();
                    added_tx.try_send(()).ok();
                }
                project::Event::DiskBasedDiagnosticsFinished { .. } => {
                    tx.try_send(()).ok();
                }
                _ => {}
            }
        }),
    ];

    cx.spawn(async move |cx| {
        if !has_lang_server {
            // some buffers never have a language server, so this aborts quickly in that case.
            let timeout = cx.background_executor().timer(Duration::from_secs(5));
            futures::select! {
                _ = added_rx.next() => {},
                _ = timeout.fuse() => {
                    anyhow::bail!("Waiting for language server add timed out after 5 seconds");
                }
            };
        }
        let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
        let result = futures::select! {
            _ = rx.next() => {
                eprintln!("{}⚑ Language server idle", log_prefix);
                anyhow::Ok(())
            },
            _ = timeout.fuse() => {
                anyhow::bail!("LSP wait timed out after 5 minutes");
            }
        };
        drop(subscriptions);
        result
    })
}
@@ -26,6 +26,7 @@ serde.workspace = true
smallvec.workspace = true
tree-sitter.workspace = true
util.workspace = true
zeta_prompt.workspace = true

[dev-dependencies]
env_logger.workspace = true
@@ -1,6 +1,6 @@
use crate::RelatedExcerpt;
use language::{BufferSnapshot, OffsetRangeExt as _, Point};
use std::ops::Range;
use zeta_prompt::RelatedExcerpt;

#[cfg(not(test))]
const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 512;
@@ -76,14 +76,9 @@ pub fn assemble_excerpts(

    input_ranges
        .into_iter()
        .map(|range| {
            let offset_range = range.to_offset(buffer);
            RelatedExcerpt {
                point_range: range,
                anchor_range: buffer.anchor_before(offset_range.start)
                    ..buffer.anchor_after(offset_range.end),
                text: buffer.as_rope().slice(offset_range),
            }
        .map(|range| RelatedExcerpt {
            row_range: range.start.row..range.end.row,
            text: buffer.text_for_range(range).collect(),
        })
        .collect()
}
@@ -3,13 +3,13 @@ use anyhow::Result;
use collections::HashMap;
use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _};
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset as _};
use project::{LocationLink, Project, ProjectPath};
use serde::{Serialize, Serializer};
use smallvec::SmallVec;
use std::{
    collections::hash_map,
    ops::Range,
    path::Path,
    sync::Arc,
    time::{Duration, Instant},
};
@@ -24,12 +24,14 @@ mod fake_definition_lsp;

pub use cloud_llm_client::predict_edits_v3::Line;
pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
pub use zeta_prompt::{RelatedExcerpt, RelatedFile};

const IDENTIFIER_LINE_COUNT: u32 = 3;

pub struct RelatedExcerptStore {
    project: WeakEntity<Project>,
    related_files: Vec<RelatedFile>,
    related_files: Arc<[RelatedFile]>,
    related_file_buffers: Vec<Entity<Buffer>>,
    cache: HashMap<Identifier, Arc<CacheEntry>>,
    update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
    identifier_line_count: u32,
@@ -68,82 +70,6 @@ struct CachedDefinition {
    anchor_range: Range<Anchor>,
}

#[derive(Clone, Debug, Serialize)]
pub struct RelatedFile {
    #[serde(serialize_with = "serialize_project_path")]
    pub path: ProjectPath,
    #[serde(skip)]
    pub buffer: WeakEntity<Buffer>,
    pub excerpts: Vec<RelatedExcerpt>,
    pub max_row: u32,
}

impl RelatedFile {
    pub fn merge_excerpts(&mut self) {
        self.excerpts.sort_unstable_by(|a, b| {
            a.point_range
                .start
                .cmp(&b.point_range.start)
                .then(b.point_range.end.cmp(&a.point_range.end))
        });

        let mut index = 1;
        while index < self.excerpts.len() {
            if self.excerpts[index - 1]
                .point_range
                .end
                .cmp(&self.excerpts[index].point_range.start)
                .is_ge()
            {
                let removed = self.excerpts.remove(index);
                if removed
                    .point_range
                    .end
                    .cmp(&self.excerpts[index - 1].point_range.end)
                    .is_gt()
                {
                    self.excerpts[index - 1].point_range.end = removed.point_range.end;
                    self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end;
                }
            } else {
                index += 1;
            }
        }
    }
}

#[derive(Clone, Debug, Serialize)]
pub struct RelatedExcerpt {
    #[serde(skip)]
    pub anchor_range: Range<Anchor>,
    #[serde(serialize_with = "serialize_point_range")]
    pub point_range: Range<Point>,
    #[serde(serialize_with = "serialize_rope")]
    pub text: Rope,
}

fn serialize_project_path<S: Serializer>(
    project_path: &ProjectPath,
    serializer: S,
) -> Result<S::Ok, S::Error> {
    project_path.path.serialize(serializer)
}

fn serialize_rope<S: Serializer>(rope: &Rope, serializer: S) -> Result<S::Ok, S::Error> {
    rope.to_string().serialize(serializer)
}

fn serialize_point_range<S: Serializer>(
    range: &Range<Point>,
    serializer: S,
) -> Result<S::Ok, S::Error> {
    [
        [range.start.row, range.start.column],
        [range.end.row, range.end.column],
    ]
    .serialize(serializer)
}

const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);

impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
@@ -179,7 +105,8 @@ impl RelatedExcerptStore {
        RelatedExcerptStore {
            project: project.downgrade(),
            update_tx,
            related_files: Vec::new(),
            related_files: Vec::new().into(),
            related_file_buffers: Vec::new(),
            cache: Default::default(),
            identifier_line_count: IDENTIFIER_LINE_COUNT,
        }
@@ -193,8 +120,21 @@ impl RelatedExcerptStore {
        self.update_tx.unbounded_send((buffer, position)).ok();
    }

    pub fn related_files(&self) -> &[RelatedFile] {
        &self.related_files
    pub fn related_files(&self) -> Arc<[RelatedFile]> {
        self.related_files.clone()
    }

    pub fn related_files_with_buffers(
        &self,
    ) -> impl Iterator<Item = (RelatedFile, Entity<Buffer>)> {
        self.related_files
            .iter()
            .cloned()
            .zip(self.related_file_buffers.iter().cloned())
    }

    pub fn set_related_files(&mut self, files: Vec<RelatedFile>) {
        self.related_files = files.into();
    }

    async fn fetch_excerpts(
@@ -297,7 +237,8 @@ impl RelatedExcerptStore {
        }
        mean_definition_latency /= cache_miss_count.max(1) as u32;

        let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?;
        let (new_cache, related_files, related_file_buffers) =
            rebuild_related_files(&project, new_cache, cx).await?;

        if let Some(file) = &file {
            log::debug!(
@@ -309,7 +250,8 @@ impl RelatedExcerptStore {

        this.update(cx, |this, cx| {
            this.cache = new_cache;
            this.related_files = related_files;
            this.related_files = related_files.into();
            this.related_file_buffers = related_file_buffers;
            cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
                cache_hit_count,
                cache_miss_count,
@@ -323,10 +265,16 @@ impl RelatedExcerptStore {
}

async fn rebuild_related_files(
    project: &Entity<Project>,
    new_entries: HashMap<Identifier, Arc<CacheEntry>>,
    cx: &mut AsyncApp,
) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedFile>)> {
) -> Result<(
    HashMap<Identifier, Arc<CacheEntry>>,
    Vec<RelatedFile>,
    Vec<Entity<Buffer>>,
)> {
    let mut snapshots = HashMap::default();
    let mut worktree_root_names = HashMap::default();
    for entry in new_entries.values() {
        for definition in &entry.definitions {
            if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
@@ -340,12 +288,22 @@ async fn rebuild_related_files(
                        .read_with(cx, |buffer, _| buffer.snapshot())?,
                );
            }
            let worktree_id = definition.path.worktree_id;
            if let hash_map::Entry::Vacant(e) =
                worktree_root_names.entry(definition.path.worktree_id)
            {
                project.read_with(cx, |project, cx| {
                    if let Some(worktree) = project.worktree_for_id(worktree_id, cx) {
                        e.insert(worktree.read(cx).root_name().as_unix_str().to_string());
                    }
                })?;
            }
        }
    }

    Ok(cx
        .background_spawn(async move {
            let mut files = Vec::<RelatedFile>::new();
            let mut files = Vec::new();
            let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
            let mut paths_by_buffer = HashMap::default();
            for entry in new_entries.values() {
@@ -369,16 +327,31 @@ async fn rebuild_related_files(
                    continue;
                };
                let excerpts = assemble_excerpts(snapshot, ranges);
                files.push(RelatedFile {
                    path: project_path.clone(),
                    buffer: buffer.downgrade(),
                    excerpts,
                    max_row: snapshot.max_point().row,
                });
                let Some(root_name) = worktree_root_names.get(&project_path.worktree_id) else {
                    continue;
                };

                let path = Path::new(&format!(
                    "{}/{}",
                    root_name,
                    project_path.path.as_unix_str()
                ))
                .into();

                files.push((
                    buffer,
                    RelatedFile {
                        path,
                        excerpts,
                        max_row: snapshot.max_point().row,
                    },
                ));
            }

            files.sort_by_key(|file| file.path.clone());
            (new_entries, files)
            files.sort_by_key(|(_, file)| file.path.clone());
            let (related_buffers, related_files) = files.into_iter().unzip();

            (new_entries, related_files, related_buffers)
        })
        .await)
}
@@ -48,7 +48,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
        &excerpts,
        &[
            (
                "src/company.rs",
                "root/src/company.rs",
                &[indoc! {"
                    pub struct Company {
                        owner: Arc<Person>,
@@ -56,7 +56,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
                }"}],
            ),
            (
                "src/main.rs",
                "root/src/main.rs",
                &[
                    indoc! {"
                        pub struct Session {
@@ -71,7 +71,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
                ],
            ),
            (
                "src/person.rs",
                "root/src/person.rs",
                &[
                    indoc! {"
                        impl Person {
@@ -446,7 +446,7 @@ fn assert_related_files(actual_files: &[RelatedFile], expected_files: &[(&str, &
            .iter()
            .map(|excerpt| excerpt.text.to_string())
            .collect::<Vec<_>>();
        (file.path.path.as_unix_str(), excerpts)
        (file.path.to_str().unwrap(), excerpts)
    })
    .collect::<Vec<_>>();
let expected_excerpts = expected_files
@@ -492,10 +492,10 @@ fn format_excerpts(buffer: &Buffer, excerpts: &[RelatedExcerpt]) -> String {
        if excerpt.text.is_empty() {
            continue;
        }
        if current_row < excerpt.point_range.start.row {
        if current_row < excerpt.row_range.start {
            writeln!(&mut output, "…").unwrap();
        }
        current_row = excerpt.point_range.start.row;
        current_row = excerpt.row_range.start;

        for line in excerpt.text.to_string().lines() {
            output.push_str(line);
@@ -17,7 +17,6 @@ anyhow.workspace = true
buffer_diff.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
cloud_zeta2_prompt.workspace = true
codestral.workspace = true
command_palette_hooks.workspace = true
copilot.workspace = true
@@ -46,6 +45,7 @@ ui_input.workspace = true
util.workspace = true
workspace.workspace = true
zed_actions.workspace = true
zeta_prompt.workspace = true

[dev-dependencies]
copilot = { workspace = true, features = ["test-support"] }
@@ -17,7 +17,7 @@ use gpui::{
};
use multi_buffer::MultiBuffer;
use project::Project;
use text::OffsetRangeExt;
use text::Point;
use ui::{
    ButtonCommon, Clickable, Disableable, FluentBuilder as _, IconButton, IconName,
    StyledTypography as _, h_flex, v_flex,
@@ -66,7 +66,7 @@ impl EditPredictionContextView {
    ) -> Self {
        let store = EditPredictionStore::global(client, user_store, cx);

        let mut debug_rx = store.update(cx, |store, _| store.debug_info());
        let mut debug_rx = store.update(cx, |store, cx| store.debug_info(&project, cx));
        let _update_task = cx.spawn_in(window, async move |this, cx| {
            while let Some(event) = debug_rx.next().await {
                this.update_in(cx, |this, window, cx| {
@@ -103,7 +103,8 @@ impl EditPredictionContextView {
                    self.handle_context_retrieval_finished(info, window, cx);
                }
            }
            DebugEvent::EditPredictionRequested(_) => {}
            DebugEvent::EditPredictionStarted(_) => {}
            DebugEvent::EditPredictionFinished(_) => {}
        }
    }

@@ -152,12 +153,11 @@ impl EditPredictionContextView {
        run.finished_at = Some(info.timestamp);
        run.metadata = info.metadata;

        let project = self.project.clone();
        let related_files = self
            .store
            .read(cx)
            .context_for_project(&self.project, cx)
            .to_vec();
            .context_for_project_with_buffers(&self.project, cx)
            .map_or(Vec::new(), |files| files.collect());

        let editor = run.editor.clone();
        let multibuffer = run.editor.read(cx).buffer().clone();
@@ -168,33 +168,14 @@ impl EditPredictionContextView {

        cx.spawn_in(window, async move |this, cx| {
            let mut paths = Vec::new();
            for related_file in related_files {
                let (buffer, point_ranges): (_, Vec<_>) =
                    if let Some(buffer) = related_file.buffer.upgrade() {
                        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;

                        (
                            buffer,
                            related_file
                                .excerpts
                                .iter()
                                .map(|excerpt| excerpt.anchor_range.to_point(&snapshot))
                                .collect(),
                        )
                    } else {
                        (
                            project
                                .update(cx, |project, cx| {
                                    project.open_buffer(related_file.path.clone(), cx)
                                })?
                                .await?,
                            related_file
                                .excerpts
                                .iter()
                                .map(|excerpt| excerpt.point_range.clone())
                                .collect(),
                        )
                    };
            for (related_file, buffer) in related_files {
                let point_ranges = related_file
                    .excerpts
                    .iter()
                    .map(|excerpt| {
                        Point::new(excerpt.row_range.start, 0)..Point::new(excerpt.row_range.end, 0)
                    })
                    .collect::<Vec<_>>();
                cx.update(|_, cx| {
                    let path = PathKey::for_buffer(&buffer, cx);
                    paths.push((path, buffer, point_ranges));
@@ -1,5 +1,4 @@
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use cloud_zeta2_prompt::write_codeblock;
use edit_prediction::{EditPrediction, EditPredictionRating, EditPredictionStore};
use editor::{Editor, ExcerptRange, MultiBuffer};
use feature_flags::FeatureFlag;
@@ -362,14 +361,14 @@ impl RatePredictionsModal {
        write!(&mut formatted_inputs, "## Events\n\n").unwrap();

        for event in &prediction.inputs.events {
            write!(&mut formatted_inputs, "```diff\n{event}```\n\n").unwrap();
            formatted_inputs.push_str("```diff\n");
            zeta_prompt::write_event(&mut formatted_inputs, event.as_ref());
            formatted_inputs.push_str("```\n\n");
        }

        write!(&mut formatted_inputs, "## Included files\n\n").unwrap();

        for included_file in &prediction.inputs.included_files {
        let cursor_insertions = &[(prediction.inputs.cursor_point, "<|CURSOR|>")];
        write!(&mut formatted_inputs, "## Related files\n\n").unwrap();

        for included_file in prediction.inputs.related_files.as_ref() {
            write!(
                &mut formatted_inputs,
                "### {}\n\n",
@@ -377,20 +376,28 @@ impl RatePredictionsModal {
            )
            .unwrap();

            write_codeblock(
                &included_file.path,
                &included_file.excerpts,
                if included_file.path == prediction.inputs.cursor_path {
                    cursor_insertions.as_slice()
                } else {
                    &[]
                },
                included_file.max_row,
                false,
                &mut formatted_inputs,
            );
            for excerpt in included_file.excerpts.iter() {
                write!(
                    &mut formatted_inputs,
                    "```{}\n{}\n```\n",
                    included_file.path.display(),
                    excerpt.text
                )
                .unwrap();
            }
        }

        write!(&mut formatted_inputs, "## Cursor Excerpt\n\n").unwrap();

        writeln!(
            &mut formatted_inputs,
            "```{}\n{}<CURSOR>{}\n```\n",
            prediction.inputs.cursor_path.display(),
            &prediction.inputs.cursor_excerpt[..prediction.inputs.cursor_offset_in_excerpt],
            &prediction.inputs.cursor_excerpt[prediction.inputs.cursor_offset_in_excerpt..],
        )
        .unwrap();

        self.active_prediction = Some(ActivePrediction {
            prediction,
            feedback_editor: cx.new(|cx| {
@@ -45,7 +45,7 @@ impl Editor {

    let bracket_matches_by_accent = self.visible_excerpts(false, cx).into_iter().fold(
        HashMap::default(),
        |mut acc, (excerpt_id, (buffer, buffer_version, buffer_range))| {
        |mut acc, (excerpt_id, (buffer, _, buffer_range))| {
            let buffer_snapshot = buffer.read(cx).snapshot();
            if language_settings::language_settings(
                buffer_snapshot.language().map(|language| language.name()),
@@ -62,7 +62,7 @@ impl Editor {
            let brackets_by_accent = buffer_snapshot
                .fetch_bracket_ranges(
                    buffer_range.start..buffer_range.end,
                    Some((&buffer_version, fetched_chunks)),
                    Some(fetched_chunks),
                )
                .into_iter()
                .flat_map(|(chunk_range, pairs)| {
@@ -56,6 +56,7 @@ use sum_tree::{Bias, TreeMap};
use text::{BufferId, LineIndent};
use ui::{SharedString, px};
use unicode_segmentation::UnicodeSegmentation;
use ztracing::instrument;

use std::{
    any::TypeId,
@@ -168,6 +169,7 @@ impl DisplayMap {
        }
    }

    #[instrument(skip_all)]
    pub fn snapshot(&mut self, cx: &mut Context<Self>) -> DisplaySnapshot {
        let tab_size = Self::tab_size(&self.buffer, cx);

@@ -195,6 +197,7 @@ impl DisplayMap {
        }
    }

    #[instrument(skip_all)]
    pub fn set_state(&mut self, other: &DisplaySnapshot, cx: &mut Context<Self>) {
        self.fold(
            other
@@ -211,6 +214,7 @@ impl DisplayMap {
    }

    /// Creates folds for the given creases.
    #[instrument(skip_all)]
    pub fn fold<T: Clone + ToOffset>(&mut self, creases: Vec<Crease<T>>, cx: &mut Context<Self>) {
        let buffer_snapshot = self.buffer.read(cx).snapshot(cx);
        let edits = self.buffer_subscription.consume().into_inner();
@@ -279,6 +283,7 @@ impl DisplayMap {
    }

    /// Removes any folds with the given ranges.
    #[instrument(skip_all)]
    pub fn remove_folds_with_type<T: ToOffset>(
        &mut self,
        ranges: impl IntoIterator<Item = Range<T>>,
@@ -304,6 +309,7 @@ impl DisplayMap {
    }

    /// Removes any folds whose ranges intersect any of the given ranges.
    #[instrument(skip_all)]
    pub fn unfold_intersecting<T: ToOffset>(
        &mut self,
        ranges: impl IntoIterator<Item = Range<T>>,
@@ -335,6 +341,7 @@ impl DisplayMap {
        block_map.remove_intersecting_replace_blocks(offset_ranges, inclusive);
    }

    #[instrument(skip_all)]
    pub fn disable_header_for_buffer(&mut self, buffer_id: BufferId, cx: &mut Context<Self>) {
        let snapshot = self.buffer.read(cx).snapshot(cx);
        let edits = self.buffer_subscription.consume().into_inner();
@@ -349,6 +356,7 @@ impl DisplayMap {
        block_map.disable_header_for_buffer(buffer_id)
    }

    #[instrument(skip_all)]
    pub fn fold_buffers(
        &mut self,
        buffer_ids: impl IntoIterator<Item = language::BufferId>,
@@ -367,6 +375,7 @@ impl DisplayMap {
        block_map.fold_buffers(buffer_ids, self.buffer.read(cx), cx)
    }

    #[instrument(skip_all)]
    pub fn unfold_buffers(
        &mut self,
        buffer_ids: impl IntoIterator<Item = language::BufferId>,
@@ -385,14 +394,17 @@ impl DisplayMap {
        block_map.unfold_buffers(buffer_ids, self.buffer.read(cx), cx)
    }

    #[instrument(skip_all)]
    pub(crate) fn is_buffer_folded(&self, buffer_id: language::BufferId) -> bool {
        self.block_map.folded_buffers.contains(&buffer_id)
    }

    #[instrument(skip_all)]
    pub(crate) fn folded_buffers(&self) -> &HashSet<BufferId> {
        &self.block_map.folded_buffers
    }

    #[instrument(skip_all)]
    pub fn insert_creases(
        &mut self,
        creases: impl IntoIterator<Item = Crease<Anchor>>,
@@ -402,6 +414,7 @@ impl DisplayMap {
        self.crease_map.insert(creases, &snapshot)
    }

    #[instrument(skip_all)]
    pub fn remove_creases(
        &mut self,
        crease_ids: impl IntoIterator<Item = CreaseId>,
@@ -411,6 +424,7 @@ impl DisplayMap {
        self.crease_map.remove(crease_ids, &snapshot)
    }

    #[instrument(skip_all)]
    pub fn insert_blocks(
        &mut self,
        blocks: impl IntoIterator<Item = BlockProperties<Anchor>>,
@@ -429,6 +443,7 @@ impl DisplayMap {
        block_map.insert(blocks)
    }

    #[instrument(skip_all)]
    pub fn resize_blocks(&mut self, heights: HashMap<CustomBlockId, u32>, cx: &mut Context<Self>) {
        let snapshot = self.buffer.read(cx).snapshot(cx);
        let edits = self.buffer_subscription.consume().into_inner();
@@ -443,10 +458,12 @@ impl DisplayMap {
        block_map.resize(heights);
    }

    #[instrument(skip_all)]
    pub fn replace_blocks(&mut self, renderers: HashMap<CustomBlockId, RenderBlock>) {
        self.block_map.replace_blocks(renderers);
    }

    #[instrument(skip_all)]
    pub fn remove_blocks(&mut self, ids: HashSet<CustomBlockId>, cx: &mut Context<Self>) {
        let snapshot = self.buffer.read(cx).snapshot(cx);
        let edits = self.buffer_subscription.consume().into_inner();
@@ -461,6 +478,7 @@ impl DisplayMap {
        block_map.remove(ids);
    }

    #[instrument(skip_all)]
    pub fn row_for_block(
        &mut self,
        block_id: CustomBlockId,
@@ -480,6 +498,7 @@ impl DisplayMap {
        Some(DisplayRow(block_row.0))
    }

    #[instrument(skip_all)]
    pub fn highlight_text(
        &mut self,
        key: HighlightKey,
@@ -507,6 +526,7 @@ impl DisplayMap {
        self.text_highlights.insert(key, to_insert);
    }

    #[instrument(skip_all)]
    pub(crate) fn highlight_inlays(
        &mut self,
        type_id: TypeId,
@@ -526,6 +546,7 @@ impl DisplayMap {
        }
    }

    #[instrument(skip_all)]
    pub fn text_highlights(&self, type_id: TypeId) -> Option<(HighlightStyle, &[Range<Anchor>])> {
        let highlights = self.text_highlights.get(&HighlightKey::Type(type_id))?;
        Some((highlights.0, &highlights.1))
@@ -538,6 +559,7 @@ impl DisplayMap {
        self.text_highlights.values()
    }

    #[instrument(skip_all)]
    pub fn clear_highlights(&mut self, type_id: TypeId) -> bool {
        let mut cleared = self
            .text_highlights
@@ -566,6 +588,7 @@ impl DisplayMap {
        .update(cx, |map, cx| map.set_wrap_width(width, cx))
    }

    #[instrument(skip_all)]
    pub fn update_fold_widths(
        &mut self,
        widths: impl IntoIterator<Item = (ChunkRendererId, Pixels)>,
@@ -597,6 +620,7 @@ impl DisplayMap {
        self.inlay_map.current_inlays()
    }

    #[instrument(skip_all)]
    pub(crate) fn splice_inlays(
        &mut self,
        to_remove: &[InlayId],
@@ -626,6 +650,7 @@ impl DisplayMap {
        self.block_map.read(snapshot, edits);
    }

    #[instrument(skip_all)]
    fn tab_size(buffer: &Entity<MultiBuffer>, cx: &App) -> NonZeroU32 {
        let buffer = buffer.read(cx).as_singleton().map(|buffer| buffer.read(cx));
        let language = buffer
@@ -675,6 +700,7 @@ pub struct HighlightedChunk<'a> {
}

impl<'a> HighlightedChunk<'a> {
    #[instrument(skip_all)]
    fn highlight_invisibles(
        self,
        editor_style: &'a EditorStyle,
@@ -832,6 +858,7 @@ impl DisplaySnapshot {
        self.buffer_snapshot().widest_line_number()
    }

    #[instrument(skip_all)]
    pub fn prev_line_boundary(&self, mut point: MultiBufferPoint) -> (Point, DisplayPoint) {
        loop {
            let mut inlay_point = self.inlay_snapshot().to_inlay_point(point);
@@ -850,6 +877,7 @@ impl DisplaySnapshot {
        }
    }

    #[instrument(skip_all)]
    pub fn next_line_boundary(
        &self,
        mut point: MultiBufferPoint,
@@ -888,6 +916,7 @@ impl DisplaySnapshot {
        new_start..new_end
    }

    #[instrument(skip_all)]
    pub fn point_to_display_point(&self, point: MultiBufferPoint, bias: Bias) -> DisplayPoint {
|
||||
let inlay_point = self.inlay_snapshot().to_inlay_point(point);
|
||||
let fold_point = self.fold_snapshot().to_fold_point(inlay_point, bias);
|
||||
@@ -917,6 +946,7 @@ impl DisplaySnapshot {
|
||||
.anchor_at(point.to_offset(self, bias), bias)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
fn display_point_to_inlay_point(&self, point: DisplayPoint, bias: Bias) -> InlayPoint {
|
||||
let block_point = point.0;
|
||||
let wrap_point = self.block_snapshot.to_wrap_point(block_point, bias);
|
||||
@@ -928,6 +958,7 @@ impl DisplaySnapshot {
|
||||
fold_point.to_inlay_point(self.fold_snapshot())
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn display_point_to_fold_point(&self, point: DisplayPoint, bias: Bias) -> FoldPoint {
|
||||
let block_point = point.0;
|
||||
let wrap_point = self.block_snapshot.to_wrap_point(block_point, bias);
|
||||
@@ -937,6 +968,7 @@ impl DisplaySnapshot {
|
||||
.0
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn fold_point_to_display_point(&self, fold_point: FoldPoint) -> DisplayPoint {
|
||||
let tab_point = self.tab_snapshot().fold_point_to_tab_point(fold_point);
|
||||
let wrap_point = self.wrap_snapshot().tab_point_to_wrap_point(tab_point);
|
||||
@@ -949,6 +981,7 @@ impl DisplaySnapshot {
|
||||
}
|
||||
|
||||
/// Returns text chunks starting at the given display row until the end of the file
|
||||
#[instrument(skip_all)]
|
||||
pub fn text_chunks(&self, display_row: DisplayRow) -> impl Iterator<Item = &str> {
|
||||
self.block_snapshot
|
||||
.chunks(
|
||||
@@ -961,6 +994,7 @@ impl DisplaySnapshot {
|
||||
}
|
||||
|
||||
/// Returns text chunks starting at the end of the given display row in reverse until the start of the file
|
||||
#[instrument(skip_all)]
|
||||
pub fn reverse_text_chunks(&self, display_row: DisplayRow) -> impl Iterator<Item = &str> {
|
||||
(0..=display_row.0).rev().flat_map(move |row| {
|
||||
self.block_snapshot
|
||||
@@ -977,6 +1011,7 @@ impl DisplaySnapshot {
|
||||
})
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn chunks(
|
||||
&self,
|
||||
display_rows: Range<DisplayRow>,
|
||||
@@ -995,6 +1030,7 @@ impl DisplaySnapshot {
|
||||
)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn highlighted_chunks<'a>(
|
||||
&'a self,
|
||||
display_rows: Range<DisplayRow>,
|
||||
@@ -1071,6 +1107,7 @@ impl DisplaySnapshot {
|
||||
})
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn layout_row(
|
||||
&self,
|
||||
display_row: DisplayRow,
|
||||
@@ -1132,6 +1169,7 @@ impl DisplaySnapshot {
|
||||
layout_line.closest_index_for_x(x) as u32
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn grapheme_at(&self, mut point: DisplayPoint) -> Option<SharedString> {
|
||||
point = DisplayPoint(self.block_snapshot.clip_point(point.0, Bias::Left));
|
||||
let chars = self
|
||||
@@ -1321,6 +1359,7 @@ impl DisplaySnapshot {
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn crease_for_buffer_row(&self, buffer_row: MultiBufferRow) -> Option<Crease<Point>> {
|
||||
let start =
|
||||
MultiBufferPoint::new(buffer_row.0, self.buffer_snapshot().line_len(buffer_row));
|
||||
@@ -1407,6 +1446,7 @@ impl DisplaySnapshot {
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
#[instrument(skip_all)]
|
||||
pub fn text_highlight_ranges<Tag: ?Sized + 'static>(
|
||||
&self,
|
||||
) -> Option<Arc<(HighlightStyle, Vec<Range<Anchor>>)>> {
|
||||
@@ -1417,6 +1457,7 @@ impl DisplaySnapshot {
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
#[instrument(skip_all)]
|
||||
pub fn all_text_highlight_ranges<Tag: ?Sized + 'static>(
|
||||
&self,
|
||||
) -> Vec<(gpui::Hsla, Range<Point>)> {
|
||||
@@ -1466,6 +1507,7 @@ impl DisplaySnapshot {
|
||||
///
|
||||
/// This moves by buffer rows instead of display rows, a distinction that is
|
||||
/// important when soft wrapping is enabled.
|
||||
#[instrument(skip_all)]
|
||||
pub fn start_of_relative_buffer_row(&self, point: DisplayPoint, times: isize) -> DisplayPoint {
|
||||
let start = self.display_point_to_fold_point(point, Bias::Left);
|
||||
let target = start.row() as isize + times;
|
||||
|
||||
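The display_map hunks above attach `#[instrument(skip_all)]` to nearly every public method. A minimal sketch of the pattern with the upstream `tracing` crate (assumed here as a stand-in; later hunks route through Zed's `ztracing` wrapper instead): `skip_all` keeps noisy or non-`Debug` arguments such as `cx` out of the span's recorded fields while the span still captures timing and nesting.

```rust
// Sketch only: `tracing` and `tracing-subscriber` are assumed dependencies.
use tracing::instrument;

struct DisplayMap;

impl DisplayMap {
    #[instrument(skip_all)]
    fn fold(&mut self, crease_count: usize) {
        // Arguments are skipped, but ad-hoc fields can still be recorded.
        tracing::debug!(crease_count, "folding creases");
    }
}

fn main() {
    tracing_subscriber::fmt().init();
    DisplayMap.fold(3);
}
```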
@@ -529,7 +529,7 @@ impl BlockMap {
        BlockMapWriter(self)
    }

    #[ztracing::instrument(skip_all, fields(edits))]
    #[ztracing::instrument(skip_all, fields(edits = ?edits))]
    fn sync(&self, wrap_snapshot: &WrapSnapshot, mut edits: WrapPatch) {
        let _timer = zlog::time!("BlockMap::sync").warn_if_gt(std::time::Duration::from_millis(50));

@@ -570,6 +570,9 @@ impl BlockMap {
        let mut wrap_point_cursor = wrap_snapshot.wrap_point_cursor();

        while let Some(edit) = edits.next() {
            let span = ztracing::debug_span!("while edits", edit = ?edit);
            let _enter = span.enter();

            let mut old_start = edit.old.start;
            let mut new_start = edit.new.start;

@@ -628,6 +631,8 @@ impl BlockMap {
            let mut old_end = edit.old.end;
            let mut new_end = edit.new.end;
            loop {
                let span = ztracing::debug_span!("decide where edit ends loop");
                let _enter = span.enter();
                // Seek to the transform starting at or after the end of the edit
                cursor.seek(&old_end, Bias::Left);
                cursor.next();
@@ -736,6 +741,10 @@ impl BlockMap {
            // and then insert the block itself.
            let mut just_processed_folded_buffer = false;
            for (block_placement, block) in blocks_in_edit.drain(..) {
                let span =
                    ztracing::debug_span!("for block in edits", block_height = block.height());
                let _enter = span.enter();

                let mut summary = TransformSummary {
                    input_rows: WrapRow(0),
                    output_rows: BlockRow(block.height()),
@@ -957,6 +966,7 @@ impl BlockMap {
    }
}

#[ztracing::instrument(skip(tree, wrap_snapshot))]
fn push_isomorphic(tree: &mut SumTree<Transform>, rows: RowDelta, wrap_snapshot: &WrapSnapshot) {
    if rows == RowDelta(0) {
        return;

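Note the attribute fix at the top of the `BlockMap::sync` hunk. Assuming semantics like upstream `tracing` for Zed's `ztracing` wrapper, the `?edits` sigil records the field with the argument's `Debug` implementation; without it, the macro would try to capture `edits` as a primitive `Value`, which a patch type does not implement. A hedged reproduction:

```rust
use tracing::instrument;

#[derive(Debug)]
struct WrapPatch(Vec<(u32, u32)>);

// `?edits` records the Debug representation of the argument in the span,
// e.g. `edits=WrapPatch([(0, 1)])`.
#[instrument(skip_all, fields(edits = ?edits))]
fn sync(edits: &WrapPatch) {
    let _ = edits;
}

fn main() {
    tracing_subscriber::fmt().init();
    sync(&WrapPatch(vec![(0, 1)]));
}
```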
@@ -840,7 +840,7 @@ impl WrapSnapshot {
        self.tab_point_to_wrap_point(self.tab_snapshot.clip_point(self.to_tab_point(point), bias))
    }

    #[ztracing::instrument(skip_all, fields(point, ret))]
    #[ztracing::instrument(skip_all, fields(point = ?point, ret))]
    pub fn prev_row_boundary(&self, mut point: WrapPoint) -> WrapRow {
        if self.transforms.is_empty() {
            return WrapRow(0);
@@ -851,11 +851,14 @@ impl WrapSnapshot {
        let mut cursor = self
            .transforms
            .cursor::<Dimensions<WrapPoint, TabPoint>>(());
        // start
        cursor.seek(&point, Bias::Right);
        // end
        if cursor.item().is_none() {
            cursor.prev();
        }

        // start
        while let Some(transform) = cursor.item() {
            if transform.is_isomorphic() && cursor.start().1.column() == 0 {
                return cmp::min(cursor.end().0.row(), point.row());
@@ -863,6 +866,7 @@ impl WrapSnapshot {
            cursor.prev();
        }
    }
    // end

        unreachable!()
    }

@@ -7135,6 +7135,7 @@ impl Editor {
        Some((query, selection_anchor_range))
    }

    #[ztracing::instrument(skip_all)]
    fn update_selection_occurrence_highlights(
        &mut self,
        query_text: String,
@@ -7279,6 +7280,7 @@ impl Editor {
        });
    }

    #[ztracing::instrument(skip_all)]
    fn refresh_selected_text_highlights(
        &mut self,
        on_buffer_edit: bool,
@@ -20973,9 +20975,22 @@ impl Editor {
                buffer_ranges.last()
            }?;

            let selection = text::ToPoint::to_point(&range.start, buffer).row
                ..text::ToPoint::to_point(&range.end, buffer).row;
            Some((multi_buffer.buffer(buffer.remote_id()).unwrap(), selection))
            let start_row_in_buffer = text::ToPoint::to_point(&range.start, buffer).row;
            let end_row_in_buffer = text::ToPoint::to_point(&range.end, buffer).row;

            let Some(buffer_diff) = multi_buffer.diff_for(buffer.remote_id()) else {
                let selection = start_row_in_buffer..end_row_in_buffer;

                return Some((multi_buffer.buffer(buffer.remote_id()).unwrap(), selection));
            };

            let buffer_diff_snapshot = buffer_diff.read(cx).snapshot(cx);

            Some((
                multi_buffer.buffer(buffer.remote_id()).unwrap(),
                buffer_diff_snapshot.row_to_base_text_row(start_row_in_buffer, buffer)
                    ..buffer_diff_snapshot.row_to_base_text_row(end_row_in_buffer, buffer),
            ))
        });

        let Some((buffer, selection)) = buffer_and_selection else {

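The `@@ -20973` hunk stops returning raw buffer rows and instead maps them through the buffer's diff snapshot via `row_to_base_text_row`, so the selection lines up with the base text when a diff is present. A simplified, hypothetical illustration of such a mapping (the anchor-pair representation below is an assumption for the sketch, not Zed's actual diff snapshot):

```rust
/// Hypothetical stand-in for a diff snapshot: (working_row, base_row) pairs,
/// sorted by working_row, marking points where the two texts re-align.
struct DiffSnapshotSketch {
    anchors: Vec<(u32, u32)>,
}

impl DiffSnapshotSketch {
    fn row_to_base_text_row(&self, row: u32) -> u32 {
        // Find the last alignment point at or before `row` and extrapolate.
        match self.anchors.iter().rev().find(|(w, _)| *w <= row) {
            Some(&(w, b)) => b + (row - w),
            None => row,
        }
    }
}

fn main() {
    // Two lines inserted at the top of the working copy: working row 5
    // corresponds to base row 3.
    let diff = DiffSnapshotSketch { anchors: vec![(2, 0)] };
    assert_eq!(diff.row_to_base_text_row(5), 3);
}
```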
@@ -27701,6 +27701,7 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
    cx.update_editor(|editor, window, cx| {
        editor.handle_input("x", window, cx);
    });
    cx.run_until_parked();
    cx.assert_editor_state(indoc! {"
        - [ ] Item 1
          - [ ] Item 1.a
@@ -27716,8 +27717,7 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
          - [ ] Item 1.a
        - [x] Item 2
          - [x] Item 2.a
          - [x] Item 2.bˇ
        "
          - [x] Item 2.bˇ"
    });
    cx.update_editor(|editor, window, cx| {
        editor.newline(&Newline, window, cx);
@@ -27728,34 +27728,41 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
        - [x] Item 2
          - [x] Item 2.a
          - [x] Item 2.b
          ˇ
        "
          ˇ"
    });

    // Case 3: Test adding a new nested list item preserves indent
    cx.set_state(&indoc! {"
        - [ ] Item 1
          - [ ] Item 1.a
        - [x] Item 2
          - [x] Item 2.a
          - [x] Item 2.b
          ˇ"
    });
    cx.update_editor(|editor, window, cx| {
        editor.handle_input("-", window, cx);
    });
    cx.run_until_parked();
    cx.assert_editor_state(indoc! {"
        - [ ] Item 1
          - [ ] Item 1.a
        - [x] Item 2
          - [x] Item 2.a
          - [x] Item 2.b
          -ˇ
        "
          -ˇ"
    });
    cx.update_editor(|editor, window, cx| {
        editor.handle_input(" [x] Item 2.c", window, cx);
    });
    cx.run_until_parked();
    cx.assert_editor_state(indoc! {"
        - [ ] Item 1
          - [ ] Item 1.a
        - [x] Item 2
          - [x] Item 2.a
          - [x] Item 2.b
          - [x] Item 2.cˇ
        "
          - [x] Item 2.cˇ"
    });

    // Case 4: Test adding new line after nested ordered list preserves indent of previous line
@@ -27764,8 +27771,7 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
            1. Item 1.a
        2. Item 2
            1. Item 2.a
            2. Item 2.bˇ
        "
            2. Item 2.bˇ"
    });
    cx.update_editor(|editor, window, cx| {
        editor.newline(&Newline, window, cx);
@@ -27776,60 +27782,81 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
        2. Item 2
            1. Item 2.a
            2. Item 2.b
            ˇ
        "
            ˇ"
    });

    // Case 5: Adding new ordered list item preserves indent
    cx.set_state(indoc! {"
        1. Item 1
            1. Item 1.a
        2. Item 2
            1. Item 2.a
            2. Item 2.b
            ˇ"
    });
    cx.update_editor(|editor, window, cx| {
        editor.handle_input("3", window, cx);
    });
    cx.run_until_parked();
    cx.assert_editor_state(indoc! {"
        1. Item 1
            1. Item 1.a
        2. Item 2
            1. Item 2.a
            2. Item 2.b
            3ˇ
        "
            3ˇ"
    });
    cx.update_editor(|editor, window, cx| {
        editor.handle_input(".", window, cx);
    });
    cx.run_until_parked();
    cx.assert_editor_state(indoc! {"
        1. Item 1
            1. Item 1.a
        2. Item 2
            1. Item 2.a
            2. Item 2.b
            3.ˇ
        "
            3.ˇ"
    });
    cx.update_editor(|editor, window, cx| {
        editor.handle_input(" Item 2.c", window, cx);
    });
    cx.run_until_parked();
    cx.assert_editor_state(indoc! {"
        1. Item 1
            1. Item 1.a
        2. Item 2
            1. Item 2.a
            2. Item 2.b
            3. Item 2.cˇ
        "
            3. Item 2.cˇ"
    });

    // Case 6: Test adding a new item after a nested unordered list preserves indent of the previous line
    cx.set_state(indoc! {"
        - Item 1
          - Item 1.a
          - Item 1.a
          ˇ"});
    cx.update_editor(|editor, window, cx| {
        editor.handle_input("-", window, cx);
    });
    cx.run_until_parked();
    cx.assert_editor_state(indoc! {"
        - Item 1
          - Item 1.a
          - Item 1.a
          -ˇ"});

    // Case 7: Test that newline after a blockquote line does not continue the quote marker
    cx.set_state(indoc! {"
        > Item 1ˇ
    "
        > Item 1ˇ"
    });
    cx.update_editor(|editor, window, cx| {
        editor.newline(&Newline, window, cx);
    });
    cx.assert_editor_state(indoc! {"
        > Item 1
        ˇ
    "
        ˇ"
    });
}
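Most of the test churn above is dropping trailing newlines from expected editor states. With `indoc!`, the position of the closing quote decides whether the literal ends in `\n`, which is what the placement of the cursor marker `ˇ` depends on:

```rust
use indoc::indoc;

fn main() {
    // Closing quote immediately after the last character: no trailing newline.
    assert_eq!(indoc! {"
        a
        b"}, "a\nb");
    // Closing quote on its own line: trailing newline preserved.
    assert_eq!(indoc! {"
        a
        b
    "}, "a\nb\n");
}
```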
@@ -7,6 +7,7 @@ use theme::ActiveTheme;
enum MatchingBracketHighlight {}

impl Editor {
    #[ztracing::instrument(skip_all)]
    pub fn refresh_matching_bracket_highlights(
        &mut self,
        window: &Window,

@@ -623,7 +623,10 @@ pub fn hover_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
    });
    MarkdownStyle {
        base_text_style,
        code_block: StyleRefinement::default().my(rems(1.)).font_buffer(cx),
        code_block: StyleRefinement::default()
            .my(rems(1.))
            .font_buffer(cx)
            .font_features(buffer_font_features.clone()),
        inline_code: TextStyleRefinement {
            background_color: Some(cx.theme().colors().background),
            font_family: Some(buffer_font_family),

@@ -892,7 +892,7 @@ pub fn wait_for_lang_server(
        .update(cx, |buffer, cx| {
            lsp_store.update(cx, |lsp_store, cx| {
                lsp_store
                    .language_servers_for_local_buffer(buffer, cx)
                    .running_language_servers_for_local_buffer(buffer, cx)
                    .next()
                    .is_some()
            })

@@ -23,6 +23,7 @@ use std::{
    path::PathBuf,
    sync::{Arc, LazyLock},
};
use text::LineEnding;
use util::{paths::PathStyle, rel_path::RelPath};

pub static LOAD_INDEX_TEXT_TASK: LazyLock<TaskLabel> = LazyLock::new(TaskLabel::new);
@@ -200,6 +201,7 @@ impl GitRepository for FakeGitRepository {
        async {
            Ok(CommitDetails {
                sha: commit.into(),
                message: "initial commit".into(),
                ..Default::default()
            })
        }
@@ -451,7 +453,12 @@ impl GitRepository for FakeGitRepository {
        })
    }

    fn blame(&self, path: RepoPath, _content: Rope) -> BoxFuture<'_, Result<git::blame::Blame>> {
    fn blame(
        &self,
        path: RepoPath,
        _content: Rope,
        _line_ending: LineEnding,
    ) -> BoxFuture<'_, Result<git::blame::Blame>> {
        self.with_state_async(false, move |state| {
            state
                .blames
@@ -568,7 +575,7 @@ impl GitRepository for FakeGitRepository {
        _askpass: AskPassDelegate,
        _env: Arc<HashMap<String, String>>,
    ) -> BoxFuture<'_, Result<()>> {
        unimplemented!()
        async { Ok(()) }.boxed()
    }

    fn run_hook(
@@ -576,7 +583,7 @@ impl GitRepository for FakeGitRepository {
        _hook: RunHook,
        _env: Arc<HashMap<String, String>>,
    ) -> BoxFuture<'_, Result<()>> {
        unimplemented!()
        async { Ok(()) }.boxed()
    }

    fn push(

@@ -803,7 +803,7 @@ impl Fs for RealFs {
        }
        let file = smol::fs::File::create(path).await?;
        let mut writer = smol::io::BufWriter::with_capacity(buffer_size, file);
        for chunk in chunks(text, line_ending) {
        for chunk in text::chunks_with_line_ending(text, line_ending) {
            writer.write_all(chunk.as_bytes()).await?;
        }
        writer.flush().await?;
@@ -2555,7 +2555,7 @@ impl Fs for FakeFs {
    async fn save(&self, path: &Path, text: &Rope, line_ending: LineEnding) -> Result<()> {
        self.simulate_random_delay().await;
        let path = normalize_path(path);
        let content = chunks(text, line_ending).collect::<String>();
        let content = text::chunks_with_line_ending(text, line_ending).collect::<String>();
        if let Some(path) = path.parent() {
            self.create_dir(path).await?;
        }
@@ -2773,25 +2773,6 @@ impl Fs for FakeFs {
    }
}

fn chunks(rope: &Rope, line_ending: LineEnding) -> impl Iterator<Item = &str> {
    rope.chunks().flat_map(move |chunk| {
        let mut newline = false;
        let end_with_newline = chunk.ends_with('\n').then_some(line_ending.as_str());
        chunk
            .lines()
            .flat_map(move |line| {
                let ending = if newline {
                    Some(line_ending.as_str())
                } else {
                    None
                };
                newline = true;
                ending.into_iter().chain([line])
            })
            .chain(end_with_newline)
    })
}
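The removed `chunks` helper is superseded by a shared `text::chunks_with_line_ending`, which the `fs` and `git` hunks now both call. A minimal, self-contained version of the same idea, operating on `&str` instead of a `Rope` (a simplification, not the actual implementation):

```rust
fn chunks_with_line_ending<'a>(
    text: &'a str,
    line_ending: &'a str,
) -> impl Iterator<Item = &'a str> {
    // Split on '\n' and re-join with the requested line ending. Note that
    // "a\nb\n".split('\n') yields ["a", "b", ""], so a trailing newline in
    // the input becomes a trailing `line_ending` in the output.
    text.split('\n').enumerate().flat_map(move |(i, line)| {
        let sep = if i > 0 { Some(line_ending) } else { None };
        sep.into_iter().chain(std::iter::once(line))
    })
}

fn main() {
    let crlf: String = chunks_with_line_ending("a\nb\n", "\r\n").collect();
    assert_eq!(crlf, "a\r\nb\r\n");
}
```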
pub fn normalize_path(path: &Path) -> PathBuf {
    let mut components = path.components().peekable();
    let mut ret = if let Some(c @ Component::Prefix(..)) = components.peek().cloned() {

@@ -8,7 +8,7 @@ use gpui::SharedString;
use serde::{Deserialize, Serialize};
use std::process::Stdio;
use std::{ops::Range, path::Path};
use text::Rope;
use text::{LineEnding, Rope};
use time::OffsetDateTime;
use time::UtcOffset;
use time::macros::format_description;
@@ -35,8 +35,10 @@ impl Blame {
        working_directory: &Path,
        path: &RepoPath,
        content: &Rope,
        line_ending: LineEnding,
    ) -> Result<Self> {
        let output = run_git_blame(git_binary, working_directory, path, content).await?;
        let output =
            run_git_blame(git_binary, working_directory, path, content, line_ending).await?;
        let mut entries = parse_git_blame(&output)?;
        entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));

@@ -63,12 +65,12 @@ async fn run_git_blame(
    working_directory: &Path,
    path: &RepoPath,
    contents: &Rope,
    line_ending: LineEnding,
) -> Result<String> {
    let mut child = util::command::new_smol_command(git_binary)
        .current_dir(working_directory)
        .arg("blame")
        .arg("--incremental")
        .arg("-w")
        .arg("--contents")
        .arg("-")
        .arg(path.as_unix_str())
@@ -83,7 +85,7 @@ async fn run_git_blame(
        .as_mut()
        .context("failed to get pipe to stdin of git blame command")?;

    for chunk in contents.chunks() {
    for chunk in text::chunks_with_line_ending(contents, line_ending) {
        stdin.write_all(chunk.as_bytes()).await?;
    }
    stdin.flush().await?;
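This hunk threads the buffer's line ending all the way into `git blame --contents -`: the blamed content is streamed over stdin, so its bytes must match the on-disk line endings or blame ranges drift. The invocation shape, reduced to `std::process` (Zed's real code goes through its smol-based command wrappers and async pipes instead):

```rust
use std::io::Write;
use std::process::{Command, Stdio};

fn run_git_blame(repo: &std::path::Path, path: &str, contents: &str) -> std::io::Result<String> {
    let mut child = Command::new("git")
        .current_dir(repo)
        // --incremental: machine-readable output; -w: ignore whitespace;
        // --contents -: blame the bytes we provide on stdin, not the on-disk file.
        .args(["blame", "--incremental", "-w", "--contents", "-", path])
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .spawn()?;
    child.stdin.take().unwrap().write_all(contents.as_bytes())?;
    let output = child.wait_with_output()?;
    Ok(String::from_utf8_lossy(&output.stdout).into_owned())
}
```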
@@ -232,14 +232,12 @@ impl From<Oid> for usize {
#[derive(Copy, Clone, Debug)]
pub enum RunHook {
    PreCommit,
    PrePush,
}

impl RunHook {
    pub fn as_str(&self) -> &str {
        match self {
            Self::PreCommit => "pre-commit",
            Self::PrePush => "pre-push",
        }
    }

@@ -250,7 +248,6 @@ impl RunHook {
    pub fn from_proto(value: i32) -> Option<Self> {
        match value {
            0 => Some(Self::PreCommit),
            1 => Some(Self::PrePush),
            _ => None,
        }
    }

@@ -14,6 +14,7 @@ use rope::Rope;
use schemars::JsonSchema;
use serde::Deserialize;
use smol::io::{AsyncBufReadExt, AsyncReadExt, BufReader};
use text::LineEnding;

use std::collections::HashSet;
use std::ffi::{OsStr, OsString};
@@ -487,7 +488,12 @@ pub trait GitRepository: Send + Sync {
    fn show(&self, commit: String) -> BoxFuture<'_, Result<CommitDetails>>;

    fn load_commit(&self, commit: String, cx: AsyncApp) -> BoxFuture<'_, Result<CommitDiff>>;
    fn blame(&self, path: RepoPath, content: Rope) -> BoxFuture<'_, Result<crate::blame::Blame>>;
    fn blame(
        &self,
        path: RepoPath,
        content: Rope,
        line_ending: LineEnding,
    ) -> BoxFuture<'_, Result<crate::blame::Blame>>;
    fn file_history(&self, path: RepoPath) -> BoxFuture<'_, Result<FileHistory>>;
    fn file_history_paginated(
        &self,
@@ -652,6 +658,7 @@ pub struct RealGitRepository {
    pub repository: Arc<Mutex<git2::Repository>>,
    pub system_git_binary_path: Option<PathBuf>,
    pub any_git_binary_path: PathBuf,
    any_git_binary_help_output: Arc<Mutex<Option<SharedString>>>,
    executor: BackgroundExecutor,
}

@@ -670,6 +677,7 @@ impl RealGitRepository {
            system_git_binary_path,
            any_git_binary_path,
            executor,
            any_git_binary_help_output: Arc::new(Mutex::new(None)),
        })
    }

@@ -680,6 +688,27 @@ impl RealGitRepository {
            .context("failed to read git work directory")
            .map(Path::to_path_buf)
    }

    async fn any_git_binary_help_output(&self) -> SharedString {
        if let Some(output) = self.any_git_binary_help_output.lock().clone() {
            return output;
        }
        let git_binary_path = self.any_git_binary_path.clone();
        let executor = self.executor.clone();
        let working_directory = self.working_directory();
        let output: SharedString = self
            .executor
            .spawn(async move {
                GitBinary::new(git_binary_path, working_directory?, executor)
                    .run(["help", "-a"])
                    .await
            })
            .await
            .unwrap_or_default()
            .into();
        *self.any_git_binary_help_output.lock() = Some(output.clone());
        output
    }
}
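`any_git_binary_help_output` memoizes a single `git help -a` invocation behind a `Mutex<Option<SharedString>>`, so the hook-availability check further down no longer shells out on every commit. The same shape in miniature (a std-only sketch; the benign race where two first callers both compute is shared with the original):

```rust
use std::sync::Mutex;

struct Repo {
    help_output: Mutex<Option<String>>,
}

impl Repo {
    fn help_output(&self) -> String {
        if let Some(cached) = self.help_output.lock().unwrap().clone() {
            return cached;
        }
        // Stand-in for the expensive `git help -a` subprocess call.
        let fresh = String::from("...\n   hook          Run git hooks\n...");
        *self.help_output.lock().unwrap() = Some(fresh.clone());
        fresh
    }
}

fn main() {
    let repo = Repo { help_output: Mutex::new(None) };
    let first = repo.help_output();
    assert_eq!(repo.help_output(), first); // second call hits the cache
}
```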
#[derive(Clone, Debug)]
@@ -1489,7 +1518,12 @@ impl GitRepository for RealGitRepository {
            .boxed()
    }

    fn blame(&self, path: RepoPath, content: Rope) -> BoxFuture<'_, Result<crate::blame::Blame>> {
    fn blame(
        &self,
        path: RepoPath,
        content: Rope,
        line_ending: LineEnding,
    ) -> BoxFuture<'_, Result<crate::blame::Blame>> {
        let working_directory = self.working_directory();
        let git_binary_path = self.any_git_binary_path.clone();
        let executor = self.executor.clone();
@@ -1501,6 +1535,7 @@ impl GitRepository for RealGitRepository {
                &working_directory?,
                &path,
                &content,
                line_ending,
            )
            .await
        })
@@ -1819,6 +1854,7 @@ impl GitRepository for RealGitRepository {
            .args(["commit", "--quiet", "-m"])
            .arg(&message.to_string())
            .arg("--cleanup=strip")
            .arg("--no-verify")
            .stdout(smol::process::Stdio::piped())
            .stderr(smol::process::Stdio::piped());

@@ -2289,48 +2325,47 @@ impl GitRepository for RealGitRepository {
        env: Arc<HashMap<String, String>>,
    ) -> BoxFuture<'_, Result<()>> {
        let working_directory = self.working_directory();
        let repository = self.repository.clone();
        let git_binary_path = self.any_git_binary_path.clone();
        let executor = self.executor.clone();
        self.executor
            .spawn(async move {
                let working_directory = working_directory?;
                let git = GitBinary::new(git_binary_path, working_directory.clone(), executor)
                    .envs(HashMap::clone(&env));
        let help_output = self.any_git_binary_help_output();

                let output = git.run(&["help", "-a"]).await?;
                if !output.lines().any(|line| line.trim().starts_with("hook ")) {
                    log::warn!(
                        "git hook command not available, running the {} hook manually",
                        hook.as_str()
                    );
        // Note: Do not spawn these commands on the background thread, as this causes some git hooks to hang.
        async move {
            let working_directory = working_directory?;
            if !help_output
                .await
                .lines()
                .any(|line| line.trim().starts_with("hook "))
            {
                let hook_abs_path = repository.lock().path().join("hooks").join(hook.as_str());
                if hook_abs_path.is_file() {
                    let output = new_smol_command(&hook_abs_path)
                        .envs(env.iter())
                        .current_dir(&working_directory)
                        .output()
                        .await?;

                    let hook_abs_path = working_directory
                        .join(".git")
                        .join("hooks")
                        .join(hook.as_str());
                    if hook_abs_path.is_file() {
                        let output = new_smol_command(&hook_abs_path)
                            .envs(env.iter())
                            .current_dir(&working_directory)
                            .output()
                            .await?;

                    anyhow::ensure!(
                        output.status.success(),
                        "{} hook failed:\n{}",
                        hook.as_str(),
                        String::from_utf8_lossy(&output.stderr)
                    );
                    if !output.status.success() {
                        return Err(GitBinaryCommandError {
                            stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
                            stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
                            status: output.status,
                        }
                        .into());
                    }

                    return Ok(());
                }

                git.run(&["hook", "run", "--ignore-missing", hook.as_str()])
                    .await?;
                Ok(())
            })
            .boxed()
                return Ok(());
            }

            let git = GitBinary::new(git_binary_path, working_directory, executor)
                .envs(HashMap::clone(&env));
            git.run(&["hook", "run", "--ignore-missing", hook.as_str()])
                .await?;
            Ok(())
        }
        .boxed()
    }
}

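The rewritten `run_hook` prefers `git hook run --ignore-missing <name>` whenever the installed git advertises a `hook` command in `git help -a`, and otherwise falls back to executing `.git/hooks/<name>` directly on the foreground thread (per the note above, spawning these on the background thread made some hooks hang). The decision logic, condensed into a std-only sketch with simplified errors:

```rust
use std::path::Path;
use std::process::Command;

fn run_hook(repo: &Path, hook: &str, git_has_hook_cmd: bool) -> std::io::Result<bool> {
    if git_has_hook_cmd {
        let status = Command::new("git")
            .current_dir(repo)
            .args(["hook", "run", "--ignore-missing", hook])
            .status()?;
        return Ok(status.success());
    }
    let hook_path = repo.join(".git").join("hooks").join(hook);
    if !hook_path.is_file() {
        // A missing hook is not an error.
        return Ok(true);
    }
    Ok(Command::new(&hook_path).current_dir(repo).status()?.success())
}
```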
@@ -47,11 +47,13 @@ impl BlameRenderer for GitBlameRenderer {
        let name = util::truncate_and_trailoff(author_name, GIT_BLAME_MAX_AUTHOR_CHARS_DISPLAYED);

        let avatar = if ProjectSettings::get_global(cx).git.blame.show_avatar {
            CommitAvatar::new(
                &blame_entry.sha.to_string().into(),
                details.as_ref().and_then(|it| it.remote.as_ref()),
            Some(
                CommitAvatar::new(
                    &blame_entry.sha.to_string().into(),
                    details.as_ref().and_then(|it| it.remote.as_ref()),
                )
                .render(window, cx),
            )
            .render(window, cx)
        } else {
            None
        };
@@ -65,7 +67,7 @@ impl BlameRenderer for GitBlameRenderer {
            .w_full()
            .gap_2()
            .justify_between()
            .font_family(style.font().family)
            .font(style.font())
            .line_height(style.line_height)
            .text_color(cx.theme().status().hint)
            .child(
@@ -264,7 +266,7 @@ impl BlameRenderer for GitBlameRenderer {
            .flex_wrap()
            .border_b_1()
            .border_color(cx.theme().colors().border_variant)
            .children(avatar)
            .child(avatar)
            .child(author)
            .when(!author_email.is_empty(), |this| {
                this.child(

@@ -139,7 +139,7 @@ impl CommitModal {
                && !git_panel.amend_pending()
            {
                git_panel.set_amend_pending(true, cx);
                git_panel.load_last_commit_message_if_empty(cx);
                git_panel.load_last_commit_message(cx);
            }
        }
        ForceMode::Commit => {
@@ -492,53 +492,20 @@ impl CommitModal {
        }
    }

    fn commit(&mut self, _: &git::Commit, window: &mut Window, cx: &mut Context<Self>) {
        if self.git_panel.read(cx).amend_pending() {
            return;
    fn on_commit(&mut self, _: &git::Commit, window: &mut Window, cx: &mut Context<Self>) {
        if self.git_panel.update(cx, |git_panel, cx| {
            git_panel.commit(&self.commit_editor.focus_handle(cx), window, cx)
        }) {
            telemetry::event!("Git Committed", source = "Git Modal");
            cx.emit(DismissEvent);
        }
        telemetry::event!("Git Committed", source = "Git Modal");
        self.git_panel.update(cx, |git_panel, cx| {
            git_panel.commit_changes(
                CommitOptions {
                    amend: false,
                    signoff: git_panel.signoff_enabled(),
                },
                window,
                cx,
            )
        });
        cx.emit(DismissEvent);
    }

    fn amend(&mut self, _: &git::Amend, window: &mut Window, cx: &mut Context<Self>) {
        if self
            .git_panel
            .read(cx)
            .active_repository
            .as_ref()
            .and_then(|repo| repo.read(cx).head_commit.as_ref())
            .is_none()
        {
            return;
        }
        if !self.git_panel.read(cx).amend_pending() {
            self.git_panel.update(cx, |git_panel, cx| {
                git_panel.set_amend_pending(true, cx);
                git_panel.load_last_commit_message_if_empty(cx);
            });
        } else {
    fn on_amend(&mut self, _: &git::Amend, window: &mut Window, cx: &mut Context<Self>) {
        if self.git_panel.update(cx, |git_panel, cx| {
            git_panel.amend(&self.commit_editor.focus_handle(cx), window, cx)
        }) {
            telemetry::event!("Git Amended", source = "Git Modal");
            self.git_panel.update(cx, |git_panel, cx| {
                git_panel.set_amend_pending(false, cx);
                git_panel.commit_changes(
                    CommitOptions {
                        amend: true,
                        signoff: git_panel.signoff_enabled(),
                    },
                    window,
                    cx,
                );
            });
            cx.emit(DismissEvent);
        }
    }
@@ -564,8 +531,8 @@ impl Render for CommitModal {
            .id("commit-modal")
            .key_context("GitCommit")
            .on_action(cx.listener(Self::dismiss))
            .on_action(cx.listener(Self::commit))
            .on_action(cx.listener(Self::amend))
            .on_action(cx.listener(Self::on_commit))
            .on_action(cx.listener(Self::on_amend))
            .when(!DisableAiSettings::get_global(cx).disable_ai, |this| {
                this.on_action(cx.listener(|this, _: &GenerateCommitMessage, _, cx| {
                    this.git_panel.update(cx, |panel, cx| {

@@ -29,11 +29,16 @@ pub struct CommitDetails {
pub struct CommitAvatar<'a> {
    sha: &'a SharedString,
    remote: Option<&'a GitRemote>,
    size: Option<IconSize>,
}

impl<'a> CommitAvatar<'a> {
    pub fn new(sha: &'a SharedString, remote: Option<&'a GitRemote>) -> Self {
        Self { sha, remote }
        Self {
            sha,
            remote,
            size: None,
        }
    }

    pub fn from_commit_details(details: &'a CommitDetails) -> Self {
@@ -43,28 +48,37 @@ impl<'a> CommitAvatar<'a> {
                .message
                .as_ref()
                .and_then(|details| details.remote.as_ref()),
            size: None,
        }
    }
}

impl<'a> CommitAvatar<'a> {
    pub fn render(&'a self, window: &mut Window, cx: &mut App) -> Option<impl IntoElement + use<>> {
    pub fn size(mut self, size: IconSize) -> Self {
        self.size = Some(size);
        self
    }

    pub fn render(&'a self, window: &mut Window, cx: &mut App) -> AnyElement {
        match self.avatar(window, cx) {
            // Loading or no avatar found
            None => Icon::new(IconName::Person)
                .color(Color::Muted)
                .when_some(self.size, |this, size| this.size(size))
                .into_any_element(),
            // Found
            Some(avatar) => avatar
                .when_some(self.size, |this, size| this.size(size.rems()))
                .into_any_element(),
        }
    }

    pub fn avatar(&'a self, window: &mut Window, cx: &mut App) -> Option<Avatar> {
        let remote = self
            .remote
            .filter(|remote| remote.host_supports_avatars())?;

        let avatar_url = CommitAvatarAsset::new(remote.clone(), self.sha.clone());

        let element = match window.use_asset::<CommitAvatarAsset>(&avatar_url, cx) {
            // Loading or no avatar found
            None | Some(None) => Icon::new(IconName::Person)
                .color(Color::Muted)
                .into_element()
                .into_any(),
            // Found
            Some(Some(url)) => Avatar::new(url.to_string()).into_element().into_any(),
        };
        Some(element)
        let url = window.use_asset::<CommitAvatarAsset>(&avatar_url, cx)??;
        Some(Avatar::new(url.to_string()))
    }
}

@@ -253,7 +267,7 @@ impl Render for CommitTooltip {
            .gap_x_2()
            .overflow_x_hidden()
            .flex_wrap()
            .children(avatar)
            .child(avatar)
            .child(author)
            .when(!author_email.is_empty(), |this| {
                this.child(

@@ -5,8 +5,8 @@ use editor::{Editor, EditorEvent, ExcerptRange, MultiBuffer, multibuffer_context
use git::repository::{CommitDetails, CommitDiff, RepoPath};
use git::{GitHostingProviderRegistry, GitRemote, parse_git_remote_url};
use gpui::{
    AnyElement, App, AppContext as _, Asset, AsyncApp, AsyncWindowContext, Context, Element,
    Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement, ParentElement,
    AnyElement, App, AppContext as _, AsyncApp, AsyncWindowContext, Context, Element, Entity,
    EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement, ParentElement,
    PromptLevel, Render, Styled, Task, WeakEntity, Window, actions,
};
use language::{
@@ -21,7 +21,7 @@ use std::{
    sync::Arc,
};
use theme::ActiveTheme;
use ui::{Avatar, DiffStat, Tooltip, prelude::*};
use ui::{DiffStat, Tooltip, prelude::*};
use util::{ResultExt, paths::PathStyle, rel_path::RelPath, truncate_and_trailoff};
use workspace::item::TabTooltipContent;
use workspace::{
@@ -33,6 +33,7 @@ use workspace::{
    searchable::SearchableItemHandle,
};

use crate::commit_tooltip::CommitAvatar;
use crate::git_panel::GitPanel;

actions!(git, [ApplyCurrentStash, PopCurrentStash, DropCurrentStash,]);
@@ -318,17 +319,7 @@ impl CommitView {
        cx: &mut App,
    ) -> AnyElement {
        let size = size.into();
        let remote = self.remote.as_ref().filter(|r| r.host_supports_avatars());

        if let Some(remote) = remote {
            let avatar_asset = CommitAvatarAsset::new(remote.clone(), sha.clone());
            if let Some(Some(url)) = window.use_asset::<CommitAvatarAsset>(&avatar_asset, cx) {
                return Avatar::new(url.to_string())
                    .size(size)
                    .into_element()
                    .into_any();
            }
        }
        let avatar = CommitAvatar::new(sha, self.remote.as_ref());

        v_flex()
            .w(size)
@@ -339,10 +330,15 @@ impl CommitView {
            .justify_center()
            .items_center()
            .child(
                Icon::new(IconName::Person)
                    .color(Color::Muted)
                    .size(IconSize::Medium)
                    .into_element(),
                avatar
                    .avatar(window, cx)
                    .map(|a| a.size(size).into_any_element())
                    .unwrap_or_else(|| {
                        Icon::new(IconName::Person)
                            .color(Color::Muted)
                            .size(IconSize::Medium)
                            .into_any_element()
                    }),
            )
            .into_any()
    }
@@ -647,54 +643,6 @@ impl CommitView {
    }
}

#[derive(Clone, Debug)]
struct CommitAvatarAsset {
    sha: SharedString,
    remote: GitRemote,
}

impl std::hash::Hash for CommitAvatarAsset {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.sha.hash(state);
        self.remote.host.name().hash(state);
    }
}

impl CommitAvatarAsset {
    fn new(remote: GitRemote, sha: SharedString) -> Self {
        Self { remote, sha }
    }
}

impl Asset for CommitAvatarAsset {
    type Source = Self;
    type Output = Option<SharedString>;

    fn load(
        source: Self::Source,
        cx: &mut App,
    ) -> impl Future<Output = Self::Output> + Send + 'static {
        let client = cx.http_client();
        async move {
            match source
                .remote
                .host
                .commit_author_avatar_url(
                    &source.remote.owner,
                    &source.remote.repo,
                    source.sha.clone(),
                    client,
                )
                .await
            {
                Ok(Some(url)) => Some(SharedString::from(url.to_string())),
                Ok(None) => None,
                Err(_) => None,
            }
        }
    }
}

impl language::File for GitBlob {
    fn as_local(&self) -> Option<&dyn language::LocalFile> {
        None

@@ -111,6 +111,7 @@ fn excerpt_for_buffer_updated(
    );
}

#[ztracing::instrument(skip_all)]
fn buffer_added(editor: &mut Editor, buffer: Entity<Buffer>, cx: &mut Context<Editor>) {
    let Some(project) = editor.project() else {
        return;
@@ -166,6 +167,7 @@ fn buffers_removed(editor: &mut Editor, removed_buffer_ids: &[BufferId], cx: &mut Context<Editor>) {
    editor.remove_blocks(removed_block_ids, None, cx);
}

#[ztracing::instrument(skip_all)]
fn conflicts_updated(
    editor: &mut Editor,
    conflict_set: Entity<ConflictSet>,
@@ -311,6 +313,7 @@ fn conflicts_updated(
    }
}

#[ztracing::instrument(skip_all)]
fn update_conflict_highlighting(
    editor: &mut Editor,
    conflict: &ConflictRegion,

@@ -1934,16 +1934,26 @@ impl GitPanel {
        }
    }

    fn commit(&mut self, _: &git::Commit, window: &mut Window, cx: &mut Context<Self>) {
        if self.amend_pending {
            return;
        }
        if self
            .commit_editor
            .focus_handle(cx)
            .contains_focused(window, cx)
        {
    fn on_commit(&mut self, _: &git::Commit, window: &mut Window, cx: &mut Context<Self>) {
        if self.commit(&self.commit_editor.focus_handle(cx), window, cx) {
            telemetry::event!("Git Committed", source = "Git Panel");
        }
    }

    /// Commits staged changes with the current commit message.
    ///
    /// Returns `true` if the commit was executed, `false` otherwise.
    pub(crate) fn commit(
        &mut self,
        commit_editor_focus_handle: &FocusHandle,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> bool {
        if self.amend_pending {
            return false;
        }

        if commit_editor_focus_handle.contains_focused(window, cx) {
            self.commit_changes(
                CommitOptions {
                    amend: false,
@@ -1951,24 +1961,39 @@ impl GitPanel {
                },
                window,
                cx,
            )
            );
            true
        } else {
            cx.propagate();
            false
        }
    }

    fn amend(&mut self, _: &git::Amend, window: &mut Window, cx: &mut Context<Self>) {
        if self
            .commit_editor
            .focus_handle(cx)
            .contains_focused(window, cx)
        {
    fn on_amend(&mut self, _: &git::Amend, window: &mut Window, cx: &mut Context<Self>) {
        if self.amend(&self.commit_editor.focus_handle(cx), window, cx) {
            telemetry::event!("Git Amended", source = "Git Panel");
        }
    }

    /// Amends the most recent commit with staged changes and/or an updated commit message.
    ///
    /// Uses a two-stage workflow where the first invocation loads the commit
    /// message for editing, and the second invocation performs the amend. Returns
    /// `true` if the amend was executed, `false` otherwise.
    pub(crate) fn amend(
        &mut self,
        commit_editor_focus_handle: &FocusHandle,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> bool {
        if commit_editor_focus_handle.contains_focused(window, cx) {
            if self.head_commit(cx).is_some() {
                if !self.amend_pending {
                    self.set_amend_pending(true, cx);
                    self.load_last_commit_message_if_empty(cx);
                    self.load_last_commit_message(cx);

                    return false;
                } else {
                    telemetry::event!("Git Amended", source = "Git Panel");
                    self.commit_changes(
                        CommitOptions {
                            amend: true,
@@ -1977,13 +2002,16 @@ impl GitPanel {
                        window,
                        cx,
                    );

                    return true;
                }
            }
            return false;
        } else {
            cx.propagate();
            return false;
        }
    }

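A hedged reduction of the two-stage amend flow documented above, with the gpui plumbing stripped out: the first invocation only flips `amend_pending` and loads the last commit message for editing; the second actually commits with `amend: true`.

```rust
#[derive(Default)]
struct AmendFlow {
    amend_pending: bool,
}

impl AmendFlow {
    /// Returns `true` once the amend has actually been executed.
    fn amend(&mut self, has_head_commit: bool) -> bool {
        if !has_head_commit {
            return false;
        }
        if !self.amend_pending {
            self.amend_pending = true; // stage 1: load message for editing
            false
        } else {
            self.amend_pending = false; // stage 2: perform the amend
            true
        }
    }
}

fn main() {
    let mut flow = AmendFlow::default();
    assert!(!flow.amend(true)); // first invocation arms the amend
    assert!(flow.amend(true)); // second invocation commits it
}
```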
    pub fn head_commit(&self, cx: &App) -> Option<CommitDetails> {
        self.active_repository
            .as_ref()
@@ -1991,13 +2019,11 @@ impl GitPanel {
            .cloned()
    }

    pub fn load_last_commit_message_if_empty(&mut self, cx: &mut Context<Self>) {
        if !self.commit_editor.read(cx).is_empty(cx) {
            return;
        }
    pub fn load_last_commit_message(&mut self, cx: &mut Context<Self>) {
        let Some(head_commit) = self.head_commit(cx) else {
            return;
        };

        let recent_sha = head_commit.sha.to_string();
        let detail_task = self.load_commit_details(recent_sha, cx);
        cx.spawn(async move |this, cx| {
@@ -2133,11 +2159,16 @@ impl GitPanel {
            let result = task.await;
            this.update_in(cx, |this, window, cx| {
                this.pending_commit.take();

                match result {
                    Ok(()) => {
                        this.commit_editor
                            .update(cx, |editor, cx| editor.clear(window, cx));
                        this.original_commit_message = None;
                        if options.amend {
                            this.set_amend_pending(false, cx);
                        } else {
                            this.commit_editor
                                .update(cx, |editor, cx| editor.clear(window, cx));
                            this.original_commit_message = None;
                        }
                    }
                    Err(e) => this.show_error_toast("commit", e, cx),
                }
@@ -2146,9 +2177,6 @@ impl GitPanel {
        });

        self.pending_commit = Some(task);
        if options.amend {
            self.set_amend_pending(false, cx);
        }
    }

    pub(crate) fn uncommit(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -5067,6 +5095,9 @@ impl GitPanel {
        self.amend_pending
    }

    /// Sets the pending amend state, ensuring that the original commit message
    /// is either saved, when `value` is `true` and there's no pending amend, or
    /// restored, when `value` is `false` and there's a pending amend.
    pub fn set_amend_pending(&mut self, value: bool, cx: &mut Context<Self>) {
        if value && !self.amend_pending {
            let current_message = self.commit_message_buffer(cx).read(cx).text();
@@ -5184,7 +5215,7 @@ impl GitPanel {
    pub(crate) fn toggle_amend_pending(&mut self, cx: &mut Context<Self>) {
        self.set_amend_pending(!self.amend_pending, cx);
        if self.amend_pending {
            self.load_last_commit_message_if_empty(cx);
            self.load_last_commit_message(cx);
        }
    }
}
@@ -5215,8 +5246,8 @@ impl Render for GitPanel {
            .when(has_write_access && !project.is_read_only(cx), |this| {
                this.on_action(cx.listener(Self::toggle_staged_for_selected))
                    .on_action(cx.listener(Self::stage_range))
                    .on_action(cx.listener(GitPanel::commit))
                    .on_action(cx.listener(GitPanel::amend))
                    .on_action(cx.listener(GitPanel::on_commit))
                    .on_action(cx.listener(GitPanel::on_amend))
                    .on_action(cx.listener(GitPanel::toggle_signoff_enabled))
                    .on_action(cx.listener(Self::stage_all))
                    .on_action(cx.listener(Self::unstage_all))
@@ -6557,6 +6588,94 @@ mod tests {
        });
    }

    #[gpui::test]
    async fn test_amend(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            "/root",
            json!({
                "project": {
                    ".git": {},
                    "src": {
                        "main.rs": "fn main() {}"
                    }
                }
            }),
        )
        .await;

        fs.set_status_for_repo(
            Path::new(path!("/root/project/.git")),
            &[("src/main.rs", StatusCode::Modified.worktree())],
        );

        let project = Project::test(fs.clone(), [Path::new(path!("/root/project"))], cx).await;
        let workspace =
            cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx));
        let cx = &mut VisualTestContext::from_window(*workspace, cx);

        // Wait for the project scanning to finish so that `head_commit(cx)` is
        // actually set; otherwise no head commit would be available from which
        // to fetch the latest commit message.
        cx.executor().run_until_parked();

        let panel = workspace.update(cx, GitPanel::new).unwrap();
        panel.read_with(cx, |panel, cx| {
            assert!(panel.active_repository.is_some());
            assert!(panel.head_commit(cx).is_some());
        });

        panel.update_in(cx, |panel, window, cx| {
            // Update the commit editor's message to ensure that its contents
            // are later restored, after amending is finished.
            panel.commit_message_buffer(cx).update(cx, |buffer, cx| {
                buffer.set_text("refactor: update main.rs", cx);
            });

            // Start amending the previous commit.
            panel.focus_editor(&Default::default(), window, cx);
            panel.on_amend(&Amend, window, cx);
        });

        // Since `GitPanel::amend` attempts to fetch the latest commit message in
        // a background task, we need to wait for it to complete before being
        // able to assert that the commit message editor's state has been
        // updated.
        cx.run_until_parked();

        panel.update_in(cx, |panel, window, cx| {
            assert_eq!(
                panel.commit_message_buffer(cx).read(cx).text(),
                "initial commit"
            );
            assert_eq!(
                panel.original_commit_message,
                Some("refactor: update main.rs".to_string())
            );

            // Finish amending the previous commit.
            panel.focus_editor(&Default::default(), window, cx);
            panel.on_amend(&Amend, window, cx);
        });

        // Since the actual commit logic is run in a background task, we need to
        // await its completion to actually ensure that the commit message
        // editor's contents are set to the original message and haven't been
        // cleared.
        cx.run_until_parked();

        panel.update_in(cx, |panel, _window, cx| {
            // After amending, the commit editor's message should be restored to
            // the original message.
            assert_eq!(
                panel.commit_message_buffer(cx).read(cx).text(),
                "refactor: update main.rs"
            );
            assert!(panel.original_commit_message.is_none());
        });
    }

    #[gpui::test]
    async fn test_open_diff(cx: &mut TestAppContext) {
        init_test(cx);

@@ -21,7 +21,6 @@ default = ["font-kit", "wayland", "x11", "windows-manifest"]
test-support = [
    "leak-detection",
    "collections/test-support",
    "rand",
    "util/test-support",
    "http_client/test-support",
    "wayland",
@@ -109,7 +108,7 @@ parking = "2.0.0"
parking_lot.workspace = true
postage.workspace = true
profiling.workspace = true
rand = { optional = true, workspace = true }
rand.workspace = true
raw-window-handle = "0.6"
refineable.workspace = true
resvg = { version = "0.45.0", default-features = false, features = [
@@ -158,8 +157,10 @@ media.workspace = true
objc.workspace = true
objc2 = { version = "0.6", optional = true }
objc2-metal = { version = "0.3", optional = true }
mach2.workspace = true
#TODO: replace with "objc2"
metal.workspace = true
flume = "0.11"

[target.'cfg(any(target_os = "linux", target_os = "freebsd", target_os = "macos"))'.dependencies]
pathfinder_geometry = "0.5"

@@ -84,6 +84,8 @@ mod macos {
        .allowlist_var("_dispatch_main_q")
        .allowlist_var("_dispatch_source_type_data_add")
        .allowlist_var("DISPATCH_QUEUE_PRIORITY_HIGH")
        .allowlist_var("DISPATCH_QUEUE_PRIORITY_DEFAULT")
        .allowlist_var("DISPATCH_QUEUE_PRIORITY_LOW")
        .allowlist_var("DISPATCH_TIME_NOW")
        .allowlist_function("dispatch_get_global_queue")
        .allowlist_function("dispatch_async_f")

@@ -38,10 +38,11 @@ use crate::{
    AssetSource, BackgroundExecutor, Bounds, ClipboardItem, CursorStyle, DispatchPhase, DisplayId,
    EventEmitter, FocusHandle, FocusMap, ForegroundExecutor, Global, KeyBinding, KeyContext,
    Keymap, Keystroke, LayoutId, Menu, MenuItem, OwnedMenu, PathPromptOptions, Pixels, Platform,
    PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper, Point, PromptBuilder,
    PromptButton, PromptHandle, PromptLevel, Render, RenderImage, RenderablePromptHandle,
    Reservation, ScreenCaptureSource, SharedString, SubscriberSet, Subscription, SvgRenderer, Task,
    TextSystem, Window, WindowAppearance, WindowHandle, WindowId, WindowInvalidator,
    PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper, Point, Priority,
    PromptBuilder, PromptButton, PromptHandle, PromptLevel, Render, RenderImage,
    RenderablePromptHandle, Reservation, ScreenCaptureSource, SharedString, SubscriberSet,
    Subscription, SvgRenderer, Task, TextSystem, Window, WindowAppearance, WindowHandle, WindowId,
    WindowInvalidator,
    colors::{Colors, GlobalColors},
    current_platform, hash, init_app_menus,
};
@@ -1494,6 +1495,24 @@ impl App {
            .spawn(async move { f(&mut cx).await })
    }

    /// Spawns the future returned by the given function on the main thread with
    /// the given priority. The closure will be invoked with [AsyncApp], which
    /// allows the application state to be accessed across await points.
    pub fn spawn_with_priority<AsyncFn, R>(&self, priority: Priority, f: AsyncFn) -> Task<R>
    where
        AsyncFn: AsyncFnOnce(&mut AsyncApp) -> R + 'static,
        R: 'static,
    {
        if self.quitting {
            debug_panic!("Can't spawn on main thread after on_app_quit")
        };

        let mut cx = self.to_async();

        self.foreground_executor
            .spawn_with_priority(priority, async move { f(&mut cx).await })
    }

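`App::spawn` keeps its signature while the new method takes the extra parameter; the executor hunks further down apply the same delegate-with-default pattern. In miniature:

```rust
#[derive(Default, Clone, Copy, Debug)]
enum Priority {
    High,
    #[default]
    Medium,
    Low,
}

struct Executor;

impl Executor {
    fn spawn<F: FnOnce()>(&self, f: F) {
        // Existing callers keep compiling and implicitly get Medium.
        self.spawn_with_priority(Priority::default(), f);
    }

    fn spawn_with_priority<F: FnOnce()>(&self, priority: Priority, f: F) {
        println!("scheduling at {priority:?}");
        f();
    }
}

fn main() {
    Executor.spawn(|| println!("hello"));
}
```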
    /// Schedules the given function to be run at the end of the current effect cycle, allowing entities
    /// that are currently on the stack to be returned to the app.
    pub fn defer(&mut self, f: impl FnOnce(&mut App) + 'static) {

@@ -1,7 +1,7 @@
use crate::{
    AnyView, AnyWindowHandle, AppContext, AsyncApp, DispatchPhase, Effect, EntityId, EventEmitter,
    FocusHandle, FocusOutEvent, Focusable, Global, KeystrokeObserver, Reservation, SubscriberSet,
    Subscription, Task, WeakEntity, WeakFocusHandle, Window, WindowHandle,
    FocusHandle, FocusOutEvent, Focusable, Global, KeystrokeObserver, Priority, Reservation,
    SubscriberSet, Subscription, Task, WeakEntity, WeakFocusHandle, Window, WindowHandle,
};
use anyhow::Result;
use futures::FutureExt;
@@ -667,6 +667,25 @@ impl<'a, T: 'static> Context<'a, T> {
        window.spawn(self, async move |cx| f(view, cx).await)
    }

    /// Schedule a future to be run asynchronously with the given priority.
    /// The given callback is invoked with a [`WeakEntity<V>`] to avoid leaking the entity for a long-running process.
    /// It's also given an [`AsyncWindowContext`], which can be used to access the state of the entity across await points.
    /// The returned future will be polled on the main thread.
    #[track_caller]
    pub fn spawn_in_with_priority<AsyncFn, R>(
        &self,
        priority: Priority,
        window: &Window,
        f: AsyncFn,
    ) -> Task<R>
    where
        R: 'static,
        AsyncFn: AsyncFnOnce(WeakEntity<T>, &mut AsyncWindowContext) -> R + 'static,
    {
        let view = self.weak_entity();
        window.spawn_with_priority(priority, self, async move |cx| f(view, cx).await)
    }

    /// Register a callback to be invoked when the given global state changes.
    pub fn observe_global_in<G: Global>(
        &mut self,

@@ -1,6 +1,7 @@
use crate::{App, PlatformDispatcher, RunnableMeta, RunnableVariant};
use crate::{App, PlatformDispatcher, RunnableMeta, RunnableVariant, TaskTiming, profiler};
use async_task::Runnable;
use futures::channel::mpsc;
use parking_lot::{Condvar, Mutex};
use smol::prelude::*;
use std::{
    fmt::Debug,
@@ -46,6 +47,52 @@ pub struct ForegroundExecutor {
    not_send: PhantomData<Rc<()>>,
}

/// Realtime task priority
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[repr(u8)]
pub enum RealtimePriority {
    /// Audio task
    Audio,
    /// Other realtime task
    #[default]
    Other,
}

/// Task priority
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[repr(u8)]
pub enum Priority {
    /// Realtime priority
    ///
    /// Spawning a task with this priority will spin it off on a separate thread dedicated just to that task.
    Realtime(RealtimePriority),
    /// High priority
    ///
    /// Only use for tasks that are critical to the user experience / responsiveness of the editor.
    High,
    /// Medium priority, which probably suits most use cases.
    #[default]
    Medium,
    /// Low priority
    ///
    /// Use this for background work that can arrive in large quantities, so that
    /// it does not starve the executor of resources for high-priority tasks.
    Low,
}

impl Priority {
    #[allow(dead_code)]
    pub(crate) const fn probability(&self) -> u32 {
        match self {
            // realtime priorities are not considered for probability scheduling
            Priority::Realtime(_) => 0,
            Priority::High => 60,
            Priority::Medium => 30,
            Priority::Low => 10,
        }
    }
}
/// Task is a primitive that allows work to happen in the background.
|
||||
///
|
||||
/// It implements [`Future`] so you can `.await` on it.
|
||||
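
The `probability` weights above suggest how a receiver can pick the next runnable: instead of always draining high before low, it can draw a random number and choose among the non-empty queues in proportion to these weights, so low-priority work still makes progress. A minimal sketch of that idea, with a hypothetical `pick` helper that is not part of this diff (only the 60/30/10 weights come from the code above):

```rust
use rand::Rng;

// Hypothetical helper: weighted choice between the three non-realtime queues.
// Index 0 = high, 1 = medium, 2 = low; `lens` holds each queue's length.
fn pick(rng: &mut impl Rng, lens: [usize; 3]) -> Option<usize> {
    let weights = [60u32, 30, 10]; // mirrors Priority::probability()
    // Only non-empty queues participate in the draw.
    let total: u32 = weights
        .iter()
        .zip(lens)
        .filter(|(_, len)| *len > 0)
        .map(|(w, _)| *w)
        .sum();
    if total == 0 {
        return None;
    }
    let mut roll = rng.gen_range(0..total);
    for (i, (w, len)) in weights.iter().zip(lens).enumerate() {
        if len == 0 {
            continue; // empty queues contributed nothing to `total`
        }
        if roll < *w {
            return Some(i);
        }
        roll -= w;
    }
    None
}
```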
@@ -151,7 +198,77 @@ impl BackgroundExecutor {
     where
         R: Send + 'static,
     {
-        self.spawn_internal::<R>(Box::pin(future), None)
+        self.spawn_with_priority(Priority::default(), future)
     }
 
+    /// Enqueues the given future to be run to completion on a background thread.
+    #[track_caller]
+    pub fn spawn_with_priority<R>(
+        &self,
+        priority: Priority,
+        future: impl Future<Output = R> + Send + 'static,
+    ) -> Task<R>
+    where
+        R: Send + 'static,
+    {
+        self.spawn_internal::<R>(Box::pin(future), None, priority)
+    }
+
+    /// Enqueues the given future to be run to completion on a background thread, blocking the current task on it.
+    ///
+    /// This allows spawning background work that borrows from its scope. Note that the supplied future will run to
+    /// completion before the current task is resumed, even if the current task is slated for cancellation.
+    pub async fn await_on_background<R>(&self, future: impl Future<Output = R> + Send) -> R
+    where
+        R: Send,
+    {
+        // We need to ensure that cancellation of the parent task does not drop the environment
+        // before our own task has completed or been cancelled.
+        struct NotifyOnDrop<'a>(&'a (Condvar, Mutex<bool>));
+
+        impl Drop for NotifyOnDrop<'_> {
+            fn drop(&mut self) {
+                *self.0.1.lock() = true;
+                self.0.0.notify_all();
+            }
+        }
+
+        struct WaitOnDrop<'a>(&'a (Condvar, Mutex<bool>));
+
+        impl Drop for WaitOnDrop<'_> {
+            fn drop(&mut self) {
+                let mut done = self.0.1.lock();
+                if !*done {
+                    self.0.0.wait(&mut done);
+                }
+            }
+        }
+
+        let dispatcher = self.dispatcher.clone();
+        let location = core::panic::Location::caller();
+
+        let pair = &(Condvar::new(), Mutex::new(false));
+        let _wait_guard = WaitOnDrop(pair);
+
+        let (runnable, task) = unsafe {
+            async_task::Builder::new()
+                .metadata(RunnableMeta { location })
+                .spawn_unchecked(
+                    move |_| async {
+                        let _notify_guard = NotifyOnDrop(pair);
+                        future.await
+                    },
+                    move |runnable| {
+                        dispatcher.dispatch(
+                            RunnableVariant::Meta(runnable),
+                            None,
+                            Priority::default(),
+                        )
+                    },
+                )
+        };
+        runnable.schedule();
+        task.await
+    }
+
     /// Enqueues the given future to be run to completion on a background thread.
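
A hedged usage sketch of the two entry points above, assuming the usual `background_executor()` accessor found elsewhere in GPUI:

```rust
// Bulk background work at low priority, so it cannot starve high-priority tasks.
fn index_project(cx: &mut gpui::App) {
    let _task = cx
        .background_executor()
        .spawn_with_priority(gpui::Priority::Low, async move {
            // ... expensive scan here (placeholder) ...
        });
}

// Inside an async task: run a borrowing computation on the background pool.
// `await_on_background` does not resume this task until the inner future
// finishes, which is what makes the borrow of `items` sound.
async fn sum_on_background(executor: &gpui::BackgroundExecutor) -> i32 {
    let items = vec![1, 2, 3];
    executor
        .await_on_background(async { items.iter().sum::<i32>() })
        .await
}
```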
@@ -165,7 +282,7 @@ impl BackgroundExecutor {
     where
         R: Send + 'static,
     {
-        self.spawn_internal::<R>(Box::pin(future), Some(label))
+        self.spawn_internal::<R>(Box::pin(future), Some(label), Priority::default())
     }
 
     #[track_caller]
@@ -173,15 +290,55 @@ impl BackgroundExecutor {
         &self,
         future: AnyFuture<R>,
         label: Option<TaskLabel>,
+        priority: Priority,
     ) -> Task<R> {
         let dispatcher = self.dispatcher.clone();
-        let location = core::panic::Location::caller();
-        let (runnable, task) = async_task::Builder::new()
-            .metadata(RunnableMeta { location })
-            .spawn(
-                move |_| future,
-                move |runnable| dispatcher.dispatch(RunnableVariant::Meta(runnable), label),
+        let (runnable, task) = if let Priority::Realtime(realtime) = priority {
+            let location = core::panic::Location::caller();
+            let (mut tx, rx) = flume::bounded::<Runnable<RunnableMeta>>(1);
+
+            dispatcher.spawn_realtime(
+                realtime,
+                Box::new(move || {
+                    while let Ok(runnable) = rx.recv() {
+                        let start = Instant::now();
+                        let location = runnable.metadata().location;
+                        let mut timing = TaskTiming {
+                            location,
+                            start,
+                            end: None,
+                        };
+                        profiler::add_task_timing(timing);
+
+                        runnable.run();
+
+                        let end = Instant::now();
+                        timing.end = Some(end);
+                        profiler::add_task_timing(timing);
+                    }
+                }),
+            );
+
+            async_task::Builder::new()
+                .metadata(RunnableMeta { location })
+                .spawn(
+                    move |_| future,
+                    move |runnable| {
+                        let _ = tx.send(runnable);
+                    },
+                )
+        } else {
+            let location = core::panic::Location::caller();
+            async_task::Builder::new()
+                .metadata(RunnableMeta { location })
+                .spawn(
+                    move |_| future,
+                    move |runnable| {
+                        dispatcher.dispatch(RunnableVariant::Meta(runnable), label, priority)
+                    },
+                )
+        };
+
         runnable.schedule();
         Task(TaskState::Spawned(task))
     }
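
For the `Priority::Realtime` arm above, the runnable never touches the thread pool: `spawn_realtime` spins up a dedicated OS thread, and a bounded channel of size one feeds it wakeups. From the caller's side that looks like any other spawn; a hedged sketch (names from this diff, the audio use case is an assumption):

```rust
// Sketch: a realtime task gets its own OS thread with an elevated scheduling
// policy; every wakeup of this future runs there, never on the pool.
fn start_audio_pump(executor: &gpui::BackgroundExecutor) -> gpui::Task<()> {
    executor.spawn_with_priority(
        gpui::Priority::Realtime(gpui::RealtimePriority::Audio),
        async move {
            // ... poll the device and fill the next audio buffer ...
        },
    )
}
```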
@@ -354,11 +511,28 @@ impl BackgroundExecutor {
     where
         F: FnOnce(&mut Scope<'scope>),
     {
-        let mut scope = Scope::new(self.clone());
+        let mut scope = Scope::new(self.clone(), Priority::default());
         (scheduler)(&mut scope);
         let spawned = mem::take(&mut scope.futures)
             .into_iter()
-            .map(|f| self.spawn(f))
+            .map(|f| self.spawn_with_priority(scope.priority, f))
             .collect::<Vec<_>>();
         for task in spawned {
             task.await;
         }
     }
 
+    /// Scoped lets you start a number of tasks and waits
+    /// for all of them to complete before returning.
+    pub async fn scoped_priority<'scope, F>(&self, priority: Priority, scheduler: F)
+    where
+        F: FnOnce(&mut Scope<'scope>),
+    {
+        let mut scope = Scope::new(self.clone(), priority);
+        (scheduler)(&mut scope);
+        let spawned = mem::take(&mut scope.futures)
+            .into_iter()
+            .map(|f| self.spawn_with_priority(scope.priority, f))
+            .collect::<Vec<_>>();
+        for task in spawned {
+            task.await;
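
`scoped_priority` behaves like `scoped`, except every future registered on the `Scope` is spawned at the caller-chosen priority. A hedged usage sketch:

```rust
// Sketch: fan out chunk processing at low priority and wait for all of it.
// `scoped`/`scoped_priority` let the spawned futures borrow from this frame.
async fn process_chunks(executor: &gpui::BackgroundExecutor, chunks: &mut [Vec<u8>]) {
    executor
        .scoped_priority(gpui::Priority::Low, |scope| {
            for chunk in chunks.iter_mut() {
                scope.spawn(async move {
                    chunk.reverse(); // stand-in for real per-chunk work
                });
            }
        })
        .await;
}
```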
@@ -494,6 +668,19 @@ impl ForegroundExecutor {
     /// Enqueues the given Task to run on the main thread at some point in the future.
     #[track_caller]
     pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
     where
         R: 'static,
     {
+        self.spawn_with_priority(Priority::default(), future)
+    }
+
+    /// Enqueues the given Task to run on the main thread at some point in the future.
+    #[track_caller]
+    pub fn spawn_with_priority<R>(
+        &self,
+        priority: Priority,
+        future: impl Future<Output = R> + 'static,
+    ) -> Task<R>
+    where
+        R: 'static,
+    {
@@ -505,16 +692,19 @@ impl ForegroundExecutor {
         dispatcher: Arc<dyn PlatformDispatcher>,
         future: AnyLocalFuture<R>,
         location: &'static core::panic::Location<'static>,
+        priority: Priority,
     ) -> Task<R> {
         let (runnable, task) = spawn_local_with_source_location(
             future,
-            move |runnable| dispatcher.dispatch_on_main_thread(RunnableVariant::Meta(runnable)),
+            move |runnable| {
+                dispatcher.dispatch_on_main_thread(RunnableVariant::Meta(runnable), priority)
+            },
             RunnableMeta { location },
         );
         runnable.schedule();
         Task(TaskState::Spawned(task))
     }
-    inner::<R>(dispatcher, Box::pin(future), location)
+    inner::<R>(dispatcher, Box::pin(future), location, priority)
 }
@@ -590,6 +780,7 @@ where
 /// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`].
 pub struct Scope<'a> {
     executor: BackgroundExecutor,
+    priority: Priority,
     futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
     tx: Option<mpsc::Sender<()>>,
     rx: mpsc::Receiver<()>,
@@ -597,10 +788,11 @@ pub struct Scope<'a> {
 }
 
 impl<'a> Scope<'a> {
-    fn new(executor: BackgroundExecutor) -> Self {
+    fn new(executor: BackgroundExecutor, priority: Priority) -> Self {
         let (tx, rx) = mpsc::channel(1);
         Self {
             executor,
+            priority,
             tx: Some(tx),
             rx,
             futures: Default::default(),
@@ -1416,9 +1416,9 @@ where
     /// ```
     pub fn contains(&self, point: &Point<T>) -> bool {
         point.x >= self.origin.x
-            && point.x <= self.origin.x.clone() + self.size.width.clone()
+            && point.x < self.origin.x.clone() + self.size.width.clone()
             && point.y >= self.origin.y
-            && point.y <= self.origin.y.clone() + self.size.height.clone()
+            && point.y < self.origin.y.clone() + self.size.height.clone()
     }
 
     /// Checks if this bounds is completely contained within another bounds.
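
The change above makes `contains` half-open: the far edges are now excluded, so two bounds that tile a surface no longer both claim their shared edge. A small worked check using GPUI's geometry helpers:

```rust
use gpui::{point, px, size, Bounds, Pixels};

let b: Bounds<Pixels> = Bounds {
    origin: point(px(0.0), px(0.0)),
    size: size(px(10.0), px(10.0)),
};
assert!(b.contains(&point(px(0.0), px(0.0)))); // origin is inside
assert!(!b.contains(&point(px(10.0), px(5.0)))); // far edge is now excluded
```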
@@ -31,6 +31,8 @@ mod path_builder;
 mod platform;
 pub mod prelude;
+mod profiler;
+#[cfg(any(target_os = "windows", target_os = "linux"))]
+mod queue;
 mod scene;
 mod shared_string;
 mod shared_uri;
@@ -89,16 +91,20 @@ pub use keymap::*;
 pub use path_builder::*;
 pub use platform::*;
+pub use profiler::*;
+#[cfg(any(target_os = "windows", target_os = "linux"))]
+pub(crate) use queue::{PriorityQueueReceiver, PriorityQueueSender};
 pub use refineable::*;
 pub use scene::*;
 pub use shared_string::*;
 pub use shared_uri::*;
 pub use smol::Timer;
+use std::{any::Any, future::Future};
 pub use style::*;
 pub use styled::*;
 pub use subscription::*;
 pub use svg_renderer::*;
 pub(crate) use tab_stop::*;
+use taffy::TaffyLayoutEngine;
 pub use taffy::{AvailableSpace, LayoutId};
 #[cfg(any(test, feature = "test-support"))]
 pub use test::*;
@@ -109,9 +115,6 @@ pub use util::{FutureExt, Timeout, arc_cow::ArcCow};
 pub use view::*;
 pub use window::*;
 
-use std::{any::Any, future::Future};
-use taffy::TaffyLayoutEngine;
-
 /// The context trait, allows the different contexts in GPUI to be used
 /// interchangeably for certain operations.
 pub trait AppContext {
@@ -39,9 +39,10 @@ use crate::{
     Action, AnyWindowHandle, App, AsyncWindowContext, BackgroundExecutor, Bounds,
     DEFAULT_WINDOW_SIZE, DevicePixels, DispatchEventResult, Font, FontId, FontMetrics, FontRun,
     ForegroundExecutor, GlyphId, GpuSpecs, ImageSource, Keymap, LineLayout, Pixels, PlatformInput,
-    Point, RenderGlyphParams, RenderImage, RenderImageParams, RenderSvgParams, Scene, ShapedGlyph,
-    ShapedRun, SharedString, Size, SvgRenderer, SystemWindowTab, Task, TaskLabel, TaskTiming,
-    ThreadTaskTimings, Window, WindowControlArea, hash, point, px, size,
+    Point, Priority, RealtimePriority, RenderGlyphParams, RenderImage, RenderImageParams,
+    RenderSvgParams, Scene, ShapedGlyph, ShapedRun, SharedString, Size, SvgRenderer,
+    SystemWindowTab, Task, TaskLabel, TaskTiming, ThreadTaskTimings, Window, WindowControlArea,
+    hash, point, px, size,
 };
 use anyhow::Result;
 use async_task::Runnable;
@@ -289,6 +290,13 @@ pub trait PlatformDisplay: Send + Sync + Debug {
     /// Get the bounds for this display
     fn bounds(&self) -> Bounds<Pixels>;
 
+    /// Get the visible bounds for this display, excluding taskbar/dock areas.
+    /// This is the usable area where windows can be placed without being obscured.
+    /// Defaults to the full display bounds if not overridden.
+    fn visible_bounds(&self) -> Bounds<Pixels> {
+        self.bounds()
+    }
+
     /// Get the default bounds for this display to place a window
     fn default_bounds(&self) -> Bounds<Pixels> {
         let bounds = self.bounds();
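
Because `visible_bounds` ships with a default body, existing `PlatformDisplay` implementations keep compiling; platforms that know their work area override it, as the macOS and Windows implementations later in this diff do. A minimal sketch of an implementor (hypothetical type, the trait's other required items elided for brevity):

```rust
// Sketch only, not from this diff.
struct HeadlessDisplay {
    bounds: gpui::Bounds<gpui::Pixels>,
}

impl gpui::PlatformDisplay for HeadlessDisplay {
    fn bounds(&self) -> gpui::Bounds<gpui::Pixels> {
        self.bounds
    }
    // No visible_bounds() override, so callers fall back to the full bounds().
}
```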
@@ -580,9 +588,10 @@ pub trait PlatformDispatcher: Send + Sync {
     fn get_all_timings(&self) -> Vec<ThreadTaskTimings>;
     fn get_current_thread_timings(&self) -> Vec<TaskTiming>;
     fn is_main_thread(&self) -> bool;
-    fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>);
-    fn dispatch_on_main_thread(&self, runnable: RunnableVariant);
+    fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>, priority: Priority);
+    fn dispatch_on_main_thread(&self, runnable: RunnableVariant, priority: Priority);
     fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant);
+    fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>);
 
     fn now(&self) -> Instant {
         Instant::now()
@@ -1,9 +1,10 @@
 use crate::{
-    GLOBAL_THREAD_TIMINGS, PlatformDispatcher, RunnableVariant, THREAD_TIMINGS, TaskLabel,
-    TaskTiming, ThreadTaskTimings,
+    GLOBAL_THREAD_TIMINGS, PlatformDispatcher, Priority, PriorityQueueReceiver,
+    PriorityQueueSender, RealtimePriority, RunnableVariant, THREAD_TIMINGS, TaskLabel, TaskTiming,
+    ThreadTaskTimings, profiler,
 };
 use calloop::{
-    EventLoop,
+    EventLoop, PostAction,
     channel::{self, Sender},
     timer::TimeoutAction,
 };
@@ -19,9 +20,9 @@ struct TimerAfter {
 }
 
 pub(crate) struct LinuxDispatcher {
-    main_sender: Sender<RunnableVariant>,
+    main_sender: PriorityQueueCalloopSender<RunnableVariant>,
     timer_sender: Sender<TimerAfter>,
-    background_sender: flume::Sender<RunnableVariant>,
+    background_sender: PriorityQueueSender<RunnableVariant>,
     _background_threads: Vec<thread::JoinHandle<()>>,
     main_thread_id: thread::ThreadId,
 }
@@ -29,18 +30,20 @@ pub(crate) struct LinuxDispatcher {
 const MIN_THREADS: usize = 2;
 
 impl LinuxDispatcher {
-    pub fn new(main_sender: Sender<RunnableVariant>) -> Self {
-        let (background_sender, background_receiver) = flume::unbounded::<RunnableVariant>();
+    pub fn new(main_sender: PriorityQueueCalloopSender<RunnableVariant>) -> Self {
+        let (background_sender, background_receiver) = PriorityQueueReceiver::new();
         let thread_count =
             std::thread::available_parallelism().map_or(MIN_THREADS, |i| i.get().max(MIN_THREADS));
 
+        // These threads should really be lower priority than the foreground
+        // executor.
         let mut background_threads = (0..thread_count)
             .map(|i| {
-                let receiver = background_receiver.clone();
+                let mut receiver = background_receiver.clone();
                 std::thread::Builder::new()
                     .name(format!("Worker-{i}"))
                     .spawn(move || {
-                        for runnable in receiver {
+                        for runnable in receiver.iter() {
                             let start = Instant::now();
 
                             let mut location = match runnable {
@@ -51,7 +54,7 @@ impl LinuxDispatcher {
                                 start,
                                 end: None,
                             };
-                            Self::add_task_timing(timing);
+                            profiler::add_task_timing(timing);
 
                             runnable.run();
                             timing
@@ -63,7 +66,7 @@ impl LinuxDispatcher {
                                 start,
                                 end: None,
                             };
-                            Self::add_task_timing(timing);
+                            profiler::add_task_timing(timing);
 
                             runnable.run();
                             timing
@@ -72,7 +75,7 @@ impl LinuxDispatcher {
 
                             let end = Instant::now();
                             location.end = Some(end);
-                            Self::add_task_timing(location);
+                            profiler::add_task_timing(location);
 
                             log::trace!(
                                 "background thread {}: ran runnable. took: {:?}",
@@ -113,7 +116,7 @@ impl LinuxDispatcher {
                     start,
                     end: None,
                 };
-                Self::add_task_timing(timing);
+                profiler::add_task_timing(timing);
 
                 runnable.run();
                 timing
@@ -124,7 +127,7 @@ impl LinuxDispatcher {
                     start,
                     end: None,
                 };
-                Self::add_task_timing(timing);
+                profiler::add_task_timing(timing);
 
                 runnable.run();
                 timing
@@ -133,7 +136,7 @@ impl LinuxDispatcher {
                 let end = Instant::now();
 
                 timing.end = Some(end);
-                Self::add_task_timing(timing);
+                profiler::add_task_timing(timing);
             }
             TimeoutAction::Drop
         },
@@ -157,22 +160,6 @@ impl LinuxDispatcher {
             main_thread_id: thread::current().id(),
         }
     }
-
-    pub(crate) fn add_task_timing(timing: TaskTiming) {
-        THREAD_TIMINGS.with(|timings| {
-            let mut timings = timings.lock();
-            let timings = &mut timings.timings;
-
-            if let Some(last_timing) = timings.iter_mut().rev().next() {
-                if last_timing.location == timing.location {
-                    last_timing.end = timing.end;
-                    return;
-                }
-            }
-
-            timings.push_back(timing);
-        });
-    }
 }
 
 impl PlatformDispatcher for LinuxDispatcher {
@@ -199,22 +186,26 @@ impl PlatformDispatcher for LinuxDispatcher {
         thread::current().id() == self.main_thread_id
     }
 
-    fn dispatch(&self, runnable: RunnableVariant, _: Option<TaskLabel>) {
-        self.background_sender.send(runnable).unwrap();
+    fn dispatch(&self, runnable: RunnableVariant, _: Option<TaskLabel>, priority: Priority) {
+        self.background_sender
+            .send(priority, runnable)
+            .unwrap_or_else(|_| panic!("blocking sender returned without value"));
     }
 
-    fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
-        self.main_sender.send(runnable).unwrap_or_else(|runnable| {
-            // NOTE: Runnable may wrap a Future that is !Send.
-            //
-            // This is usually safe because we only poll it on the main thread.
-            // However if the send fails, we know that:
-            // 1. main_receiver has been dropped (which implies the app is shutting down)
-            // 2. we are on a background thread.
-            // It is not safe to drop something !Send on the wrong thread, and
-            // the app will exit soon anyway, so we must forget the runnable.
-            std::mem::forget(runnable);
-        });
+    fn dispatch_on_main_thread(&self, runnable: RunnableVariant, priority: Priority) {
+        self.main_sender
+            .send(priority, runnable)
+            .unwrap_or_else(|runnable| {
+                // NOTE: Runnable may wrap a Future that is !Send.
+                //
+                // This is usually safe because we only poll it on the main thread.
+                // However if the send fails, we know that:
+                // 1. main_receiver has been dropped (which implies the app is shutting down)
+                // 2. we are on a background thread.
+                // It is not safe to drop something !Send on the wrong thread, and
+                // the app will exit soon anyway, so we must forget the runnable.
+                std::mem::forget(runnable);
            });
     }
 
     fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant) {
@@ -222,4 +213,252 @@ impl PlatformDispatcher for LinuxDispatcher {
             .send(TimerAfter { duration, runnable })
             .ok();
     }
+
+    fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>) {
+        std::thread::spawn(move || {
+            // SAFETY: always safe to call
+            let thread_id = unsafe { libc::pthread_self() };
+
+            let policy = match priority {
+                RealtimePriority::Audio => libc::SCHED_FIFO,
+                RealtimePriority::Other => libc::SCHED_RR,
+            };
+            let sched_priority = match priority {
+                RealtimePriority::Audio => 65,
+                RealtimePriority::Other => 45,
+            };
+
+            let sched_param = libc::sched_param { sched_priority };
+            // SAFETY: sched_param is a valid initialized structure
+            let result = unsafe { libc::pthread_setschedparam(thread_id, policy, &sched_param) };
+            if result != 0 {
+                log::warn!("failed to set realtime thread priority to {:?}", priority);
+            }
+
+            f();
+        });
+    }
+}
+
+pub struct PriorityQueueCalloopSender<T> {
+    sender: PriorityQueueSender<T>,
+    ping: calloop::ping::Ping,
+}
+
+impl<T> PriorityQueueCalloopSender<T> {
+    fn new(tx: PriorityQueueSender<T>, ping: calloop::ping::Ping) -> Self {
+        Self { sender: tx, ping }
+    }
+
+    fn send(&self, priority: Priority, item: T) -> Result<(), crate::queue::SendError<T>> {
+        let res = self.sender.send(priority, item);
+        if res.is_ok() {
+            self.ping.ping();
+        }
+        res
+    }
+}
+
+impl<T> Drop for PriorityQueueCalloopSender<T> {
+    fn drop(&mut self) {
+        self.ping.ping();
+    }
+}
+
+pub struct PriorityQueueCalloopReceiver<T> {
+    receiver: PriorityQueueReceiver<T>,
+    source: calloop::ping::PingSource,
+    ping: calloop::ping::Ping,
+}
+
+impl<T> PriorityQueueCalloopReceiver<T> {
+    pub fn new() -> (PriorityQueueCalloopSender<T>, Self) {
+        let (ping, source) = calloop::ping::make_ping().expect("Failed to create a Ping.");
+
+        let (tx, rx) = PriorityQueueReceiver::new();
+
+        (
+            PriorityQueueCalloopSender::new(tx, ping.clone()),
+            Self {
+                receiver: rx,
+                source,
+                ping,
+            },
+        )
+    }
+}
+
+use calloop::channel::Event;
+
+#[derive(Debug)]
+pub struct ChannelError(calloop::ping::PingError);
+
+impl std::fmt::Display for ChannelError {
+    #[cfg_attr(feature = "nightly_coverage", coverage(off))]
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        std::fmt::Display::fmt(&self.0, f)
+    }
+}
+
+impl std::error::Error for ChannelError {
+    #[cfg_attr(feature = "nightly_coverage", coverage(off))]
+    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+        Some(&self.0)
+    }
+}
+
+impl<T> calloop::EventSource for PriorityQueueCalloopReceiver<T> {
+    type Event = Event<T>;
+    type Metadata = ();
+    type Ret = ();
+    type Error = ChannelError;
+
+    fn process_events<F>(
+        &mut self,
+        readiness: calloop::Readiness,
+        token: calloop::Token,
+        mut callback: F,
+    ) -> Result<calloop::PostAction, Self::Error>
+    where
+        F: FnMut(Self::Event, &mut Self::Metadata) -> Self::Ret,
+    {
+        let mut clear_readiness = false;
+        let mut disconnected = false;
+
+        let action = self
+            .source
+            .process_events(readiness, token, |(), &mut ()| {
+                let mut is_empty = true;
+
+                let mut receiver = self.receiver.clone();
+                for runnable in receiver.try_iter() {
+                    match runnable {
+                        Ok(r) => {
+                            callback(Event::Msg(r), &mut ());
+                            is_empty = false;
+                        }
+                        Err(_) => {
+                            disconnected = true;
+                        }
+                    }
+                }
+
+                if disconnected {
+                    callback(Event::Closed, &mut ());
+                }
+
+                if is_empty {
+                    clear_readiness = true;
+                }
+            })
+            .map_err(ChannelError)?;
+
+        if disconnected {
+            Ok(PostAction::Remove)
+        } else if clear_readiness {
+            Ok(action)
+        } else {
+            // Re-notify the ping source so we can try again.
+            self.ping.ping();
+            Ok(PostAction::Continue)
+        }
+    }
+
+    fn register(
+        &mut self,
+        poll: &mut calloop::Poll,
+        token_factory: &mut calloop::TokenFactory,
+    ) -> calloop::Result<()> {
+        self.source.register(poll, token_factory)
+    }
+
+    fn reregister(
+        &mut self,
+        poll: &mut calloop::Poll,
+        token_factory: &mut calloop::TokenFactory,
+    ) -> calloop::Result<()> {
+        self.source.reregister(poll, token_factory)
+    }
+
+    fn unregister(&mut self, poll: &mut calloop::Poll) -> calloop::Result<()> {
+        self.source.unregister(poll)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn calloop_works() {
+        let mut event_loop = calloop::EventLoop::try_new().unwrap();
+        let handle = event_loop.handle();
+
+        let (tx, rx) = PriorityQueueCalloopReceiver::new();
+
+        struct Data {
+            got_msg: bool,
+            got_closed: bool,
+        }
+
+        let mut data = Data {
+            got_msg: false,
+            got_closed: false,
+        };
+
+        let _channel_token = handle
+            .insert_source(rx, move |evt, &mut (), data: &mut Data| match evt {
+                Event::Msg(()) => {
+                    data.got_msg = true;
+                }
+                Event::Closed => {
+                    data.got_closed = true;
+                }
+            })
+            .unwrap();
+
+        // Nothing is sent, so nothing is received.
+        event_loop
+            .dispatch(Some(::std::time::Duration::ZERO), &mut data)
+            .unwrap();
+        assert!(!data.got_msg);
+        assert!(!data.got_closed);
+
+        // A message is sent and received.
+        tx.send(Priority::Medium, ()).unwrap();
+        event_loop
+            .dispatch(Some(::std::time::Duration::ZERO), &mut data)
+            .unwrap();
+        assert!(data.got_msg);
+        assert!(!data.got_closed);
+
+        // The sender is dropped, so the channel reports closed.
+        drop(tx);
+        event_loop
+            .dispatch(Some(::std::time::Duration::ZERO), &mut data)
+            .unwrap();
+        assert!(data.got_msg);
+        assert!(data.got_closed);
+    }
+}
@@ -14,7 +14,7 @@ use std::{
 };
 
 use anyhow::{Context as _, anyhow};
-use calloop::{LoopSignal, channel::Channel};
+use calloop::LoopSignal;
 use futures::channel::oneshot;
 use util::ResultExt as _;
 use util::command::{new_smol_command, new_std_command};
@@ -25,8 +25,8 @@ use crate::{
     Action, AnyWindowHandle, BackgroundExecutor, ClipboardItem, CursorStyle, DisplayId,
     ForegroundExecutor, Keymap, LinuxDispatcher, Menu, MenuItem, OwnedMenu, PathPromptOptions,
     Pixels, Platform, PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper,
-    PlatformTextSystem, PlatformWindow, Point, Result, RunnableVariant, Task, WindowAppearance,
-    WindowParams, px,
+    PlatformTextSystem, PlatformWindow, Point, PriorityQueueCalloopReceiver, Result,
+    RunnableVariant, Task, WindowAppearance, WindowParams, px,
 };
 
 #[cfg(any(feature = "wayland", feature = "x11"))]
@@ -149,8 +149,8 @@ pub(crate) struct LinuxCommon {
 }
 
 impl LinuxCommon {
-    pub fn new(signal: LoopSignal) -> (Self, Channel<RunnableVariant>) {
-        let (main_sender, main_receiver) = calloop::channel::channel::<RunnableVariant>();
+    pub fn new(signal: LoopSignal) -> (Self, PriorityQueueCalloopReceiver<RunnableVariant>) {
+        let (main_sender, main_receiver) = PriorityQueueCalloopReceiver::new();
 
         #[cfg(any(feature = "wayland", feature = "x11"))]
         let text_system = Arc::new(crate::CosmicTextSystem::new());
@@ -77,10 +77,10 @@ use crate::{
     LinuxKeyboardLayout, Modifiers, ModifiersChangedEvent, MouseButton, MouseDownEvent,
     MouseExitEvent, MouseMoveEvent, MouseUpEvent, NavigationDirection, Pixels, PlatformDisplay,
     PlatformInput, PlatformKeyboardLayout, Point, ResultExt as _, SCROLL_LINES, ScrollDelta,
-    ScrollWheelEvent, Size, TouchPhase, WindowParams, point, px, size,
+    ScrollWheelEvent, Size, TouchPhase, WindowParams, point, profiler, px, size,
 };
 use crate::{
-    LinuxDispatcher, RunnableVariant, TaskTiming,
+    RunnableVariant, TaskTiming,
     platform::{PlatformWindow, blade::BladeContext},
 };
 use crate::{
@@ -503,7 +503,7 @@ impl WaylandClient {
                     start,
                     end: None,
                 };
-                LinuxDispatcher::add_task_timing(timing);
+                profiler::add_task_timing(timing);
 
                 runnable.run();
                 timing
@@ -515,7 +515,7 @@ impl WaylandClient {
                     start,
                     end: None,
                 };
-                LinuxDispatcher::add_task_timing(timing);
+                profiler::add_task_timing(timing);
 
                 runnable.run();
                 timing
@@ -524,7 +524,7 @@ impl WaylandClient {
 
                 let end = Instant::now();
                 timing.end = Some(end);
-                LinuxDispatcher::add_task_timing(timing);
+                profiler::add_task_timing(timing);
             });
         }
 }
@@ -1,4 +1,4 @@
-use crate::{Capslock, LinuxDispatcher, ResultExt as _, RunnableVariant, TaskTiming, xcb_flush};
+use crate::{Capslock, ResultExt as _, RunnableVariant, TaskTiming, profiler, xcb_flush};
 use anyhow::{Context as _, anyhow};
 use ashpd::WindowIdentifier;
 use calloop::{
@@ -322,7 +322,7 @@ impl X11Client {
                     start,
                     end: None,
                 };
-                LinuxDispatcher::add_task_timing(timing);
+                profiler::add_task_timing(timing);
 
                 runnable.run();
                 timing
@@ -334,7 +334,7 @@ impl X11Client {
                     start,
                     end: None,
                 };
-                LinuxDispatcher::add_task_timing(timing);
+                profiler::add_task_timing(timing);
 
                 runnable.run();
                 timing
@@ -343,7 +343,7 @@ impl X11Client {
 
                 let end = Instant::now();
                 timing.end = Some(end);
-                LinuxDispatcher::add_task_timing(timing);
+                profiler::add_task_timing(timing);
             });
         }
 }
@@ -3,11 +3,22 @@
 #![allow(non_snake_case)]
 
 use crate::{
-    GLOBAL_THREAD_TIMINGS, PlatformDispatcher, RunnableMeta, RunnableVariant, THREAD_TIMINGS,
-    TaskLabel, TaskTiming, ThreadTaskTimings,
+    GLOBAL_THREAD_TIMINGS, PlatformDispatcher, Priority, RealtimePriority, RunnableMeta,
+    RunnableVariant, THREAD_TIMINGS, TaskLabel, TaskTiming, ThreadTaskTimings,
 };
+use anyhow::Context;
 use async_task::Runnable;
+use mach2::{
+    kern_return::KERN_SUCCESS,
+    mach_time::mach_timebase_info_data_t,
+    thread_policy::{
+        THREAD_EXTENDED_POLICY, THREAD_EXTENDED_POLICY_COUNT, THREAD_PRECEDENCE_POLICY,
+        THREAD_PRECEDENCE_POLICY_COUNT, THREAD_TIME_CONSTRAINT_POLICY,
+        THREAD_TIME_CONSTRAINT_POLICY_COUNT, thread_extended_policy_data_t,
+        thread_precedence_policy_data_t, thread_time_constraint_policy_data_t,
+    },
+};
 use objc::{
     class, msg_send,
     runtime::{BOOL, YES},
@@ -15,9 +26,11 @@ use objc::{
 };
 use std::{
     ffi::c_void,
+    mem::MaybeUninit,
     ptr::{NonNull, addr_of},
     time::{Duration, Instant},
 };
+use util::ResultExt;
 
 /// All items in the generated file are marked as pub, so we're gonna wrap it in a separate mod to prevent
 /// these pub items from leaking into public API.
@@ -56,7 +69,7 @@ impl PlatformDispatcher for MacDispatcher {
         is_main_thread == YES
     }
 
-    fn dispatch(&self, runnable: RunnableVariant, _: Option<TaskLabel>) {
+    fn dispatch(&self, runnable: RunnableVariant, _: Option<TaskLabel>, priority: Priority) {
         let (context, trampoline) = match runnable {
             RunnableVariant::Meta(runnable) => (
                 runnable.into_raw().as_ptr() as *mut c_void,
@@ -67,16 +80,24 @@ impl PlatformDispatcher for MacDispatcher {
                 Some(trampoline_compat as unsafe extern "C" fn(*mut c_void)),
             ),
         };
 
+        let queue_priority = match priority {
+            Priority::Realtime(_) => unreachable!(),
+            Priority::High => DISPATCH_QUEUE_PRIORITY_HIGH as isize,
+            Priority::Medium => DISPATCH_QUEUE_PRIORITY_DEFAULT as isize,
+            Priority::Low => DISPATCH_QUEUE_PRIORITY_LOW as isize,
+        };
+
         unsafe {
             dispatch_async_f(
-                dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_HIGH.try_into().unwrap(), 0),
+                dispatch_get_global_queue(queue_priority, 0),
                 context,
                 trampoline,
             );
         }
     }
 
-    fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
+    fn dispatch_on_main_thread(&self, runnable: RunnableVariant, _priority: Priority) {
         let (context, trampoline) = match runnable {
             RunnableVariant::Meta(runnable) => (
                 runnable.into_raw().as_ptr() as *mut c_void,
@@ -110,6 +131,120 @@ impl PlatformDispatcher for MacDispatcher {
             dispatch_after_f(when, queue, context, trampoline);
         }
     }
+
+    fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>) {
+        std::thread::spawn(move || {
+            match priority {
+                RealtimePriority::Audio => set_audio_thread_priority(),
+                RealtimePriority::Other => set_high_thread_priority(),
+            }
+            .context(format!("for priority {:?}", priority))
+            .log_err();
+
+            f();
+        });
+    }
+}
+
+fn set_high_thread_priority() -> anyhow::Result<()> {
+    // SAFETY: always safe to call
+    let thread_id = unsafe { libc::pthread_self() };
+
+    // SAFETY: all sched_param members are valid when initialized to zero.
+    let mut sched_param = unsafe { MaybeUninit::<libc::sched_param>::zeroed().assume_init() };
+    sched_param.sched_priority = 45;
+
+    let result = unsafe { libc::pthread_setschedparam(thread_id, libc::SCHED_FIFO, &sched_param) };
+    if result != 0 {
+        anyhow::bail!("failed to set realtime thread priority")
+    }
+
+    Ok(())
+}
+
+fn set_audio_thread_priority() -> anyhow::Result<()> {
+    // https://chromium.googlesource.com/chromium/chromium/+/master/base/threading/platform_thread_mac.mm#93
+
+    // SAFETY: always safe to call
+    let thread_id = unsafe { libc::pthread_self() };
+
+    // SAFETY: thread_id is a valid thread id
+    let thread_id = unsafe { libc::pthread_mach_thread_np(thread_id) };
+
+    // Fixed priority thread
+    let mut policy = thread_extended_policy_data_t { timeshare: 0 };
+
+    // SAFETY: thread_id is a valid thread id
+    // SAFETY: thread_extended_policy_data_t is passed as THREAD_EXTENDED_POLICY
+    let result = unsafe {
+        mach2::thread_policy::thread_policy_set(
+            thread_id,
+            THREAD_EXTENDED_POLICY,
+            &mut policy as *mut _ as *mut _,
+            THREAD_EXTENDED_POLICY_COUNT,
+        )
+    };
+
+    if result != KERN_SUCCESS {
+        anyhow::bail!("failed to set thread extended policy");
+    }
+
+    // relatively high priority
+    let mut precedence = thread_precedence_policy_data_t { importance: 63 };
+
+    // SAFETY: thread_id is a valid thread id
+    // SAFETY: thread_precedence_policy_data_t is passed as THREAD_PRECEDENCE_POLICY
+    let result = unsafe {
+        mach2::thread_policy::thread_policy_set(
+            thread_id,
+            THREAD_PRECEDENCE_POLICY,
+            &mut precedence as *mut _ as *mut _,
+            THREAD_PRECEDENCE_POLICY_COUNT,
+        )
+    };
+
+    if result != KERN_SUCCESS {
+        anyhow::bail!("failed to set thread precedence policy");
+    }
+
+    const GUARANTEED_AUDIO_DUTY_CYCLE: f32 = 0.75;
+    const MAX_AUDIO_DUTY_CYCLE: f32 = 0.85;
+
+    // ~128 frames @ 44.1KHz
+    const TIME_QUANTUM: f32 = 2.9;
+
+    const AUDIO_TIME_NEEDED: f32 = GUARANTEED_AUDIO_DUTY_CYCLE * TIME_QUANTUM;
+    const MAX_TIME_ALLOWED: f32 = MAX_AUDIO_DUTY_CYCLE * TIME_QUANTUM;
+
+    let mut timebase_info = mach_timebase_info_data_t { numer: 0, denom: 0 };
+    // SAFETY: timebase_info is a valid pointer to a mach_timebase_info_data_t struct
+    unsafe { mach2::mach_time::mach_timebase_info(&mut timebase_info) };
+
+    let ms_to_abs_time = ((timebase_info.denom as f32) / (timebase_info.numer as f32)) * 1000000f32;
+
+    let mut time_constraints = thread_time_constraint_policy_data_t {
+        period: (TIME_QUANTUM * ms_to_abs_time) as u32,
+        computation: (AUDIO_TIME_NEEDED * ms_to_abs_time) as u32,
+        constraint: (MAX_TIME_ALLOWED * ms_to_abs_time) as u32,
+        preemptible: 0,
+    };
+
+    // SAFETY: thread_id is a valid thread id
+    // SAFETY: thread_time_constraint_policy_data_t is passed as THREAD_TIME_CONSTRAINT_POLICY
+    let result = unsafe {
+        mach2::thread_policy::thread_policy_set(
+            thread_id,
+            THREAD_TIME_CONSTRAINT_POLICY,
+            &mut time_constraints as *mut _ as *mut _,
+            THREAD_TIME_CONSTRAINT_POLICY_COUNT,
+        )
+    };
+
+    if result != KERN_SUCCESS {
+        anyhow::bail!("failed to set thread time constraint policy");
+    }
+
+    Ok(())
+}
 
 extern "C" fn trampoline(runnable: *mut c_void) {
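
The time-constraint numbers above are easier to see with units written out: the quantum is 2.9 ms (roughly 128 frames at 44.1 kHz), the thread asks to be guaranteed 75% of each quantum and capped at 85%, and `ms_to_abs_time` converts milliseconds to Mach absolute-time ticks (`denom / numer` rescales nanoseconds to ticks, times 1e6 ns per ms). Worked arithmetic, assuming a 1:1 timebase as on Intel Macs:

```rust
// Sketch of the numbers only; not part of the diff.
let ms_to_abs_time = (1.0f32 / 1.0) * 1_000_000.0; // ticks per millisecond
let period = 2.9 * ms_to_abs_time;        // 2_900_000 ticks per quantum
let computation = 0.75 * 2.9 * ms_to_abs_time; // ~2_175_000 ticks guaranteed
let constraint = 0.85 * 2.9 * ms_to_abs_time;  // ~2_465_000 ticks ceiling
```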
@@ -1,9 +1,9 @@
-use crate::{Bounds, DisplayId, Pixels, PlatformDisplay, px, size};
+use crate::{Bounds, DisplayId, Pixels, PlatformDisplay, point, px, size};
 use anyhow::Result;
 use cocoa::{
     appkit::NSScreen,
     base::{id, nil},
-    foundation::{NSDictionary, NSString},
+    foundation::{NSArray, NSDictionary, NSString},
 };
 use core_foundation::uuid::{CFUUIDGetUUIDBytes, CFUUIDRef};
 use core_graphics::display::{CGDirectDisplayID, CGDisplayBounds, CGGetActiveDisplayList};
@@ -114,4 +114,53 @@ impl PlatformDisplay for MacDisplay {
             }
         }
     }
+
+    fn visible_bounds(&self) -> Bounds<Pixels> {
+        unsafe {
+            let dominated_screen = self.get_nsscreen();
+
+            if dominated_screen == nil {
+                return self.bounds();
+            }
+
+            let screen_frame = NSScreen::frame(dominated_screen);
+            let visible_frame = NSScreen::visibleFrame(dominated_screen);
+
+            // Convert from bottom-left origin (AppKit) to top-left origin
+            let origin_y =
+                screen_frame.size.height - visible_frame.origin.y - visible_frame.size.height
+                    + screen_frame.origin.y;
+
+            Bounds {
+                origin: point(
+                    px(visible_frame.origin.x as f32 - screen_frame.origin.x as f32),
+                    px(origin_y as f32),
+                ),
+                size: size(
+                    px(visible_frame.size.width as f32),
+                    px(visible_frame.size.height as f32),
+                ),
+            }
+        }
+    }
+}
+
+impl MacDisplay {
+    /// Find the NSScreen corresponding to this display
+    unsafe fn get_nsscreen(&self) -> id {
+        let screens = unsafe { NSScreen::screens(nil) };
+        let count = unsafe { NSArray::count(screens) };
+        let screen_number_key: id = unsafe { NSString::alloc(nil).init_str("NSScreenNumber") };
+
+        for i in 0..count {
+            let screen = unsafe { NSArray::objectAtIndex(screens, i) };
+            let device_description = unsafe { NSScreen::deviceDescription(screen) };
+            let screen_number = unsafe { device_description.objectForKey_(screen_number_key) };
+            let screen_id: CGDirectDisplayID = msg_send![screen_number, unsignedIntegerValue];
+            if screen_id == self.0 {
+                return screen;
+            }
+        }
+        nil
+    }
+}
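
The y-flip above is the standard AppKit conversion: `visibleFrame` uses a bottom-left origin, while GPUI wants top-left. For a 1000 pt tall screen at origin 0 with a 25 pt menu bar and a 70 pt Dock, `visibleFrame` is `(y = 70, height = 905)`, and the flip gives `1000 - 70 - 905 + 0 = 25`, i.e. the visible area starts just below the menu bar. The same check in code, with those hypothetical numbers:

```rust
// Worked example of the bottom-left to top-left flip.
let (screen_y, screen_h) = (0.0f32, 1000.0); // AppKit screen frame
let (visible_y, visible_h) = (70.0f32, 905.0); // 70 pt Dock, 25 pt menu bar
let top_left_y = screen_h - visible_y - visible_h + screen_y;
assert_eq!(top_left_y, 25.0); // visible area starts below the menu bar
```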
@@ -1,4 +1,4 @@
-use crate::{PlatformDispatcher, RunnableVariant, TaskLabel};
+use crate::{PlatformDispatcher, Priority, RunnableVariant, TaskLabel};
 use backtrace::Backtrace;
 use collections::{HashMap, HashSet, VecDeque};
 use parking::Unparker;
@@ -284,7 +284,7 @@ impl PlatformDispatcher for TestDispatcher {
         state.start_time + state.time
     }
 
-    fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>) {
+    fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>, _priority: Priority) {
         {
             let mut state = self.state.lock();
             if label.is_some_and(|label| state.deprioritized_task_labels.contains(&label)) {
@@ -296,7 +296,7 @@ impl PlatformDispatcher for TestDispatcher {
         self.unpark_all();
     }
 
-    fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
+    fn dispatch_on_main_thread(&self, runnable: RunnableVariant, _priority: Priority) {
         self.state
             .lock()
             .foreground
@@ -318,4 +318,10 @@ impl PlatformDispatcher for TestDispatcher {
     fn as_test(&self) -> Option<&TestDispatcher> {
         Some(self)
     }
+
+    fn spawn_realtime(&self, _priority: crate::RealtimePriority, f: Box<dyn FnOnce() + Send>) {
+        std::thread::spawn(move || {
+            f();
+        });
+    }
 }
@@ -4,24 +4,31 @@ use std::{
     time::{Duration, Instant},
 };
 
-use flume::Sender;
+use anyhow::Context;
 use util::ResultExt;
 use windows::{
-    System::Threading::{ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler},
+    System::Threading::{
+        ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler, WorkItemPriority,
+    },
     Win32::{
         Foundation::{LPARAM, WPARAM},
+        System::Threading::{
+            GetCurrentThread, HIGH_PRIORITY_CLASS, SetPriorityClass, SetThreadPriority,
+            THREAD_PRIORITY_HIGHEST, THREAD_PRIORITY_TIME_CRITICAL,
+        },
         UI::WindowsAndMessaging::PostMessageW,
     },
 };
 
 use crate::{
-    GLOBAL_THREAD_TIMINGS, HWND, PlatformDispatcher, RunnableVariant, SafeHwnd, THREAD_TIMINGS,
-    TaskLabel, TaskTiming, ThreadTaskTimings, WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD,
+    GLOBAL_THREAD_TIMINGS, HWND, PlatformDispatcher, Priority, PriorityQueueSender,
+    RealtimePriority, RunnableVariant, SafeHwnd, THREAD_TIMINGS, TaskLabel, TaskTiming,
+    ThreadTaskTimings, WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD, profiler,
 };
 
 pub(crate) struct WindowsDispatcher {
     pub(crate) wake_posted: AtomicBool,
-    main_sender: Sender<RunnableVariant>,
+    main_sender: PriorityQueueSender<RunnableVariant>,
     main_thread_id: ThreadId,
     pub(crate) platform_window_handle: SafeHwnd,
     validation_number: usize,
@@ -29,7 +36,7 @@ pub(crate) struct WindowsDispatcher {
 
 impl WindowsDispatcher {
     pub(crate) fn new(
-        main_sender: Sender<RunnableVariant>,
+        main_sender: PriorityQueueSender<RunnableVariant>,
         platform_window_handle: HWND,
         validation_number: usize,
     ) -> Self {
@@ -45,7 +52,7 @@ impl WindowsDispatcher {
         }
     }
 
-    fn dispatch_on_threadpool(&self, runnable: RunnableVariant) {
+    fn dispatch_on_threadpool(&self, priority: WorkItemPriority, runnable: RunnableVariant) {
         let handler = {
             let mut task_wrapper = Some(runnable);
             WorkItemHandler::new(move |_| {
@@ -53,7 +60,8 @@ impl WindowsDispatcher {
                 Ok(())
             })
         };
-        ThreadPool::RunAsync(&handler).log_err();
+
+        ThreadPool::RunWithPriorityAsync(&handler, priority).log_err();
     }
 
     fn dispatch_on_threadpool_after(&self, runnable: RunnableVariant, duration: Duration) {
@@ -79,7 +87,7 @@ impl WindowsDispatcher {
                     start,
                     end: None,
                 };
-                Self::add_task_timing(timing);
+                profiler::add_task_timing(timing);
 
                 runnable.run();
 
@@ -91,7 +99,7 @@ impl WindowsDispatcher {
                     start,
                     end: None,
                 };
-                Self::add_task_timing(timing);
+                profiler::add_task_timing(timing);
 
                 runnable.run();
 
@@ -102,23 +110,7 @@ impl WindowsDispatcher {
         let end = Instant::now();
         timing.end = Some(end);
 
-        Self::add_task_timing(timing);
-    }
-
-    pub(crate) fn add_task_timing(timing: TaskTiming) {
-        THREAD_TIMINGS.with(|timings| {
-            let mut timings = timings.lock();
-            let timings = &mut timings.timings;
-
-            if let Some(last_timing) = timings.iter_mut().rev().next() {
-                if last_timing.location == timing.location {
-                    last_timing.end = timing.end;
-                    return;
-                }
-            }
-
-            timings.push_back(timing);
-        });
+        profiler::add_task_timing(timing);
     }
 }
@@ -146,15 +138,22 @@ impl PlatformDispatcher for WindowsDispatcher {
         current().id() == self.main_thread_id
     }
 
-    fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>) {
-        self.dispatch_on_threadpool(runnable);
+    fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>, priority: Priority) {
+        let priority = match priority {
+            Priority::Realtime(_) => unreachable!(),
+            Priority::High => WorkItemPriority::High,
+            Priority::Medium => WorkItemPriority::Normal,
+            Priority::Low => WorkItemPriority::Low,
+        };
+        self.dispatch_on_threadpool(priority, runnable);
 
         if let Some(label) = label {
            log::debug!("TaskLabel: {label:?}");
         }
     }
 
-    fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
-        match self.main_sender.send(runnable) {
+    fn dispatch_on_main_thread(&self, runnable: RunnableVariant, priority: Priority) {
+        match self.main_sender.send(priority, runnable) {
             Ok(_) => {
                 if !self.wake_posted.swap(true, Ordering::AcqRel) {
                     unsafe {
@@ -185,4 +184,28 @@ impl PlatformDispatcher for WindowsDispatcher {
     fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant) {
         self.dispatch_on_threadpool_after(runnable, duration);
     }
+
+    fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>) {
+        std::thread::spawn(move || {
+            // SAFETY: always safe to call
+            let thread_handle = unsafe { GetCurrentThread() };
+
+            let thread_priority = match priority {
+                RealtimePriority::Audio => THREAD_PRIORITY_TIME_CRITICAL,
+                RealtimePriority::Other => THREAD_PRIORITY_HIGHEST,
+            };
+
+            // SAFETY: thread_handle is a valid handle to a thread
+            unsafe { SetPriorityClass(thread_handle, HIGH_PRIORITY_CLASS) }
+                .context("thread priority class")
+                .log_err();
+
+            // SAFETY: thread_handle is a valid handle to a thread
+            unsafe { SetThreadPriority(thread_handle, thread_priority) }
+                .context("thread priority")
+                .log_err();
+
+            f();
+        });
+    }
 }
@@ -23,6 +23,7 @@ pub(crate) struct WindowsDisplay {
     pub display_id: DisplayId,
     scale_factor: f32,
     bounds: Bounds<Pixels>,
+    visible_bounds: Bounds<Pixels>,
     physical_bounds: Bounds<DevicePixels>,
     uuid: Uuid,
 }
@@ -36,6 +37,7 @@ impl WindowsDisplay {
         let screen = available_monitors().into_iter().nth(display_id.0 as _)?;
         let info = get_monitor_info(screen).log_err()?;
         let monitor_size = info.monitorInfo.rcMonitor;
+        let work_area = info.monitorInfo.rcWork;
         let uuid = generate_uuid(&info.szDevice);
         let scale_factor = get_scale_factor_for_monitor(screen).log_err()?;
         let physical_size = size(
@@ -55,6 +57,14 @@ impl WindowsDisplay {
                 ),
                 size: physical_size.to_pixels(scale_factor),
             },
+            visible_bounds: Bounds {
+                origin: logical_point(work_area.left as f32, work_area.top as f32, scale_factor),
+                size: size(
+                    (work_area.right - work_area.left) as f32 / scale_factor,
+                    (work_area.bottom - work_area.top) as f32 / scale_factor,
+                )
+                .map(crate::px),
+            },
             physical_bounds: Bounds {
                 origin: point(monitor_size.left.into(), monitor_size.top.into()),
                 size: physical_size,
@@ -66,6 +76,7 @@ impl WindowsDisplay {
     pub fn new_with_handle(monitor: HMONITOR) -> anyhow::Result<Self> {
         let info = get_monitor_info(monitor)?;
         let monitor_size = info.monitorInfo.rcMonitor;
+        let work_area = info.monitorInfo.rcWork;
         let uuid = generate_uuid(&info.szDevice);
         let display_id = available_monitors()
             .iter()
@@ -89,6 +100,14 @@ impl WindowsDisplay {
                 ),
                 size: physical_size.to_pixels(scale_factor),
             },
+            visible_bounds: Bounds {
+                origin: logical_point(work_area.left as f32, work_area.top as f32, scale_factor),
+                size: size(
+                    (work_area.right - work_area.left) as f32 / scale_factor,
+                    (work_area.bottom - work_area.top) as f32 / scale_factor,
+                )
+                .map(crate::px),
+            },
             physical_bounds: Bounds {
                 origin: point(monitor_size.left.into(), monitor_size.top.into()),
                 size: physical_size,
@@ -100,6 +119,7 @@ impl WindowsDisplay {
     fn new_with_handle_and_id(handle: HMONITOR, display_id: DisplayId) -> anyhow::Result<Self> {
         let info = get_monitor_info(handle)?;
         let monitor_size = info.monitorInfo.rcMonitor;
+        let work_area = info.monitorInfo.rcWork;
         let uuid = generate_uuid(&info.szDevice);
         let scale_factor = get_scale_factor_for_monitor(handle)?;
         let physical_size = size(
@@ -119,6 +139,14 @@ impl WindowsDisplay {
                 ),
                 size: physical_size.to_pixels(scale_factor),
             },
+            visible_bounds: Bounds {
+                origin: logical_point(work_area.left as f32, work_area.top as f32, scale_factor),
+                size: size(
+                    (work_area.right - work_area.left) as f32 / scale_factor,
+                    (work_area.bottom - work_area.top) as f32 / scale_factor,
+                )
+                .map(crate::px),
+            },
             physical_bounds: Bounds {
                 origin: point(monitor_size.left.into(), monitor_size.top.into()),
                 size: physical_size,
@@ -193,6 +221,10 @@ impl PlatformDisplay for WindowsDisplay {
     fn bounds(&self) -> Bounds<Pixels> {
         self.bounds
     }
+
+    fn visible_bounds(&self) -> Bounds<Pixels> {
+        self.visible_bounds
+    }
 }
 
 fn available_monitors() -> SmallVec<[HMONITOR; 4]> {
@@ -243,7 +243,8 @@ impl WindowsWindowInner {
 
     fn handle_timer_msg(&self, handle: HWND, wparam: WPARAM) -> Option<isize> {
         if wparam.0 == SIZE_MOVE_LOOP_TIMER_ID {
-            for runnable in self.main_receiver.drain() {
+            let mut runnables = self.main_receiver.clone().try_iter();
+            while let Some(Ok(runnable)) = runnables.next() {
                 WindowsDispatcher::execute_runnable(runnable);
             }
             self.handle_paint_msg(handle)
@@ -51,7 +51,7 @@ struct WindowsPlatformInner {
     raw_window_handles: std::sync::Weak<RwLock<SmallVec<[SafeHwnd; 4]>>>,
     // The below members will never change throughout the entire lifecycle of the app.
     validation_number: usize,
-    main_receiver: flume::Receiver<RunnableVariant>,
+    main_receiver: PriorityQueueReceiver<RunnableVariant>,
     dispatcher: Arc<WindowsDispatcher>,
 }
@@ -98,7 +98,7 @@ impl WindowsPlatform {
             OleInitialize(None).context("unable to initialize Windows OLE")?;
         }
         let directx_devices = DirectXDevices::new().context("Creating DirectX devices")?;
-        let (main_sender, main_receiver) = flume::unbounded::<RunnableVariant>();
+        let (main_sender, main_receiver) = PriorityQueueReceiver::new();
         let validation_number = if usize::BITS == 64 {
             rand::random::<u64>() as usize
         } else {
@@ -857,22 +857,24 @@ impl WindowsPlatformInner {
                 }
                 break 'tasks;
             }
-            match self.main_receiver.try_recv() {
-                Err(_) => break 'timeout_loop,
-                Ok(runnable) => WindowsDispatcher::execute_runnable(runnable),
+            let mut main_receiver = self.main_receiver.clone();
+            match main_receiver.try_pop() {
+                Ok(Some(runnable)) => WindowsDispatcher::execute_runnable(runnable),
+                _ => break 'timeout_loop,
             }
         }
 
         // Someone could enqueue a Runnable here. The flag is still true, so they will not PostMessage.
         // We need to check for those Runnables after we clear the flag.
         self.dispatcher.wake_posted.store(false, Ordering::Release);
-        match self.main_receiver.try_recv() {
-            Err(_) => break 'tasks,
-            Ok(runnable) => {
+        let mut main_receiver = self.main_receiver.clone();
+        match main_receiver.try_pop() {
+            Ok(Some(runnable)) => {
                 self.dispatcher.wake_posted.store(true, Ordering::Release);
 
                 WindowsDispatcher::execute_runnable(runnable);
             }
+            _ => break 'tasks,
        }
@@ -934,7 +936,7 @@ pub(crate) struct WindowCreationInfo {
     pub(crate) windows_version: WindowsVersion,
     pub(crate) drop_target_helper: IDropTargetHelper,
     pub(crate) validation_number: usize,
-    pub(crate) main_receiver: flume::Receiver<RunnableVariant>,
+    pub(crate) main_receiver: PriorityQueueReceiver<RunnableVariant>,
     pub(crate) platform_window_handle: HWND,
     pub(crate) disable_direct_composition: bool,
     pub(crate) directx_devices: DirectXDevices,
@@ -947,8 +949,8 @@ struct PlatformWindowCreateContext {
     inner: Option<Result<Rc<WindowsPlatformInner>>>,
     raw_window_handles: std::sync::Weak<RwLock<SmallVec<[SafeHwnd; 4]>>>,
     validation_number: usize,
-    main_sender: Option<flume::Sender<RunnableVariant>>,
-    main_receiver: Option<flume::Receiver<RunnableVariant>>,
+    main_sender: Option<PriorityQueueSender<RunnableVariant>>,
+    main_receiver: Option<PriorityQueueReceiver<RunnableVariant>>,
     directx_devices: Option<DirectXDevices>,
     dispatcher: Option<Arc<WindowsDispatcher>>,
 }
@@ -81,7 +81,7 @@ pub(crate) struct WindowsWindowInner {
     pub(crate) executor: ForegroundExecutor,
     pub(crate) windows_version: WindowsVersion,
     pub(crate) validation_number: usize,
-    pub(crate) main_receiver: flume::Receiver<RunnableVariant>,
+    pub(crate) main_receiver: PriorityQueueReceiver<RunnableVariant>,
     pub(crate) platform_window_handle: HWND,
 }
@@ -362,7 +362,7 @@ struct WindowCreateContext {
     windows_version: WindowsVersion,
     drop_target_helper: IDropTargetHelper,
     validation_number: usize,
-    main_receiver: flume::Receiver<RunnableVariant>,
+    main_receiver: PriorityQueueReceiver<RunnableVariant>,
     platform_window_handle: HWND,
     appearance: WindowAppearance,
     disable_direct_composition: bool,
@@ -216,3 +216,19 @@ impl Drop for ThreadTimings {
         thread_timings.swap_remove(index);
     }
 }
+
+pub(crate) fn add_task_timing(timing: TaskTiming) {
+    THREAD_TIMINGS.with(|timings| {
+        let mut timings = timings.lock();
+        let timings = &mut timings.timings;
+
+        if let Some(last_timing) = timings.iter_mut().rev().next() {
+            if last_timing.location == timing.location {
+                last_timing.end = timing.end;
+                return;
+            }
+        }
+
+        timings.push_back(timing);
+    });
+}
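
Note the coalescing in `add_task_timing` above: when the incoming timing comes from the same source location as the most recent entry, only the `end` timestamp advances instead of a new record being pushed, which keeps the per-thread buffer from filling with back-to-back samples of one hot task. A crate-internal sketch of the observable behavior (`TaskTiming` fields as used throughout this diff):

```rust
use std::time::Instant;

// Sketch: two consecutive timings from one location collapse into a single
// record whose `end` is the later of the two.
let location = core::panic::Location::caller();
let (t0, t1, t2) = (Instant::now(), Instant::now(), Instant::now());
let first = TaskTiming { location, start: t0, end: Some(t1) };
let second = TaskTiming { location, start: t1, end: Some(t2) };
add_task_timing(first);
add_task_timing(second); // same location: the stored record's end becomes t2
```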
crates/gpui/src/queue.rs (new file, 329 lines)
@@ -0,0 +1,329 @@
use std::{
    fmt,
    iter::FusedIterator,
    sync::{Arc, atomic::AtomicUsize},
};

use rand::{Rng, SeedableRng, rngs::SmallRng};

use crate::Priority;

struct PriorityQueues<T> {
    high_priority: Vec<T>,
    medium_priority: Vec<T>,
    low_priority: Vec<T>,
}

impl<T> PriorityQueues<T> {
    fn is_empty(&self) -> bool {
        self.high_priority.is_empty()
            && self.medium_priority.is_empty()
            && self.low_priority.is_empty()
    }
}

struct PriorityQueueState<T> {
    queues: parking_lot::Mutex<PriorityQueues<T>>,
    condvar: parking_lot::Condvar,
    receiver_count: AtomicUsize,
    sender_count: AtomicUsize,
}

impl<T> PriorityQueueState<T> {
    fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
        if self
            .receiver_count
            .load(std::sync::atomic::Ordering::Relaxed)
            == 0
        {
            return Err(SendError(item));
        }

        let mut queues = self.queues.lock();
        match priority {
            Priority::Realtime(_) => unreachable!(),
            Priority::High => queues.high_priority.push(item),
            Priority::Medium => queues.medium_priority.push(item),
            Priority::Low => queues.low_priority.push(item),
        };
        self.condvar.notify_one();
        Ok(())
    }

    fn recv<'a>(&'a self) -> Result<parking_lot::MutexGuard<'a, PriorityQueues<T>>, RecvError> {
        let mut queues = self.queues.lock();

        let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
        if queues.is_empty() && sender_count == 0 {
            return Err(crate::queue::RecvError);
        }

        // parking_lot doesn't do spurious wakeups so an if is fine
        if queues.is_empty() {
            self.condvar.wait(&mut queues);
        }

        Ok(queues)
    }

    fn try_recv<'a>(
        &'a self,
    ) -> Result<Option<parking_lot::MutexGuard<'a, PriorityQueues<T>>>, RecvError> {
        let mut queues = self.queues.lock();

        let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
        if queues.is_empty() && sender_count == 0 {
            return Err(crate::queue::RecvError);
        }

        if queues.is_empty() {
            Ok(None)
        } else {
            Ok(Some(queues))
        }
    }
}

pub(crate) struct PriorityQueueSender<T> {
    state: Arc<PriorityQueueState<T>>,
}

impl<T> PriorityQueueSender<T> {
    fn new(state: Arc<PriorityQueueState<T>>) -> Self {
        Self { state }
    }

    pub(crate) fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
        self.state.send(priority, item)?;
        Ok(())
    }
}

impl<T> Drop for PriorityQueueSender<T> {
    fn drop(&mut self) {
        self.state
            .sender_count
            .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
    }
}

pub(crate) struct PriorityQueueReceiver<T> {
    state: Arc<PriorityQueueState<T>>,
    rand: SmallRng,
    disconnected: bool,
}

impl<T> Clone for PriorityQueueReceiver<T> {
    fn clone(&self) -> Self {
        self.state
            .receiver_count
            .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
        Self {
            state: Arc::clone(&self.state),
            rand: SmallRng::seed_from_u64(0),
            disconnected: self.disconnected,
        }
    }
}

pub(crate) struct SendError<T>(T);

impl<T: fmt::Debug> fmt::Debug for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_tuple("SendError").field(&self.0).finish()
    }
}

#[derive(Debug)]
pub(crate) struct RecvError;

#[allow(dead_code)]
impl<T> PriorityQueueReceiver<T> {
    pub(crate) fn new() -> (PriorityQueueSender<T>, Self) {
        let state = PriorityQueueState {
            queues: parking_lot::Mutex::new(PriorityQueues {
                high_priority: Vec::new(),
                medium_priority: Vec::new(),
                low_priority: Vec::new(),
            }),
            condvar: parking_lot::Condvar::new(),
            receiver_count: AtomicUsize::new(1),
            sender_count: AtomicUsize::new(1),
        };
        let state = Arc::new(state);

        let sender = PriorityQueueSender::new(Arc::clone(&state));

        let receiver = PriorityQueueReceiver {
            state,
            rand: SmallRng::seed_from_u64(0),
            disconnected: false,
        };

        (sender, receiver)
    }

    /// Tries to pop one element from the priority queue without blocking.
    ///
    /// This returns early with `Ok(None)` if there are no elements in the queue.
    ///
    /// This method is best suited if you only intend to pop one element; for better performance
    /// on large queues see [`Self::try_iter`].
    ///
    /// # Errors
    ///
    /// If the sender was dropped.
    pub(crate) fn try_pop(&mut self) -> Result<Option<T>, RecvError> {
        self.pop_inner(false)
    }

    /// Pops an element from the priority queue, blocking if necessary.
    ///
    /// This method is best suited if you only intend to pop one element; for better performance
    /// on large queues see [`Self::iter`].
    ///
    /// # Errors
    ///
    /// If the sender was dropped.
    pub(crate) fn pop(&mut self) -> Result<T, RecvError> {
        self.pop_inner(true).map(|e| e.unwrap())
    }

    /// Returns an iterator over the elements of the queue.
    /// The iterator ends once all elements have been consumed; it does not wait for new ones.
    pub(crate) fn try_iter(self) -> TryIter<T> {
        TryIter {
            receiver: self,
            ended: false,
        }
    }

    /// Returns an iterator over the elements of the queue.
    /// The iterator waits for new elements if the queue is empty.
    pub(crate) fn iter(self) -> Iter<T> {
        Iter(self)
    }

    #[inline(always)]
    // The algorithm is the loaded die from biased coins, from
    // https://www.keithschwarz.com/darts-dice-coins/
    fn pop_inner(&mut self, block: bool) -> Result<Option<T>, RecvError> {
        use Priority as P;

        let mut queues = if !block {
            let Some(queues) = self.state.try_recv()? else {
                return Ok(None);
            };
            queues
        } else {
            self.state.recv()?
        };

        let high = P::High.probability() * !queues.high_priority.is_empty() as u32;
        let medium = P::Medium.probability() * !queues.medium_priority.is_empty() as u32;
        let low = P::Low.probability() * !queues.low_priority.is_empty() as u32;
        let mut mass = high + medium + low;

        if !queues.high_priority.is_empty() {
            let flip = self.rand.random_ratio(P::High.probability(), mass);
            if flip {
                return Ok(queues.high_priority.pop());
            }
            mass -= P::High.probability();
        }

        if !queues.medium_priority.is_empty() {
            let flip = self.rand.random_ratio(P::Medium.probability(), mass);
            if flip {
                return Ok(queues.medium_priority.pop());
            }
            mass -= P::Medium.probability();
        }

        if !queues.low_priority.is_empty() {
            let flip = self.rand.random_ratio(P::Low.probability(), mass);
            if flip {
                return Ok(queues.low_priority.pop());
            }
        }

        Ok(None)
    }
}

impl<T> Drop for PriorityQueueReceiver<T> {
    fn drop(&mut self) {
        self.state
            .receiver_count
            .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
    }
}

/// If `None` is returned, the sender disconnected.
pub(crate) struct Iter<T>(PriorityQueueReceiver<T>);
impl<T> Iterator for Iter<T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        self.0.pop_inner(true).ok().flatten()
    }
}
impl<T> FusedIterator for Iter<T> {}

/// If `None` is returned, there are no more elements in the queue.
pub(crate) struct TryIter<T> {
    receiver: PriorityQueueReceiver<T>,
    ended: bool,
}
impl<T> Iterator for TryIter<T> {
    type Item = Result<T, RecvError>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.ended {
            return None;
        }

        let res = self.receiver.pop_inner(false);
        self.ended = res.is_err();

        res.transpose()
    }
}
impl<T> FusedIterator for TryIter<T> {}

#[cfg(test)]
mod tests {
    use collections::HashSet;

    use super::*;

    #[test]
    fn all_tasks_get_yielded() {
        let (tx, mut rx) = PriorityQueueReceiver::new();
        tx.send(Priority::Medium, 20).unwrap();
        tx.send(Priority::High, 30).unwrap();
        tx.send(Priority::Low, 10).unwrap();
        tx.send(Priority::Medium, 21).unwrap();
        tx.send(Priority::High, 31).unwrap();

        drop(tx);

        assert_eq!(
            rx.iter().collect::<HashSet<_>>(),
            [30, 31, 20, 21, 10].into_iter().collect::<HashSet<_>>()
        )
    }

    #[test]
    fn new_high_prio_task_get_scheduled_quickly() {
        let (tx, mut rx) = PriorityQueueReceiver::new();
        for _ in 0..100 {
            tx.send(Priority::Low, 1).unwrap();
        }

        assert_eq!(rx.pop().unwrap(), 1);
        tx.send(Priority::High, 3).unwrap();
        assert_eq!(rx.pop().unwrap(), 3);
        assert_eq!(rx.pop().unwrap(), 1);
    }
}
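pop_inner selects a tier with the loaded-die-from-biased-coins construction: each non-empty tier is tried in descending order, flipping a coin weighted by that tier's share of the remaining probability mass, so low tiers are not starved while high tiers win most flips. A standalone sketch of the same selection; Priority::probability() is not part of this diff, so the 7:2:1 split below is an assumption for illustration only:

    use rand::{Rng, SeedableRng, rngs::SmallRng};
    use std::collections::HashMap;

    // Assumed weights; the real values come from Priority::probability().
    const WEIGHTS: [(&str, u32); 3] = [("high", 7), ("medium", 2), ("low", 1)];

    // Loaded die built from successive biased coins
    // (https://www.keithschwarz.com/darts-dice-coins/), the same scheme
    // pop_inner uses: flip per tier with probability weight / remaining mass.
    fn pick_tier(rng: &mut SmallRng) -> &'static str {
        let mut mass: u32 = WEIGHTS.iter().map(|&(_, w)| w).sum();
        for &(name, weight) in &WEIGHTS[..WEIGHTS.len() - 1] {
            if rng.random_ratio(weight, mass) {
                return name;
            }
            mass -= weight;
        }
        // The final coin would be weight / weight, i.e. certain.
        WEIGHTS[WEIGHTS.len() - 1].0
    }

    fn main() {
        let mut rng = SmallRng::seed_from_u64(0);
        let mut counts: HashMap<&str, u32> = HashMap::new();
        for _ in 0..10_000 {
            *counts.entry(pick_tier(&mut rng)).or_insert(0) += 1;
        }
        println!("{counts:?}"); // expect roughly 7000 / 2000 / 1000
    }

Note that when only one tier is non-empty its coin is certain, which is why the blocking pop can unwrap: a non-empty queue always yields an element.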
@@ -1,8 +1,9 @@
use crate::{
    self as gpui, AbsoluteLength, AlignContent, AlignItems, BorderStyle, CursorStyle,
    DefiniteLength, Display, Fill, FlexDirection, FlexWrap, Font, FontStyle, FontWeight,
    GridPlacement, Hsla, JustifyContent, Length, SharedString, StrikethroughStyle, StyleRefinement,
    TextAlign, TextOverflow, TextStyleRefinement, UnderlineStyle, WhiteSpace, px, relative, rems,
    DefiniteLength, Display, Fill, FlexDirection, FlexWrap, Font, FontFeatures, FontStyle,
    FontWeight, GridPlacement, Hsla, JustifyContent, Length, SharedString, StrikethroughStyle,
    StyleRefinement, TextAlign, TextOverflow, TextStyleRefinement, UnderlineStyle, WhiteSpace, px,
    relative, rems,
};
pub use gpui_macros::{
    border_style_methods, box_shadow_style_methods, cursor_style_methods, margin_style_methods,
@@ -630,6 +631,14 @@ pub trait Styled: Sized {
        self
    }

    /// Sets the font features of this element and its children.
    fn font_features(mut self, features: FontFeatures) -> Self {
        self.text_style()
            .get_or_insert_with(Default::default)
            .font_features = Some(features);
        self
    }

    /// Sets the font of this element and its children.
    fn font(mut self, font: Font) -> Self {
        let Font {
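The new font_features setter mirrors the surrounding text-style setters: it lazily creates the element's TextStyleRefinement and records the OpenType features for the element and its children. A minimal usage sketch, assuming a gpui element builder like div(); this diff only adds the setter, so how the FontFeatures value is constructed is not shown here:

    // `features: FontFeatures` is assumed to come from the caller; the setter
    // simply stores it on the element's text style refinement for this subtree.
    div()
        .font_features(features)
        .child("text rendered with custom OpenType features")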
Some files were not shown because too many files have changed in this diff