Compare commits
78 Commits
ex-test-in
...
migrate-pr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9cc517e0dd | ||
|
|
d1390a5b78 | ||
|
|
ee4faede38 | ||
|
|
8d96a699b3 | ||
|
|
8cfb7471db | ||
|
|
def9c87837 | ||
|
|
0313ab6d41 | ||
|
|
c5329fdff2 | ||
|
|
a676a6895b | ||
|
|
3b5d7d7d89 | ||
|
|
91f01131b1 | ||
|
|
5fa5226286 | ||
|
|
ae94007227 | ||
|
|
8f425a1bd5 | ||
|
|
743c414e7b | ||
|
|
0fe335efc5 | ||
|
|
36b95aac4b | ||
|
|
b2df70ab58 | ||
|
|
36293d7dd9 | ||
|
|
3ae3e1fce8 | ||
|
|
e5f1fc7478 | ||
|
|
a4f6076da7 | ||
|
|
43726b2620 | ||
|
|
94980ffb49 | ||
|
|
22cc731450 | ||
|
|
d9396373e3 | ||
|
|
48002be135 | ||
|
|
58db83f8f5 | ||
|
|
0243d5b542 | ||
|
|
06230327fa | ||
|
|
ca5c8992f9 | ||
|
|
1038e1c2ef | ||
|
|
e1fe0b3287 | ||
|
|
a0e10a91bf | ||
|
|
272b1aa4bc | ||
|
|
9ef0537b44 | ||
|
|
77f1de742b | ||
|
|
e054cabd41 | ||
|
|
3b95cb5682 | ||
|
|
c89653bd07 | ||
|
|
b90ac2dc07 | ||
|
|
c9998541f0 | ||
|
|
e2b49b3cd3 | ||
|
|
d1e77397c6 | ||
|
|
cc5f5e35e4 | ||
|
|
7183b8a1cd | ||
|
|
b1934fb712 | ||
|
|
a198b6c0d1 | ||
|
|
8b5b2712c8 | ||
|
|
4464392e8e | ||
|
|
a0d3bc31e9 | ||
|
|
ccd6672d1a | ||
|
|
21de6d35dd | ||
|
|
2031ca17e5 | ||
|
|
8b1ce75a57 | ||
|
|
5559726fd7 | ||
|
|
e1a9269921 | ||
|
|
3b6b3ff504 | ||
|
|
aabed94970 | ||
|
|
2d3a3521ba | ||
|
|
a48bd10da0 | ||
|
|
fec9525be4 | ||
|
|
bf2b8e999e | ||
|
|
63c35d2b00 | ||
|
|
1396c68010 | ||
|
|
fcb3d3dec6 | ||
|
|
f54e7f8c9d | ||
|
|
2a89529d7f | ||
|
|
58207325e2 | ||
|
|
e08ab99e8d | ||
|
|
a95f3f33a4 | ||
|
|
b0767c1b1f | ||
|
|
b200e10bc4 | ||
|
|
948905d916 | ||
|
|
04de456373 | ||
|
|
e5ce32e936 | ||
|
|
d7caae30de | ||
|
|
c7e77674a1 |
16
.github/ISSUE_TEMPLATE/10_bug_report.yml
vendored
16
.github/ISSUE_TEMPLATE/10_bug_report.yml
vendored
@@ -75,22 +75,6 @@ body:
|
||||
</details>
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Relevant Keymap
|
||||
description: |
|
||||
Open the command palette in Zed, then type “zed: open keymap file” and copy/paste the file's contents.
|
||||
value: |
|
||||
<details><summary>keymap.json</summary>
|
||||
|
||||
<!-- Paste your keymap file inside the code block. -->
|
||||
```json
|
||||
|
||||
```
|
||||
|
||||
</details>
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: (for AI issues) Model provider details
|
||||
|
||||
25
.github/workflows/after_release.yml
vendored
25
.github/workflows/after_release.yml
vendored
@@ -5,27 +5,13 @@ on:
|
||||
release:
|
||||
types:
|
||||
- published
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tag_name:
|
||||
description: tag_name
|
||||
required: true
|
||||
type: string
|
||||
prerelease:
|
||||
description: prerelease
|
||||
required: true
|
||||
type: boolean
|
||||
body:
|
||||
description: body
|
||||
type: string
|
||||
default: ''
|
||||
jobs:
|
||||
rebuild_releases_page:
|
||||
if: (github.repository_owner == 'zed-industries' || github.repository_owner == 'zed-extensions')
|
||||
runs-on: namespace-profile-2x4-ubuntu-2404
|
||||
steps:
|
||||
- name: after_release::rebuild_releases_page::refresh_cloud_releases
|
||||
run: curl -fX POST https://cloud.zed.dev/releases/refresh?expect_tag=${{ github.event.release.tag_name || inputs.tag_name }}
|
||||
run: curl -fX POST https://cloud.zed.dev/releases/refresh?expect_tag=${{ github.event.release.tag_name }}
|
||||
shell: bash -euxo pipefail {0}
|
||||
- name: after_release::rebuild_releases_page::redeploy_zed_dev
|
||||
run: npm exec --yes -- vercel@37 --token="$VERCEL_TOKEN" --scope zed-industries redeploy https://zed.dev
|
||||
@@ -41,7 +27,7 @@ jobs:
|
||||
- id: get-release-url
|
||||
name: after_release::post_to_discord::get_release_url
|
||||
run: |
|
||||
if [ "${{ github.event.release.prerelease || inputs.prerelease }}" == "true" ]; then
|
||||
if [ "${{ github.event.release.prerelease }}" == "true" ]; then
|
||||
URL="https://zed.dev/releases/preview"
|
||||
else
|
||||
URL="https://zed.dev/releases/stable"
|
||||
@@ -54,9 +40,9 @@ jobs:
|
||||
uses: 2428392/gh-truncate-string-action@b3ff790d21cf42af3ca7579146eedb93c8fb0757
|
||||
with:
|
||||
stringToTruncate: |
|
||||
📣 Zed [${{ github.event.release.tag_name || inputs.tag_name }}](<${{ steps.get-release-url.outputs.URL }}>) was just released!
|
||||
📣 Zed [${{ github.event.release.tag_name }}](<${{ steps.get-release-url.outputs.URL }}>) was just released!
|
||||
|
||||
${{ github.event.release.body || inputs.body }}
|
||||
${{ github.event.release.body }}
|
||||
maxLength: 2000
|
||||
truncationSymbol: '...'
|
||||
- name: after_release::post_to_discord::discord_webhook_action
|
||||
@@ -70,7 +56,7 @@ jobs:
|
||||
- id: set-package-name
|
||||
name: after_release::publish_winget::set_package_name
|
||||
run: |
|
||||
if ("${{ github.event.release.prerelease || inputs.prerelease }}" -eq "true") {
|
||||
if ("${{ github.event.release.prerelease }}" -eq "true") {
|
||||
$PACKAGE_NAME = "ZedIndustries.Zed.Preview"
|
||||
} else {
|
||||
$PACKAGE_NAME = "ZedIndustries.Zed"
|
||||
@@ -82,7 +68,6 @@ jobs:
|
||||
uses: vedantmgoyal9/winget-releaser@19e706d4c9121098010096f9c495a70a7518b30f
|
||||
with:
|
||||
identifier: ${{ steps.set-package-name.outputs.PACKAGE_NAME }}
|
||||
release-tag: ${{ github.event.release.tag_name || inputs.tag_name }}
|
||||
max-versions-to-keep: 5
|
||||
token: ${{ secrets.WINGET_TOKEN }}
|
||||
create_sentry_release:
|
||||
|
||||
39
Cargo.lock
generated
39
Cargo.lock
generated
@@ -2770,9 +2770,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.2.49"
|
||||
version = "1.2.41"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "90583009037521a116abf44494efecd645ba48b6622457080f080b85544e2215"
|
||||
checksum = "ac9fe6cdbb24b6ade63616c0a0688e45bb56732262c158df3c0c4bea4ca47cb7"
|
||||
dependencies = [
|
||||
"find-msvc-tools",
|
||||
"jobserver",
|
||||
@@ -3113,9 +3113,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "cmake"
|
||||
version = "0.1.56"
|
||||
version = "0.1.54"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b042e5d8a74ae91bb0961acd039822472ec99f8ab0948cbf6d1369588f8be586"
|
||||
checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
@@ -5111,6 +5111,7 @@ dependencies = [
|
||||
"cloud_llm_client",
|
||||
"collections",
|
||||
"copilot",
|
||||
"credentials_provider",
|
||||
"ctor",
|
||||
"db",
|
||||
"edit_prediction_context",
|
||||
@@ -5178,7 +5179,6 @@ dependencies = [
|
||||
"language_model",
|
||||
"language_models",
|
||||
"languages",
|
||||
"libc",
|
||||
"log",
|
||||
"node_runtime",
|
||||
"paths",
|
||||
@@ -5200,6 +5200,7 @@ dependencies = [
|
||||
"wasmtime",
|
||||
"watch",
|
||||
"zeta_prompt",
|
||||
"zlog",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5274,6 +5275,7 @@ dependencies = [
|
||||
"text",
|
||||
"theme",
|
||||
"ui",
|
||||
"ui_input",
|
||||
"util",
|
||||
"workspace",
|
||||
"zed_actions",
|
||||
@@ -5841,9 +5843,12 @@ dependencies = [
|
||||
"async-trait",
|
||||
"client",
|
||||
"collections",
|
||||
"credentials_provider",
|
||||
"criterion",
|
||||
"ctor",
|
||||
"dap",
|
||||
"dirs 4.0.0",
|
||||
"editor",
|
||||
"extension",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
@@ -5852,8 +5857,11 @@ dependencies = [
|
||||
"http_client",
|
||||
"language",
|
||||
"language_extension",
|
||||
"language_model",
|
||||
"log",
|
||||
"lsp",
|
||||
"markdown",
|
||||
"menu",
|
||||
"moka",
|
||||
"node_runtime",
|
||||
"parking_lot",
|
||||
@@ -5868,12 +5876,14 @@ dependencies = [
|
||||
"serde_json",
|
||||
"serde_json_lenient",
|
||||
"settings",
|
||||
"smol",
|
||||
"task",
|
||||
"telemetry",
|
||||
"tempfile",
|
||||
"theme",
|
||||
"theme_extension",
|
||||
"toml 0.8.23",
|
||||
"ui",
|
||||
"url",
|
||||
"util",
|
||||
"wasmparser 0.221.3",
|
||||
@@ -6091,9 +6101,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "find-msvc-tools"
|
||||
version = "0.1.5"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3a3076410a55c90011c298b04d0cfa770b00fa04e1e3c97d3f6c9de105a03844"
|
||||
checksum = "52051878f80a721bb68ebfbc930e07b65ba72f2da88968ea5c06fd6ca3d3a127"
|
||||
|
||||
[[package]]
|
||||
name = "fixedbitset"
|
||||
@@ -7237,7 +7247,6 @@ dependencies = [
|
||||
"libc",
|
||||
"log",
|
||||
"lyon",
|
||||
"mach2 0.5.0",
|
||||
"media",
|
||||
"metal",
|
||||
"naga",
|
||||
@@ -8800,7 +8809,6 @@ dependencies = [
|
||||
"cloud_api_types",
|
||||
"cloud_llm_client",
|
||||
"collections",
|
||||
"credentials_provider",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"http_client",
|
||||
@@ -8819,7 +8827,6 @@ dependencies = [
|
||||
"telemetry_events",
|
||||
"thiserror 2.0.17",
|
||||
"util",
|
||||
"zed_env_vars",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -8843,6 +8850,8 @@ dependencies = [
|
||||
"credentials_provider",
|
||||
"deepseek",
|
||||
"editor",
|
||||
"extension",
|
||||
"extension_host",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"google_ai",
|
||||
@@ -8876,6 +8885,7 @@ dependencies = [
|
||||
"util",
|
||||
"vercel",
|
||||
"x_ai",
|
||||
"zed_env_vars",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -14445,14 +14455,12 @@ dependencies = [
|
||||
"settings",
|
||||
"smol",
|
||||
"theme",
|
||||
"tracing",
|
||||
"ui",
|
||||
"unindent",
|
||||
"util",
|
||||
"util_macros",
|
||||
"workspace",
|
||||
"zed_actions",
|
||||
"ztracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -14777,8 +14785,6 @@ dependencies = [
|
||||
"assets",
|
||||
"bm25",
|
||||
"client",
|
||||
"copilot",
|
||||
"edit_prediction",
|
||||
"editor",
|
||||
"feature_flags",
|
||||
"fs",
|
||||
@@ -14787,7 +14793,6 @@ dependencies = [
|
||||
"gpui",
|
||||
"heck 0.5.0",
|
||||
"language",
|
||||
"language_models",
|
||||
"log",
|
||||
"menu",
|
||||
"node_runtime",
|
||||
@@ -16368,13 +16373,13 @@ dependencies = [
|
||||
"alacritty_terminal",
|
||||
"anyhow",
|
||||
"collections",
|
||||
"fancy-regex",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"itertools 0.14.0",
|
||||
"libc",
|
||||
"log",
|
||||
"rand 0.9.2",
|
||||
"regex",
|
||||
"release_channel",
|
||||
"schemars",
|
||||
"serde",
|
||||
@@ -18102,7 +18107,6 @@ dependencies = [
|
||||
"language",
|
||||
"log",
|
||||
"lsp",
|
||||
"markdown_preview",
|
||||
"menu",
|
||||
"multi_buffer",
|
||||
"nvim-rs",
|
||||
@@ -21027,7 +21031,6 @@ dependencies = [
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"tracing-tracy",
|
||||
"zlog",
|
||||
"ztracing_macro",
|
||||
]
|
||||
|
||||
|
||||
@@ -631,7 +631,7 @@ shellexpand = "2.1.0"
|
||||
shlex = "1.3.0"
|
||||
simplelog = "0.12.2"
|
||||
slotmap = "1.0.6"
|
||||
smallvec = { version = "1.6", features = ["union", "const_new"] }
|
||||
smallvec = { version = "1.6", features = ["union"] }
|
||||
smol = "2.0"
|
||||
sqlformat = "0.2"
|
||||
stacksafe = "0.1"
|
||||
|
||||
@@ -25,8 +25,7 @@
|
||||
"ctrl-shift-w": "workspace::CloseWindow",
|
||||
"shift-escape": "workspace::ToggleZoom",
|
||||
"open": "workspace::Open",
|
||||
"ctrl-o": "workspace::OpenFiles",
|
||||
"ctrl-k ctrl-o": "workspace::Open",
|
||||
"ctrl-o": "workspace::Open",
|
||||
"ctrl-=": ["zed::IncreaseBufferFontSize", { "persist": false }],
|
||||
"ctrl-+": ["zed::IncreaseBufferFontSize", { "persist": false }],
|
||||
"ctrl--": ["zed::DecreaseBufferFontSize", { "persist": false }],
|
||||
@@ -815,6 +814,7 @@
|
||||
"ctrl-]": "agent::CycleNextInlineAssist",
|
||||
"ctrl-shift-enter": "inline_assistant::ThumbsUpResult",
|
||||
"ctrl-shift-backspace": "inline_assistant::ThumbsDownResult"
|
||||
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -1192,12 +1192,8 @@
|
||||
{
|
||||
"context": "MarkdownPreview",
|
||||
"bindings": {
|
||||
"pageup": "markdown::ScrollPageUp",
|
||||
"pagedown": "markdown::ScrollPageDown",
|
||||
"up": "markdown::ScrollUp",
|
||||
"down": "markdown::ScrollDown",
|
||||
"alt-up": "markdown::ScrollUpByItem",
|
||||
"alt-down": "markdown::ScrollDownByItem"
|
||||
"pageup": "markdown::MovePageUp",
|
||||
"pagedown": "markdown::MovePageDown"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1296,12 +1296,8 @@
|
||||
{
|
||||
"context": "MarkdownPreview",
|
||||
"bindings": {
|
||||
"pageup": "markdown::ScrollPageUp",
|
||||
"pagedown": "markdown::ScrollPageDown",
|
||||
"up": "markdown::ScrollUp",
|
||||
"down": "markdown::ScrollDown",
|
||||
"alt-up": "markdown::ScrollUpByItem",
|
||||
"alt-down": "markdown::ScrollDownByItem"
|
||||
"pageup": "markdown::MovePageUp",
|
||||
"pagedown": "markdown::MovePageDown"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@@ -489,8 +489,8 @@
|
||||
"bindings": {
|
||||
"ctrl-[": "editor::Outdent",
|
||||
"ctrl-]": "editor::Indent",
|
||||
"ctrl-alt-up": ["editor::AddSelectionAbove", { "skip_soft_wrap": true }], // Insert Cursor Above
|
||||
"ctrl-alt-down": ["editor::AddSelectionBelow", { "skip_soft_wrap": true }], // Insert Cursor Below
|
||||
"ctrl-shift-alt-up": ["editor::AddSelectionAbove", { "skip_soft_wrap": true }], // Insert Cursor Above
|
||||
"ctrl-shift-alt-down": ["editor::AddSelectionBelow", { "skip_soft_wrap": true }], // Insert Cursor Below
|
||||
"ctrl-shift-k": "editor::DeleteLine",
|
||||
"alt-up": "editor::MoveLineUp",
|
||||
"alt-down": "editor::MoveLineDown",
|
||||
@@ -501,12 +501,9 @@
|
||||
"ctrl-shift-l": "editor::SelectAllMatches", // Select all occurrences of current selection
|
||||
"ctrl-f2": "editor::SelectAllMatches", // Select all occurrences of current word
|
||||
"ctrl-d": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand
|
||||
"ctrl-f3": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand
|
||||
"ctrl-k ctrl-d": ["editor::SelectNext", { "replace_newest": true }], // editor.action.moveSelectionToNextFindMatch / find_under_expand_skip
|
||||
"ctrl-shift-f3": ["editor::SelectPrevious", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand
|
||||
"ctrl-k ctrl-i": "editor::Hover",
|
||||
"ctrl-k ctrl-b": "editor::BlameHover",
|
||||
"ctrl-k ctrl-f": "editor::FormatSelections",
|
||||
"ctrl-/": ["editor::ToggleComments", { "advance_downwards": false }],
|
||||
"f8": ["editor::GoToDiagnostic", { "severity": { "min": "hint", "max": "error" } }],
|
||||
"shift-f8": ["editor::GoToPreviousDiagnostic", { "severity": { "min": "hint", "max": "error" } }],
|
||||
@@ -539,7 +536,7 @@
|
||||
"ctrl-k p": "editor::CopyPath",
|
||||
"ctrl-\\": "pane::SplitRight",
|
||||
"alt-.": "editor::GoToHunk",
|
||||
"alt-,": "editor::GoToPreviousHunk",
|
||||
"alt-,": "editor::GoToPreviousHunk"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -1223,12 +1220,8 @@
|
||||
"context": "MarkdownPreview",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"pageup": "markdown::ScrollPageUp",
|
||||
"pagedown": "markdown::ScrollPageDown",
|
||||
"up": "markdown::ScrollUp",
|
||||
"down": "markdown::ScrollDown",
|
||||
"alt-up": "markdown::ScrollUpByItem",
|
||||
"alt-down": "markdown::ScrollDownByItem"
|
||||
"pageup": "markdown::MovePageUp",
|
||||
"pagedown": "markdown::MovePageDown"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1046,14 +1046,5 @@
|
||||
"g g": "settings_editor::FocusFirstNavEntry",
|
||||
"shift-g": "settings_editor::FocusLastNavEntry"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "MarkdownPreview",
|
||||
"bindings": {
|
||||
"ctrl-u": "markdown::ScrollPageUp",
|
||||
"ctrl-d": "markdown::ScrollPageDown",
|
||||
"ctrl-y": "markdown::ScrollUp",
|
||||
"ctrl-e": "markdown::ScrollDown"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -39,5 +39,6 @@ Only make changes that are necessary to fulfill the prompt, leave everything els
|
||||
|
||||
Start at the indentation level in the original file in the rewritten {{content_type}}.
|
||||
|
||||
IMPORTANT: You MUST use one of the provided tools to make the rewrite or to provide an explanation as to why the user's request cannot be fulfilled. You MUST NOT send back unstructured text. If you need to make a statement or ask a question you MUST use one of the tools to do so.
|
||||
You must use one of the provided tools to make the rewrite or to provide an explanation as to why the user's request cannot be fulfilled. It is an error if
|
||||
you simply send back unstructured text. If you need to make a statement or ask a question you must use one of the tools to do so.
|
||||
It is an error if you try to make a change that cannot be made simply by editing the rewrite_section.
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
"theme": {
|
||||
"mode": "system",
|
||||
"light": "One Light",
|
||||
"dark": "One Dark",
|
||||
"dark": "One Dark"
|
||||
},
|
||||
"icon_theme": "Zed (Default)",
|
||||
// The name of a base set of key bindings to use.
|
||||
@@ -29,7 +29,7 @@
|
||||
// Features that can be globally enabled or disabled
|
||||
"features": {
|
||||
// Which edit prediction provider to use.
|
||||
"edit_prediction_provider": "zed",
|
||||
"edit_prediction_provider": "zed"
|
||||
},
|
||||
// The name of a font to use for rendering text in the editor
|
||||
// ".ZedMono" currently aliases to Lilex
|
||||
@@ -69,7 +69,7 @@
|
||||
// The OpenType features to enable for text in the UI
|
||||
"ui_font_features": {
|
||||
// Disable ligatures:
|
||||
"calt": false,
|
||||
"calt": false
|
||||
},
|
||||
// The weight of the UI font in standard CSS units from 100 to 900.
|
||||
"ui_font_weight": 400,
|
||||
@@ -87,7 +87,7 @@
|
||||
"border_size": 0.0,
|
||||
// Opacity of the inactive panes. 0 means transparent, 1 means opaque.
|
||||
// Values are clamped to the [0.0, 1.0] range.
|
||||
"inactive_opacity": 1.0,
|
||||
"inactive_opacity": 1.0
|
||||
},
|
||||
// Layout mode of the bottom dock. Defaults to "contained"
|
||||
// choices: contained, full, left_aligned, right_aligned
|
||||
@@ -103,12 +103,12 @@
|
||||
"left_padding": 0.2,
|
||||
// The relative width of the right padding of the central pane from the
|
||||
// workspace when the centered layout is used.
|
||||
"right_padding": 0.2,
|
||||
"right_padding": 0.2
|
||||
},
|
||||
// Image viewer settings
|
||||
"image_viewer": {
|
||||
// The unit for image file sizes: "binary" (KiB, MiB) or decimal (KB, MB)
|
||||
"unit": "binary",
|
||||
"unit": "binary"
|
||||
},
|
||||
// Determines the modifier to be used to add multiple cursors with the mouse. The open hover link mouse gestures will adapt such that it do not conflict with the multicursor modifier.
|
||||
//
|
||||
@@ -296,7 +296,7 @@
|
||||
// When true, enables drag and drop text selection in buffer.
|
||||
"enabled": true,
|
||||
// The delay in milliseconds that must elapse before drag and drop is allowed. Otherwise, a new text selection is created.
|
||||
"delay": 300,
|
||||
"delay": 300
|
||||
},
|
||||
// What to do when go to definition yields no results.
|
||||
//
|
||||
@@ -400,14 +400,14 @@
|
||||
// Visible characters used to render whitespace when show_whitespaces is enabled.
|
||||
"whitespace_map": {
|
||||
"space": "•",
|
||||
"tab": "→",
|
||||
"tab": "→"
|
||||
},
|
||||
// Settings related to calls in Zed
|
||||
"calls": {
|
||||
// Join calls with the microphone live by default
|
||||
"mute_on_join": false,
|
||||
// Share your project when you are the first to join a channel
|
||||
"share_on_join": false,
|
||||
"share_on_join": false
|
||||
},
|
||||
// Toolbar related settings
|
||||
"toolbar": {
|
||||
@@ -420,7 +420,7 @@
|
||||
// Whether to show agent review buttons in the editor toolbar.
|
||||
"agent_review": true,
|
||||
// Whether to show code action buttons in the editor toolbar.
|
||||
"code_actions": false,
|
||||
"code_actions": false
|
||||
},
|
||||
// Whether to allow windows to tab together based on the user’s tabbing preference (macOS only).
|
||||
"use_system_window_tabs": false,
|
||||
@@ -439,7 +439,7 @@
|
||||
// Whether to show the sign in button in the titlebar.
|
||||
"show_sign_in": true,
|
||||
// Whether to show the menus in the titlebar.
|
||||
"show_menus": false,
|
||||
"show_menus": false
|
||||
},
|
||||
"audio": {
|
||||
// Opt into the new audio system.
|
||||
@@ -472,7 +472,7 @@
|
||||
// the future we will migrate by setting this to false
|
||||
//
|
||||
// You need to rejoin a call for this setting to apply
|
||||
"experimental.legacy_audio_compatible": true,
|
||||
"experimental.legacy_audio_compatible": true
|
||||
},
|
||||
// Scrollbar related settings
|
||||
"scrollbar": {
|
||||
@@ -511,8 +511,8 @@
|
||||
// When false, forcefully disables the horizontal scrollbar. Otherwise, obey other settings.
|
||||
"horizontal": true,
|
||||
// When false, forcefully disables the vertical scrollbar. Otherwise, obey other settings.
|
||||
"vertical": true,
|
||||
},
|
||||
"vertical": true
|
||||
}
|
||||
},
|
||||
// Minimap related settings
|
||||
"minimap": {
|
||||
@@ -560,7 +560,7 @@
|
||||
// 3. "gutter" or "none" to not highlight the current line in the minimap.
|
||||
"current_line_highlight": null,
|
||||
// Maximum number of columns to display in the minimap.
|
||||
"max_width_columns": 80,
|
||||
"max_width_columns": 80
|
||||
},
|
||||
// Enable middle-click paste on Linux.
|
||||
"middle_click_paste": true,
|
||||
@@ -583,7 +583,7 @@
|
||||
// Whether to show fold buttons in the gutter.
|
||||
"folds": true,
|
||||
// Minimum number of characters to reserve space for in the gutter.
|
||||
"min_line_number_digits": 4,
|
||||
"min_line_number_digits": 4
|
||||
},
|
||||
"indent_guides": {
|
||||
// Whether to show indent guides in the editor.
|
||||
@@ -604,7 +604,7 @@
|
||||
//
|
||||
// 1. "disabled"
|
||||
// 2. "indent_aware"
|
||||
"background_coloring": "disabled",
|
||||
"background_coloring": "disabled"
|
||||
},
|
||||
// Whether the editor will scroll beyond the last line.
|
||||
"scroll_beyond_last_line": "one_page",
|
||||
@@ -623,7 +623,7 @@
|
||||
"fast_scroll_sensitivity": 4.0,
|
||||
"sticky_scroll": {
|
||||
// Whether to stick scopes to the top of the editor.
|
||||
"enabled": false,
|
||||
"enabled": false
|
||||
},
|
||||
"relative_line_numbers": "disabled",
|
||||
// If 'search_wrap' is disabled, search result do not wrap around the end of the file.
|
||||
@@ -641,7 +641,7 @@
|
||||
// Whether to interpret the search query as a regular expression.
|
||||
"regex": false,
|
||||
// Whether to center the cursor on each search match when navigating.
|
||||
"center_on_match": false,
|
||||
"center_on_match": false
|
||||
},
|
||||
// When to populate a new search's query based on the text under the cursor.
|
||||
// This setting can take the following three values:
|
||||
@@ -684,8 +684,8 @@
|
||||
"shift": false,
|
||||
"alt": false,
|
||||
"platform": false,
|
||||
"function": false,
|
||||
},
|
||||
"function": false
|
||||
}
|
||||
},
|
||||
// Whether to resize all the panels in a dock when resizing the dock.
|
||||
// Can be a combination of "left", "right" and "bottom".
|
||||
@@ -733,7 +733,7 @@
|
||||
// "always"
|
||||
// 5. Never show the scrollbar:
|
||||
// "never"
|
||||
"show": null,
|
||||
"show": null
|
||||
},
|
||||
// Which files containing diagnostic errors/warnings to mark in the project panel.
|
||||
// This setting can take the following three values:
|
||||
@@ -756,7 +756,7 @@
|
||||
// "always"
|
||||
// 2. Never show indent guides:
|
||||
// "never"
|
||||
"show": "always",
|
||||
"show": "always"
|
||||
},
|
||||
// Sort order for entries in the project panel.
|
||||
// This setting can take three values:
|
||||
@@ -781,8 +781,8 @@
|
||||
// Whether to automatically open files after pasting or duplicating them.
|
||||
"on_paste": true,
|
||||
// Whether to automatically open files dropped from external sources.
|
||||
"on_drop": true,
|
||||
},
|
||||
"on_drop": true
|
||||
}
|
||||
},
|
||||
"outline_panel": {
|
||||
// Whether to show the outline panel button in the status bar
|
||||
@@ -815,7 +815,7 @@
|
||||
// "always"
|
||||
// 2. Never show indent guides:
|
||||
// "never"
|
||||
"show": "always",
|
||||
"show": "always"
|
||||
},
|
||||
// Scrollbar-related settings
|
||||
"scrollbar": {
|
||||
@@ -832,11 +832,11 @@
|
||||
// "always"
|
||||
// 5. Never show the scrollbar:
|
||||
// "never"
|
||||
"show": null,
|
||||
"show": null
|
||||
},
|
||||
// Default depth to expand outline items in the current file.
|
||||
// Set to 0 to collapse all items that have children, 1 or higher to collapse items at that depth or deeper.
|
||||
"expand_outlines_with_depth": 100,
|
||||
"expand_outlines_with_depth": 100
|
||||
},
|
||||
"collaboration_panel": {
|
||||
// Whether to show the collaboration panel button in the status bar.
|
||||
@@ -844,7 +844,7 @@
|
||||
// Where to dock the collaboration panel. Can be 'left' or 'right'.
|
||||
"dock": "left",
|
||||
// Default width of the collaboration panel.
|
||||
"default_width": 240,
|
||||
"default_width": 240
|
||||
},
|
||||
"git_panel": {
|
||||
// Whether to show the git panel button in the status bar.
|
||||
@@ -880,12 +880,12 @@
|
||||
// Choices: always, auto, never, system
|
||||
// Default: inherits editor scrollbar settings
|
||||
// "show": null
|
||||
},
|
||||
}
|
||||
},
|
||||
"message_editor": {
|
||||
// Whether to automatically replace emoji shortcodes with emoji characters.
|
||||
// For example: typing `:wave:` gets replaced with `👋`.
|
||||
"auto_replace_emoji_shortcode": true,
|
||||
"auto_replace_emoji_shortcode": true
|
||||
},
|
||||
"notification_panel": {
|
||||
// Whether to show the notification panel button in the status bar.
|
||||
@@ -893,11 +893,9 @@
|
||||
// Where to dock the notification panel. Can be 'left' or 'right'.
|
||||
"dock": "right",
|
||||
// Default width of the notification panel.
|
||||
"default_width": 380,
|
||||
"default_width": 380
|
||||
},
|
||||
"agent": {
|
||||
// Whether the inline assistant should use streaming tools, when available
|
||||
"inline_assistant_use_streaming_tools": true,
|
||||
// Whether the agent is enabled.
|
||||
"enabled": true,
|
||||
// What completion mode to start new threads in, if available. Can be 'normal' or 'burn'.
|
||||
@@ -917,7 +915,7 @@
|
||||
// The provider to use.
|
||||
"provider": "zed.dev",
|
||||
// The model to use.
|
||||
"model": "claude-sonnet-4",
|
||||
"model": "claude-sonnet-4"
|
||||
},
|
||||
// Additional parameters for language model requests. When making a request to a model, parameters will be taken
|
||||
// from the last entry in this list that matches the model's provider and name. In each entry, both provider
|
||||
@@ -972,8 +970,8 @@
|
||||
"grep": true,
|
||||
"terminal": true,
|
||||
"thinking": true,
|
||||
"web_search": true,
|
||||
},
|
||||
"web_search": true
|
||||
}
|
||||
},
|
||||
"ask": {
|
||||
"name": "Ask",
|
||||
@@ -990,14 +988,14 @@
|
||||
"open": true,
|
||||
"grep": true,
|
||||
"thinking": true,
|
||||
"web_search": true,
|
||||
},
|
||||
"web_search": true
|
||||
}
|
||||
},
|
||||
"minimal": {
|
||||
"name": "Minimal",
|
||||
"enable_all_context_servers": false,
|
||||
"tools": {},
|
||||
},
|
||||
"tools": {}
|
||||
}
|
||||
},
|
||||
// Where to show notifications when the agent has either completed
|
||||
// its response, or else needs confirmation before it can run a
|
||||
@@ -1026,7 +1024,7 @@
|
||||
// Minimum number of lines to display in the agent message editor.
|
||||
//
|
||||
// Default: 4
|
||||
"message_editor_min_lines": 4,
|
||||
"message_editor_min_lines": 4
|
||||
},
|
||||
// Whether the screen sharing icon is shown in the os status bar.
|
||||
"show_call_status_icon": true,
|
||||
@@ -1061,7 +1059,7 @@
|
||||
// Whether or not to show the navigation history buttons.
|
||||
"show_nav_history_buttons": true,
|
||||
// Whether or not to show the tab bar buttons.
|
||||
"show_tab_bar_buttons": true,
|
||||
"show_tab_bar_buttons": true
|
||||
},
|
||||
// Settings related to the editor's tabs
|
||||
"tabs": {
|
||||
@@ -1100,7 +1098,7 @@
|
||||
// "errors"
|
||||
// 3. Mark files with errors and warnings:
|
||||
// "all"
|
||||
"show_diagnostics": "off",
|
||||
"show_diagnostics": "off"
|
||||
},
|
||||
// Settings related to preview tabs.
|
||||
"preview_tabs": {
|
||||
@@ -1121,7 +1119,7 @@
|
||||
"enable_preview_file_from_code_navigation": true,
|
||||
// Whether to keep tabs in preview mode when code navigation is used to navigate away from them.
|
||||
// If `enable_preview_file_from_code_navigation` or `enable_preview_multibuffer_from_code_navigation` is also true, the new tab may replace the existing one.
|
||||
"enable_keep_preview_on_code_navigation": false,
|
||||
"enable_keep_preview_on_code_navigation": false
|
||||
},
|
||||
// Settings related to the file finder.
|
||||
"file_finder": {
|
||||
@@ -1165,7 +1163,7 @@
|
||||
// * "all": Use all gitignored files
|
||||
// * "indexed": Use only the files Zed had indexed
|
||||
// * "smart": Be smart and search for ignored when called from a gitignored worktree
|
||||
"include_ignored": "smart",
|
||||
"include_ignored": "smart"
|
||||
},
|
||||
// Whether or not to remove any trailing whitespace from lines of a buffer
|
||||
// before saving it.
|
||||
@@ -1236,7 +1234,7 @@
|
||||
// Send debug info like crash reports.
|
||||
"diagnostics": true,
|
||||
// Send anonymized usage data like what languages you're using Zed with.
|
||||
"metrics": true,
|
||||
"metrics": true
|
||||
},
|
||||
// Whether to disable all AI features in Zed.
|
||||
//
|
||||
@@ -1270,7 +1268,7 @@
|
||||
"enabled": true,
|
||||
// Minimum time to wait before pulling diagnostics from the language server(s).
|
||||
// 0 turns the debounce off.
|
||||
"debounce_ms": 50,
|
||||
"debounce_ms": 50
|
||||
},
|
||||
// Settings for inline diagnostics
|
||||
"inline": {
|
||||
@@ -1288,8 +1286,8 @@
|
||||
"min_column": 0,
|
||||
// The minimum severity of the diagnostics to show inline.
|
||||
// Inherits editor's diagnostics' max severity settings when `null`.
|
||||
"max_severity": null,
|
||||
},
|
||||
"max_severity": null
|
||||
}
|
||||
},
|
||||
// Files or globs of files that will be excluded by Zed entirely. They will be skipped during file
|
||||
// scans, file searches, and not be displayed in the project file tree. Takes precedence over `file_scan_inclusions`.
|
||||
@@ -1303,7 +1301,7 @@
|
||||
"**/.DS_Store",
|
||||
"**/Thumbs.db",
|
||||
"**/.classpath",
|
||||
"**/.settings",
|
||||
"**/.settings"
|
||||
],
|
||||
// Files or globs of files that will be included by Zed, even when ignored by git. This is useful
|
||||
// for files that are not tracked by git, but are still important to your project. Note that globs
|
||||
@@ -1338,14 +1336,14 @@
|
||||
// Whether or not to display the git commit summary on the same line.
|
||||
"show_commit_summary": false,
|
||||
// The minimum column number to show the inline blame information at
|
||||
"min_column": 0,
|
||||
"min_column": 0
|
||||
},
|
||||
"blame": {
|
||||
"show_avatar": true,
|
||||
"show_avatar": true
|
||||
},
|
||||
// Control which information is shown in the branch picker.
|
||||
"branch_picker": {
|
||||
"show_author_name": true,
|
||||
"show_author_name": true
|
||||
},
|
||||
// How git hunks are displayed visually in the editor.
|
||||
// This setting can take two values:
|
||||
@@ -1357,7 +1355,7 @@
|
||||
"hunk_style": "staged_hollow",
|
||||
// Should the name or path be displayed first in the git view.
|
||||
// "path_style": "file_name_first" or "file_path_first"
|
||||
"path_style": "file_name_first",
|
||||
"path_style": "file_name_first"
|
||||
},
|
||||
// The list of custom Git hosting providers.
|
||||
"git_hosting_providers": [
|
||||
@@ -1391,7 +1389,7 @@
|
||||
"**/secrets.yml",
|
||||
"**/.zed/settings.json", // zed project settings
|
||||
"/**/zed/settings.json", // zed user settings
|
||||
"/**/zed/keymap.json",
|
||||
"/**/zed/keymap.json"
|
||||
],
|
||||
// When to show edit predictions previews in buffer.
|
||||
// This setting takes two possible values:
|
||||
@@ -1409,16 +1407,15 @@
|
||||
"copilot": {
|
||||
"enterprise_uri": null,
|
||||
"proxy": null,
|
||||
"proxy_no_verify": null,
|
||||
"proxy_no_verify": null
|
||||
},
|
||||
"codestral": {
|
||||
"api_url": "https://codestral.mistral.ai",
|
||||
"model": "codestral-latest",
|
||||
"max_tokens": 150,
|
||||
"model": null,
|
||||
"max_tokens": null
|
||||
},
|
||||
// Whether edit predictions are enabled when editing text threads in the agent panel.
|
||||
// This setting has no effect if globally disabled.
|
||||
"enabled_in_text_threads": true,
|
||||
"enabled_in_text_threads": true
|
||||
},
|
||||
// Settings specific to journaling
|
||||
"journal": {
|
||||
@@ -1428,7 +1425,7 @@
|
||||
// May take 2 values:
|
||||
// 1. hour12
|
||||
// 2. hour24
|
||||
"hour_format": "hour12",
|
||||
"hour_format": "hour12"
|
||||
},
|
||||
// Status bar-related settings.
|
||||
"status_bar": {
|
||||
@@ -1439,7 +1436,7 @@
|
||||
// Whether to show the cursor position button in the status bar.
|
||||
"cursor_position_button": true,
|
||||
// Whether to show active line endings button in the status bar.
|
||||
"line_endings_button": false,
|
||||
"line_endings_button": false
|
||||
},
|
||||
// Settings specific to the terminal
|
||||
"terminal": {
|
||||
@@ -1560,8 +1557,8 @@
|
||||
// Preferred Conda manager to use when activating Conda environments.
|
||||
// Values: "auto", "conda", "mamba", "micromamba"
|
||||
// Default: "auto"
|
||||
"conda_manager": "auto",
|
||||
},
|
||||
"conda_manager": "auto"
|
||||
}
|
||||
},
|
||||
"toolbar": {
|
||||
// Whether to display the terminal title in its toolbar's breadcrumbs.
|
||||
@@ -1569,7 +1566,7 @@
|
||||
//
|
||||
// The shell running in the terminal needs to be configured to emit the title.
|
||||
// Example: `echo -e "\e]2;New Title\007";`
|
||||
"breadcrumbs": false,
|
||||
"breadcrumbs": false
|
||||
},
|
||||
// Scrollbar-related settings
|
||||
"scrollbar": {
|
||||
@@ -1586,7 +1583,7 @@
|
||||
// "always"
|
||||
// 5. Never show the scrollbar:
|
||||
// "never"
|
||||
"show": null,
|
||||
"show": null
|
||||
},
|
||||
// Set the terminal's font size. If this option is not included,
|
||||
// the terminal will default to matching the buffer's font size.
|
||||
@@ -1649,26 +1646,30 @@
|
||||
// surrounding symbols or quotes
|
||||
[
|
||||
"(?x)",
|
||||
"(?<path>",
|
||||
" (",
|
||||
" # multi-char path: first char (not opening delimiter or space)",
|
||||
" [^({\\[<\"'`\\ ]",
|
||||
" # middle chars: non-space, and colon/paren only if not followed by digit/paren",
|
||||
" ([^\\ :(]|[:(][^0-9()])*",
|
||||
" # last char: not closing delimiter or colon",
|
||||
" [^()}\\]>\"'`.,;:\\ ]",
|
||||
" |",
|
||||
" # single-char path: not delimiter, punctuation, or space",
|
||||
" [^(){}\\[\\]<>\"'`.,;:\\ ]",
|
||||
" )",
|
||||
" # optional line/column suffix (included in path for PathWithPosition::parse_str)",
|
||||
" (:+[0-9]+(:[0-9]+)?|:?\\([0-9]+([,:]?[0-9]+)?\\))?",
|
||||
")",
|
||||
],
|
||||
"# optionally starts with 0-2 opening prefix symbols",
|
||||
"[({\\[<]{0,2}",
|
||||
"# which may be followed by an opening quote",
|
||||
"(?<quote>[\"'`])?",
|
||||
"# `path` is the shortest sequence of any non-space character",
|
||||
"(?<link>(?<path>[^ ]+?",
|
||||
" # which may end with a line and optionally a column,",
|
||||
" (?<line_column>:+[0-9]+(:[0-9]+)?|:?\\([0-9]+([,:][0-9]+)?\\))?",
|
||||
"))",
|
||||
"# which must be followed by a matching quote",
|
||||
"(?(<quote>)\\k<quote>)",
|
||||
"# and optionally a single closing symbol",
|
||||
"[)}\\]>]?",
|
||||
"# if line/column matched, may be followed by a description",
|
||||
"(?(<line_column>):[^ 0-9][^ ]*)?",
|
||||
"# which may be followed by trailing punctuation",
|
||||
"[.,:)}\\]>]*",
|
||||
"# and always includes trailing whitespace or end of line",
|
||||
"([ ]+|$)"
|
||||
]
|
||||
],
|
||||
// Timeout for hover and Cmd-click path hyperlink discovery in milliseconds. Specifying a
|
||||
// timeout of `0` will disable path hyperlinking in terminal.
|
||||
"path_hyperlink_timeout_ms": 1,
|
||||
"path_hyperlink_timeout_ms": 1
|
||||
},
|
||||
"code_actions_on_format": {},
|
||||
// Settings related to running tasks.
|
||||
@@ -1684,7 +1685,7 @@
|
||||
// * Zed task from history (e.g. one-off task was spawned before)
|
||||
//
|
||||
// Default: true
|
||||
"prefer_lsp": true,
|
||||
"prefer_lsp": true
|
||||
},
|
||||
// An object whose keys are language names, and whose values
|
||||
// are arrays of filenames or extensions of files that should
|
||||
@@ -1701,7 +1702,7 @@
|
||||
"file_types": {
|
||||
"JSONC": ["**/.zed/**/*.json", "**/zed/**/*.json", "**/Zed/**/*.json", "**/.vscode/**/*.json", "tsconfig*.json"],
|
||||
"Markdown": [".rules", ".cursorrules", ".windsurfrules", ".clinerules"],
|
||||
"Shell Script": [".env.*"],
|
||||
"Shell Script": [".env.*"]
|
||||
},
|
||||
// Settings for which version of Node.js and NPM to use when installing
|
||||
// language servers and Copilot.
|
||||
@@ -1717,7 +1718,7 @@
|
||||
// `path`, but not `npm_path`, Zed will assume that `npm` is located at
|
||||
// `${path}/../npm`.
|
||||
"path": null,
|
||||
"npm_path": null,
|
||||
"npm_path": null
|
||||
},
|
||||
// The extensions that Zed should automatically install on startup.
|
||||
//
|
||||
@@ -1725,6 +1726,11 @@
|
||||
// and change the value to `false`.
|
||||
"auto_install_extensions": {
|
||||
"html": true,
|
||||
"copilot-chat": true,
|
||||
"anthropic": true,
|
||||
"google-ai": true,
|
||||
"openai": true,
|
||||
"openrouter": true,
|
||||
},
|
||||
// The capabilities granted to extensions.
|
||||
//
|
||||
@@ -1732,7 +1738,7 @@
|
||||
"granted_extension_capabilities": [
|
||||
{ "kind": "process:exec", "command": "*", "args": ["**"] },
|
||||
{ "kind": "download_file", "host": "*", "path": ["**"] },
|
||||
{ "kind": "npm:install", "package": "*" },
|
||||
{ "kind": "npm:install", "package": "*" }
|
||||
],
|
||||
// Controls how completions are processed for this language.
|
||||
"completions": {
|
||||
@@ -1783,7 +1789,7 @@
|
||||
// 4. "replace_suffix"
|
||||
// Behaves like `"replace"` if the text after the cursor is a suffix of the completion, and like
|
||||
// `"insert"` otherwise.
|
||||
"lsp_insert_mode": "replace_suffix",
|
||||
"lsp_insert_mode": "replace_suffix"
|
||||
},
|
||||
// Different settings for specific languages.
|
||||
"languages": {
|
||||
@@ -1791,116 +1797,116 @@
|
||||
"language_servers": ["astro-language-server", "..."],
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
"plugins": ["prettier-plugin-astro"],
|
||||
},
|
||||
"plugins": ["prettier-plugin-astro"]
|
||||
}
|
||||
},
|
||||
"Blade": {
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
},
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"C": {
|
||||
"format_on_save": "off",
|
||||
"use_on_type_format": false,
|
||||
"prettier": {
|
||||
"allowed": false,
|
||||
},
|
||||
"allowed": false
|
||||
}
|
||||
},
|
||||
"C++": {
|
||||
"format_on_save": "off",
|
||||
"use_on_type_format": false,
|
||||
"prettier": {
|
||||
"allowed": false,
|
||||
},
|
||||
"allowed": false
|
||||
}
|
||||
},
|
||||
"CSharp": {
|
||||
"language_servers": ["roslyn", "!omnisharp", "..."],
|
||||
"language_servers": ["roslyn", "!omnisharp", "..."]
|
||||
},
|
||||
"CSS": {
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
},
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"Dart": {
|
||||
"tab_size": 2,
|
||||
"tab_size": 2
|
||||
},
|
||||
"Diff": {
|
||||
"show_edit_predictions": false,
|
||||
"remove_trailing_whitespace_on_save": false,
|
||||
"ensure_final_newline_on_save": false,
|
||||
"ensure_final_newline_on_save": false
|
||||
},
|
||||
"Elixir": {
|
||||
"language_servers": ["elixir-ls", "!expert", "!next-ls", "!lexical", "..."],
|
||||
"language_servers": ["elixir-ls", "!expert", "!next-ls", "!lexical", "..."]
|
||||
},
|
||||
"Elm": {
|
||||
"tab_size": 4,
|
||||
"tab_size": 4
|
||||
},
|
||||
"Erlang": {
|
||||
"language_servers": ["erlang-ls", "!elp", "..."],
|
||||
"language_servers": ["erlang-ls", "!elp", "..."]
|
||||
},
|
||||
"Git Commit": {
|
||||
"allow_rewrap": "anywhere",
|
||||
"soft_wrap": "editor_width",
|
||||
"preferred_line_length": 72,
|
||||
"preferred_line_length": 72
|
||||
},
|
||||
"Go": {
|
||||
"hard_tabs": true,
|
||||
"code_actions_on_format": {
|
||||
"source.organizeImports": true,
|
||||
"source.organizeImports": true
|
||||
},
|
||||
"debuggers": ["Delve"],
|
||||
"debuggers": ["Delve"]
|
||||
},
|
||||
"GraphQL": {
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
},
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"HEEX": {
|
||||
"language_servers": ["elixir-ls", "!expert", "!next-ls", "!lexical", "..."],
|
||||
"language_servers": ["elixir-ls", "!expert", "!next-ls", "!lexical", "..."]
|
||||
},
|
||||
"HTML": {
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
},
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"HTML+ERB": {
|
||||
"language_servers": ["herb", "!ruby-lsp", "..."],
|
||||
"language_servers": ["herb", "!ruby-lsp", "..."]
|
||||
},
|
||||
"Java": {
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
"plugins": ["prettier-plugin-java"],
|
||||
},
|
||||
"plugins": ["prettier-plugin-java"]
|
||||
}
|
||||
},
|
||||
"JavaScript": {
|
||||
"language_servers": ["!typescript-language-server", "vtsls", "..."],
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
},
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"JSON": {
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
},
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"JSONC": {
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
},
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"JS+ERB": {
|
||||
"language_servers": ["!ruby-lsp", "..."],
|
||||
"language_servers": ["!ruby-lsp", "..."]
|
||||
},
|
||||
"Kotlin": {
|
||||
"language_servers": ["!kotlin-language-server", "kotlin-lsp", "..."],
|
||||
"language_servers": ["!kotlin-language-server", "kotlin-lsp", "..."]
|
||||
},
|
||||
"LaTeX": {
|
||||
"formatter": "language_server",
|
||||
"language_servers": ["texlab", "..."],
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
"plugins": ["prettier-plugin-latex"],
|
||||
},
|
||||
"plugins": ["prettier-plugin-latex"]
|
||||
}
|
||||
},
|
||||
"Markdown": {
|
||||
"format_on_save": "off",
|
||||
@@ -1908,142 +1914,136 @@
|
||||
"remove_trailing_whitespace_on_save": false,
|
||||
"allow_rewrap": "anywhere",
|
||||
"soft_wrap": "editor_width",
|
||||
"completions": {
|
||||
"words": "disabled",
|
||||
},
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
},
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"PHP": {
|
||||
"language_servers": ["phpactor", "!intelephense", "!phptools", "..."],
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
"plugins": ["@prettier/plugin-php"],
|
||||
"parser": "php",
|
||||
},
|
||||
"parser": "php"
|
||||
}
|
||||
},
|
||||
"Plain Text": {
|
||||
"allow_rewrap": "anywhere",
|
||||
"soft_wrap": "editor_width",
|
||||
"completions": {
|
||||
"words": "disabled",
|
||||
},
|
||||
"soft_wrap": "editor_width"
|
||||
},
|
||||
"Python": {
|
||||
"code_actions_on_format": {
|
||||
"source.organizeImports.ruff": true,
|
||||
"source.organizeImports.ruff": true
|
||||
},
|
||||
"formatter": {
|
||||
"language_server": {
|
||||
"name": "ruff",
|
||||
},
|
||||
"name": "ruff"
|
||||
}
|
||||
},
|
||||
"debuggers": ["Debugpy"],
|
||||
"language_servers": ["basedpyright", "ruff", "!ty", "!pyrefly", "!pyright", "!pylsp", "..."],
|
||||
"language_servers": ["basedpyright", "ruff", "!ty", "!pyrefly", "!pyright", "!pylsp", "..."]
|
||||
},
|
||||
"Ruby": {
|
||||
"language_servers": ["solargraph", "!ruby-lsp", "!rubocop", "!sorbet", "!steep", "..."],
|
||||
"language_servers": ["solargraph", "!ruby-lsp", "!rubocop", "!sorbet", "!steep", "..."]
|
||||
},
|
||||
"Rust": {
|
||||
"debuggers": ["CodeLLDB"],
|
||||
"debuggers": ["CodeLLDB"]
|
||||
},
|
||||
"SCSS": {
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
},
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"Starlark": {
|
||||
"language_servers": ["starpls", "!buck2-lsp", "..."],
|
||||
"language_servers": ["starpls", "!buck2-lsp", "..."]
|
||||
},
|
||||
"Svelte": {
|
||||
"language_servers": ["svelte-language-server", "..."],
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
"plugins": ["prettier-plugin-svelte"],
|
||||
},
|
||||
"plugins": ["prettier-plugin-svelte"]
|
||||
}
|
||||
},
|
||||
"TSX": {
|
||||
"language_servers": ["!typescript-language-server", "vtsls", "..."],
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
},
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"Twig": {
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
},
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"TypeScript": {
|
||||
"language_servers": ["!typescript-language-server", "vtsls", "..."],
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
},
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"SystemVerilog": {
|
||||
"format_on_save": "off",
|
||||
"language_servers": ["!slang", "..."],
|
||||
"use_on_type_format": false,
|
||||
"use_on_type_format": false
|
||||
},
|
||||
"Vue.js": {
|
||||
"language_servers": ["vue-language-server", "vtsls", "..."],
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
},
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"XML": {
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
"plugins": ["@prettier/plugin-xml"],
|
||||
},
|
||||
"plugins": ["@prettier/plugin-xml"]
|
||||
}
|
||||
},
|
||||
"YAML": {
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
},
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"YAML+ERB": {
|
||||
"language_servers": ["!ruby-lsp", "..."],
|
||||
"language_servers": ["!ruby-lsp", "..."]
|
||||
},
|
||||
"Zig": {
|
||||
"language_servers": ["zls", "..."],
|
||||
},
|
||||
"language_servers": ["zls", "..."]
|
||||
}
|
||||
},
|
||||
// Different settings for specific language models.
|
||||
"language_models": {
|
||||
"anthropic": {
|
||||
"api_url": "https://api.anthropic.com",
|
||||
"api_url": "https://api.anthropic.com"
|
||||
},
|
||||
"bedrock": {},
|
||||
"google": {
|
||||
"api_url": "https://generativelanguage.googleapis.com",
|
||||
"api_url": "https://generativelanguage.googleapis.com"
|
||||
},
|
||||
"ollama": {
|
||||
"api_url": "http://localhost:11434",
|
||||
"api_url": "http://localhost:11434"
|
||||
},
|
||||
"openai": {
|
||||
"api_url": "https://api.openai.com/v1",
|
||||
"api_url": "https://api.openai.com/v1"
|
||||
},
|
||||
"openai_compatible": {},
|
||||
"open_router": {
|
||||
"api_url": "https://openrouter.ai/api/v1",
|
||||
"api_url": "https://openrouter.ai/api/v1"
|
||||
},
|
||||
"lmstudio": {
|
||||
"api_url": "http://localhost:1234/api/v0",
|
||||
"api_url": "http://localhost:1234/api/v0"
|
||||
},
|
||||
"deepseek": {
|
||||
"api_url": "https://api.deepseek.com/v1",
|
||||
"api_url": "https://api.deepseek.com/v1"
|
||||
},
|
||||
"mistral": {
|
||||
"api_url": "https://api.mistral.ai/v1",
|
||||
"api_url": "https://api.mistral.ai/v1"
|
||||
},
|
||||
"vercel": {
|
||||
"api_url": "https://api.v0.dev/v1",
|
||||
"api_url": "https://api.v0.dev/v1"
|
||||
},
|
||||
"x_ai": {
|
||||
"api_url": "https://api.x.ai/v1",
|
||||
"api_url": "https://api.x.ai/v1"
|
||||
},
|
||||
"zed.dev": {},
|
||||
"zed.dev": {}
|
||||
},
|
||||
"session": {
|
||||
// Whether or not to restore unsaved buffers on restart.
|
||||
@@ -2052,7 +2052,7 @@
|
||||
// dirty files when closing the application.
|
||||
//
|
||||
// Default: true
|
||||
"restore_unsaved_buffers": true,
|
||||
"restore_unsaved_buffers": true
|
||||
},
|
||||
// Zed's Prettier integration settings.
|
||||
// Allows to enable/disable formatting with Prettier
|
||||
@@ -2070,11 +2070,11 @@
|
||||
// "singleQuote": true
|
||||
// Forces Prettier integration to use a specific parser name when formatting files with the language
|
||||
// when set to a non-empty string.
|
||||
"parser": "",
|
||||
"parser": ""
|
||||
},
|
||||
// Settings for auto-closing of JSX tags.
|
||||
"jsx_tag_auto_close": {
|
||||
"enabled": true,
|
||||
"enabled": true
|
||||
},
|
||||
// LSP Specific settings.
|
||||
"lsp": {
|
||||
@@ -2095,19 +2095,19 @@
|
||||
// Specify the DAP name as a key here.
|
||||
"CodeLLDB": {
|
||||
"env": {
|
||||
"RUST_LOG": "info",
|
||||
},
|
||||
},
|
||||
"RUST_LOG": "info"
|
||||
}
|
||||
}
|
||||
},
|
||||
// Common language server settings.
|
||||
"global_lsp_settings": {
|
||||
// Whether to show the LSP servers button in the status bar.
|
||||
"button": true,
|
||||
"button": true
|
||||
},
|
||||
// Jupyter settings
|
||||
"jupyter": {
|
||||
"enabled": true,
|
||||
"kernel_selections": {},
|
||||
"kernel_selections": {}
|
||||
// Specify the language name as the key and the kernel name as the value.
|
||||
// "kernel_selections": {
|
||||
// "python": "conda-base"
|
||||
@@ -2121,7 +2121,7 @@
|
||||
"max_columns": 128,
|
||||
// Maximum number of lines to keep in REPL's scrollback buffer.
|
||||
// Clamped with [4, 256] range.
|
||||
"max_lines": 32,
|
||||
"max_lines": 32
|
||||
},
|
||||
// Vim settings
|
||||
"vim": {
|
||||
@@ -2135,7 +2135,7 @@
|
||||
// Specify the mode as the key and the shape as the value.
|
||||
// The mode can be one of the following: "normal", "replace", "insert", "visual".
|
||||
// The shape can be one of the following: "block", "bar", "underline", "hollow".
|
||||
"cursor_shape": {},
|
||||
"cursor_shape": {}
|
||||
},
|
||||
// The server to connect to. If the environment variable
|
||||
// ZED_SERVER_URL is set, it will override this setting.
|
||||
@@ -2168,9 +2168,9 @@
|
||||
"windows": {
|
||||
"languages": {
|
||||
"PHP": {
|
||||
"language_servers": ["intelephense", "!phpactor", "!phptools", "..."],
|
||||
},
|
||||
},
|
||||
"language_servers": ["intelephense", "!phpactor", "!phptools", "..."]
|
||||
}
|
||||
}
|
||||
},
|
||||
// Whether to show full labels in line indicator or short ones
|
||||
//
|
||||
@@ -2229,7 +2229,7 @@
|
||||
"dock": "bottom",
|
||||
"log_dap_communications": true,
|
||||
"format_dap_log_messages": true,
|
||||
"button": true,
|
||||
"button": true
|
||||
},
|
||||
// Configures any number of settings profiles that are temporarily applied on
|
||||
// top of your existing user settings when selected from
|
||||
@@ -2256,5 +2256,5 @@
|
||||
// Useful for filtering out noisy logs or enabling more verbose logging.
|
||||
//
|
||||
// Example: {"log": {"client": "warn"}}
|
||||
"log": {},
|
||||
"log": {}
|
||||
}
|
||||
|
||||
@@ -204,12 +204,21 @@ pub trait AgentModelSelector: 'static {
|
||||
}
|
||||
}
|
||||
|
||||
/// Icon for a model in the model selector.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum AgentModelIcon {
|
||||
/// A built-in icon from Zed's icon set.
|
||||
Named(IconName),
|
||||
/// Path to a custom SVG icon file.
|
||||
Path(SharedString),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct AgentModelInfo {
|
||||
pub id: acp::ModelId,
|
||||
pub name: SharedString,
|
||||
pub description: Option<SharedString>,
|
||||
pub icon: Option<IconName>,
|
||||
pub icon: Option<AgentModelIcon>,
|
||||
}
|
||||
|
||||
impl From<acp::ModelInfo> for AgentModelInfo {
|
||||
|
||||
@@ -739,7 +739,7 @@ impl ActivityIndicator {
|
||||
extension_store.outstanding_operations().iter().next()
|
||||
{
|
||||
let (message, icon, rotate) = match operation {
|
||||
ExtensionOperation::Install => (
|
||||
ExtensionOperation::Install | ExtensionOperation::AutoInstall => (
|
||||
format!("Installing {extension_id} extension…"),
|
||||
IconName::LoadCircle,
|
||||
true,
|
||||
|
||||
@@ -18,7 +18,7 @@ pub use templates::*;
|
||||
pub use thread::*;
|
||||
pub use tools::*;
|
||||
|
||||
use acp_thread::{AcpThread, AgentModelSelector};
|
||||
use acp_thread::{AcpThread, AgentModelIcon, AgentModelSelector};
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use chrono::{DateTime, Utc};
|
||||
@@ -105,7 +105,7 @@ impl LanguageModels {
|
||||
fn refresh_list(&mut self, cx: &App) {
|
||||
let providers = LanguageModelRegistry::global(cx)
|
||||
.read(cx)
|
||||
.providers()
|
||||
.visible_providers()
|
||||
.into_iter()
|
||||
.filter(|provider| provider.is_authenticated(cx))
|
||||
.collect::<Vec<_>>();
|
||||
@@ -161,11 +161,16 @@ impl LanguageModels {
|
||||
model: &Arc<dyn LanguageModel>,
|
||||
provider: &Arc<dyn LanguageModelProvider>,
|
||||
) -> acp_thread::AgentModelInfo {
|
||||
let icon = if let Some(path) = provider.icon_path() {
|
||||
Some(AgentModelIcon::Path(path))
|
||||
} else {
|
||||
Some(AgentModelIcon::Named(provider.icon()))
|
||||
};
|
||||
acp_thread::AgentModelInfo {
|
||||
id: Self::model_id(model),
|
||||
name: model.name().0,
|
||||
description: None,
|
||||
icon: Some(provider.icon()),
|
||||
icon,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1356,7 +1361,7 @@ mod internal_tests {
|
||||
id: acp::ModelId::new("fake/fake"),
|
||||
name: "Fake".into(),
|
||||
description: None,
|
||||
icon: Some(ui::IconName::ZedAssistant),
|
||||
icon: Some(AgentModelIcon::Named(ui::IconName::ZedAssistant)),
|
||||
}]
|
||||
)])
|
||||
);
|
||||
|
||||
@@ -11,6 +11,8 @@ use project::agent_server_store::AgentServerCommand;
|
||||
use serde::Deserialize;
|
||||
use settings::Settings as _;
|
||||
use task::ShellBuilder;
|
||||
#[cfg(windows)]
|
||||
use task::ShellKind;
|
||||
use util::ResultExt as _;
|
||||
|
||||
use std::path::PathBuf;
|
||||
@@ -90,8 +92,23 @@ impl AcpConnection {
|
||||
) -> Result<Self> {
|
||||
let shell = cx.update(|cx| TerminalSettings::get(None, cx).shell.clone())?;
|
||||
let builder = ShellBuilder::new(&shell, cfg!(windows));
|
||||
let mut child =
|
||||
builder.build_command(Some(command.path.display().to_string()), &command.args);
|
||||
#[cfg(windows)]
|
||||
let kind = builder.kind();
|
||||
let (cmd, args) = builder.build(Some(command.path.display().to_string()), &command.args);
|
||||
|
||||
let mut child = util::command::new_smol_command(cmd);
|
||||
#[cfg(windows)]
|
||||
if kind == ShellKind::Cmd {
|
||||
use smol::process::windows::CommandExt;
|
||||
for arg in args {
|
||||
child.raw_arg(arg);
|
||||
}
|
||||
} else {
|
||||
child.args(args);
|
||||
}
|
||||
#[cfg(not(windows))]
|
||||
child.args(args);
|
||||
|
||||
child
|
||||
.envs(command.env.iter().flatten())
|
||||
.stdin(std::process::Stdio::piped())
|
||||
|
||||
@@ -5,7 +5,7 @@ use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
|
||||
use acp_thread::AgentConnection;
|
||||
use anyhow::{Context as _, Result};
|
||||
use gpui::{App, SharedString, Task};
|
||||
use language_models::provider::google::GoogleLanguageModelProvider;
|
||||
use language_models::api_key_for_gemini_cli;
|
||||
use project::agent_server_store::GEMINI_NAME;
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -37,11 +37,7 @@ impl AgentServer for Gemini {
|
||||
cx.spawn(async move |cx| {
|
||||
extra_env.insert("SURFACE".to_owned(), "zed".to_owned());
|
||||
|
||||
if let Some(api_key) = cx
|
||||
.update(GoogleLanguageModelProvider::api_key_for_gemini_cli)?
|
||||
.await
|
||||
.ok()
|
||||
{
|
||||
if let Some(api_key) = cx.update(api_key_for_gemini_cli)?.await.ok() {
|
||||
extra_env.insert("GEMINI_API_KEY".into(), api_key);
|
||||
}
|
||||
let (command, root_dir, login) = store
|
||||
|
||||
@@ -28,7 +28,6 @@ pub struct AgentSettings {
|
||||
pub default_height: Pixels,
|
||||
pub default_model: Option<LanguageModelSelection>,
|
||||
pub inline_assistant_model: Option<LanguageModelSelection>,
|
||||
pub inline_assistant_use_streaming_tools: bool,
|
||||
pub commit_message_model: Option<LanguageModelSelection>,
|
||||
pub thread_summary_model: Option<LanguageModelSelection>,
|
||||
pub inline_alternatives: Vec<LanguageModelSelection>,
|
||||
@@ -156,9 +155,6 @@ impl Settings for AgentSettings {
|
||||
default_height: px(agent.default_height.unwrap()),
|
||||
default_model: Some(agent.default_model.unwrap()),
|
||||
inline_assistant_model: agent.inline_assistant_model,
|
||||
inline_assistant_use_streaming_tools: agent
|
||||
.inline_assistant_use_streaming_tools
|
||||
.unwrap_or(true),
|
||||
commit_message_model: agent.commit_message_model,
|
||||
thread_summary_model: agent.thread_summary_model,
|
||||
inline_alternatives: agent.inline_alternatives.unwrap_or_default(),
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::{cmp::Reverse, rc::Rc, sync::Arc};
|
||||
|
||||
use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
|
||||
use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector};
|
||||
use agent_servers::AgentServer;
|
||||
use anyhow::Result;
|
||||
use collections::IndexMap;
|
||||
@@ -292,12 +292,18 @@ impl PickerDelegate for AcpModelPickerDelegate {
|
||||
h_flex()
|
||||
.w_full()
|
||||
.gap_1p5()
|
||||
.when_some(model_info.icon, |this, icon| {
|
||||
this.child(
|
||||
Icon::new(icon)
|
||||
.map(|this| match &model_info.icon {
|
||||
Some(AgentModelIcon::Path(path)) => this.child(
|
||||
Icon::from_external_svg(path.clone())
|
||||
.color(model_icon_color)
|
||||
.size(IconSize::Small)
|
||||
)
|
||||
.size(IconSize::Small),
|
||||
),
|
||||
Some(AgentModelIcon::Named(icon)) => this.child(
|
||||
Icon::new(*icon)
|
||||
.color(model_icon_color)
|
||||
.size(IconSize::Small),
|
||||
),
|
||||
None => this,
|
||||
})
|
||||
.child(Label::new(model_info.name.clone()).truncate()),
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
|
||||
use acp_thread::{AgentModelInfo, AgentModelSelector};
|
||||
use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelSelector};
|
||||
use agent_servers::AgentServer;
|
||||
use fs::Fs;
|
||||
use gpui::{Entity, FocusHandle};
|
||||
@@ -64,7 +64,7 @@ impl Render for AcpModelSelectorPopover {
|
||||
.map(|model| model.name.clone())
|
||||
.unwrap_or_else(|| SharedString::from("Select a Model"));
|
||||
|
||||
let model_icon = model.as_ref().and_then(|model| model.icon);
|
||||
let model_icon = model.as_ref().and_then(|model| model.icon.clone());
|
||||
|
||||
let focus_handle = self.focus_handle.clone();
|
||||
|
||||
@@ -78,8 +78,15 @@ impl Render for AcpModelSelectorPopover {
|
||||
self.selector.clone(),
|
||||
ButtonLike::new("active-model")
|
||||
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
|
||||
.when_some(model_icon, |this, icon| {
|
||||
this.child(Icon::new(icon).color(color).size(IconSize::XSmall))
|
||||
.when_some(model_icon, |this, icon| match icon {
|
||||
AgentModelIcon::Path(path) => this.child(
|
||||
Icon::from_external_svg(path)
|
||||
.color(color)
|
||||
.size(IconSize::XSmall),
|
||||
),
|
||||
AgentModelIcon::Named(icon_name) => {
|
||||
this.child(Icon::new(icon_name).color(color).size(IconSize::XSmall))
|
||||
}
|
||||
})
|
||||
.child(
|
||||
Label::new(model_name)
|
||||
|
||||
@@ -34,9 +34,9 @@ use project::{
|
||||
};
|
||||
use settings::{Settings, SettingsStore, update_settings_file};
|
||||
use ui::{
|
||||
ButtonStyle, Chip, CommonAnimationExt, ContextMenu, ContextMenuEntry, Disclosure, Divider,
|
||||
DividerColor, ElevationIndex, Indicator, LabelSize, PopoverMenu, Switch, Tooltip,
|
||||
WithScrollbar, prelude::*,
|
||||
Button, ButtonStyle, Chip, CommonAnimationExt, ContextMenu, ContextMenuEntry, Disclosure,
|
||||
Divider, DividerColor, ElevationIndex, IconName, IconPosition, IconSize, Indicator, LabelSize,
|
||||
PopoverMenu, Switch, SwitchColor, Tooltip, WithScrollbar, prelude::*,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
use workspace::{Workspace, create_and_open_local_file};
|
||||
@@ -117,7 +117,7 @@ impl AgentConfiguration {
|
||||
}
|
||||
|
||||
fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let providers = LanguageModelRegistry::read_global(cx).providers();
|
||||
let providers = LanguageModelRegistry::read_global(cx).visible_providers();
|
||||
for provider in providers {
|
||||
self.add_provider_configuration_view(&provider, window, cx);
|
||||
}
|
||||
@@ -260,11 +260,15 @@ impl AgentConfiguration {
|
||||
h_flex()
|
||||
.w_full()
|
||||
.gap_1p5()
|
||||
.child(
|
||||
.child(if let Some(icon_path) = provider.icon_path() {
|
||||
Icon::from_external_svg(icon_path)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Muted)
|
||||
} else {
|
||||
Icon::new(provider.icon())
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.color(Color::Muted)
|
||||
})
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
@@ -416,7 +420,7 @@ impl AgentConfiguration {
|
||||
&mut self,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
let providers = LanguageModelRegistry::read_global(cx).providers();
|
||||
let providers = LanguageModelRegistry::read_global(cx).visible_providers();
|
||||
|
||||
let popover_menu = PopoverMenu::new("add-provider-popover")
|
||||
.trigger(
|
||||
@@ -879,6 +883,7 @@ impl AgentConfiguration {
|
||||
.child(context_server_configuration_menu)
|
||||
.child(
|
||||
Switch::new("context-server-switch", is_running.into())
|
||||
.color(SwitchColor::Accent)
|
||||
.on_click({
|
||||
let context_server_manager = self.context_server_store.clone();
|
||||
let fs = self.fs.clone();
|
||||
|
||||
@@ -77,7 +77,8 @@ impl Render for AgentModelSelector {
|
||||
.map(|model| model.model.name().0)
|
||||
.unwrap_or_else(|| SharedString::from("Select a Model"));
|
||||
|
||||
let provider_icon = model.as_ref().map(|model| model.provider.icon());
|
||||
let provider_icon_path = model.as_ref().and_then(|model| model.provider.icon_path());
|
||||
let provider_icon_name = model.as_ref().map(|model| model.provider.icon());
|
||||
let color = if self.menu_handle.is_deployed() {
|
||||
Color::Accent
|
||||
} else {
|
||||
@@ -89,8 +90,17 @@ impl Render for AgentModelSelector {
|
||||
PickerPopoverMenu::new(
|
||||
self.selector.clone(),
|
||||
ButtonLike::new("active-model")
|
||||
.when_some(provider_icon, |this, icon| {
|
||||
this.child(Icon::new(icon).color(color).size(IconSize::XSmall))
|
||||
.when_some(provider_icon_path.clone(), |this, icon_path| {
|
||||
this.child(
|
||||
Icon::from_external_svg(icon_path)
|
||||
.color(color)
|
||||
.size(IconSize::XSmall),
|
||||
)
|
||||
})
|
||||
.when(provider_icon_path.is_none(), |this| {
|
||||
this.when_some(provider_icon_name, |this, icon| {
|
||||
this.child(Icon::new(icon).color(color).size(IconSize::XSmall))
|
||||
})
|
||||
})
|
||||
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
|
||||
.child(
|
||||
@@ -102,7 +112,7 @@ impl Render for AgentModelSelector {
|
||||
.child(
|
||||
Icon::new(IconName::ChevronDown)
|
||||
.color(color)
|
||||
.size(IconSize::Small),
|
||||
.size(IconSize::XSmall),
|
||||
),
|
||||
move |_window, cx| {
|
||||
Tooltip::for_action_in("Change Model", &ToggleModelSelector, &focus_handle, cx)
|
||||
|
||||
@@ -2292,7 +2292,7 @@ impl AgentPanel {
|
||||
let history_is_empty = self.history_store.read(cx).is_empty(cx);
|
||||
|
||||
let has_configured_non_zed_providers = LanguageModelRegistry::read_global(cx)
|
||||
.providers()
|
||||
.visible_providers()
|
||||
.iter()
|
||||
.any(|provider| {
|
||||
provider.is_authenticated(cx)
|
||||
|
||||
@@ -338,7 +338,8 @@ fn init_language_model_settings(cx: &mut App) {
|
||||
|_, event: &language_model::Event, cx| match event {
|
||||
language_model::Event::ProviderStateChanged(_)
|
||||
| language_model::Event::AddedProvider(_)
|
||||
| language_model::Event::RemovedProvider(_) => {
|
||||
| language_model::Event::RemovedProvider(_)
|
||||
| language_model::Event::ProvidersChanged => {
|
||||
update_active_language_model_from_settings(cx);
|
||||
}
|
||||
_ => {}
|
||||
@@ -357,26 +358,49 @@ fn update_active_language_model_from_settings(cx: &mut App) {
|
||||
}
|
||||
}
|
||||
|
||||
let default = settings.default_model.as_ref().map(to_selected_model);
|
||||
// Filter out models from providers that are not authenticated
|
||||
fn is_provider_authenticated(
|
||||
selection: &LanguageModelSelection,
|
||||
registry: &LanguageModelRegistry,
|
||||
cx: &App,
|
||||
) -> bool {
|
||||
let provider_id = LanguageModelProviderId::from(selection.provider.0.clone());
|
||||
registry
|
||||
.provider(&provider_id)
|
||||
.map_or(false, |provider| provider.is_authenticated(cx))
|
||||
}
|
||||
|
||||
let registry = LanguageModelRegistry::global(cx);
|
||||
let registry_ref = registry.read(cx);
|
||||
|
||||
let default = settings
|
||||
.default_model
|
||||
.as_ref()
|
||||
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
|
||||
.map(to_selected_model);
|
||||
let inline_assistant = settings
|
||||
.inline_assistant_model
|
||||
.as_ref()
|
||||
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
|
||||
.map(to_selected_model);
|
||||
let commit_message = settings
|
||||
.commit_message_model
|
||||
.as_ref()
|
||||
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
|
||||
.map(to_selected_model);
|
||||
let thread_summary = settings
|
||||
.thread_summary_model
|
||||
.as_ref()
|
||||
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
|
||||
.map(to_selected_model);
|
||||
let inline_alternatives = settings
|
||||
.inline_alternatives
|
||||
.iter()
|
||||
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
|
||||
.map(to_selected_model)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.select_default_model(default.as_ref(), cx);
|
||||
registry.select_inline_assistant_model(inline_assistant.as_ref(), cx);
|
||||
registry.select_commit_message_model(commit_message.as_ref(), cx);
|
||||
@@ -445,7 +469,6 @@ mod tests {
|
||||
default_height: px(600.),
|
||||
default_model: None,
|
||||
inline_assistant_model: None,
|
||||
inline_assistant_use_streaming_tools: false,
|
||||
commit_message_model: None,
|
||||
thread_summary_model: None,
|
||||
inline_alternatives: vec![],
|
||||
|
||||
@@ -1,26 +1,23 @@
|
||||
use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
|
||||
use agent_settings::AgentSettings;
|
||||
use anyhow::{Context as _, Result};
|
||||
|
||||
use client::telemetry::Telemetry;
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use collections::HashSet;
|
||||
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
|
||||
use feature_flags::{FeatureFlagAppExt as _, InlineAssistantUseToolFeatureFlag};
|
||||
use feature_flags::{FeatureFlagAppExt as _, InlineAssistantV2FeatureFlag};
|
||||
use futures::{
|
||||
SinkExt, Stream, StreamExt, TryStreamExt as _,
|
||||
channel::mpsc,
|
||||
future::{LocalBoxFuture, Shared},
|
||||
join,
|
||||
stream::BoxStream,
|
||||
};
|
||||
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
|
||||
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelRequestTool, LanguageModelTextStream, LanguageModelToolChoice,
|
||||
LanguageModelToolUse, Role, TokenUsage, report_assistant_event,
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelTextStream, Role,
|
||||
report_assistant_event,
|
||||
};
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use parking_lot::Mutex;
|
||||
@@ -28,7 +25,6 @@ use prompt_store::PromptBuilder;
|
||||
use rope::Rope;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Settings as _;
|
||||
use smol::future::FutureExt;
|
||||
use std::{
|
||||
cmp,
|
||||
@@ -50,7 +46,6 @@ pub struct FailureMessageInput {
|
||||
/// A brief message to the user explaining why you're unable to fulfill the request or to ask a question about the request.
|
||||
///
|
||||
/// The message may use markdown formatting if you wish.
|
||||
#[serde(default)]
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
@@ -61,11 +56,9 @@ pub struct RewriteSectionInput {
|
||||
///
|
||||
/// The description may use markdown formatting if you wish.
|
||||
/// This is optional - if the edit is simple or obvious, you should leave it empty.
|
||||
#[serde(default)]
|
||||
pub description: String,
|
||||
|
||||
/// The text to replace the section with.
|
||||
#[serde(default)]
|
||||
pub replacement_text: String,
|
||||
}
|
||||
|
||||
@@ -386,12 +379,6 @@ impl CodegenAlternative {
|
||||
&self.last_equal_ranges
|
||||
}
|
||||
|
||||
fn use_streaming_tools(model: &dyn LanguageModel, cx: &App) -> bool {
|
||||
model.supports_streaming_tools()
|
||||
&& cx.has_flag::<InlineAssistantUseToolFeatureFlag>()
|
||||
&& AgentSettings::get_global(cx).inline_assistant_use_streaming_tools
|
||||
}
|
||||
|
||||
pub fn start(
|
||||
&mut self,
|
||||
user_prompt: String,
|
||||
@@ -411,17 +398,11 @@ impl CodegenAlternative {
|
||||
let telemetry_id = model.telemetry_id();
|
||||
let provider_id = model.provider_id();
|
||||
|
||||
if Self::use_streaming_tools(model.as_ref(), cx) {
|
||||
if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
|
||||
let request = self.build_request(&model, user_prompt, context_task, cx)?;
|
||||
let completion_events =
|
||||
cx.spawn(async move |_, cx| model.stream_completion(request.await, cx).await);
|
||||
self.generation = self.handle_completion(
|
||||
telemetry_id,
|
||||
provider_id.to_string(),
|
||||
api_key,
|
||||
completion_events,
|
||||
cx,
|
||||
);
|
||||
let tool_use =
|
||||
cx.spawn(async move |_, cx| model.stream_completion_tool(request.await, cx).await);
|
||||
self.handle_tool_use(telemetry_id, provider_id.to_string(), api_key, tool_use, cx);
|
||||
} else {
|
||||
let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
|
||||
if user_prompt.trim().to_lowercase() == "delete" {
|
||||
@@ -433,14 +414,13 @@ impl CodegenAlternative {
|
||||
})
|
||||
.boxed_local()
|
||||
};
|
||||
self.generation =
|
||||
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
|
||||
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_request_tools(
|
||||
fn build_request_v2(
|
||||
&self,
|
||||
model: &Arc<dyn LanguageModel>,
|
||||
user_prompt: String,
|
||||
@@ -476,7 +456,7 @@ impl CodegenAlternative {
|
||||
|
||||
let system_prompt = self
|
||||
.builder
|
||||
.generate_inline_transformation_prompt_tools(
|
||||
.generate_inline_transformation_prompt_v2(
|
||||
language_name,
|
||||
buffer,
|
||||
range.start.0..range.end.0,
|
||||
@@ -486,9 +466,6 @@ impl CodegenAlternative {
|
||||
let temperature = AgentSettings::temperature_for_model(model, cx);
|
||||
|
||||
let tool_input_format = model.tool_input_format();
|
||||
let tool_choice = model
|
||||
.supports_tool_choice(LanguageModelToolChoice::Any)
|
||||
.then_some(LanguageModelToolChoice::Any);
|
||||
|
||||
Ok(cx.spawn(async move |_cx| {
|
||||
let mut messages = vec![LanguageModelRequestMessage {
|
||||
@@ -531,7 +508,7 @@ impl CodegenAlternative {
|
||||
intent: Some(CompletionIntent::InlineAssist),
|
||||
mode: None,
|
||||
tools,
|
||||
tool_choice,
|
||||
tool_choice: None,
|
||||
stop: Vec::new(),
|
||||
temperature,
|
||||
messages,
|
||||
@@ -547,8 +524,8 @@ impl CodegenAlternative {
|
||||
context_task: Shared<Task<Option<LoadedContext>>>,
|
||||
cx: &mut App,
|
||||
) -> Result<Task<LanguageModelRequest>> {
|
||||
if Self::use_streaming_tools(model.as_ref(), cx) {
|
||||
return self.build_request_tools(model, user_prompt, context_task, cx);
|
||||
if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
|
||||
return self.build_request_v2(model, user_prompt, context_task, cx);
|
||||
}
|
||||
|
||||
let buffer = self.buffer.read(cx).snapshot(cx);
|
||||
@@ -626,7 +603,7 @@ impl CodegenAlternative {
|
||||
model_api_key: Option<String>,
|
||||
stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<()> {
|
||||
) {
|
||||
let start_time = Instant::now();
|
||||
|
||||
// Make a new snapshot and re-resolve anchor in case the document was modified.
|
||||
@@ -682,8 +659,7 @@ impl CodegenAlternative {
|
||||
let completion = Arc::new(Mutex::new(String::new()));
|
||||
let completion_clone = completion.clone();
|
||||
|
||||
cx.notify();
|
||||
cx.spawn(async move |codegen, cx| {
|
||||
self.generation = cx.spawn(async move |codegen, cx| {
|
||||
let stream = stream.await;
|
||||
|
||||
let token_usage = stream
|
||||
@@ -709,7 +685,6 @@ impl CodegenAlternative {
|
||||
stream?.stream.map_err(|error| error.into()),
|
||||
);
|
||||
futures::pin_mut!(chunks);
|
||||
|
||||
let mut diff = StreamingDiff::new(selected_text.to_string());
|
||||
let mut line_diff = LineDiff::default();
|
||||
|
||||
@@ -901,7 +876,8 @@ impl CodegenAlternative {
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
});
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn current_completion(&self) -> Option<String> {
|
||||
@@ -1084,29 +1060,21 @@ impl CodegenAlternative {
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_completion(
|
||||
fn handle_tool_use(
|
||||
&mut self,
|
||||
telemetry_id: String,
|
||||
provider_id: String,
|
||||
api_key: Option<String>,
|
||||
completion_stream: Task<
|
||||
Result<
|
||||
BoxStream<
|
||||
'static,
|
||||
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
||||
>,
|
||||
LanguageModelCompletionError,
|
||||
>,
|
||||
_telemetry_id: String,
|
||||
_provider_id: String,
|
||||
_api_key: Option<String>,
|
||||
tool_use: impl 'static
|
||||
+ Future<
|
||||
Output = Result<language_model::LanguageModelToolUse, LanguageModelCompletionError>,
|
||||
>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<()> {
|
||||
) {
|
||||
self.diff = Diff::default();
|
||||
self.status = CodegenStatus::Pending;
|
||||
|
||||
cx.notify();
|
||||
// Leaving this in generation so that STOP equivalent events are respected even
|
||||
// while we're still pre-processing the completion event
|
||||
cx.spawn(async move |codegen, cx| {
|
||||
self.generation = cx.spawn(async move |codegen, cx| {
|
||||
let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
|
||||
let _ = codegen.update(cx, |this, cx| {
|
||||
this.status = status;
|
||||
@@ -1115,176 +1083,76 @@ impl CodegenAlternative {
|
||||
});
|
||||
};
|
||||
|
||||
let mut completion_events = match completion_stream.await {
|
||||
Ok(events) => events,
|
||||
Err(err) => {
|
||||
finish_with_status(CodegenStatus::Error(err.into()), cx);
|
||||
return;
|
||||
}
|
||||
};
|
||||
let tool_use = tool_use.await;
|
||||
|
||||
let chars_read_so_far = Arc::new(Mutex::new(0usize));
|
||||
let tool_to_text_and_message =
|
||||
move |tool_use: LanguageModelToolUse| -> (Option<String>, Option<String>) {
|
||||
let mut chars_read_so_far = chars_read_so_far.lock();
|
||||
match tool_use.name.as_ref() {
|
||||
"rewrite_section" => {
|
||||
let Ok(mut input) =
|
||||
serde_json::from_value::<RewriteSectionInput>(tool_use.input)
|
||||
else {
|
||||
return (None, None);
|
||||
match tool_use {
|
||||
Ok(tool_use) if tool_use.name.as_ref() == "rewrite_section" => {
|
||||
// Parse the input JSON into RewriteSectionInput
|
||||
match serde_json::from_value::<RewriteSectionInput>(tool_use.input) {
|
||||
Ok(input) => {
|
||||
// Store the description if non-empty
|
||||
let description = if !input.description.trim().is_empty() {
|
||||
Some(input.description.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let value = input.replacement_text[*chars_read_so_far..].to_string();
|
||||
*chars_read_so_far = input.replacement_text.len();
|
||||
(Some(value), Some(std::mem::take(&mut input.description)))
|
||||
}
|
||||
"failure_message" => {
|
||||
let Ok(mut input) =
|
||||
serde_json::from_value::<FailureMessageInput>(tool_use.input)
|
||||
else {
|
||||
return (None, None);
|
||||
};
|
||||
(None, Some(std::mem::take(&mut input.message)))
|
||||
}
|
||||
_ => (None, None),
|
||||
}
|
||||
};
|
||||
|
||||
let mut message_id = None;
|
||||
let mut first_text = None;
|
||||
let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
|
||||
let total_text = Arc::new(Mutex::new(String::new()));
|
||||
// Apply the replacement text to the buffer and compute diff
|
||||
let batch_diff_task = codegen
|
||||
.update(cx, |this, cx| {
|
||||
this.model_explanation = description.map(Into::into);
|
||||
let range = this.range.clone();
|
||||
this.apply_edits(
|
||||
std::iter::once((range, input.replacement_text)),
|
||||
cx,
|
||||
);
|
||||
this.reapply_batch_diff(cx)
|
||||
})
|
||||
.ok();
|
||||
|
||||
loop {
|
||||
if let Some(first_event) = completion_events.next().await {
|
||||
match first_event {
|
||||
Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
|
||||
message_id = Some(id);
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
|
||||
if matches!(
|
||||
tool_use.name.as_ref(),
|
||||
"rewrite_section" | "failure_message"
|
||||
) =>
|
||||
{
|
||||
let is_complete = tool_use.is_input_complete;
|
||||
let (text, message) = tool_to_text_and_message(tool_use);
|
||||
// Only update the model explanation if the tool use is complete.
|
||||
// Otherwise the UI element bounces around as it's updated.
|
||||
if is_complete {
|
||||
let _ = codegen.update(cx, |this, _cx| {
|
||||
this.model_explanation = message.map(Into::into);
|
||||
});
|
||||
// Wait for the diff computation to complete
|
||||
if let Some(diff_task) = batch_diff_task {
|
||||
diff_task.await;
|
||||
}
|
||||
first_text = text;
|
||||
if first_text.is_some() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
|
||||
*last_token_usage.lock() = token_usage;
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Text(text)) => {
|
||||
let mut lock = total_text.lock();
|
||||
lock.push_str(&text);
|
||||
}
|
||||
Ok(e) => {
|
||||
log::warn!("Unexpected event: {:?}", e);
|
||||
break;
|
||||
|
||||
finish_with_status(CodegenStatus::Done, cx);
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
finish_with_status(CodegenStatus::Error(e.into()), cx);
|
||||
break;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(tool_use) if tool_use.name.as_ref() == "failure_message" => {
|
||||
// Handle failure message tool use
|
||||
match serde_json::from_value::<FailureMessageInput>(tool_use.input) {
|
||||
Ok(input) => {
|
||||
let _ = codegen.update(cx, |this, _cx| {
|
||||
// Store the failure message as the tool description
|
||||
this.model_explanation = Some(input.message.into());
|
||||
});
|
||||
finish_with_status(CodegenStatus::Done, cx);
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
finish_with_status(CodegenStatus::Error(e.into()), cx);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(_tool_use) => {
|
||||
// Unexpected tool.
|
||||
finish_with_status(CodegenStatus::Done, cx);
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
finish_with_status(CodegenStatus::Error(e.into()), cx);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let Some(first_text) = first_text else {
|
||||
finish_with_status(CodegenStatus::Done, cx);
|
||||
return;
|
||||
};
|
||||
|
||||
let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded();
|
||||
|
||||
cx.spawn({
|
||||
let codegen = codegen.clone();
|
||||
async move |cx| {
|
||||
while let Some(message) = message_rx.next().await {
|
||||
let _ = codegen.update(cx, |this, _cx| {
|
||||
this.model_explanation = message;
|
||||
});
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
let move_last_token_usage = last_token_usage.clone();
|
||||
|
||||
let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain(
|
||||
completion_events.filter_map(move |e| {
|
||||
let tool_to_text_and_message = tool_to_text_and_message.clone();
|
||||
let last_token_usage = move_last_token_usage.clone();
|
||||
let total_text = total_text.clone();
|
||||
let mut message_tx = message_tx.clone();
|
||||
async move {
|
||||
match e {
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
|
||||
if matches!(
|
||||
tool_use.name.as_ref(),
|
||||
"rewrite_section" | "failure_message"
|
||||
) =>
|
||||
{
|
||||
let is_complete = tool_use.is_input_complete;
|
||||
let (text, message) = tool_to_text_and_message(tool_use);
|
||||
if is_complete {
|
||||
// Again only send the message when complete to not get a bouncing UI element.
|
||||
let _ = message_tx.send(message.map(Into::into)).await;
|
||||
}
|
||||
text.map(Ok)
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
|
||||
*last_token_usage.lock() = token_usage;
|
||||
None
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Text(text)) => {
|
||||
let mut lock = total_text.lock();
|
||||
lock.push_str(&text);
|
||||
None
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Stop(_reason)) => None,
|
||||
e => {
|
||||
log::error!("UNEXPECTED EVENT {:?}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
));
|
||||
|
||||
let language_model_text_stream = LanguageModelTextStream {
|
||||
message_id: message_id,
|
||||
stream: text_stream,
|
||||
last_token_usage,
|
||||
};
|
||||
|
||||
let Some(task) = codegen
|
||||
.update(cx, move |codegen, cx| {
|
||||
codegen.handle_stream(
|
||||
telemetry_id,
|
||||
provider_id,
|
||||
api_key,
|
||||
async { Ok(language_model_text_stream) },
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.ok()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
task.await;
|
||||
})
|
||||
});
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1811,7 +1679,7 @@ mod tests {
|
||||
) -> mpsc::UnboundedSender<String> {
|
||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||
codegen.update(cx, |codegen, cx| {
|
||||
codegen.generation = codegen.handle_stream(
|
||||
codegen.handle_stream(
|
||||
String::new(),
|
||||
String::new(),
|
||||
None,
|
||||
|
||||
@@ -1455,8 +1455,60 @@ impl InlineAssistant {
|
||||
let old_snapshot = codegen.snapshot(cx);
|
||||
let old_buffer = codegen.old_buffer(cx);
|
||||
let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone();
|
||||
// let model_explanation = codegen.model_explanation(cx);
|
||||
|
||||
editor.update(cx, |editor, cx| {
|
||||
// Update tool description block
|
||||
// if let Some(description) = model_explanation {
|
||||
// if let Some(block_id) = decorations.model_explanation {
|
||||
// editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
|
||||
// let new_block_id = editor.insert_blocks(
|
||||
// [BlockProperties {
|
||||
// style: BlockStyle::Flex,
|
||||
// placement: BlockPlacement::Below(assist.range.end),
|
||||
// height: Some(1),
|
||||
// render: Arc::new({
|
||||
// let description = description.clone();
|
||||
// move |cx| {
|
||||
// div()
|
||||
// .w_full()
|
||||
// .py_1()
|
||||
// .px_2()
|
||||
// .bg(cx.theme().colors().editor_background)
|
||||
// .border_y_1()
|
||||
// .border_color(cx.theme().status().info_border)
|
||||
// .child(
|
||||
// Label::new(description.clone())
|
||||
// .color(Color::Muted)
|
||||
// .size(LabelSize::Small),
|
||||
// )
|
||||
// .into_any_element()
|
||||
// }
|
||||
// }),
|
||||
// priority: 0,
|
||||
// }],
|
||||
// None,
|
||||
// cx,
|
||||
// );
|
||||
// decorations.model_explanation = new_block_id.into_iter().next();
|
||||
// }
|
||||
// } else if let Some(block_id) = decorations.model_explanation {
|
||||
// // Hide the block if there's no description
|
||||
// editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
|
||||
// let new_block_id = editor.insert_blocks(
|
||||
// [BlockProperties {
|
||||
// style: BlockStyle::Flex,
|
||||
// placement: BlockPlacement::Below(assist.range.end),
|
||||
// height: Some(0),
|
||||
// render: Arc::new(|_cx| div().into_any_element()),
|
||||
// priority: 0,
|
||||
// }],
|
||||
// None,
|
||||
// cx,
|
||||
// );
|
||||
// decorations.model_explanation = new_block_id.into_iter().next();
|
||||
// }
|
||||
|
||||
let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
|
||||
editor.remove_blocks(old_blocks, None, cx);
|
||||
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
use std::{cmp::Reverse, sync::Arc};
|
||||
|
||||
use collections::IndexMap;
|
||||
use futures::{StreamExt, channel::mpsc};
|
||||
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
|
||||
use gpui::{
|
||||
Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
|
||||
};
|
||||
use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
|
||||
use language_model::{
|
||||
AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
|
||||
LanguageModelRegistry,
|
||||
AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProvider,
|
||||
LanguageModelProviderId, LanguageModelRegistry,
|
||||
};
|
||||
use ordered_float::OrderedFloat;
|
||||
use picker::{Picker, PickerDelegate};
|
||||
@@ -47,7 +46,9 @@ pub fn language_model_selector(
|
||||
}
|
||||
|
||||
fn all_models(cx: &App) -> GroupedModels {
|
||||
let providers = LanguageModelRegistry::global(cx).read(cx).providers();
|
||||
let providers = LanguageModelRegistry::global(cx)
|
||||
.read(cx)
|
||||
.visible_providers();
|
||||
|
||||
let recommended = providers
|
||||
.iter()
|
||||
@@ -57,12 +58,12 @@ fn all_models(cx: &App) -> GroupedModels {
|
||||
.into_iter()
|
||||
.map(|model| ModelInfo {
|
||||
model,
|
||||
icon: provider.icon(),
|
||||
icon: ProviderIcon::from_provider(provider.as_ref()),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let all = providers
|
||||
let all: Vec<ModelInfo> = providers
|
||||
.iter()
|
||||
.flat_map(|provider| {
|
||||
provider
|
||||
@@ -70,7 +71,7 @@ fn all_models(cx: &App) -> GroupedModels {
|
||||
.into_iter()
|
||||
.map(|model| ModelInfo {
|
||||
model,
|
||||
icon: provider.icon(),
|
||||
icon: ProviderIcon::from_provider(provider.as_ref()),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
@@ -78,10 +79,26 @@ fn all_models(cx: &App) -> GroupedModels {
|
||||
GroupedModels::new(all, recommended)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum ProviderIcon {
|
||||
Name(IconName),
|
||||
Path(SharedString),
|
||||
}
|
||||
|
||||
impl ProviderIcon {
|
||||
fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
|
||||
if let Some(path) = provider.icon_path() {
|
||||
Self::Path(path)
|
||||
} else {
|
||||
Self::Name(provider.icon())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ModelInfo {
|
||||
model: Arc<dyn LanguageModel>,
|
||||
icon: IconName,
|
||||
icon: ProviderIcon,
|
||||
}
|
||||
|
||||
pub struct LanguageModelPickerDelegate {
|
||||
@@ -91,7 +108,7 @@ pub struct LanguageModelPickerDelegate {
|
||||
filtered_entries: Vec<LanguageModelPickerEntry>,
|
||||
selected_index: usize,
|
||||
_authenticate_all_providers_task: Task<()>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
_refresh_models_task: Task<()>,
|
||||
popover_styles: bool,
|
||||
focus_handle: FocusHandle,
|
||||
}
|
||||
@@ -116,24 +133,43 @@ impl LanguageModelPickerDelegate {
|
||||
filtered_entries: entries,
|
||||
get_active_model: Arc::new(get_active_model),
|
||||
_authenticate_all_providers_task: Self::authenticate_all_providers(cx),
|
||||
_subscriptions: vec![cx.subscribe_in(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
window,
|
||||
|picker, _, event, window, cx| {
|
||||
match event {
|
||||
language_model::Event::ProviderStateChanged(_)
|
||||
| language_model::Event::AddedProvider(_)
|
||||
| language_model::Event::RemovedProvider(_) => {
|
||||
let query = picker.query(cx);
|
||||
picker.delegate.all_models = Arc::new(all_models(cx));
|
||||
// Update matches will automatically drop the previous task
|
||||
// if we get a provider event again
|
||||
picker.update_matches(query, window, cx)
|
||||
}
|
||||
_ => {}
|
||||
_refresh_models_task: {
|
||||
// Create a channel to signal when models need refreshing
|
||||
let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
|
||||
|
||||
// Subscribe to registry events and send refresh signals through the channel
|
||||
let registry = LanguageModelRegistry::global(cx);
|
||||
cx.subscribe(®istry, move |_picker, _, event, _cx| match event {
|
||||
language_model::Event::ProviderStateChanged(_) => {
|
||||
refresh_tx.unbounded_send(()).ok();
|
||||
}
|
||||
},
|
||||
)],
|
||||
language_model::Event::AddedProvider(_) => {
|
||||
refresh_tx.unbounded_send(()).ok();
|
||||
}
|
||||
language_model::Event::RemovedProvider(_) => {
|
||||
refresh_tx.unbounded_send(()).ok();
|
||||
}
|
||||
language_model::Event::ProvidersChanged => {
|
||||
refresh_tx.unbounded_send(()).ok();
|
||||
}
|
||||
_ => {}
|
||||
})
|
||||
.detach();
|
||||
|
||||
// Spawn a task that listens for refresh signals and updates the picker
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
while let Some(()) = refresh_rx.next().await {
|
||||
let result = this.update_in(cx, |picker, window, cx| {
|
||||
picker.delegate.all_models = Arc::new(all_models(cx));
|
||||
picker.refresh(window, cx);
|
||||
});
|
||||
if result.is_err() {
|
||||
// Picker was dropped, exit the loop
|
||||
break;
|
||||
}
|
||||
}
|
||||
})
|
||||
},
|
||||
popover_styles,
|
||||
focus_handle,
|
||||
}
|
||||
@@ -392,7 +428,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
||||
|
||||
let configured_providers = language_model_registry
|
||||
.read(cx)
|
||||
.providers()
|
||||
.visible_providers()
|
||||
.into_iter()
|
||||
.filter(|provider| provider.is_authenticated(cx))
|
||||
.collect::<Vec<_>>();
|
||||
@@ -504,11 +540,16 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
||||
h_flex()
|
||||
.w_full()
|
||||
.gap_1p5()
|
||||
.child(
|
||||
Icon::new(model_info.icon)
|
||||
.child(match &model_info.icon {
|
||||
ProviderIcon::Name(icon_name) => Icon::new(*icon_name)
|
||||
.color(model_icon_color)
|
||||
.size(IconSize::Small),
|
||||
)
|
||||
ProviderIcon::Path(icon_path) => {
|
||||
Icon::from_external_svg(icon_path.clone())
|
||||
.color(model_icon_color)
|
||||
.size(IconSize::Small)
|
||||
}
|
||||
})
|
||||
.child(Label::new(model_info.model.name().0).truncate()),
|
||||
)
|
||||
.end_slot(div().pr_3().when(is_selected, |this| {
|
||||
@@ -657,7 +698,7 @@ mod tests {
|
||||
.into_iter()
|
||||
.map(|(provider, name)| ModelInfo {
|
||||
model: Arc::new(TestLanguageModel::new(name, provider)),
|
||||
icon: IconName::Ai,
|
||||
icon: ProviderIcon::Name(IconName::Ai),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -1682,98 +1682,6 @@ impl TextThreadEditor {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let editor_clipboard_selections = cx
|
||||
.read_from_clipboard()
|
||||
.and_then(|item| item.entries().first().cloned())
|
||||
.and_then(|entry| match entry {
|
||||
ClipboardEntry::String(text) => {
|
||||
text.metadata_json::<Vec<editor::ClipboardSelection>>()
|
||||
}
|
||||
_ => None,
|
||||
});
|
||||
|
||||
let has_file_context = editor_clipboard_selections
|
||||
.as_ref()
|
||||
.is_some_and(|selections| {
|
||||
selections
|
||||
.iter()
|
||||
.any(|sel| sel.file_path.is_some() && sel.line_range.is_some())
|
||||
});
|
||||
|
||||
if has_file_context {
|
||||
if let Some(clipboard_item) = cx.read_from_clipboard() {
|
||||
if let Some(ClipboardEntry::String(clipboard_text)) =
|
||||
clipboard_item.entries().first()
|
||||
{
|
||||
if let Some(selections) = editor_clipboard_selections {
|
||||
cx.stop_propagation();
|
||||
|
||||
let text = clipboard_text.text();
|
||||
self.editor.update(cx, |editor, cx| {
|
||||
let mut current_offset = 0;
|
||||
let weak_editor = cx.entity().downgrade();
|
||||
|
||||
for selection in selections {
|
||||
if let (Some(file_path), Some(line_range)) =
|
||||
(selection.file_path, selection.line_range)
|
||||
{
|
||||
let selected_text =
|
||||
&text[current_offset..current_offset + selection.len];
|
||||
let fence = assistant_slash_commands::codeblock_fence_for_path(
|
||||
file_path.to_str(),
|
||||
Some(line_range.clone()),
|
||||
);
|
||||
let formatted_text = format!("{fence}{selected_text}\n```");
|
||||
|
||||
let insert_point = editor
|
||||
.selections
|
||||
.newest::<Point>(&editor.display_snapshot(cx))
|
||||
.head();
|
||||
let start_row = MultiBufferRow(insert_point.row);
|
||||
|
||||
editor.insert(&formatted_text, window, cx);
|
||||
|
||||
let snapshot = editor.buffer().read(cx).snapshot(cx);
|
||||
let anchor_before = snapshot.anchor_after(insert_point);
|
||||
let anchor_after = editor
|
||||
.selections
|
||||
.newest_anchor()
|
||||
.head()
|
||||
.bias_left(&snapshot);
|
||||
|
||||
editor.insert("\n", window, cx);
|
||||
|
||||
let crease_text = acp_thread::selection_name(
|
||||
Some(file_path.as_ref()),
|
||||
&line_range,
|
||||
);
|
||||
|
||||
let fold_placeholder = quote_selection_fold_placeholder(
|
||||
crease_text,
|
||||
weak_editor.clone(),
|
||||
);
|
||||
let crease = Crease::inline(
|
||||
anchor_before..anchor_after,
|
||||
fold_placeholder,
|
||||
render_quote_selection_output_toggle,
|
||||
|_, _, _, _| Empty.into_any(),
|
||||
);
|
||||
editor.insert_creases(vec![crease], cx);
|
||||
editor.fold_at(start_row, window, cx);
|
||||
|
||||
current_offset += selection.len;
|
||||
if !selection.is_entire_line && current_offset < text.len() {
|
||||
current_offset += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cx.stop_propagation();
|
||||
|
||||
let mut images = if let Some(item) = cx.read_from_clipboard() {
|
||||
@@ -2189,7 +2097,8 @@ impl TextThreadEditor {
|
||||
.default_model()
|
||||
.map(|default| default.provider);
|
||||
|
||||
let provider_icon = match active_provider {
|
||||
let provider_icon_path = active_provider.as_ref().and_then(|p| p.icon_path());
|
||||
let provider_icon_name = match &active_provider {
|
||||
Some(provider) => provider.icon(),
|
||||
None => IconName::Ai,
|
||||
};
|
||||
@@ -2201,6 +2110,16 @@ impl TextThreadEditor {
|
||||
(Color::Muted, IconName::ChevronDown)
|
||||
};
|
||||
|
||||
let provider_icon_element = if let Some(icon_path) = provider_icon_path {
|
||||
Icon::from_external_svg(icon_path)
|
||||
.color(color)
|
||||
.size(IconSize::XSmall)
|
||||
} else {
|
||||
Icon::new(provider_icon_name)
|
||||
.color(color)
|
||||
.size(IconSize::XSmall)
|
||||
};
|
||||
|
||||
PickerPopoverMenu::new(
|
||||
self.language_model_selector.clone(),
|
||||
ButtonLike::new("active-model")
|
||||
@@ -2208,7 +2127,7 @@ impl TextThreadEditor {
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_0p5()
|
||||
.child(Icon::new(provider_icon).color(color).size(IconSize::XSmall))
|
||||
.child(provider_icon_element)
|
||||
.child(
|
||||
Label::new(model_name)
|
||||
.color(color)
|
||||
|
||||
@@ -1,9 +1,25 @@
|
||||
use gpui::{Action, IntoElement, ParentElement, RenderOnce, point};
|
||||
use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
|
||||
use language_model::{LanguageModelProvider, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
|
||||
use ui::{Divider, List, ListBulletItem, prelude::*};
|
||||
|
||||
#[derive(Clone)]
|
||||
enum ProviderIcon {
|
||||
Name(IconName),
|
||||
Path(SharedString),
|
||||
}
|
||||
|
||||
impl ProviderIcon {
|
||||
fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
|
||||
if let Some(path) = provider.icon_path() {
|
||||
Self::Path(path)
|
||||
} else {
|
||||
Self::Name(provider.icon())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ApiKeysWithProviders {
|
||||
configured_providers: Vec<(IconName, SharedString)>,
|
||||
configured_providers: Vec<(ProviderIcon, SharedString)>,
|
||||
}
|
||||
|
||||
impl ApiKeysWithProviders {
|
||||
@@ -13,7 +29,8 @@ impl ApiKeysWithProviders {
|
||||
|this: &mut Self, _registry, event: &language_model::Event, cx| match event {
|
||||
language_model::Event::ProviderStateChanged(_)
|
||||
| language_model::Event::AddedProvider(_)
|
||||
| language_model::Event::RemovedProvider(_) => {
|
||||
| language_model::Event::RemovedProvider(_)
|
||||
| language_model::Event::ProvidersChanged => {
|
||||
this.configured_providers = Self::compute_configured_providers(cx)
|
||||
}
|
||||
_ => {}
|
||||
@@ -26,14 +43,19 @@ impl ApiKeysWithProviders {
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_configured_providers(cx: &App) -> Vec<(IconName, SharedString)> {
|
||||
fn compute_configured_providers(cx: &App) -> Vec<(ProviderIcon, SharedString)> {
|
||||
LanguageModelRegistry::read_global(cx)
|
||||
.providers()
|
||||
.visible_providers()
|
||||
.iter()
|
||||
.filter(|provider| {
|
||||
provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID
|
||||
})
|
||||
.map(|provider| (provider.icon(), provider.name().0))
|
||||
.map(|provider| {
|
||||
(
|
||||
ProviderIcon::from_provider(provider.as_ref()),
|
||||
provider.name().0,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
@@ -47,7 +69,14 @@ impl Render for ApiKeysWithProviders {
|
||||
.map(|(icon, name)| {
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.child(Icon::new(icon).size(IconSize::XSmall).color(Color::Muted))
|
||||
.child(match icon {
|
||||
ProviderIcon::Name(icon_name) => Icon::new(icon_name)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
ProviderIcon::Path(icon_path) => Icon::from_external_svg(icon_path)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
})
|
||||
.child(Label::new(name))
|
||||
});
|
||||
div()
|
||||
|
||||
@@ -11,7 +11,7 @@ use crate::{AgentPanelOnboardingCard, ApiKeysWithoutProviders, ZedAiOnboarding};
|
||||
pub struct AgentPanelOnboarding {
|
||||
user_store: Entity<UserStore>,
|
||||
client: Arc<Client>,
|
||||
configured_providers: Vec<(IconName, SharedString)>,
|
||||
has_configured_providers: bool,
|
||||
continue_with_zed_ai: Arc<dyn Fn(&mut Window, &mut App)>,
|
||||
}
|
||||
|
||||
@@ -27,8 +27,9 @@ impl AgentPanelOnboarding {
|
||||
|this: &mut Self, _registry, event: &language_model::Event, cx| match event {
|
||||
language_model::Event::ProviderStateChanged(_)
|
||||
| language_model::Event::AddedProvider(_)
|
||||
| language_model::Event::RemovedProvider(_) => {
|
||||
this.configured_providers = Self::compute_available_providers(cx)
|
||||
| language_model::Event::RemovedProvider(_)
|
||||
| language_model::Event::ProvidersChanged => {
|
||||
this.has_configured_providers = Self::has_configured_providers(cx)
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
@@ -38,20 +39,16 @@ impl AgentPanelOnboarding {
|
||||
Self {
|
||||
user_store,
|
||||
client,
|
||||
configured_providers: Self::compute_available_providers(cx),
|
||||
has_configured_providers: Self::has_configured_providers(cx),
|
||||
continue_with_zed_ai: Arc::new(continue_with_zed_ai),
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_available_providers(cx: &App) -> Vec<(IconName, SharedString)> {
|
||||
fn has_configured_providers(cx: &App) -> bool {
|
||||
LanguageModelRegistry::read_global(cx)
|
||||
.providers()
|
||||
.visible_providers()
|
||||
.iter()
|
||||
.filter(|provider| {
|
||||
provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID
|
||||
})
|
||||
.map(|provider| (provider.icon(), provider.name().0))
|
||||
.collect()
|
||||
.any(|provider| provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,7 +78,7 @@ impl Render for AgentPanelOnboarding {
|
||||
}),
|
||||
)
|
||||
.map(|this| {
|
||||
if enrolled_in_trial || is_pro_user || !self.configured_providers.is_empty() {
|
||||
if enrolled_in_trial || is_pro_user || self.has_configured_providers {
|
||||
this
|
||||
} else {
|
||||
this.child(ApiKeysWithoutProviders::new())
|
||||
|
||||
@@ -8,7 +8,7 @@ use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::B
|
||||
use http_client::http::{self, HeaderMap, HeaderValue};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
pub use settings::{AnthropicAvailableModel as AvailableModel, ModelMode};
|
||||
pub use settings::ModelMode;
|
||||
use strum::{EnumIter, EnumString};
|
||||
use thiserror::Error;
|
||||
|
||||
@@ -429,24 +429,10 @@ impl Model {
|
||||
let mut headers = vec![];
|
||||
|
||||
match self {
|
||||
Self::ClaudeOpus4
|
||||
| Self::ClaudeOpus4_1
|
||||
| Self::ClaudeOpus4_5
|
||||
| Self::ClaudeSonnet4
|
||||
| Self::ClaudeSonnet4_5
|
||||
| Self::ClaudeOpus4Thinking
|
||||
| Self::ClaudeOpus4_1Thinking
|
||||
| Self::ClaudeOpus4_5Thinking
|
||||
| Self::ClaudeSonnet4Thinking
|
||||
| Self::ClaudeSonnet4_5Thinking => {
|
||||
// Fine-grained tool streaming for newer models
|
||||
headers.push("fine-grained-tool-streaming-2025-05-14".to_string());
|
||||
}
|
||||
Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => {
|
||||
// Try beta token-efficient tool use (supported in Claude 3.7 Sonnet only)
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/tool-use/token-efficient-tool-use
|
||||
headers.push("token-efficient-tools-2025-02-19".to_string());
|
||||
headers.push("fine-grained-tool-streaming-2025-05-14".to_string());
|
||||
}
|
||||
Self::Custom {
|
||||
extra_beta_headers, ..
|
||||
|
||||
@@ -371,8 +371,6 @@ pub struct LanguageModel {
|
||||
pub supports_images: bool,
|
||||
pub supports_thinking: bool,
|
||||
pub supports_max_mode: bool,
|
||||
#[serde(default)]
|
||||
pub supports_streaming_tools: bool,
|
||||
// only used by OpenAI and xAI
|
||||
#[serde(default)]
|
||||
pub supports_parallel_tool_calls: bool,
|
||||
|
||||
@@ -33,10 +33,12 @@ impl StdioTransport {
|
||||
) -> Result<Self> {
|
||||
let shell = cx.update(|cx| TerminalSettings::get(None, cx).shell.clone())?;
|
||||
let builder = ShellBuilder::new(&shell, cfg!(windows));
|
||||
let mut command =
|
||||
builder.build_command(Some(binary.executable.display().to_string()), &binary.args);
|
||||
let (command, args) =
|
||||
builder.build(Some(binary.executable.display().to_string()), &binary.args);
|
||||
|
||||
let mut command = util::command::new_smol_command(command);
|
||||
command
|
||||
.args(args)
|
||||
.envs(binary.env.unwrap_or_default())
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
|
||||
@@ -4,7 +4,7 @@ pub mod copilot_responses;
|
||||
pub mod request;
|
||||
mod sign_in;
|
||||
|
||||
use crate::sign_in::initiate_sign_out;
|
||||
use crate::sign_in::initiate_sign_in_within_workspace;
|
||||
use ::fs::Fs;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashMap, HashSet};
|
||||
@@ -28,10 +28,12 @@ use project::DisableAiSettings;
|
||||
use request::StatusNotification;
|
||||
use semver::Version;
|
||||
use serde_json::json;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use settings::Settings;
|
||||
use settings::SettingsStore;
|
||||
use sign_in::{reinstall_and_sign_in_within_workspace, sign_out_within_workspace};
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::{
|
||||
any::TypeId,
|
||||
collections::hash_map::Entry,
|
||||
env,
|
||||
ffi::OsString,
|
||||
mem,
|
||||
@@ -40,14 +42,12 @@ use std::{
|
||||
sync::Arc,
|
||||
};
|
||||
use sum_tree::Dimensions;
|
||||
use util::{ResultExt, fs::remove_matching, rel_path::RelPath};
|
||||
use util::rel_path::RelPath;
|
||||
use util::{ResultExt, fs::remove_matching};
|
||||
use workspace::Workspace;
|
||||
|
||||
pub use crate::copilot_edit_prediction_delegate::CopilotEditPredictionDelegate;
|
||||
pub use crate::sign_in::{
|
||||
ConfigurationMode, ConfigurationView, CopilotCodeVerification, initiate_sign_in,
|
||||
reinstall_and_sign_in,
|
||||
};
|
||||
pub use crate::sign_in::{CopilotCodeVerification, initiate_sign_in, reinstall_and_sign_in};
|
||||
|
||||
actions!(
|
||||
copilot,
|
||||
@@ -98,14 +98,21 @@ pub fn init(
|
||||
.detach();
|
||||
|
||||
cx.observe_new(|workspace: &mut Workspace, _window, _cx| {
|
||||
workspace.register_action(|_, _: &SignIn, window, cx| {
|
||||
initiate_sign_in(window, cx);
|
||||
workspace.register_action(|workspace, _: &SignIn, window, cx| {
|
||||
if let Some(copilot) = Copilot::global(cx) {
|
||||
let is_reinstall = false;
|
||||
initiate_sign_in_within_workspace(workspace, copilot, is_reinstall, window, cx);
|
||||
}
|
||||
});
|
||||
workspace.register_action(|_, _: &Reinstall, window, cx| {
|
||||
reinstall_and_sign_in(window, cx);
|
||||
workspace.register_action(|workspace, _: &Reinstall, window, cx| {
|
||||
if let Some(copilot) = Copilot::global(cx) {
|
||||
reinstall_and_sign_in_within_workspace(workspace, copilot, window, cx);
|
||||
}
|
||||
});
|
||||
workspace.register_action(|_, _: &SignOut, window, cx| {
|
||||
initiate_sign_out(window, cx);
|
||||
workspace.register_action(|workspace, _: &SignOut, _window, cx| {
|
||||
if let Some(copilot) = Copilot::global(cx) {
|
||||
sign_out_within_workspace(workspace, copilot, cx);
|
||||
}
|
||||
});
|
||||
})
|
||||
.detach();
|
||||
@@ -368,7 +375,7 @@ impl Copilot {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start_copilot(
|
||||
fn start_copilot(
|
||||
&mut self,
|
||||
check_edit_prediction_provider: bool,
|
||||
awaiting_sign_in_after_start: bool,
|
||||
@@ -556,14 +563,6 @@ impl Copilot {
|
||||
let server = start_language_server.await;
|
||||
this.update(cx, |this, cx| {
|
||||
cx.notify();
|
||||
|
||||
if env::var("ZED_FORCE_COPILOT_ERROR").is_ok() {
|
||||
this.server = CopilotServer::Error(
|
||||
"Forced error for testing (ZED_FORCE_COPILOT_ERROR)".into(),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
match server {
|
||||
Ok((server, status)) => {
|
||||
this.server = CopilotServer::Running(RunningCopilotServer {
|
||||
@@ -585,17 +584,7 @@ impl Copilot {
|
||||
.ok();
|
||||
}
|
||||
|
||||
pub fn is_authenticated(&self) -> bool {
|
||||
return matches!(
|
||||
self.server,
|
||||
CopilotServer::Running(RunningCopilotServer {
|
||||
sign_in_status: SignInStatus::Authorized,
|
||||
..
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
pub fn sign_in(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
pub(crate) fn sign_in(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
if let CopilotServer::Running(server) = &mut self.server {
|
||||
let task = match &server.sign_in_status {
|
||||
SignInStatus::Authorized => Task::ready(Ok(())).shared(),
|
||||
|
||||
@@ -1,151 +1,160 @@
|
||||
use crate::{Copilot, Status, request::PromptUserDeviceFlow};
|
||||
use anyhow::Context as _;
|
||||
use gpui::{
|
||||
App, ClipboardItem, Context, DismissEvent, Element, Entity, EventEmitter, FocusHandle,
|
||||
Focusable, InteractiveElement, IntoElement, MouseDownEvent, ParentElement, Render, Styled,
|
||||
Subscription, Window, WindowBounds, WindowOptions, div, point,
|
||||
Animation, AnimationExt, App, ClipboardItem, Context, DismissEvent, Element, Entity,
|
||||
EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement, MouseDownEvent,
|
||||
ParentElement, Render, Styled, Subscription, Transformation, Window, div, percentage, svg,
|
||||
};
|
||||
use ui::{ButtonLike, CommonAnimationExt, ConfiguredApiCard, Vector, VectorName, prelude::*};
|
||||
use std::time::Duration;
|
||||
use ui::{Button, Label, Vector, VectorName, prelude::*};
|
||||
use util::ResultExt as _;
|
||||
use workspace::{Toast, Workspace, notifications::NotificationId};
|
||||
use workspace::notifications::NotificationId;
|
||||
use workspace::{ModalView, Toast, Workspace};
|
||||
|
||||
const COPILOT_SIGN_UP_URL: &str = "https://github.com/features/copilot";
|
||||
const ERROR_LABEL: &str =
|
||||
"Copilot had issues starting. You can try reinstalling it and signing in again.";
|
||||
|
||||
struct CopilotStatusToast;
|
||||
|
||||
pub fn initiate_sign_in(window: &mut Window, cx: &mut App) {
|
||||
let is_reinstall = false;
|
||||
initiate_sign_in_impl(is_reinstall, window, cx)
|
||||
}
|
||||
|
||||
pub fn initiate_sign_out(window: &mut Window, cx: &mut App) {
|
||||
let Some(copilot) = Copilot::global(cx) else {
|
||||
return;
|
||||
};
|
||||
|
||||
copilot_toast(Some("Signing out of Copilot…"), window, cx);
|
||||
|
||||
let sign_out_task = copilot.update(cx, |copilot, cx| copilot.sign_out(cx));
|
||||
window
|
||||
.spawn(cx, async move |cx| match sign_out_task.await {
|
||||
Ok(()) => {
|
||||
cx.update(|window, cx| copilot_toast(Some("Signed out of Copilot"), window, cx))
|
||||
}
|
||||
Err(err) => cx.update(|window, cx| {
|
||||
if let Some(workspace) = window.root::<Workspace>().flatten() {
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
workspace.show_error(&err, cx);
|
||||
})
|
||||
} else {
|
||||
log::error!("{:?}", err);
|
||||
}
|
||||
}),
|
||||
})
|
||||
.detach();
|
||||
let Some(workspace) = window.root::<Workspace>().flatten() else {
|
||||
return;
|
||||
};
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
let is_reinstall = false;
|
||||
initiate_sign_in_within_workspace(workspace, copilot, is_reinstall, window, cx)
|
||||
});
|
||||
}
|
||||
|
||||
pub fn reinstall_and_sign_in(window: &mut Window, cx: &mut App) {
|
||||
let Some(copilot) = Copilot::global(cx) else {
|
||||
return;
|
||||
};
|
||||
let _ = copilot.update(cx, |copilot, cx| copilot.reinstall(cx));
|
||||
let is_reinstall = true;
|
||||
initiate_sign_in_impl(is_reinstall, window, cx);
|
||||
}
|
||||
|
||||
fn open_copilot_code_verification_window(copilot: &Entity<Copilot>, window: &Window, cx: &mut App) {
|
||||
let current_window_center = window.bounds().center();
|
||||
let height = px(450.);
|
||||
let width = px(350.);
|
||||
let window_bounds = WindowBounds::Windowed(gpui::bounds(
|
||||
current_window_center - point(height / 2.0, width / 2.0),
|
||||
gpui::size(height, width),
|
||||
));
|
||||
cx.open_window(
|
||||
WindowOptions {
|
||||
kind: gpui::WindowKind::PopUp,
|
||||
window_bounds: Some(window_bounds),
|
||||
is_resizable: false,
|
||||
is_movable: true,
|
||||
titlebar: Some(gpui::TitlebarOptions {
|
||||
appears_transparent: true,
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
},
|
||||
|window, cx| cx.new(|cx| CopilotCodeVerification::new(&copilot, window, cx)),
|
||||
)
|
||||
.context("Failed to open Copilot code verification window")
|
||||
.log_err();
|
||||
}
|
||||
|
||||
fn copilot_toast(message: Option<&'static str>, window: &Window, cx: &mut App) {
|
||||
const NOTIFICATION_ID: NotificationId = NotificationId::unique::<CopilotStatusToast>();
|
||||
|
||||
let Some(workspace) = window.root::<Workspace>().flatten() else {
|
||||
return;
|
||||
};
|
||||
|
||||
workspace.update(cx, |workspace, cx| match message {
|
||||
Some(message) => workspace.show_toast(Toast::new(NOTIFICATION_ID, message), cx),
|
||||
None => workspace.dismiss_toast(&NOTIFICATION_ID, cx),
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
reinstall_and_sign_in_within_workspace(workspace, copilot, window, cx);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn initiate_sign_in_impl(is_reinstall: bool, window: &mut Window, cx: &mut App) {
|
||||
let Some(copilot) = Copilot::global(cx) else {
|
||||
return;
|
||||
};
|
||||
pub fn reinstall_and_sign_in_within_workspace(
|
||||
workspace: &mut Workspace,
|
||||
copilot: Entity<Copilot>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Workspace>,
|
||||
) {
|
||||
let _ = copilot.update(cx, |copilot, cx| copilot.reinstall(cx));
|
||||
let is_reinstall = true;
|
||||
initiate_sign_in_within_workspace(workspace, copilot, is_reinstall, window, cx);
|
||||
}
|
||||
|
||||
pub fn initiate_sign_in_within_workspace(
|
||||
workspace: &mut Workspace,
|
||||
copilot: Entity<Copilot>,
|
||||
is_reinstall: bool,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Workspace>,
|
||||
) {
|
||||
if matches!(copilot.read(cx).status(), Status::Disabled) {
|
||||
copilot.update(cx, |copilot, cx| copilot.start_copilot(false, true, cx));
|
||||
}
|
||||
match copilot.read(cx).status() {
|
||||
Status::Starting { task } => {
|
||||
copilot_toast(
|
||||
Some(if is_reinstall {
|
||||
"Copilot is reinstalling…"
|
||||
} else {
|
||||
"Copilot is starting…"
|
||||
}),
|
||||
window,
|
||||
workspace.show_toast(
|
||||
Toast::new(
|
||||
NotificationId::unique::<CopilotStatusToast>(),
|
||||
if is_reinstall {
|
||||
"Copilot is reinstalling..."
|
||||
} else {
|
||||
"Copilot is starting..."
|
||||
},
|
||||
),
|
||||
cx,
|
||||
);
|
||||
|
||||
window
|
||||
.spawn(cx, async move |cx| {
|
||||
task.await;
|
||||
cx.update(|window, cx| {
|
||||
let Some(copilot) = Copilot::global(cx) else {
|
||||
return;
|
||||
};
|
||||
match copilot.read(cx).status() {
|
||||
Status::Authorized => {
|
||||
copilot_toast(Some("Copilot has started."), window, cx)
|
||||
cx.spawn_in(window, async move |workspace, cx| {
|
||||
task.await;
|
||||
if let Some(copilot) = cx.update(|_window, cx| Copilot::global(cx)).ok().flatten() {
|
||||
workspace
|
||||
.update_in(cx, |workspace, window, cx| {
|
||||
match copilot.read(cx).status() {
|
||||
Status::Authorized => workspace.show_toast(
|
||||
Toast::new(
|
||||
NotificationId::unique::<CopilotStatusToast>(),
|
||||
"Copilot has started.",
|
||||
),
|
||||
cx,
|
||||
),
|
||||
_ => {
|
||||
workspace.dismiss_toast(
|
||||
&NotificationId::unique::<CopilotStatusToast>(),
|
||||
cx,
|
||||
);
|
||||
copilot
|
||||
.update(cx, |copilot, cx| copilot.sign_in(cx))
|
||||
.detach_and_log_err(cx);
|
||||
workspace.toggle_modal(window, cx, |_, cx| {
|
||||
CopilotCodeVerification::new(&copilot, cx)
|
||||
});
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
copilot_toast(None, window, cx);
|
||||
copilot
|
||||
.update(cx, |copilot, cx| copilot.sign_in(cx))
|
||||
.detach_and_log_err(cx);
|
||||
open_copilot_code_verification_window(&copilot, window, cx);
|
||||
}
|
||||
}
|
||||
})
|
||||
.log_err();
|
||||
})
|
||||
.detach();
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
_ => {
|
||||
copilot
|
||||
.update(cx, |copilot, cx| copilot.sign_in(cx))
|
||||
.detach();
|
||||
open_copilot_code_verification_window(&copilot, window, cx);
|
||||
workspace.toggle_modal(window, cx, |_, cx| {
|
||||
CopilotCodeVerification::new(&copilot, cx)
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sign_out_within_workspace(
|
||||
workspace: &mut Workspace,
|
||||
copilot: Entity<Copilot>,
|
||||
cx: &mut Context<Workspace>,
|
||||
) {
|
||||
workspace.show_toast(
|
||||
Toast::new(
|
||||
NotificationId::unique::<CopilotStatusToast>(),
|
||||
"Signing out of Copilot...",
|
||||
),
|
||||
cx,
|
||||
);
|
||||
let sign_out_task = copilot.update(cx, |copilot, cx| copilot.sign_out(cx));
|
||||
cx.spawn(async move |workspace, cx| match sign_out_task.await {
|
||||
Ok(()) => {
|
||||
workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
workspace.show_toast(
|
||||
Toast::new(
|
||||
NotificationId::unique::<CopilotStatusToast>(),
|
||||
"Signed out of Copilot.",
|
||||
),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
Err(err) => {
|
||||
workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
workspace.show_error(&err, cx);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
pub struct CopilotCodeVerification {
|
||||
status: Status,
|
||||
connect_clicked: bool,
|
||||
@@ -161,27 +170,23 @@ impl Focusable for CopilotCodeVerification {
|
||||
}
|
||||
|
||||
impl EventEmitter<DismissEvent> for CopilotCodeVerification {}
|
||||
impl ModalView for CopilotCodeVerification {
|
||||
fn on_before_dismiss(
|
||||
&mut self,
|
||||
_: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> workspace::DismissDecision {
|
||||
self.copilot.update(cx, |copilot, cx| {
|
||||
if matches!(copilot.status(), Status::SigningIn { .. }) {
|
||||
copilot.sign_out(cx).detach_and_log_err(cx);
|
||||
}
|
||||
});
|
||||
workspace::DismissDecision::Dismiss(true)
|
||||
}
|
||||
}
|
||||
|
||||
impl CopilotCodeVerification {
|
||||
pub fn new(copilot: &Entity<Copilot>, window: &mut Window, cx: &mut Context<Self>) -> Self {
|
||||
window.on_window_should_close(cx, |window, cx| {
|
||||
if let Some(this) = window.root::<CopilotCodeVerification>().flatten() {
|
||||
this.update(cx, |this, cx| {
|
||||
this.before_dismiss(cx);
|
||||
});
|
||||
}
|
||||
true
|
||||
});
|
||||
cx.subscribe_in(
|
||||
&cx.entity(),
|
||||
window,
|
||||
|this, _, _: &DismissEvent, window, cx| {
|
||||
window.remove_window();
|
||||
this.before_dismiss(cx);
|
||||
},
|
||||
)
|
||||
.detach();
|
||||
|
||||
pub fn new(copilot: &Entity<Copilot>, cx: &mut Context<Self>) -> Self {
|
||||
let status = copilot.read(cx).status();
|
||||
Self {
|
||||
status,
|
||||
@@ -210,45 +215,45 @@ impl CopilotCodeVerification {
|
||||
.read_from_clipboard()
|
||||
.map(|item| item.text().as_ref() == Some(&data.user_code))
|
||||
.unwrap_or(false);
|
||||
|
||||
ButtonLike::new("copy-button")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||
.size(ButtonSize::Medium)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.p_1()
|
||||
.justify_between()
|
||||
.child(Label::new(data.user_code.clone()))
|
||||
.child(Label::new(if copied { "Copied!" } else { "Copy" })),
|
||||
)
|
||||
.on_click({
|
||||
h_flex()
|
||||
.w_full()
|
||||
.p_1()
|
||||
.border_1()
|
||||
.border_muted(cx)
|
||||
.rounded_sm()
|
||||
.cursor_pointer()
|
||||
.justify_between()
|
||||
.on_mouse_down(gpui::MouseButton::Left, {
|
||||
let user_code = data.user_code.clone();
|
||||
move |_, window, cx| {
|
||||
cx.write_to_clipboard(ClipboardItem::new_string(user_code.clone()));
|
||||
window.refresh();
|
||||
}
|
||||
})
|
||||
.child(div().flex_1().child(Label::new(data.user_code.clone())))
|
||||
.child(div().flex_none().px_1().child(Label::new(if copied {
|
||||
"Copied!"
|
||||
} else {
|
||||
"Copy"
|
||||
})))
|
||||
}
|
||||
|
||||
fn render_prompting_modal(
|
||||
connect_clicked: bool,
|
||||
data: &PromptUserDeviceFlow,
|
||||
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl Element {
|
||||
let connect_button_label = if connect_clicked {
|
||||
"Waiting for connection…"
|
||||
"Waiting for connection..."
|
||||
} else {
|
||||
"Connect to GitHub"
|
||||
};
|
||||
|
||||
v_flex()
|
||||
.flex_1()
|
||||
.gap_2p5()
|
||||
.gap_2()
|
||||
.items_center()
|
||||
.text_center()
|
||||
.child(Headline::new("Use GitHub Copilot in Zed").size(HeadlineSize::Large))
|
||||
.child(Headline::new("Use GitHub Copilot in Zed.").size(HeadlineSize::Large))
|
||||
.child(
|
||||
Label::new("Using Copilot requires an active subscription on GitHub.")
|
||||
.color(Color::Muted),
|
||||
@@ -256,119 +261,83 @@ impl CopilotCodeVerification {
|
||||
.child(Self::render_device_code(data, cx))
|
||||
.child(
|
||||
Label::new("Paste this code into GitHub after clicking the button below.")
|
||||
.color(Color::Muted),
|
||||
.size(ui::LabelSize::Small),
|
||||
)
|
||||
.child(
|
||||
v_flex()
|
||||
.w_full()
|
||||
.gap_1()
|
||||
.child(
|
||||
Button::new("connect-button", connect_button_label)
|
||||
.full_width()
|
||||
.style(ButtonStyle::Outlined)
|
||||
.size(ButtonSize::Medium)
|
||||
.on_click({
|
||||
let verification_uri = data.verification_uri.clone();
|
||||
cx.listener(move |this, _, _window, cx| {
|
||||
cx.open_url(&verification_uri);
|
||||
this.connect_clicked = true;
|
||||
})
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
Button::new("copilot-enable-cancel-button", "Cancel")
|
||||
.full_width()
|
||||
.size(ButtonSize::Medium)
|
||||
.on_click(cx.listener(|_, _, _, cx| {
|
||||
cx.emit(DismissEvent);
|
||||
})),
|
||||
),
|
||||
Button::new("connect-button", connect_button_label)
|
||||
.on_click({
|
||||
let verification_uri = data.verification_uri.clone();
|
||||
cx.listener(move |this, _, _window, cx| {
|
||||
cx.open_url(&verification_uri);
|
||||
this.connect_clicked = true;
|
||||
})
|
||||
})
|
||||
.full_width()
|
||||
.style(ButtonStyle::Filled),
|
||||
)
|
||||
.child(
|
||||
Button::new("copilot-enable-cancel-button", "Cancel")
|
||||
.full_width()
|
||||
.on_click(cx.listener(|_, _, _, cx| {
|
||||
cx.emit(DismissEvent);
|
||||
})),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_enabled_modal(cx: &mut Context<Self>) -> impl Element {
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.text_center()
|
||||
.justify_center()
|
||||
.child(Headline::new("Copilot Enabled!").size(HeadlineSize::Large))
|
||||
.child(Label::new("You're all set to use GitHub Copilot.").color(Color::Muted))
|
||||
.child(Label::new(
|
||||
"You can update your settings or sign out from the Copilot menu in the status bar.",
|
||||
))
|
||||
.child(
|
||||
Button::new("copilot-enabled-done-button", "Done")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Outlined)
|
||||
.size(ButtonSize::Medium)
|
||||
.on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_unauthorized_modal(cx: &mut Context<Self>) -> impl Element {
|
||||
let description = "Enable Copilot by connecting your existing license once you have subscribed or renewed your subscription.";
|
||||
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.text_center()
|
||||
.justify_center()
|
||||
.child(
|
||||
Headline::new("You must have an active GitHub Copilot subscription.")
|
||||
.size(HeadlineSize::Large),
|
||||
)
|
||||
.child(Label::new(description).color(Color::Warning))
|
||||
.child(Headline::new("You must have an active GitHub Copilot subscription.").size(HeadlineSize::Large))
|
||||
|
||||
.child(Label::new(
|
||||
"You can enable Copilot by connecting your existing license once you have subscribed or renewed your subscription.",
|
||||
).color(Color::Warning))
|
||||
.child(
|
||||
Button::new("copilot-subscribe-button", "Subscribe on GitHub")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Outlined)
|
||||
.size(ButtonSize::Medium)
|
||||
.on_click(|_, _, cx| cx.open_url(COPILOT_SIGN_UP_URL)),
|
||||
)
|
||||
.child(
|
||||
Button::new("copilot-subscribe-cancel-button", "Cancel")
|
||||
.full_width()
|
||||
.size(ButtonSize::Medium)
|
||||
.on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_error_modal(_cx: &mut Context<Self>) -> impl Element {
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.text_center()
|
||||
.justify_center()
|
||||
.child(Headline::new("An Error Happened").size(HeadlineSize::Large))
|
||||
.child(Label::new(ERROR_LABEL).color(Color::Muted))
|
||||
.child(
|
||||
Button::new("copilot-subscribe-button", "Reinstall Copilot and Sign In")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Outlined)
|
||||
.size(ButtonSize::Medium)
|
||||
.icon(IconName::Download)
|
||||
.icon_color(Color::Muted)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::Small)
|
||||
.on_click(|_, window, cx| reinstall_and_sign_in(window, cx)),
|
||||
)
|
||||
}
|
||||
fn render_loading(window: &mut Window, _: &mut Context<Self>) -> impl Element {
|
||||
let loading_icon = svg()
|
||||
.size_8()
|
||||
.path(IconName::ArrowCircle.path())
|
||||
.text_color(window.text_style().color)
|
||||
.with_animation(
|
||||
"icon_circle_arrow",
|
||||
Animation::new(Duration::from_secs(2)).repeat(),
|
||||
|svg, delta| svg.with_transformation(Transformation::rotate(percentage(delta))),
|
||||
);
|
||||
|
||||
fn before_dismiss(
|
||||
&mut self,
|
||||
cx: &mut Context<'_, CopilotCodeVerification>,
|
||||
) -> workspace::DismissDecision {
|
||||
self.copilot.update(cx, |copilot, cx| {
|
||||
if matches!(copilot.status(), Status::SigningIn { .. }) {
|
||||
copilot.sign_out(cx).detach_and_log_err(cx);
|
||||
}
|
||||
});
|
||||
workspace::DismissDecision::Dismiss(true)
|
||||
h_flex().justify_center().child(loading_icon)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for CopilotCodeVerification {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let prompt = match &self.status {
|
||||
Status::SigningIn { prompt: None } => Icon::new(IconName::ArrowCircle)
|
||||
.color(Color::Muted)
|
||||
.with_rotate_animation(2)
|
||||
.into_any_element(),
|
||||
Status::SigningIn { prompt: None } => {
|
||||
Self::render_loading(window, cx).into_any_element()
|
||||
}
|
||||
Status::SigningIn {
|
||||
prompt: Some(prompt),
|
||||
} => Self::render_prompting_modal(self.connect_clicked, prompt, cx).into_any_element(),
|
||||
@@ -380,20 +349,17 @@ impl Render for CopilotCodeVerification {
|
||||
self.connect_clicked = false;
|
||||
Self::render_enabled_modal(cx).into_any_element()
|
||||
}
|
||||
Status::Error(..) => Self::render_error_modal(cx).into_any_element(),
|
||||
_ => div().into_any_element(),
|
||||
};
|
||||
|
||||
v_flex()
|
||||
.id("copilot_code_verification")
|
||||
.id("copilot code verification")
|
||||
.track_focus(&self.focus_handle(cx))
|
||||
.size_full()
|
||||
.px_4()
|
||||
.py_8()
|
||||
.gap_2()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
.elevation_3(cx)
|
||||
.w_96()
|
||||
.items_center()
|
||||
.p_4()
|
||||
.gap_2()
|
||||
.on_action(cx.listener(|_, _: &menu::Cancel, _, cx| {
|
||||
cx.emit(DismissEvent);
|
||||
}))
|
||||
@@ -407,243 +373,3 @@ impl Render for CopilotCodeVerification {
|
||||
.child(prompt)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ConfigurationView {
|
||||
copilot_status: Option<Status>,
|
||||
is_authenticated: fn(cx: &App) -> bool,
|
||||
edit_prediction: bool,
|
||||
_subscription: Option<Subscription>,
|
||||
}
|
||||
|
||||
pub enum ConfigurationMode {
|
||||
Chat,
|
||||
EditPrediction,
|
||||
}
|
||||
|
||||
impl ConfigurationView {
|
||||
pub fn new(
|
||||
is_authenticated: fn(cx: &App) -> bool,
|
||||
mode: ConfigurationMode,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let copilot = Copilot::global(cx);
|
||||
|
||||
Self {
|
||||
copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
|
||||
is_authenticated,
|
||||
edit_prediction: matches!(mode, ConfigurationMode::EditPrediction),
|
||||
_subscription: copilot.as_ref().map(|copilot| {
|
||||
cx.observe(copilot, |this, model, cx| {
|
||||
this.copilot_status = Some(model.read(cx).status());
|
||||
cx.notify();
|
||||
})
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ConfigurationView {
|
||||
fn is_starting(&self) -> bool {
|
||||
matches!(&self.copilot_status, Some(Status::Starting { .. }))
|
||||
}
|
||||
|
||||
fn is_signing_in(&self) -> bool {
|
||||
matches!(
|
||||
&self.copilot_status,
|
||||
Some(Status::SigningIn { .. })
|
||||
| Some(Status::SignedOut {
|
||||
awaiting_signing_in: true
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
fn is_error(&self) -> bool {
|
||||
matches!(&self.copilot_status, Some(Status::Error(_)))
|
||||
}
|
||||
|
||||
fn has_no_status(&self) -> bool {
|
||||
self.copilot_status.is_none()
|
||||
}
|
||||
|
||||
fn loading_message(&self) -> Option<SharedString> {
|
||||
if self.is_starting() {
|
||||
Some("Starting Copilot…".into())
|
||||
} else if self.is_signing_in() {
|
||||
Some("Signing into Copilot…".into())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn render_loading_button(
|
||||
&self,
|
||||
label: impl Into<SharedString>,
|
||||
edit_prediction: bool,
|
||||
) -> impl IntoElement {
|
||||
ButtonLike::new("loading_button")
|
||||
.disabled(true)
|
||||
.style(ButtonStyle::Outlined)
|
||||
.when(edit_prediction, |this| this.size(ButtonSize::Medium))
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.gap_1()
|
||||
.justify_center()
|
||||
.child(
|
||||
Icon::new(IconName::ArrowCircle)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Muted)
|
||||
.with_rotate_animation(4),
|
||||
)
|
||||
.child(Label::new(label)),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_sign_in_button(&self, edit_prediction: bool) -> impl IntoElement {
|
||||
let label = if edit_prediction {
|
||||
"Sign in to GitHub"
|
||||
} else {
|
||||
"Sign in to use GitHub Copilot"
|
||||
};
|
||||
|
||||
Button::new("sign_in", label)
|
||||
.map(|this| {
|
||||
if edit_prediction {
|
||||
this.size(ButtonSize::Medium)
|
||||
} else {
|
||||
this.full_width()
|
||||
}
|
||||
})
|
||||
.style(ButtonStyle::Outlined)
|
||||
.icon(IconName::Github)
|
||||
.icon_color(Color::Muted)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::Small)
|
||||
.on_click(|_, window, cx| initiate_sign_in(window, cx))
|
||||
}
|
||||
|
||||
fn render_reinstall_button(&self, edit_prediction: bool) -> impl IntoElement {
|
||||
let label = if edit_prediction {
|
||||
"Reinstall and Sign in"
|
||||
} else {
|
||||
"Reinstall Copilot and Sign in"
|
||||
};
|
||||
|
||||
Button::new("reinstall_and_sign_in", label)
|
||||
.map(|this| {
|
||||
if edit_prediction {
|
||||
this.size(ButtonSize::Medium)
|
||||
} else {
|
||||
this.full_width()
|
||||
}
|
||||
})
|
||||
.style(ButtonStyle::Outlined)
|
||||
.icon(IconName::Download)
|
||||
.icon_color(Color::Muted)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::Small)
|
||||
.on_click(|_, window, cx| reinstall_and_sign_in(window, cx))
|
||||
}
|
||||
|
||||
fn render_for_edit_prediction(&self) -> impl IntoElement {
|
||||
let container = |description: SharedString, action: AnyElement| {
|
||||
h_flex()
|
||||
.pt_2p5()
|
||||
.w_full()
|
||||
.justify_between()
|
||||
.child(
|
||||
v_flex()
|
||||
.w_full()
|
||||
.max_w_1_2()
|
||||
.child(Label::new("Authenticate To Use"))
|
||||
.child(
|
||||
Label::new(description)
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small),
|
||||
),
|
||||
)
|
||||
.child(action)
|
||||
};
|
||||
|
||||
let start_label = "To use Copilot for edit predictions, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot subscription.".into();
|
||||
let no_status_label = "Copilot requires an active GitHub Copilot subscription. Please ensure Copilot is configured and try again, or use a different edit predictions provider.".into();
|
||||
|
||||
if let Some(msg) = self.loading_message() {
|
||||
container(
|
||||
start_label,
|
||||
self.render_loading_button(msg, true).into_any_element(),
|
||||
)
|
||||
.into_any_element()
|
||||
} else if self.is_error() {
|
||||
container(
|
||||
ERROR_LABEL.into(),
|
||||
self.render_reinstall_button(true).into_any_element(),
|
||||
)
|
||||
.into_any_element()
|
||||
} else if self.has_no_status() {
|
||||
container(
|
||||
no_status_label,
|
||||
self.render_sign_in_button(true).into_any_element(),
|
||||
)
|
||||
.into_any_element()
|
||||
} else {
|
||||
container(
|
||||
start_label,
|
||||
self.render_sign_in_button(true).into_any_element(),
|
||||
)
|
||||
.into_any_element()
|
||||
}
|
||||
}
|
||||
|
||||
fn render_for_chat(&self) -> impl IntoElement {
|
||||
let start_label = "To use Zed's agent with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription.";
|
||||
let no_status_label = "Copilot Chat requires an active GitHub Copilot subscription. Please ensure Copilot is configured and try again, or use a different LLM provider.";
|
||||
|
||||
if let Some(msg) = self.loading_message() {
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.child(Label::new(start_label))
|
||||
.child(self.render_loading_button(msg, false))
|
||||
.into_any_element()
|
||||
} else if self.is_error() {
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.child(Label::new(ERROR_LABEL))
|
||||
.child(self.render_reinstall_button(false))
|
||||
.into_any_element()
|
||||
} else if self.has_no_status() {
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.child(Label::new(no_status_label))
|
||||
.child(self.render_sign_in_button(false))
|
||||
.into_any_element()
|
||||
} else {
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.child(Label::new(start_label))
|
||||
.child(self.render_sign_in_button(false))
|
||||
.into_any_element()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for ConfigurationView {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let is_authenticated = self.is_authenticated;
|
||||
|
||||
if is_authenticated(cx) {
|
||||
return ConfiguredApiCard::new("Authorized")
|
||||
.button_label("Sign Out")
|
||||
.on_click(|_, window, cx| {
|
||||
initiate_sign_out(window, cx);
|
||||
})
|
||||
.into_any_element();
|
||||
}
|
||||
|
||||
if self.edit_prediction {
|
||||
self.render_for_edit_prediction().into_any_element()
|
||||
} else {
|
||||
self.render_for_chat().into_any_element()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1045,47 +1045,54 @@ async fn heuristic_syntactic_expand(
|
||||
let node_range = node_start..node_end;
|
||||
let row_count = node_end.row - node_start.row + 1;
|
||||
let mut ancestor_range = None;
|
||||
cx.background_executor()
|
||||
.await_on_background(async {
|
||||
// Stop if we've exceeded the row count or reached an outline node. Then, find the interval
|
||||
// of node children which contains the query range. For example, this allows just returning
|
||||
// the header of a declaration rather than the entire declaration.
|
||||
if row_count > max_row_count || outline_range == Some(node_range.clone()) {
|
||||
let mut cursor = node.walk();
|
||||
let mut included_child_start = None;
|
||||
let mut included_child_end = None;
|
||||
let mut previous_end = node_start;
|
||||
if cursor.goto_first_child() {
|
||||
loop {
|
||||
let child_node = cursor.node();
|
||||
let child_range =
|
||||
previous_end..Point::from_ts_point(child_node.end_position());
|
||||
if included_child_start.is_none()
|
||||
&& child_range.contains(&input_range.start)
|
||||
{
|
||||
included_child_start = Some(child_range.start);
|
||||
}
|
||||
if child_range.contains(&input_range.end) {
|
||||
included_child_end = Some(child_range.end);
|
||||
}
|
||||
previous_end = child_range.end;
|
||||
if !cursor.goto_next_sibling() {
|
||||
break;
|
||||
let reached_outline_node = cx.background_executor().scoped({
|
||||
let node_range = node_range.clone();
|
||||
let outline_range = outline_range.clone();
|
||||
let ancestor_range = &mut ancestor_range;
|
||||
|scope| {
|
||||
scope.spawn(async move {
|
||||
// Stop if we've exceeded the row count or reached an outline node. Then, find the interval
|
||||
// of node children which contains the query range. For example, this allows just returning
|
||||
// the header of a declaration rather than the entire declaration.
|
||||
if row_count > max_row_count || outline_range == Some(node_range.clone()) {
|
||||
let mut cursor = node.walk();
|
||||
let mut included_child_start = None;
|
||||
let mut included_child_end = None;
|
||||
let mut previous_end = node_start;
|
||||
if cursor.goto_first_child() {
|
||||
loop {
|
||||
let child_node = cursor.node();
|
||||
let child_range =
|
||||
previous_end..Point::from_ts_point(child_node.end_position());
|
||||
if included_child_start.is_none()
|
||||
&& child_range.contains(&input_range.start)
|
||||
{
|
||||
included_child_start = Some(child_range.start);
|
||||
}
|
||||
if child_range.contains(&input_range.end) {
|
||||
included_child_end = Some(child_range.end);
|
||||
}
|
||||
previous_end = child_range.end;
|
||||
if !cursor.goto_next_sibling() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let end = included_child_end.unwrap_or(node_range.end);
|
||||
if let Some(start) = included_child_start {
|
||||
let row_count = end.row - start.row;
|
||||
if row_count < max_row_count {
|
||||
ancestor_range = Some(Some(RangeInclusive::new(start.row, end.row)));
|
||||
return;
|
||||
let end = included_child_end.unwrap_or(node_range.end);
|
||||
if let Some(start) = included_child_start {
|
||||
let row_count = end.row - start.row;
|
||||
if row_count < max_row_count {
|
||||
*ancestor_range =
|
||||
Some(Some(RangeInclusive::new(start.row, end.row)));
|
||||
return;
|
||||
}
|
||||
}
|
||||
*ancestor_range = Some(None);
|
||||
}
|
||||
ancestor_range = Some(None);
|
||||
}
|
||||
})
|
||||
.await;
|
||||
})
|
||||
}
|
||||
});
|
||||
reached_outline_node.await;
|
||||
if let Some(node) = ancestor_range {
|
||||
return node;
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ workspace = true
|
||||
path = "src/edit_prediction.rs"
|
||||
|
||||
[features]
|
||||
cli-support = []
|
||||
eval-support = []
|
||||
|
||||
[dependencies]
|
||||
ai_onboarding.workspace = true
|
||||
@@ -23,6 +23,7 @@ client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
copilot.workspace = true
|
||||
credentials_provider.workspace = true
|
||||
db.workspace = true
|
||||
edit_prediction_types.workspace = true
|
||||
edit_prediction_context.workspace = true
|
||||
|
||||
@@ -55,7 +55,7 @@ pub mod open_ai_response;
|
||||
mod prediction;
|
||||
pub mod sweep_ai;
|
||||
|
||||
#[cfg(any(test, feature = "test-support", feature = "cli-support"))]
|
||||
#[cfg(any(test, feature = "test-support", feature = "eval-support"))]
|
||||
pub mod udiff;
|
||||
|
||||
mod zed_edit_prediction_delegate;
|
||||
@@ -72,7 +72,6 @@ pub use crate::prediction::EditPrediction;
|
||||
pub use crate::prediction::EditPredictionId;
|
||||
use crate::prediction::EditPredictionResult;
|
||||
pub use crate::sweep_ai::SweepAi;
|
||||
pub use language_model::ApiKeyState;
|
||||
pub use telemetry_events::EditPredictionRating;
|
||||
pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
|
||||
|
||||
@@ -159,7 +158,7 @@ pub struct EditPredictionStore {
|
||||
use_context: bool,
|
||||
options: ZetaOptions,
|
||||
update_required: bool,
|
||||
#[cfg(feature = "cli-support")]
|
||||
#[cfg(feature = "eval-support")]
|
||||
eval_cache: Option<Arc<dyn EvalCache>>,
|
||||
edit_prediction_model: EditPredictionModel,
|
||||
pub sweep_ai: SweepAi,
|
||||
@@ -284,18 +283,6 @@ impl ProjectState {
|
||||
})
|
||||
.detach()
|
||||
}
|
||||
|
||||
fn active_buffer(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
cx: &App,
|
||||
) -> Option<(Entity<Buffer>, Option<Anchor>)> {
|
||||
let project = project.read(cx);
|
||||
let active_path = project.path_for_entry(project.active_entry()?, cx)?;
|
||||
let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
|
||||
let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
|
||||
Some((active_buffer, registered_buffer.last_position))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -386,7 +373,6 @@ impl std::ops::Deref for BufferEditPrediction<'_> {
|
||||
|
||||
struct RegisteredBuffer {
|
||||
snapshot: BufferSnapshot,
|
||||
last_position: Option<Anchor>,
|
||||
_subscriptions: [gpui::Subscription; 2],
|
||||
}
|
||||
|
||||
@@ -506,7 +492,7 @@ impl EditPredictionStore {
|
||||
},
|
||||
),
|
||||
update_required: false,
|
||||
#[cfg(feature = "cli-support")]
|
||||
#[cfg(feature = "eval-support")]
|
||||
eval_cache: None,
|
||||
edit_prediction_model: EditPredictionModel::Zeta2,
|
||||
sweep_ai: SweepAi::new(cx),
|
||||
@@ -537,15 +523,25 @@ impl EditPredictionStore {
|
||||
self.edit_prediction_model = model;
|
||||
}
|
||||
|
||||
pub fn has_sweep_api_token(&self, cx: &App) -> bool {
|
||||
self.sweep_ai.api_token.read(cx).has_key()
|
||||
pub fn has_sweep_api_token(&self) -> bool {
|
||||
self.sweep_ai
|
||||
.api_token
|
||||
.clone()
|
||||
.now_or_never()
|
||||
.flatten()
|
||||
.is_some()
|
||||
}
|
||||
|
||||
pub fn has_mercury_api_token(&self, cx: &App) -> bool {
|
||||
self.mercury.api_token.read(cx).has_key()
|
||||
pub fn has_mercury_api_token(&self) -> bool {
|
||||
self.mercury
|
||||
.api_token
|
||||
.clone()
|
||||
.now_or_never()
|
||||
.flatten()
|
||||
.is_some()
|
||||
}
|
||||
|
||||
#[cfg(feature = "cli-support")]
|
||||
#[cfg(feature = "eval-support")]
|
||||
pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
|
||||
self.eval_cache = Some(cache);
|
||||
}
|
||||
@@ -568,20 +564,13 @@ impl EditPredictionStore {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
|
||||
if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
|
||||
project_state.events.clear();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn edit_history_for_project(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
cx: &App,
|
||||
) -> Vec<Arc<zeta_prompt::Event>> {
|
||||
self.projects
|
||||
.get(&project.entity_id())
|
||||
.map(|project_state| project_state.events(cx))
|
||||
.map(|project_state| project_state.events.iter().cloned().collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
@@ -806,7 +795,6 @@ impl EditPredictionStore {
|
||||
let project_entity_id = project.entity_id();
|
||||
entry.insert(RegisteredBuffer {
|
||||
snapshot,
|
||||
last_position: None,
|
||||
_subscriptions: [
|
||||
cx.subscribe(buffer, {
|
||||
let project = project.downgrade();
|
||||
@@ -894,21 +882,13 @@ impl EditPredictionStore {
|
||||
});
|
||||
}
|
||||
|
||||
fn prediction_at(
|
||||
&mut self,
|
||||
fn current_prediction_for_buffer(
|
||||
&self,
|
||||
buffer: &Entity<Buffer>,
|
||||
position: Option<language::Anchor>,
|
||||
project: &Entity<Project>,
|
||||
cx: &App,
|
||||
) -> Option<BufferEditPrediction<'_>> {
|
||||
let project_state = self.projects.get_mut(&project.entity_id())?;
|
||||
if let Some(position) = position
|
||||
&& let Some(buffer) = project_state
|
||||
.registered_buffers
|
||||
.get_mut(&buffer.entity_id())
|
||||
{
|
||||
buffer.last_position = Some(position);
|
||||
}
|
||||
let project_state = self.projects.get(&project.entity_id())?;
|
||||
|
||||
let CurrentEditPrediction {
|
||||
requested_by,
|
||||
@@ -1151,21 +1131,12 @@ impl EditPredictionStore {
|
||||
};
|
||||
|
||||
self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
|
||||
let Some((active_buffer, snapshot, cursor_point)) = this
|
||||
.read_with(cx, |this, cx| {
|
||||
let project_state = this.projects.get(&project.entity_id())?;
|
||||
let (buffer, position) = project_state.active_buffer(&project, cx)?;
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
|
||||
if !Self::predictions_enabled_at(&snapshot, position, cx) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let cursor_point = position
|
||||
.map(|pos| pos.to_point(&snapshot))
|
||||
.unwrap_or_default();
|
||||
|
||||
Some((buffer, snapshot, cursor_point))
|
||||
let Some(open_buffer_task) = project
|
||||
.update(cx, |project, cx| {
|
||||
project
|
||||
.active_entry()
|
||||
.and_then(|entry| project.path_for_entry(entry, cx))
|
||||
.map(|path| project.open_buffer(path, cx))
|
||||
})
|
||||
.log_err()
|
||||
.flatten()
|
||||
@@ -1174,11 +1145,14 @@ impl EditPredictionStore {
|
||||
};
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let active_buffer = open_buffer_task.await?;
|
||||
let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
|
||||
let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
|
||||
active_buffer,
|
||||
&snapshot,
|
||||
Default::default(),
|
||||
cursor_point,
|
||||
Default::default(),
|
||||
&project,
|
||||
cx,
|
||||
)
|
||||
@@ -1223,37 +1197,6 @@ impl EditPredictionStore {
|
||||
});
|
||||
}
|
||||
|
||||
fn predictions_enabled_at(
|
||||
snapshot: &BufferSnapshot,
|
||||
position: Option<language::Anchor>,
|
||||
cx: &App,
|
||||
) -> bool {
|
||||
let file = snapshot.file();
|
||||
let all_settings = all_language_settings(file, cx);
|
||||
if !all_settings.show_edit_predictions(snapshot.language(), cx)
|
||||
|| file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if let Some(last_position) = position {
|
||||
let settings = snapshot.settings_at(last_position, cx);
|
||||
|
||||
if !settings.edit_predictions_disabled_in.is_empty()
|
||||
&& let Some(scope) = snapshot.language_scope_at(last_position)
|
||||
&& let Some(scope_name) = scope.override_name()
|
||||
&& settings
|
||||
.edit_predictions_disabled_in
|
||||
.iter()
|
||||
.any(|s| s == scope_name)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(not(test))]
|
||||
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
|
||||
#[cfg(test)]
|
||||
@@ -1588,8 +1531,8 @@ impl EditPredictionStore {
|
||||
client: Arc<Client>,
|
||||
llm_token: LlmApiToken,
|
||||
app_version: Version,
|
||||
#[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
|
||||
#[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
|
||||
#[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
|
||||
#[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
|
||||
) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
|
||||
let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
|
||||
http_client::Url::parse(&predict_edits_url)?
|
||||
@@ -1599,7 +1542,7 @@ impl EditPredictionStore {
|
||||
.build_zed_llm_url("/predict_edits/raw", &[])?
|
||||
};
|
||||
|
||||
#[cfg(feature = "cli-support")]
|
||||
#[cfg(feature = "eval-support")]
|
||||
let cache_key = if let Some(cache) = eval_cache {
|
||||
use collections::FxHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
@@ -1633,7 +1576,7 @@ impl EditPredictionStore {
|
||||
)
|
||||
.await?;
|
||||
|
||||
#[cfg(feature = "cli-support")]
|
||||
#[cfg(feature = "eval-support")]
|
||||
if let Some((cache, request, key)) = cache_key {
|
||||
cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
|
||||
}
|
||||
@@ -1765,7 +1708,7 @@ impl EditPredictionStore {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cli-support")]
|
||||
#[cfg(feature = "eval-support")]
|
||||
pub fn set_context_for_buffer(
|
||||
&mut self,
|
||||
project: &Entity<Project>,
|
||||
@@ -1890,10 +1833,10 @@ pub struct ZedUpdateRequiredError {
|
||||
minimum_version: Version,
|
||||
}
|
||||
|
||||
#[cfg(feature = "cli-support")]
|
||||
#[cfg(feature = "eval-support")]
|
||||
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
|
||||
|
||||
#[cfg(feature = "cli-support")]
|
||||
#[cfg(feature = "eval-support")]
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum EvalCacheEntryKind {
|
||||
Context,
|
||||
@@ -1901,7 +1844,7 @@ pub enum EvalCacheEntryKind {
|
||||
Prediction,
|
||||
}
|
||||
|
||||
#[cfg(feature = "cli-support")]
|
||||
#[cfg(feature = "eval-support")]
|
||||
impl std::fmt::Display for EvalCacheEntryKind {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
@@ -1912,7 +1855,7 @@ impl std::fmt::Display for EvalCacheEntryKind {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cli-support")]
|
||||
#[cfg(feature = "eval-support")]
|
||||
pub trait EvalCache: Send + Sync {
|
||||
fn read(&self, key: EvalCacheKey) -> Option<String>;
|
||||
fn write(&self, key: EvalCacheKey, input: &str, value: &str);
|
||||
|
||||
@@ -45,6 +45,10 @@ async fn test_current_state(cx: &mut TestAppContext) {
|
||||
.await;
|
||||
let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.register_project(&project, cx);
|
||||
});
|
||||
|
||||
let buffer1 = project
|
||||
.update(cx, |project, cx| {
|
||||
let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
|
||||
@@ -56,11 +60,6 @@ async fn test_current_state(cx: &mut TestAppContext) {
|
||||
let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
let position = snapshot1.anchor_before(language::Point::new(1, 3));
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.register_project(&project, cx);
|
||||
ep_store.register_buffer(&buffer1, &project, cx);
|
||||
});
|
||||
|
||||
// Prediction for current file
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
@@ -85,9 +84,9 @@ async fn test_current_state(cx: &mut TestAppContext) {
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
let prediction = ep_store
|
||||
.prediction_at(&buffer1, None, &project, cx)
|
||||
.current_prediction_for_buffer(&buffer1, &project, cx)
|
||||
.unwrap();
|
||||
assert_matches!(prediction, BufferEditPrediction::Local { .. });
|
||||
});
|
||||
@@ -141,9 +140,9 @@ async fn test_current_state(cx: &mut TestAppContext) {
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
let prediction = ep_store
|
||||
.prediction_at(&buffer1, None, &project, cx)
|
||||
.current_prediction_for_buffer(&buffer1, &project, cx)
|
||||
.unwrap();
|
||||
assert_matches!(
|
||||
prediction,
|
||||
@@ -159,9 +158,9 @@ async fn test_current_state(cx: &mut TestAppContext) {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
let prediction = ep_store
|
||||
.prediction_at(&buffer2, None, &project, cx)
|
||||
.current_prediction_for_buffer(&buffer2, &project, cx)
|
||||
.unwrap();
|
||||
assert_matches!(prediction, BufferEditPrediction::Local { .. });
|
||||
});
|
||||
@@ -345,10 +344,10 @@ async fn test_empty_prediction(cx: &mut TestAppContext) {
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
assert!(
|
||||
ep_store
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.is_none()
|
||||
);
|
||||
});
|
||||
@@ -405,10 +404,10 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) {
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
assert!(
|
||||
ep_store
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.is_none()
|
||||
);
|
||||
});
|
||||
@@ -470,10 +469,10 @@ async fn test_replace_current(cx: &mut TestAppContext) {
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
assert_eq!(
|
||||
ep_store
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.unwrap()
|
||||
.id
|
||||
.0,
|
||||
@@ -493,11 +492,11 @@ async fn test_replace_current(cx: &mut TestAppContext) {
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
// second replaces first
|
||||
assert_eq!(
|
||||
ep_store
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.unwrap()
|
||||
.id
|
||||
.0,
|
||||
@@ -552,10 +551,10 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
assert_eq!(
|
||||
ep_store
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.unwrap()
|
||||
.id
|
||||
.0,
|
||||
@@ -587,11 +586,11 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
// first is preferred over second
|
||||
assert_eq!(
|
||||
ep_store
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.unwrap()
|
||||
.id
|
||||
.0,
|
||||
@@ -658,11 +657,11 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
// current prediction is second
|
||||
assert_eq!(
|
||||
ep_store
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.unwrap()
|
||||
.id
|
||||
.0,
|
||||
@@ -676,11 +675,11 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
// current prediction is still second, since first was cancelled
|
||||
assert_eq!(
|
||||
ep_store
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.unwrap()
|
||||
.id
|
||||
.0,
|
||||
@@ -769,11 +768,11 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
// current prediction is first
|
||||
assert_eq!(
|
||||
ep_store
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.unwrap()
|
||||
.id
|
||||
.0,
|
||||
@@ -787,11 +786,11 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
// current prediction is still first, since second was cancelled
|
||||
assert_eq!(
|
||||
ep_store
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.unwrap()
|
||||
.id
|
||||
.0,
|
||||
@@ -805,11 +804,11 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
// third completes and replaces first
|
||||
assert_eq!(
|
||||
ep_store
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.unwrap()
|
||||
.id
|
||||
.0,
|
||||
|
||||
@@ -735,7 +735,6 @@ mod tests {
|
||||
true,
|
||||
fs.clone(),
|
||||
Default::default(),
|
||||
true,
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
.await
|
||||
@@ -759,7 +758,6 @@ mod tests {
|
||||
true,
|
||||
fs.clone(),
|
||||
Default::default(),
|
||||
true,
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
.await
|
||||
@@ -818,7 +816,6 @@ mod tests {
|
||||
true,
|
||||
fs.clone(),
|
||||
Default::default(),
|
||||
true,
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
.await
|
||||
|
||||
@@ -1,34 +1,40 @@
|
||||
use anyhow::{Context as _, Result};
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
|
||||
use gpui::{
|
||||
App, AppContext as _, Task,
|
||||
http_client::{self, AsyncBody, Method},
|
||||
};
|
||||
use language::{OffsetRangeExt as _, ToOffset, ToPoint as _};
|
||||
use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant};
|
||||
use zeta_prompt::ZetaPromptInput;
|
||||
|
||||
use crate::{
|
||||
DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
|
||||
EditPredictionStartedDebugEvent, open_ai_response::text_from_response,
|
||||
prediction::EditPredictionResult,
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use futures::AsyncReadExt as _;
|
||||
use gpui::{
|
||||
App, AppContext as _, Entity, SharedString, Task,
|
||||
http_client::{self, AsyncBody, Method},
|
||||
};
|
||||
use language::{OffsetRangeExt as _, ToOffset, ToPoint as _};
|
||||
use language_model::{ApiKeyState, EnvVar, env_var};
|
||||
use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant};
|
||||
use zeta_prompt::ZetaPromptInput;
|
||||
|
||||
const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
|
||||
const MAX_CONTEXT_TOKENS: usize = 150;
|
||||
const MAX_REWRITE_TOKENS: usize = 350;
|
||||
|
||||
pub struct Mercury {
|
||||
pub api_token: Entity<ApiKeyState>,
|
||||
pub api_token: Shared<Task<Option<String>>>,
|
||||
}
|
||||
|
||||
impl Mercury {
|
||||
pub fn new(cx: &mut App) -> Self {
|
||||
pub fn new(cx: &App) -> Self {
|
||||
Mercury {
|
||||
api_token: mercury_api_token(cx),
|
||||
api_token: load_api_token(cx).shared(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_api_token(&mut self, api_token: Option<String>, cx: &mut App) -> Task<Result<()>> {
|
||||
self.api_token = Task::ready(api_token.clone()).shared();
|
||||
store_api_token_in_keychain(api_token, cx)
|
||||
}
|
||||
|
||||
pub(crate) fn request_prediction(
|
||||
&self,
|
||||
EditPredictionModelInput {
|
||||
@@ -42,10 +48,7 @@ impl Mercury {
|
||||
}: EditPredictionModelInput,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Option<EditPredictionResult>>> {
|
||||
self.api_token.update(cx, |key_state, cx| {
|
||||
_ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx);
|
||||
});
|
||||
let Some(api_token) = self.api_token.read(cx).key(&MERCURY_CREDENTIALS_URL) else {
|
||||
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
|
||||
return Task::ready(Ok(None));
|
||||
};
|
||||
let full_path: Arc<Path> = snapshot
|
||||
@@ -296,16 +299,45 @@ fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce(
|
||||
prompt.push_str(delimiters.end);
|
||||
}
|
||||
|
||||
pub const MERCURY_CREDENTIALS_URL: SharedString =
|
||||
SharedString::new_static("https://api.inceptionlabs.ai/v1/edit/completions");
|
||||
pub const MERCURY_CREDENTIALS_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
|
||||
pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token";
|
||||
pub static MERCURY_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("MERCURY_AI_TOKEN");
|
||||
pub static MERCURY_API_KEY: std::sync::OnceLock<Entity<ApiKeyState>> = std::sync::OnceLock::new();
|
||||
|
||||
pub fn mercury_api_token(cx: &mut App) -> Entity<ApiKeyState> {
|
||||
MERCURY_API_KEY
|
||||
.get_or_init(|| {
|
||||
cx.new(|_| ApiKeyState::new(MERCURY_CREDENTIALS_URL, MERCURY_TOKEN_ENV_VAR.clone()))
|
||||
})
|
||||
.clone()
|
||||
pub fn load_api_token(cx: &App) -> Task<Option<String>> {
|
||||
if let Some(api_token) = std::env::var("MERCURY_AI_TOKEN")
|
||||
.ok()
|
||||
.filter(|value| !value.is_empty())
|
||||
{
|
||||
return Task::ready(Some(api_token));
|
||||
}
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
cx.spawn(async move |cx| {
|
||||
let (_, credentials) = credentials_provider
|
||||
.read_credentials(MERCURY_CREDENTIALS_URL, &cx)
|
||||
.await
|
||||
.ok()??;
|
||||
String::from_utf8(credentials).ok()
|
||||
})
|
||||
}
|
||||
|
||||
fn store_api_token_in_keychain(api_token: Option<String>, cx: &App) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
if let Some(api_token) = api_token {
|
||||
credentials_provider
|
||||
.write_credentials(
|
||||
MERCURY_CREDENTIALS_URL,
|
||||
MERCURY_CREDENTIALS_USERNAME,
|
||||
api_token.as_bytes(),
|
||||
cx,
|
||||
)
|
||||
.await
|
||||
.context("Failed to save Mercury API token to system keychain")
|
||||
} else {
|
||||
credentials_provider
|
||||
.delete_credentials(MERCURY_CREDENTIALS_URL, cx)
|
||||
.await
|
||||
.context("Failed to delete Mercury API token from system keychain")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use anyhow::Result;
|
||||
use futures::AsyncReadExt as _;
|
||||
use anyhow::{Context as _, Result};
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
|
||||
use gpui::{
|
||||
App, AppContext as _, Entity, SharedString, Task,
|
||||
App, AppContext as _, Task,
|
||||
http_client::{self, AsyncBody, Method},
|
||||
};
|
||||
use language::{Point, ToOffset as _};
|
||||
use language_model::{ApiKeyState, EnvVar, env_var};
|
||||
use lsp::DiagnosticSeverity;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
@@ -20,28 +20,30 @@ use crate::{EditPredictionId, EditPredictionModelInput, prediction::EditPredicti
|
||||
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
|
||||
|
||||
pub struct SweepAi {
|
||||
pub api_token: Entity<ApiKeyState>,
|
||||
pub api_token: Shared<Task<Option<String>>>,
|
||||
pub debug_info: Arc<str>,
|
||||
}
|
||||
|
||||
impl SweepAi {
|
||||
pub fn new(cx: &mut App) -> Self {
|
||||
pub fn new(cx: &App) -> Self {
|
||||
SweepAi {
|
||||
api_token: sweep_api_token(cx),
|
||||
api_token: load_api_token(cx).shared(),
|
||||
debug_info: debug_info(cx),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_api_token(&mut self, api_token: Option<String>, cx: &mut App) -> Task<Result<()>> {
|
||||
self.api_token = Task::ready(api_token.clone()).shared();
|
||||
store_api_token_in_keychain(api_token, cx)
|
||||
}
|
||||
|
||||
pub fn request_prediction_with_sweep(
|
||||
&self,
|
||||
inputs: EditPredictionModelInput,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Option<EditPredictionResult>>> {
|
||||
let debug_info = self.debug_info.clone();
|
||||
self.api_token.update(cx, |key_state, cx| {
|
||||
_ = key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx);
|
||||
});
|
||||
let Some(api_token) = self.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
|
||||
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
|
||||
return Task::ready(Ok(None));
|
||||
};
|
||||
let full_path: Arc<Path> = inputs
|
||||
@@ -268,18 +270,47 @@ impl SweepAi {
|
||||
}
|
||||
}
|
||||
|
||||
pub const SWEEP_CREDENTIALS_URL: SharedString =
|
||||
SharedString::new_static("https://autocomplete.sweep.dev");
|
||||
pub const SWEEP_CREDENTIALS_URL: &str = "https://autocomplete.sweep.dev";
|
||||
pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
|
||||
pub static SWEEP_AI_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("SWEEP_AI_TOKEN");
|
||||
pub static SWEEP_API_KEY: std::sync::OnceLock<Entity<ApiKeyState>> = std::sync::OnceLock::new();
|
||||
|
||||
pub fn sweep_api_token(cx: &mut App) -> Entity<ApiKeyState> {
|
||||
SWEEP_API_KEY
|
||||
.get_or_init(|| {
|
||||
cx.new(|_| ApiKeyState::new(SWEEP_CREDENTIALS_URL, SWEEP_AI_TOKEN_ENV_VAR.clone()))
|
||||
})
|
||||
.clone()
|
||||
pub fn load_api_token(cx: &App) -> Task<Option<String>> {
|
||||
if let Some(api_token) = std::env::var("SWEEP_AI_TOKEN")
|
||||
.ok()
|
||||
.filter(|value| !value.is_empty())
|
||||
{
|
||||
return Task::ready(Some(api_token));
|
||||
}
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
cx.spawn(async move |cx| {
|
||||
let (_, credentials) = credentials_provider
|
||||
.read_credentials(SWEEP_CREDENTIALS_URL, &cx)
|
||||
.await
|
||||
.ok()??;
|
||||
String::from_utf8(credentials).ok()
|
||||
})
|
||||
}
|
||||
|
||||
fn store_api_token_in_keychain(api_token: Option<String>, cx: &App) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
if let Some(api_token) = api_token {
|
||||
credentials_provider
|
||||
.write_credentials(
|
||||
SWEEP_CREDENTIALS_URL,
|
||||
SWEEP_CREDENTIALS_USERNAME,
|
||||
api_token.as_bytes(),
|
||||
cx,
|
||||
)
|
||||
.await
|
||||
.context("Failed to save Sweep API token to system keychain")
|
||||
} else {
|
||||
credentials_provider
|
||||
.delete_credentials(SWEEP_CREDENTIALS_URL, cx)
|
||||
.await
|
||||
.context("Failed to delete Sweep API token from system keychain")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
|
||||
@@ -15,9 +15,7 @@ use collections::HashMap;
|
||||
use gpui::AsyncApp;
|
||||
use gpui::Entity;
|
||||
use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot};
|
||||
use project::{Project, ProjectPath};
|
||||
use util::paths::PathStyle;
|
||||
use util::rel_path::RelPath;
|
||||
use project::Project;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct OpenedBuffers(#[allow(unused)] HashMap<String, Entity<Buffer>>);
|
||||
@@ -30,27 +28,18 @@ pub async fn apply_diff(
|
||||
) -> Result<OpenedBuffers> {
|
||||
let mut included_files = HashMap::default();
|
||||
|
||||
let worktree_id = project.read_with(cx, |project, cx| {
|
||||
anyhow::Ok(
|
||||
project
|
||||
.visible_worktrees(cx)
|
||||
.next()
|
||||
.context("no worktrees")?
|
||||
.read(cx)
|
||||
.id(),
|
||||
)
|
||||
})??;
|
||||
|
||||
for line in diff_str.lines() {
|
||||
let diff_line = DiffLine::parse(line);
|
||||
|
||||
if let DiffLine::OldPath { path } = diff_line {
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
let project_path = ProjectPath {
|
||||
worktree_id,
|
||||
path: RelPath::new(Path::new(path.as_ref()), PathStyle::Posix)?.into_arc(),
|
||||
};
|
||||
let project_path =
|
||||
project
|
||||
.find_project_path(path.as_ref(), cx)
|
||||
.with_context(|| {
|
||||
format!("Failed to find worktree for new path: {}", path)
|
||||
})?;
|
||||
anyhow::Ok(project.open_buffer(project_path, cx))
|
||||
})??
|
||||
.await?;
|
||||
@@ -138,7 +127,7 @@ pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
|
||||
DiffEvent::Hunk { hunk, .. } => {
|
||||
let hunk_offset = text
|
||||
.find(&hunk.context)
|
||||
.ok_or_else(|| anyhow!("couldn't resolve hunk {:?}", hunk.context))?;
|
||||
.ok_or_else(|| anyhow!("couldn't result hunk {:?}", hunk.context))?;
|
||||
for edit in hunk.edits.iter().rev() {
|
||||
let range = (hunk_offset + edit.range.start)..(hunk_offset + edit.range.end);
|
||||
text.replace_range(range, &edit.text);
|
||||
@@ -737,38 +726,38 @@ mod tests {
|
||||
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
|
||||
|
||||
let diff = indoc! {r#"
|
||||
--- a/file1
|
||||
+++ b/file1
|
||||
--- a/root/file1
|
||||
+++ b/root/file1
|
||||
one
|
||||
two
|
||||
-three
|
||||
+3
|
||||
four
|
||||
five
|
||||
--- a/file1
|
||||
+++ b/file1
|
||||
--- a/root/file1
|
||||
+++ b/root/file1
|
||||
3
|
||||
-four
|
||||
-five
|
||||
+4
|
||||
+5
|
||||
--- a/file1
|
||||
+++ b/file1
|
||||
--- a/root/file1
|
||||
+++ b/root/file1
|
||||
-one
|
||||
-two
|
||||
3
|
||||
4
|
||||
--- a/file2
|
||||
+++ b/file2
|
||||
--- a/root/file2
|
||||
+++ b/root/file2
|
||||
+5
|
||||
six
|
||||
--- a/file2
|
||||
+++ b/file2
|
||||
--- a/root/file2
|
||||
+++ b/root/file2
|
||||
seven
|
||||
+7.5
|
||||
eight
|
||||
--- a/file2
|
||||
+++ b/file2
|
||||
--- a/root/file2
|
||||
+++ b/root/file2
|
||||
ten
|
||||
+11
|
||||
"#};
|
||||
@@ -837,8 +826,8 @@ mod tests {
|
||||
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
|
||||
|
||||
let diff = indoc! {r#"
|
||||
--- a/file1
|
||||
+++ b/file1
|
||||
--- a/root/file1
|
||||
+++ b/root/file1
|
||||
one
|
||||
two
|
||||
-three
|
||||
|
||||
@@ -100,7 +100,7 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
|
||||
) -> bool {
|
||||
let store = self.store.read(cx);
|
||||
if store.edit_prediction_model == EditPredictionModel::Sweep {
|
||||
store.has_sweep_api_token(cx)
|
||||
store.has_sweep_api_token()
|
||||
} else {
|
||||
true
|
||||
}
|
||||
@@ -125,15 +125,14 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
|
||||
return;
|
||||
}
|
||||
|
||||
self.store.update(cx, |store, cx| {
|
||||
if let Some(current) =
|
||||
store.prediction_at(&buffer, Some(cursor_position), &self.project, cx)
|
||||
&& let BufferEditPrediction::Local { prediction } = current
|
||||
&& prediction.interpolate(buffer.read(cx)).is_some()
|
||||
{
|
||||
return;
|
||||
}
|
||||
if let Some(current) = store.current_prediction_for_buffer(&buffer, &self.project, cx)
|
||||
&& let BufferEditPrediction::Local { prediction } = current
|
||||
&& prediction.interpolate(buffer.read(cx)).is_some()
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
self.store.update(cx, |store, cx| {
|
||||
store.refresh_context(&self.project, &buffer, cursor_position, cx);
|
||||
store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
|
||||
});
|
||||
@@ -172,68 +171,69 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
|
||||
cursor_position: language::Anchor,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<edit_prediction_types::EditPrediction> {
|
||||
self.store.update(cx, |store, cx| {
|
||||
let prediction =
|
||||
store.prediction_at(buffer, Some(cursor_position), &self.project, cx)?;
|
||||
let prediction =
|
||||
self.store
|
||||
.read(cx)
|
||||
.current_prediction_for_buffer(buffer, &self.project, cx)?;
|
||||
|
||||
let prediction = match prediction {
|
||||
BufferEditPrediction::Local { prediction } => prediction,
|
||||
BufferEditPrediction::Jump { prediction } => {
|
||||
return Some(edit_prediction_types::EditPrediction::Jump {
|
||||
id: Some(prediction.id.to_string().into()),
|
||||
snapshot: prediction.snapshot.clone(),
|
||||
target: prediction.edits.first().unwrap().0.start,
|
||||
});
|
||||
}
|
||||
};
|
||||
let prediction = match prediction {
|
||||
BufferEditPrediction::Local { prediction } => prediction,
|
||||
BufferEditPrediction::Jump { prediction } => {
|
||||
return Some(edit_prediction_types::EditPrediction::Jump {
|
||||
id: Some(prediction.id.to_string().into()),
|
||||
snapshot: prediction.snapshot.clone(),
|
||||
target: prediction.edits.first().unwrap().0.start,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let buffer = buffer.read(cx);
|
||||
let snapshot = buffer.snapshot();
|
||||
let buffer = buffer.read(cx);
|
||||
let snapshot = buffer.snapshot();
|
||||
|
||||
let Some(edits) = prediction.interpolate(&snapshot) else {
|
||||
let Some(edits) = prediction.interpolate(&snapshot) else {
|
||||
self.store.update(cx, |store, _cx| {
|
||||
store.reject_current_prediction(
|
||||
EditPredictionRejectReason::InterpolatedEmpty,
|
||||
&self.project,
|
||||
);
|
||||
return None;
|
||||
};
|
||||
});
|
||||
return None;
|
||||
};
|
||||
|
||||
let cursor_row = cursor_position.to_point(&snapshot).row;
|
||||
let (closest_edit_ix, (closest_edit_range, _)) =
|
||||
edits.iter().enumerate().min_by_key(|(_, (range, _))| {
|
||||
let distance_from_start =
|
||||
cursor_row.abs_diff(range.start.to_point(&snapshot).row);
|
||||
let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
|
||||
cmp::min(distance_from_start, distance_from_end)
|
||||
})?;
|
||||
let cursor_row = cursor_position.to_point(&snapshot).row;
|
||||
let (closest_edit_ix, (closest_edit_range, _)) =
|
||||
edits.iter().enumerate().min_by_key(|(_, (range, _))| {
|
||||
let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
|
||||
let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
|
||||
cmp::min(distance_from_start, distance_from_end)
|
||||
})?;
|
||||
|
||||
let mut edit_start_ix = closest_edit_ix;
|
||||
for (range, _) in edits[..edit_start_ix].iter().rev() {
|
||||
let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
|
||||
- range.end.to_point(&snapshot).row;
|
||||
if distance_from_closest_edit <= 1 {
|
||||
edit_start_ix -= 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
let mut edit_start_ix = closest_edit_ix;
|
||||
for (range, _) in edits[..edit_start_ix].iter().rev() {
|
||||
let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
|
||||
- range.end.to_point(&snapshot).row;
|
||||
if distance_from_closest_edit <= 1 {
|
||||
edit_start_ix -= 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let mut edit_end_ix = closest_edit_ix + 1;
|
||||
for (range, _) in &edits[edit_end_ix..] {
|
||||
let distance_from_closest_edit = range.start.to_point(buffer).row
|
||||
- closest_edit_range.end.to_point(&snapshot).row;
|
||||
if distance_from_closest_edit <= 1 {
|
||||
edit_end_ix += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
let mut edit_end_ix = closest_edit_ix + 1;
|
||||
for (range, _) in &edits[edit_end_ix..] {
|
||||
let distance_from_closest_edit =
|
||||
range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
|
||||
if distance_from_closest_edit <= 1 {
|
||||
edit_end_ix += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Some(edit_prediction_types::EditPrediction::Local {
|
||||
id: Some(prediction.id.to_string().into()),
|
||||
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
|
||||
edit_preview: Some(prediction.edit_preview.clone()),
|
||||
})
|
||||
Some(edit_prediction_types::EditPrediction::Local {
|
||||
id: Some(prediction.id.to_string().into()),
|
||||
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
|
||||
edit_preview: Some(prediction.edit_preview.clone()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#[cfg(feature = "cli-support")]
|
||||
#[cfg(feature = "eval-support")]
|
||||
use crate::EvalCacheEntryKind;
|
||||
use crate::open_ai_response::text_from_response;
|
||||
use crate::prediction::EditPredictionResult;
|
||||
@@ -44,7 +44,7 @@ pub fn request_prediction_with_zeta2(
|
||||
let llm_token = store.llm_token.clone();
|
||||
let app_version = AppVersion::global(cx);
|
||||
|
||||
#[cfg(feature = "cli-support")]
|
||||
#[cfg(feature = "eval-support")]
|
||||
let eval_cache = store.eval_cache.clone();
|
||||
|
||||
let request_task = cx.background_spawn({
|
||||
@@ -95,9 +95,9 @@ pub fn request_prediction_with_zeta2(
|
||||
client,
|
||||
llm_token,
|
||||
app_version,
|
||||
#[cfg(feature = "cli-support")]
|
||||
#[cfg(feature = "eval-support")]
|
||||
eval_cache,
|
||||
#[cfg(feature = "cli-support")]
|
||||
#[cfg(feature = "eval-support")]
|
||||
EvalCacheEntryKind::Prediction,
|
||||
)
|
||||
.await;
|
||||
@@ -226,18 +226,3 @@ pub fn zeta2_prompt_input(
|
||||
};
|
||||
(editable_offset_range, prompt_input)
|
||||
}
|
||||
|
||||
#[cfg(feature = "cli-support")]
|
||||
pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> Result<String> {
|
||||
let text = &input.cursor_excerpt;
|
||||
let editable_region = input.editable_range_in_excerpt.clone();
|
||||
let old_prefix = &text[..editable_region.start];
|
||||
let old_suffix = &text[editable_region.end..];
|
||||
|
||||
let new = crate::udiff::apply_diff_to_string(patch, text)?;
|
||||
if !new.starts_with(old_prefix) || !new.ends_with(old_suffix) {
|
||||
anyhow::bail!("Patch shouldn't affect text outside of editable region");
|
||||
}
|
||||
|
||||
Ok(new[editable_region.start..new.len() - old_suffix.len()].to_string())
|
||||
}
|
||||
|
||||
@@ -34,7 +34,6 @@ language_extension.workspace = true
|
||||
language_model.workspace = true
|
||||
language_models.workspace = true
|
||||
languages = { workspace = true, features = ["load-grammars"] }
|
||||
libc.workspace = true
|
||||
log.workspace = true
|
||||
node_runtime.workspace = true
|
||||
paths.workspace = true
|
||||
@@ -53,9 +52,10 @@ sqlez_macros.workspace = true
|
||||
terminal_view.workspace = true
|
||||
util.workspace = true
|
||||
watch.workspace = true
|
||||
edit_prediction = { workspace = true, features = ["cli-support"] }
|
||||
edit_prediction = { workspace = true, features = ["eval-support"] }
|
||||
wasmtime.workspace = true
|
||||
zeta_prompt.workspace = true
|
||||
zlog.workspace = true
|
||||
|
||||
# Wasmtime is included as a dependency in order to enable the same
|
||||
# features that are enabled in Zed.
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
use anyhow::{Result, anyhow};
|
||||
use std::mem;
|
||||
|
||||
use crate::example::Example;
|
||||
|
||||
pub async fn run_distill(example: &mut Example) -> Result<()> {
|
||||
let [prediction]: [_; 1] =
|
||||
mem::take(&mut example.predictions)
|
||||
.try_into()
|
||||
.map_err(|preds: Vec<_>| {
|
||||
anyhow!(
|
||||
"Example has {} predictions, but it should have exactly one",
|
||||
preds.len()
|
||||
)
|
||||
})?;
|
||||
|
||||
example.expected_patch = prediction.actual_patch;
|
||||
example.prompt = None;
|
||||
example.predictions = Vec::new();
|
||||
example.score = Vec::new();
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,6 +1,9 @@
|
||||
use crate::{PredictionProvider, PromptFormat, metrics::ClassificationMetrics};
|
||||
use crate::{
|
||||
PredictionProvider, PromptFormat,
|
||||
metrics::ClassificationMetrics,
|
||||
paths::{REPOS_DIR, WORKTREES_DIR},
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use edit_prediction::udiff::OpenedBuffers;
|
||||
use gpui::Entity;
|
||||
use http_client::Url;
|
||||
@@ -22,7 +25,6 @@ pub struct Example {
|
||||
pub name: String,
|
||||
pub repository_url: String,
|
||||
pub revision: String,
|
||||
#[serde(default)]
|
||||
pub uncommitted_diff: String,
|
||||
pub cursor_path: Arc<Path>,
|
||||
pub cursor_position: String,
|
||||
@@ -99,7 +101,7 @@ pub struct ExampleScore {
|
||||
}
|
||||
|
||||
impl Example {
|
||||
pub fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
|
||||
fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
|
||||
// git@github.com:owner/repo.git
|
||||
if self.repository_url.contains('@') {
|
||||
let (owner, repo) = self
|
||||
@@ -131,6 +133,17 @@ impl Example {
|
||||
Ok((owner.into(), repo.into()))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn worktree_path(&self) -> PathBuf {
|
||||
WORKTREES_DIR
|
||||
.join(&self.name)
|
||||
.join(self.repo_name().unwrap().1.as_ref())
|
||||
}
|
||||
|
||||
pub fn repo_path(&self) -> PathBuf {
|
||||
let (repo_owner, repo_name) = self.repo_name().expect("failed to get repo name");
|
||||
REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
|
||||
@@ -182,9 +195,9 @@ pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
|
||||
.enumerate()
|
||||
.map(|(line_ix, line)| {
|
||||
let mut example =
|
||||
serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
|
||||
serde_json::from_str::<Example>(line).unwrap_or_else(|_| {
|
||||
panic!(
|
||||
"Failed to parse example on {}:{}\n{error}",
|
||||
"Failed to parse example on {}:{}",
|
||||
path.display(),
|
||||
line_ix + 1
|
||||
)
|
||||
@@ -204,8 +217,6 @@ pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sort_examples_by_repo_and_rev(&mut examples);
|
||||
examples
|
||||
}
|
||||
|
||||
@@ -223,25 +234,6 @@ pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
|
||||
examples.sort_by(|a, b| {
|
||||
a.repository_url
|
||||
.cmp(&b.repository_url)
|
||||
.then(b.revision.cmp(&a.revision))
|
||||
});
|
||||
}
|
||||
|
||||
pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
|
||||
let mut examples_by_repo = HashMap::default();
|
||||
for example in examples.iter_mut() {
|
||||
examples_by_repo
|
||||
.entry(example.repository_url.clone())
|
||||
.or_insert_with(Vec::new)
|
||||
.push(example);
|
||||
}
|
||||
examples_by_repo.into_values().collect()
|
||||
}
|
||||
|
||||
fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
|
||||
use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
|
||||
|
||||
@@ -272,12 +264,12 @@ fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
|
||||
state: None,
|
||||
};
|
||||
|
||||
let mut name = String::new();
|
||||
let mut text = String::new();
|
||||
let mut block_info: CowStr = "".into();
|
||||
|
||||
#[derive(PartialEq)]
|
||||
enum Section {
|
||||
Start,
|
||||
UncommittedDiff,
|
||||
EditHistory,
|
||||
CursorPosition,
|
||||
@@ -286,16 +278,14 @@ fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
|
||||
Other,
|
||||
}
|
||||
|
||||
let mut current_section = Section::Start;
|
||||
let mut current_section = Section::Other;
|
||||
|
||||
for event in parser {
|
||||
match event {
|
||||
Event::Text(line) => {
|
||||
text.push_str(&line);
|
||||
|
||||
if let Section::Start = current_section
|
||||
&& let Some((field, value)) = line.split_once('=')
|
||||
{
|
||||
if let Some((field, value)) = line.split_once('=') {
|
||||
match field.trim() {
|
||||
REPOSITORY_URL_FIELD => {
|
||||
example.repository_url = value.trim().to_string();
|
||||
@@ -307,6 +297,14 @@ fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
|
||||
}
|
||||
}
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
|
||||
if !name.is_empty() {
|
||||
anyhow::bail!(
|
||||
"Found multiple H1 headings. There should only be one with the name of the example."
|
||||
);
|
||||
}
|
||||
name = mem::take(&mut text);
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
|
||||
let title = mem::take(&mut text);
|
||||
current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
|
||||
@@ -365,7 +363,7 @@ fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
|
||||
Section::ExpectedPatch => {
|
||||
example.expected_patch = mem::take(&mut text);
|
||||
}
|
||||
Section::Start | Section::Other => {}
|
||||
Section::Other => {}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
|
||||
@@ -2,15 +2,9 @@ use crate::{
|
||||
PromptFormat,
|
||||
example::{Example, ExamplePrompt},
|
||||
headless::EpAppState,
|
||||
load_project::run_load_project,
|
||||
progress::{Progress, Step},
|
||||
retrieve_context::run_context_retrieval,
|
||||
};
|
||||
use anyhow::{Context as _, Result, ensure};
|
||||
use edit_prediction::{
|
||||
EditPredictionStore,
|
||||
zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
|
||||
};
|
||||
use edit_prediction::{EditPredictionStore, zeta2::zeta2_prompt_input};
|
||||
use gpui::AsyncApp;
|
||||
use std::sync::Arc;
|
||||
use zeta_prompt::format_zeta_prompt;
|
||||
@@ -20,71 +14,57 @@ pub async fn run_format_prompt(
|
||||
prompt_format: PromptFormat,
|
||||
app_state: Arc<EpAppState>,
|
||||
mut cx: AsyncApp,
|
||||
) -> Result<()> {
|
||||
run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
|
||||
) {
|
||||
run_context_retrieval(example, app_state, cx.clone()).await;
|
||||
|
||||
let _step_progress = Progress::global().start(Step::FormatPrompt, &example.name);
|
||||
|
||||
match prompt_format {
|
||||
PromptFormat::Teacher => {
|
||||
let prompt = TeacherPrompt::format_prompt(example);
|
||||
example.prompt = Some(ExamplePrompt {
|
||||
input: prompt,
|
||||
expected_output: example.expected_patch.clone(), // TODO
|
||||
format: prompt_format,
|
||||
});
|
||||
}
|
||||
let prompt = match prompt_format {
|
||||
PromptFormat::Teacher => TeacherPrompt::format(example),
|
||||
PromptFormat::Zeta2 => {
|
||||
run_load_project(example, app_state, cx.clone()).await?;
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
|
||||
.unwrap();
|
||||
|
||||
let ep_store = cx.update(|cx| {
|
||||
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
|
||||
})??;
|
||||
|
||||
let state = example.state.as_ref().context("state must be set")?;
|
||||
let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
|
||||
let state = example.state.as_ref().unwrap();
|
||||
let snapshot = state
|
||||
.buffer
|
||||
.read_with(&cx, |buffer, _| buffer.snapshot())
|
||||
.unwrap();
|
||||
let project = state.project.clone();
|
||||
let (_, input) = ep_store.update(&mut cx, |ep_store, cx| {
|
||||
anyhow::Ok(zeta2_prompt_input(
|
||||
&snapshot,
|
||||
example
|
||||
.context
|
||||
.as_ref()
|
||||
.context("context must be set")?
|
||||
.files
|
||||
.clone(),
|
||||
ep_store.edit_history_for_project(&project, cx),
|
||||
example.cursor_path.clone(),
|
||||
example
|
||||
.buffer
|
||||
.as_ref()
|
||||
.context("buffer must be set")?
|
||||
.cursor_offset,
|
||||
))
|
||||
})??;
|
||||
let prompt = format_zeta_prompt(&input);
|
||||
let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone())?;
|
||||
example.prompt = Some(ExamplePrompt {
|
||||
input: prompt,
|
||||
expected_output,
|
||||
format: prompt_format,
|
||||
});
|
||||
let (_, input) = ep_store
|
||||
.update(&mut cx, |ep_store, _cx| {
|
||||
zeta2_prompt_input(
|
||||
&snapshot,
|
||||
example.context.as_ref().unwrap().files.clone(),
|
||||
ep_store.edit_history_for_project(&project),
|
||||
example.cursor_path.clone(),
|
||||
example.buffer.as_ref().unwrap().cursor_offset,
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
format_zeta_prompt(&input)
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
|
||||
example.prompt = Some(ExamplePrompt {
|
||||
input: prompt,
|
||||
expected_output: example.expected_patch.clone(), // TODO
|
||||
format: prompt_format,
|
||||
});
|
||||
}
|
||||
|
||||
pub trait PromptFormatter {
|
||||
fn format(example: &Example) -> String;
|
||||
}
|
||||
|
||||
pub trait PromptParser {
|
||||
/// Return unified diff patch of prediction given raw LLM response
|
||||
fn parse(example: &Example, response: &str) -> String;
|
||||
}
|
||||
|
||||
pub struct TeacherPrompt;
|
||||
|
||||
impl TeacherPrompt {
|
||||
const PROMPT: &str = include_str!("teacher.prompt.md");
|
||||
pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
|
||||
pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
|
||||
|
||||
/// Truncate edit history to this number of last lines
|
||||
const MAX_HISTORY_LINES: usize = 128;
|
||||
|
||||
pub fn format_prompt(example: &Example) -> String {
|
||||
impl PromptFormatter for TeacherPrompt {
|
||||
fn format(example: &Example) -> String {
|
||||
let edit_history = Self::format_edit_history(&example.edit_history);
|
||||
let context = Self::format_context(example);
|
||||
let editable_region = Self::format_editable_region(example);
|
||||
@@ -96,47 +76,15 @@ impl TeacherPrompt {
|
||||
|
||||
prompt
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse(example: &Example, response: &str) -> Result<String> {
|
||||
// Ideally, we should always be able to find cursor position in the retrieved context.
|
||||
// In reality, sometimes we don't find it for these reasons:
|
||||
// 1. `example.cursor_position` contains _more_ context than included in the retrieved context
|
||||
// (can be fixed by getting cursor coordinates at the load_example stage)
|
||||
// 2. Context retriever just didn't include cursor line.
|
||||
//
|
||||
// In that case, fallback to using `cursor_position` as excerpt.
|
||||
let cursor_file = &example
|
||||
.buffer
|
||||
.as_ref()
|
||||
.context("`buffer` should be filled in in the context collection step")?
|
||||
.content;
|
||||
impl TeacherPrompt {
|
||||
const PROMPT: &str = include_str!("teacher.prompt.md");
|
||||
pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
|
||||
pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
|
||||
|
||||
// Extract updated (new) editable region from the model response
|
||||
let new_editable_region = extract_last_codeblock(response);
|
||||
|
||||
// Reconstruct old editable region we sent to the model
|
||||
let old_editable_region = Self::format_editable_region(example);
|
||||
let old_editable_region = Self::extract_editable_region(&old_editable_region);
|
||||
ensure!(
|
||||
cursor_file.contains(&old_editable_region),
|
||||
"Something's wrong: editable_region is not found in the cursor file"
|
||||
);
|
||||
|
||||
// Apply editable region to a larger context and compute diff.
|
||||
// This is needed to get a better context lines around the editable region
|
||||
let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
|
||||
let diff = language::unified_diff(&cursor_file, &edited_file);
|
||||
|
||||
let diff = indoc::formatdoc! {"
|
||||
--- a/{path}
|
||||
+++ b/{path}
|
||||
{diff}",
|
||||
path = example.cursor_path.to_string_lossy(),
|
||||
diff = diff,
|
||||
};
|
||||
|
||||
Ok(diff)
|
||||
}
|
||||
/// Truncate edit history to this number of last lines
|
||||
const MAX_HISTORY_LINES: usize = 128;
|
||||
|
||||
fn format_edit_history(edit_history: &str) -> String {
|
||||
// Strip comments ("garbage lines") from edit history
|
||||
@@ -159,7 +107,9 @@ impl TeacherPrompt {
|
||||
}
|
||||
|
||||
fn format_context(example: &Example) -> String {
|
||||
assert!(example.context.is_some(), "Missing context retriever step");
|
||||
if example.context.is_none() {
|
||||
panic!("Missing context retriever step");
|
||||
}
|
||||
|
||||
let mut prompt = String::new();
|
||||
zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
|
||||
@@ -207,6 +157,49 @@ impl TeacherPrompt {
|
||||
}
|
||||
}
|
||||
|
||||
impl PromptParser for TeacherPrompt {
|
||||
fn parse(example: &Example, response: &str) -> String {
|
||||
// Ideally, we should always be able to find cursor position in the retrieved context.
|
||||
// In reality, sometimes we don't find it for these reasons:
|
||||
// 1. `example.cursor_position` contains _more_ context than included in the retrieved context
|
||||
// (can be fixed by getting cursor coordinates at the load_example stage)
|
||||
// 2. Context retriever just didn't include cursor line.
|
||||
//
|
||||
// In that case, fallback to using `cursor_position` as excerpt.
|
||||
let cursor_file = &example
|
||||
.buffer
|
||||
.as_ref()
|
||||
.expect("`buffer` should be filled in in the context collection step")
|
||||
.content;
|
||||
|
||||
// Extract updated (new) editable region from the model response
|
||||
let new_editable_region = extract_last_codeblock(response);
|
||||
|
||||
// Reconstruct old editable region we sent to the model
|
||||
let old_editable_region = Self::format_editable_region(example);
|
||||
let old_editable_region = Self::extract_editable_region(&old_editable_region);
|
||||
if !cursor_file.contains(&old_editable_region) {
|
||||
panic!("Something's wrong: editable_region is not found in the cursor file")
|
||||
}
|
||||
|
||||
// Apply editable region to a larger context and compute diff.
|
||||
// This is needed to get a better context lines around the editable region
|
||||
let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
|
||||
let diff = language::unified_diff(&cursor_file, &edited_file);
|
||||
|
||||
let diff = indoc::formatdoc! {"
|
||||
--- a/{path}
|
||||
+++ b/{path}
|
||||
{diff}
|
||||
",
|
||||
path = example.cursor_path.to_string_lossy(),
|
||||
diff = diff,
|
||||
};
|
||||
|
||||
diff
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_last_codeblock(text: &str) -> String {
|
||||
let mut last_block = None;
|
||||
let mut search_start = 0;
|
||||
@@ -228,7 +221,7 @@ fn extract_last_codeblock(text: &str) -> String {
|
||||
}
|
||||
|
||||
if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
|
||||
let code_block = &text[backtick_end + 1..backtick_end + end_pos];
|
||||
let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1];
|
||||
last_block = Some(code_block.to_string());
|
||||
search_start = backtick_end + end_pos + backtick_count;
|
||||
} else {
|
||||
@@ -257,7 +250,7 @@ mod tests {
|
||||
`````
|
||||
"};
|
||||
let last_block = extract_last_codeblock(text);
|
||||
assert_eq!(last_block, "last block\n");
|
||||
assert_eq!(last_block, "last block");
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use client::{Client, ProxySettings, UserStore};
|
||||
use collections::HashMap;
|
||||
use extension::ExtensionHostProxy;
|
||||
use fs::RealFs;
|
||||
use gpui::http_client::read_proxy_from_env;
|
||||
@@ -8,13 +7,12 @@ use gpui_tokio::Tokio;
|
||||
use language::LanguageRegistry;
|
||||
use language_extension::LspAccess;
|
||||
use node_runtime::{NodeBinaryOptions, NodeRuntime};
|
||||
use project::Project;
|
||||
use project::project_settings::ProjectSettings;
|
||||
use release_channel::{AppCommitSha, AppVersion};
|
||||
use reqwest_client::ReqwestClient;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::sync::Arc;
|
||||
use util::ResultExt as _;
|
||||
|
||||
/// Headless subset of `workspace::AppState`.
|
||||
@@ -24,22 +22,9 @@ pub struct EpAppState {
|
||||
pub user_store: Entity<UserStore>,
|
||||
pub fs: Arc<dyn fs::Fs>,
|
||||
pub node_runtime: NodeRuntime,
|
||||
pub project_cache: ProjectCache,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ProjectCache(Mutex<HashMap<String, Entity<Project>>>);
|
||||
|
||||
impl ProjectCache {
|
||||
pub fn insert(&self, repository_url: String, project: Entity<Project>) {
|
||||
self.0.lock().unwrap().insert(repository_url, project);
|
||||
}
|
||||
|
||||
pub fn get(&self, repository_url: &String) -> Option<Entity<Project>> {
|
||||
self.0.lock().unwrap().get(repository_url).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: dedupe with crates/eval/src/eval.rs
|
||||
pub fn init(cx: &mut App) -> EpAppState {
|
||||
let app_commit_sha = option_env!("ZED_COMMIT_SHA").map(|s| AppCommitSha::new(s.to_owned()));
|
||||
|
||||
@@ -127,14 +112,11 @@ pub fn init(cx: &mut App) -> EpAppState {
|
||||
prompt_store::init(cx);
|
||||
terminal_view::init(cx);
|
||||
|
||||
let project_cache = ProjectCache::default();
|
||||
|
||||
EpAppState {
|
||||
languages,
|
||||
client,
|
||||
user_store,
|
||||
fs,
|
||||
node_runtime,
|
||||
project_cache,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
use crate::{
|
||||
example::{Example, ExampleBuffer, ExampleState},
|
||||
headless::EpAppState,
|
||||
paths::{REPOS_DIR, WORKTREES_DIR},
|
||||
progress::{InfoStyle, Progress, Step, StepProgress},
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use anyhow::{Result, anyhow};
|
||||
use collections::HashMap;
|
||||
use edit_prediction::EditPredictionStore;
|
||||
use edit_prediction::udiff::OpenedBuffers;
|
||||
@@ -13,7 +11,7 @@ use futures::{
|
||||
lock::{Mutex, OwnedMutexGuard},
|
||||
};
|
||||
use gpui::{AsyncApp, Entity};
|
||||
use language::{Anchor, Buffer, LanguageNotFound, ToOffset, ToPoint};
|
||||
use language::{Anchor, Buffer, ToOffset, ToPoint};
|
||||
use project::buffer_store::BufferStoreEvent;
|
||||
use project::{Project, ProjectPath};
|
||||
use std::{
|
||||
@@ -25,76 +23,68 @@ use std::{
|
||||
use util::{paths::PathStyle, rel_path::RelPath};
|
||||
use zeta_prompt::CURSOR_MARKER;
|
||||
|
||||
pub async fn run_load_project(
|
||||
example: &mut Example,
|
||||
app_state: Arc<EpAppState>,
|
||||
mut cx: AsyncApp,
|
||||
) -> Result<()> {
|
||||
pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>, mut cx: AsyncApp) {
|
||||
if example.state.is_some() {
|
||||
return Ok(());
|
||||
return;
|
||||
}
|
||||
|
||||
let progress = Progress::global().start(Step::LoadProject, &example.name);
|
||||
let project = setup_project(example, &app_state, &mut cx).await;
|
||||
let buffer_store = project
|
||||
.read_with(&cx, |project, _| project.buffer_store().clone())
|
||||
.unwrap();
|
||||
|
||||
let project = setup_project(example, &app_state, &progress, &mut cx).await?;
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
|
||||
.unwrap();
|
||||
|
||||
let _open_buffers = apply_edit_history(example, &project, &mut cx).await?;
|
||||
cx.subscribe(&buffer_store, {
|
||||
let project = project.clone();
|
||||
move |_, event, cx| match event {
|
||||
BufferStoreEvent::BufferAdded(buffer) => {
|
||||
ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
})
|
||||
.unwrap()
|
||||
.detach();
|
||||
|
||||
let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await?;
|
||||
let (example_buffer, language_name) = buffer.read_with(&cx, |buffer, _cx| {
|
||||
let cursor_point = cursor_position.to_point(&buffer);
|
||||
let language_name = buffer
|
||||
.language()
|
||||
.map(|l| l.name().to_string())
|
||||
.unwrap_or_else(|| "Unknown".to_string());
|
||||
(
|
||||
ExampleBuffer {
|
||||
let _open_buffers = apply_edit_history(example, &project, &mut cx)
|
||||
.await
|
||||
.unwrap();
|
||||
let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await;
|
||||
example.buffer = buffer
|
||||
.read_with(&cx, |buffer, _cx| {
|
||||
let cursor_point = cursor_position.to_point(&buffer);
|
||||
Some(ExampleBuffer {
|
||||
content: buffer.text(),
|
||||
cursor_row: cursor_point.row,
|
||||
cursor_column: cursor_point.column,
|
||||
cursor_offset: cursor_position.to_offset(&buffer),
|
||||
},
|
||||
language_name,
|
||||
)
|
||||
})?;
|
||||
|
||||
progress.set_info(language_name, InfoStyle::Normal);
|
||||
|
||||
example.buffer = Some(example_buffer);
|
||||
})
|
||||
})
|
||||
.unwrap();
|
||||
example.state = Some(ExampleState {
|
||||
buffer,
|
||||
project,
|
||||
cursor_position,
|
||||
_open_buffers,
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cursor_position(
|
||||
example: &Example,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<(Entity<Buffer>, Anchor)> {
|
||||
let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
|
||||
let result = language_registry
|
||||
.load_language_for_file_path(&example.cursor_path)
|
||||
.await;
|
||||
|
||||
if let Err(error) = result
|
||||
&& !error.is::<LanguageNotFound>()
|
||||
{
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
let worktree = project.read_with(cx, |project, cx| {
|
||||
project
|
||||
.visible_worktrees(cx)
|
||||
.next()
|
||||
.context("No visible worktrees")
|
||||
})??;
|
||||
) -> (Entity<Buffer>, Anchor) {
|
||||
let worktree = project
|
||||
.read_with(cx, |project, cx| {
|
||||
project.visible_worktrees(cx).next().unwrap()
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix)
|
||||
.context("Failed to create RelPath")?
|
||||
.unwrap()
|
||||
.into_arc();
|
||||
let cursor_buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
@@ -105,12 +95,15 @@ async fn cursor_position(
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
})
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
let cursor_offset_within_excerpt = example
|
||||
.cursor_position
|
||||
.find(CURSOR_MARKER)
|
||||
.context("missing cursor marker")?;
|
||||
.ok_or_else(|| anyhow!("missing cursor marker"))
|
||||
.unwrap();
|
||||
let mut cursor_excerpt = example.cursor_position.clone();
|
||||
cursor_excerpt.replace_range(
|
||||
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
|
||||
@@ -120,107 +113,73 @@ async fn cursor_position(
|
||||
let text = buffer.text();
|
||||
|
||||
let mut matches = text.match_indices(&cursor_excerpt);
|
||||
let (excerpt_offset, _) = matches.next().with_context(|| {
|
||||
format!(
|
||||
"\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Example: {}\nCursor excerpt did not exist in buffer.",
|
||||
example.name
|
||||
)
|
||||
})?;
|
||||
anyhow::ensure!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
|
||||
Ok(excerpt_offset)
|
||||
})??;
|
||||
let (excerpt_offset, _) = matches.next().unwrap_or_else(|| {
|
||||
panic!(
|
||||
"\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
|
||||
);
|
||||
});
|
||||
assert!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
|
||||
excerpt_offset
|
||||
}).unwrap();
|
||||
|
||||
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
|
||||
let cursor_anchor =
|
||||
cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
|
||||
let cursor_anchor = cursor_buffer
|
||||
.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))
|
||||
.unwrap();
|
||||
|
||||
Ok((cursor_buffer, cursor_anchor))
|
||||
(cursor_buffer, cursor_anchor)
|
||||
}
|
||||
|
||||
async fn setup_project(
|
||||
example: &mut Example,
|
||||
app_state: &Arc<EpAppState>,
|
||||
step_progress: &StepProgress,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Entity<Project>> {
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx))?
|
||||
.context("Store should be initialized at init")?;
|
||||
) -> Entity<Project> {
|
||||
setup_worktree(example).await;
|
||||
|
||||
let worktree_path = setup_worktree(example, step_progress).await?;
|
||||
let project = cx
|
||||
.update(|cx| {
|
||||
Project::local(
|
||||
app_state.client.clone(),
|
||||
app_state.node_runtime.clone(),
|
||||
app_state.user_store.clone(),
|
||||
app_state.languages.clone(),
|
||||
app_state.fs.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
if let Some(project) = app_state.project_cache.get(&example.repository_url) {
|
||||
ep_store.update(cx, |ep_store, _| {
|
||||
ep_store.clear_history_for_project(&project);
|
||||
})?;
|
||||
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
|
||||
let buffers = buffer_store.read_with(cx, |buffer_store, _| {
|
||||
buffer_store.buffers().collect::<Vec<_>>()
|
||||
})?;
|
||||
for buffer in buffers {
|
||||
buffer
|
||||
.update(cx, |buffer, cx| buffer.reload(cx))?
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
return Ok(project);
|
||||
}
|
||||
|
||||
let project = cx.update(|cx| {
|
||||
Project::local(
|
||||
app_state.client.clone(),
|
||||
app_state.node_runtime.clone(),
|
||||
app_state.user_store.clone(),
|
||||
app_state.languages.clone(),
|
||||
app_state.fs.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
|
||||
project
|
||||
let worktree = project
|
||||
.update(cx, |project, cx| {
|
||||
project.disable_worktree_scanner(cx);
|
||||
project.create_worktree(&worktree_path, true, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
app_state
|
||||
.project_cache
|
||||
.insert(example.repository_url.clone(), project.clone());
|
||||
|
||||
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
|
||||
cx.subscribe(&buffer_store, {
|
||||
let project = project.clone();
|
||||
move |_, event, cx| match event {
|
||||
BufferStoreEvent::BufferAdded(buffer) => {
|
||||
ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
})?
|
||||
.detach();
|
||||
|
||||
Ok(project)
|
||||
project.create_worktree(&example.worktree_path(), true, cx)
|
||||
})
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
worktree
|
||||
.read_with(cx, |worktree, _cx| {
|
||||
worktree.as_local().unwrap().scan_complete()
|
||||
})
|
||||
.unwrap()
|
||||
.await;
|
||||
project
|
||||
}
|
||||
|
||||
async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Result<PathBuf> {
|
||||
let (repo_owner, repo_name) = example.repo_name().context("failed to get repo name")?;
|
||||
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
|
||||
let worktree_path = WORKTREES_DIR
|
||||
.join(repo_owner.as_ref())
|
||||
.join(repo_name.as_ref());
|
||||
pub async fn setup_worktree(example: &Example) {
|
||||
let repo_dir = example.repo_path();
|
||||
let repo_lock = lock_repo(&repo_dir).await;
|
||||
|
||||
if !repo_dir.is_dir() {
|
||||
step_progress.set_substatus(format!("cloning {}", repo_name));
|
||||
fs::create_dir_all(&repo_dir)?;
|
||||
run_git(&repo_dir, &["init"]).await?;
|
||||
fs::create_dir_all(&repo_dir).unwrap();
|
||||
run_git(&repo_dir, &["init"]).await.unwrap();
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["remote", "add", "origin", &example.repository_url],
|
||||
)
|
||||
.await?;
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Resolve the example to a revision, fetching it if needed.
|
||||
@@ -232,7 +191,6 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu
|
||||
let revision = if let Ok(revision) = revision {
|
||||
revision
|
||||
} else {
|
||||
step_progress.set_substatus("fetching");
|
||||
if run_git(
|
||||
&repo_dir,
|
||||
&["fetch", "--depth", "1", "origin", &example.revision],
|
||||
@@ -240,25 +198,39 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
run_git(&repo_dir, &["fetch", "origin"]).await?;
|
||||
run_git(&repo_dir, &["fetch", "origin"]).await.unwrap();
|
||||
}
|
||||
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"])
|
||||
.await
|
||||
.unwrap();
|
||||
if revision != example.revision {
|
||||
run_git(&repo_dir, &["tag", &example.revision, &revision])
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
|
||||
revision
|
||||
};
|
||||
|
||||
// Create the worktree for this example if needed.
|
||||
step_progress.set_substatus("preparing worktree");
|
||||
let worktree_path = example.worktree_path();
|
||||
if worktree_path.is_dir() {
|
||||
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
|
||||
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
|
||||
run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
|
||||
run_git(&worktree_path, &["clean", "--force", "-d"])
|
||||
.await
|
||||
.unwrap();
|
||||
run_git(&worktree_path, &["reset", "--hard", "HEAD"])
|
||||
.await
|
||||
.unwrap();
|
||||
run_git(&worktree_path, &["checkout", revision.as_str()])
|
||||
.await
|
||||
.unwrap();
|
||||
} else {
|
||||
let worktree_path_string = worktree_path.to_string_lossy();
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["branch", "-f", &example.name, revision.as_str()],
|
||||
)
|
||||
.await?;
|
||||
.await
|
||||
.unwrap();
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&[
|
||||
@@ -269,36 +241,38 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu
|
||||
&example.name,
|
||||
],
|
||||
)
|
||||
.await?;
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
drop(repo_lock);
|
||||
|
||||
// Apply the uncommitted diff for this example.
|
||||
if !example.uncommitted_diff.is_empty() {
|
||||
step_progress.set_substatus("applying diff");
|
||||
let mut apply_process = smol::process::Command::new("git")
|
||||
.current_dir(&worktree_path)
|
||||
.args(&["apply", "-"])
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.spawn()?;
|
||||
.spawn()
|
||||
.unwrap();
|
||||
|
||||
let mut stdin = apply_process.stdin.take().context("Failed to get stdin")?;
|
||||
stdin.write_all(example.uncommitted_diff.as_bytes()).await?;
|
||||
stdin.close().await?;
|
||||
let mut stdin = apply_process.stdin.take().unwrap();
|
||||
stdin
|
||||
.write_all(example.uncommitted_diff.as_bytes())
|
||||
.await
|
||||
.unwrap();
|
||||
stdin.close().await.unwrap();
|
||||
drop(stdin);
|
||||
|
||||
let apply_result = apply_process.output().await?;
|
||||
anyhow::ensure!(
|
||||
apply_result.status.success(),
|
||||
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
apply_result.status,
|
||||
String::from_utf8_lossy(&apply_result.stderr),
|
||||
String::from_utf8_lossy(&apply_result.stdout),
|
||||
);
|
||||
let apply_result = apply_process.output().await.unwrap();
|
||||
if !apply_result.status.success() {
|
||||
panic!(
|
||||
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
apply_result.status,
|
||||
String::from_utf8_lossy(&apply_result.stderr),
|
||||
String::from_utf8_lossy(&apply_result.stdout),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
step_progress.clear_substatus();
|
||||
Ok(worktree_path)
|
||||
}
|
||||
|
||||
async fn apply_edit_history(
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
mod anthropic_client;
|
||||
mod distill;
|
||||
mod example;
|
||||
mod format_prompt;
|
||||
mod headless;
|
||||
@@ -7,7 +6,6 @@ mod load_project;
|
||||
mod metrics;
|
||||
mod paths;
|
||||
mod predict;
|
||||
mod progress;
|
||||
mod retrieve_context;
|
||||
mod score;
|
||||
|
||||
@@ -16,16 +14,12 @@ use edit_prediction::EditPredictionStore;
|
||||
use gpui::Application;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Display;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
|
||||
use crate::distill::run_distill;
|
||||
use crate::example::{group_examples_by_repo, read_examples, write_examples};
|
||||
use crate::example::{read_examples, write_examples};
|
||||
use crate::format_prompt::run_format_prompt;
|
||||
use crate::load_project::run_load_project;
|
||||
use crate::paths::FAILED_EXAMPLES_DIR;
|
||||
use crate::predict::run_prediction;
|
||||
use crate::progress::Progress;
|
||||
use crate::retrieve_context::run_context_retrieval;
|
||||
use crate::score::run_scoring;
|
||||
|
||||
@@ -34,7 +28,7 @@ use crate::score::run_scoring;
|
||||
struct EpArgs {
|
||||
#[arg(long, default_value_t = false)]
|
||||
printenv: bool,
|
||||
#[clap(long, default_value_t = 10, global = true)]
|
||||
#[clap(long, default_value_t = 10)]
|
||||
max_parallelism: usize,
|
||||
#[command(subcommand)]
|
||||
command: Option<Command>,
|
||||
@@ -44,8 +38,6 @@ struct EpArgs {
|
||||
output: Option<PathBuf>,
|
||||
#[arg(long, short, global = true)]
|
||||
in_place: bool,
|
||||
#[arg(long, short, global = true)]
|
||||
failfast: bool,
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
@@ -53,7 +45,7 @@ enum Command {
|
||||
/// Parse markdown examples and output a combined .jsonl file
|
||||
ParseExample,
|
||||
/// Create git worktrees for each example and load file contents
|
||||
LoadProject,
|
||||
LoadBuffer,
|
||||
/// Retrieve context for input examples.
|
||||
Context,
|
||||
/// Generate a prompt string for a specific model
|
||||
@@ -62,67 +54,12 @@ enum Command {
|
||||
Predict(PredictArgs),
|
||||
/// Computes a score based on actual and expected patches
|
||||
Score(PredictArgs),
|
||||
/// Prepares a distillation dataset by copying expected outputs to
|
||||
/// predicted outputs and removing actual outputs and prompts.
|
||||
Distill,
|
||||
/// Print aggregated scores
|
||||
Eval(PredictArgs),
|
||||
/// Remove git repositories and worktrees
|
||||
Clean,
|
||||
}
|
||||
|
||||
impl Display for Command {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Command::ParseExample => write!(f, "parse-example"),
|
||||
Command::LoadProject => write!(f, "load-project"),
|
||||
Command::Context => write!(f, "context"),
|
||||
Command::FormatPrompt(format_prompt_args) => write!(
|
||||
f,
|
||||
"format-prompt --prompt-format={}",
|
||||
format_prompt_args
|
||||
.prompt_format
|
||||
.to_possible_value()
|
||||
.unwrap()
|
||||
.get_name()
|
||||
),
|
||||
Command::Predict(predict_args) => {
|
||||
write!(
|
||||
f,
|
||||
"predict --provider={:?}",
|
||||
predict_args
|
||||
.provider
|
||||
.to_possible_value()
|
||||
.unwrap()
|
||||
.get_name()
|
||||
)
|
||||
}
|
||||
Command::Score(predict_args) => {
|
||||
write!(
|
||||
f,
|
||||
"score --provider={:?}",
|
||||
predict_args
|
||||
.provider
|
||||
.to_possible_value()
|
||||
.unwrap()
|
||||
.get_name()
|
||||
)
|
||||
}
|
||||
Command::Distill => write!(f, "distill"),
|
||||
Command::Eval(predict_args) => write!(
|
||||
f,
|
||||
"eval --provider={:?}",
|
||||
predict_args
|
||||
.provider
|
||||
.to_possible_value()
|
||||
.unwrap()
|
||||
.get_name()
|
||||
),
|
||||
Command::Clean => write!(f, "clean"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
struct FormatPromptArgs {
|
||||
#[clap(long)]
|
||||
@@ -150,7 +87,6 @@ enum PredictionProvider {
|
||||
Zeta1,
|
||||
Zeta2,
|
||||
Teacher,
|
||||
TeacherNonBatching,
|
||||
}
|
||||
|
||||
impl EpArgs {
|
||||
@@ -168,6 +104,8 @@ impl EpArgs {
|
||||
}
|
||||
|
||||
fn main() {
|
||||
zlog::init();
|
||||
zlog::init_output_stderr();
|
||||
let args = EpArgs::parse();
|
||||
|
||||
if args.printenv {
|
||||
@@ -201,141 +139,60 @@ fn main() {
|
||||
EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let result = async {
|
||||
if let Command::Predict(args) = &command {
|
||||
predict::sync_batches(&args.provider).await?;
|
||||
}
|
||||
match &command {
|
||||
Command::Predict(args) => predict::sync_batches(&args.provider).await,
|
||||
_ => (),
|
||||
};
|
||||
|
||||
let total_examples = examples.len();
|
||||
Progress::global().set_total_examples(total_examples);
|
||||
|
||||
let mut grouped_examples = group_examples_by_repo(&mut examples);
|
||||
let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
|
||||
|
||||
for example_batch in example_batches {
|
||||
let futures = example_batch.into_iter().map(|repo_examples| async {
|
||||
for example in repo_examples.iter_mut() {
|
||||
let result = async {
|
||||
match &command {
|
||||
Command::ParseExample => {}
|
||||
Command::LoadProject => {
|
||||
run_load_project(example, app_state.clone(), cx.clone())
|
||||
.await?;
|
||||
}
|
||||
Command::Context => {
|
||||
run_context_retrieval(
|
||||
example,
|
||||
app_state.clone(),
|
||||
cx.clone(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
Command::FormatPrompt(args) => {
|
||||
run_format_prompt(
|
||||
example,
|
||||
args.prompt_format,
|
||||
app_state.clone(),
|
||||
cx.clone(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
Command::Predict(args) => {
|
||||
run_prediction(
|
||||
example,
|
||||
Some(args.provider),
|
||||
args.repetitions,
|
||||
app_state.clone(),
|
||||
cx.clone(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
Command::Distill => {
|
||||
run_distill(example).await?;
|
||||
}
|
||||
Command::Score(args) | Command::Eval(args) => {
|
||||
run_scoring(example, &args, app_state.clone(), cx.clone())
|
||||
.await?;
|
||||
}
|
||||
Command::Clean => {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
anyhow::Ok(())
|
||||
for data in examples.chunks_mut(args.max_parallelism) {
|
||||
let mut futures = Vec::new();
|
||||
for example in data.iter_mut() {
|
||||
let cx = cx.clone();
|
||||
let app_state = app_state.clone();
|
||||
futures.push(async {
|
||||
match &command {
|
||||
Command::ParseExample => {}
|
||||
Command::LoadBuffer => {
|
||||
run_load_project(example, app_state.clone(), cx).await;
|
||||
}
|
||||
.await;
|
||||
|
||||
if let Err(e) = result {
|
||||
Progress::global().increment_failed();
|
||||
let failed_example_path =
|
||||
FAILED_EXAMPLES_DIR.join(format!("{}.json", example.name));
|
||||
app_state
|
||||
.fs
|
||||
.write(
|
||||
&failed_example_path,
|
||||
&serde_json::to_vec_pretty(&example).unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let err_path =
|
||||
FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example.name));
|
||||
app_state
|
||||
.fs
|
||||
.write(&err_path, e.to_string().as_bytes())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let msg = format!(
|
||||
indoc::indoc! {"
|
||||
While processing {}:
|
||||
|
||||
{:?}
|
||||
|
||||
Written to: \x1b[36m{}\x1b[0m
|
||||
|
||||
Explore this example data with:
|
||||
fx \x1b[36m{}\x1b[0m
|
||||
|
||||
Re-run this example with:
|
||||
cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
|
||||
"},
|
||||
example.name,
|
||||
e,
|
||||
err_path.display(),
|
||||
failed_example_path.display(),
|
||||
command,
|
||||
failed_example_path.display(),
|
||||
);
|
||||
if args.failfast || total_examples == 1 {
|
||||
Progress::global().finalize();
|
||||
panic!("{}", msg);
|
||||
} else {
|
||||
log::error!("{}", msg);
|
||||
}
|
||||
Command::Context => {
|
||||
run_context_retrieval(example, app_state, cx).await;
|
||||
}
|
||||
Command::FormatPrompt(args) => {
|
||||
run_format_prompt(example, args.prompt_format, app_state, cx).await;
|
||||
}
|
||||
Command::Predict(args) => {
|
||||
run_prediction(
|
||||
example,
|
||||
Some(args.provider),
|
||||
args.repetitions,
|
||||
app_state.clone(),
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Command::Score(args) | Command::Eval(args) => {
|
||||
run_scoring(example, &args, app_state, cx).await;
|
||||
}
|
||||
Command::Clean => {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
});
|
||||
futures::future::join_all(futures).await;
|
||||
}
|
||||
Progress::global().finalize();
|
||||
|
||||
if args.output.is_some() || !matches!(command, Command::Eval(_)) {
|
||||
write_examples(&examples, output.as_ref());
|
||||
}
|
||||
|
||||
match &command {
|
||||
Command::Predict(args) => predict::sync_batches(&args.provider).await?,
|
||||
Command::Eval(_) => score::print_report(&examples),
|
||||
_ => (),
|
||||
};
|
||||
|
||||
anyhow::Ok(())
|
||||
futures::future::join_all(futures).await;
|
||||
}
|
||||
.await;
|
||||
|
||||
if let Err(e) = result {
|
||||
panic!("Fatal error: {:?}", e);
|
||||
if args.output.is_some() || !matches!(command, Command::Eval(_)) {
|
||||
write_examples(&examples, output.as_ref());
|
||||
}
|
||||
|
||||
match &command {
|
||||
Command::Predict(args) => predict::sync_batches(&args.provider).await,
|
||||
Command::Eval(_) => score::print_report(&examples),
|
||||
_ => (),
|
||||
};
|
||||
|
||||
let _ = cx.update(|cx| cx.quit());
|
||||
})
|
||||
.detach();
|
||||
|
||||
@@ -18,8 +18,6 @@ pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
|
||||
});
|
||||
pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("latest"));
|
||||
pub static LLM_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("llm_cache.sqlite"));
|
||||
pub static FAILED_EXAMPLES_DIR: LazyLock<PathBuf> =
|
||||
LazyLock::new(|| ensure_dir(&RUN_DIR.join("failed")));
|
||||
|
||||
fn ensure_dir(path: &Path) -> PathBuf {
|
||||
std::fs::create_dir_all(path).expect("Failed to create directory");
|
||||
|
||||
@@ -2,14 +2,12 @@ use crate::{
|
||||
PredictionProvider, PromptFormat,
|
||||
anthropic_client::AnthropicClient,
|
||||
example::{Example, ExamplePrediction},
|
||||
format_prompt::{TeacherPrompt, run_format_prompt},
|
||||
format_prompt::{PromptParser, TeacherPrompt, run_format_prompt},
|
||||
headless::EpAppState,
|
||||
load_project::run_load_project,
|
||||
paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
|
||||
progress::{InfoStyle, Progress, Step},
|
||||
retrieve_context::run_context_retrieval,
|
||||
};
|
||||
use anyhow::Context as _;
|
||||
use edit_prediction::{DebugEvent, EditPredictionStore};
|
||||
use futures::{FutureExt as _, StreamExt as _, future::Shared};
|
||||
use gpui::{AppContext as _, AsyncApp, Task};
|
||||
@@ -27,33 +25,25 @@ pub async fn run_prediction(
|
||||
repetition_count: usize,
|
||||
app_state: Arc<EpAppState>,
|
||||
mut cx: AsyncApp,
|
||||
) -> anyhow::Result<()> {
|
||||
) {
|
||||
if !example.predictions.is_empty() {
|
||||
return Ok(());
|
||||
return;
|
||||
}
|
||||
|
||||
let provider = provider.context("provider is required")?;
|
||||
run_load_project(example, app_state.clone(), cx.clone()).await;
|
||||
run_context_retrieval(example, app_state.clone(), cx.clone()).await;
|
||||
|
||||
run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
|
||||
|
||||
if matches!(
|
||||
provider,
|
||||
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
|
||||
) {
|
||||
let _step_progress = Progress::global().start(Step::Predict, &example.name);
|
||||
let provider = provider.unwrap();
|
||||
|
||||
if matches!(provider, PredictionProvider::Teacher) {
|
||||
if example.prompt.is_none() {
|
||||
run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?;
|
||||
run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await;
|
||||
}
|
||||
|
||||
let batched = matches!(provider, PredictionProvider::Teacher);
|
||||
let batched = true;
|
||||
return predict_anthropic(example, repetition_count, batched).await;
|
||||
}
|
||||
|
||||
run_load_project(example, app_state.clone(), cx.clone()).await?;
|
||||
|
||||
let _step_progress = Progress::global().start(Step::Predict, &example.name);
|
||||
|
||||
if matches!(
|
||||
provider,
|
||||
PredictionProvider::Zeta1 | PredictionProvider::Zeta2
|
||||
@@ -63,9 +53,10 @@ pub async fn run_prediction(
|
||||
.get_or_init(|| {
|
||||
let client = app_state.client.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
|
||||
eprintln!("Authentication failed: {}", e);
|
||||
}
|
||||
client
|
||||
.sign_in_with_optional_connect(true, cx)
|
||||
.await
|
||||
.unwrap();
|
||||
})
|
||||
.shared()
|
||||
})
|
||||
@@ -73,30 +64,31 @@ pub async fn run_prediction(
|
||||
.await;
|
||||
}
|
||||
|
||||
let ep_store = cx.update(|cx| {
|
||||
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
|
||||
})??;
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
|
||||
.unwrap();
|
||||
|
||||
ep_store.update(&mut cx, |store, _cx| {
|
||||
let model = match provider {
|
||||
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
|
||||
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
|
||||
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
|
||||
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
|
||||
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
|
||||
unreachable!()
|
||||
}
|
||||
};
|
||||
store.set_edit_prediction_model(model);
|
||||
})?;
|
||||
let state = example.state.as_ref().context("state must be set")?;
|
||||
ep_store
|
||||
.update(&mut cx, |store, _cx| {
|
||||
let model = match provider {
|
||||
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
|
||||
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
|
||||
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
|
||||
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
|
||||
PredictionProvider::Teacher => unreachable!(),
|
||||
};
|
||||
store.set_edit_prediction_model(model);
|
||||
})
|
||||
.unwrap();
|
||||
let state = example.state.as_ref().unwrap();
|
||||
let run_dir = RUN_DIR.join(&example.name);
|
||||
|
||||
let updated_example = Arc::new(Mutex::new(example.clone()));
|
||||
let current_run_ix = Arc::new(AtomicUsize::new(0));
|
||||
|
||||
let mut debug_rx =
|
||||
ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))?;
|
||||
let mut debug_rx = ep_store
|
||||
.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))
|
||||
.unwrap();
|
||||
let debug_task = cx.background_spawn({
|
||||
let updated_example = updated_example.clone();
|
||||
let current_run_ix = current_run_ix.clone();
|
||||
@@ -150,14 +142,14 @@ pub async fn run_prediction(
|
||||
run_dir.clone()
|
||||
};
|
||||
|
||||
fs::create_dir_all(&run_dir)?;
|
||||
fs::create_dir_all(&run_dir).unwrap();
|
||||
if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
|
||||
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
|
||||
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap();
|
||||
}
|
||||
#[cfg(unix)]
|
||||
std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
|
||||
std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
|
||||
#[cfg(windows)]
|
||||
std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
|
||||
std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
|
||||
|
||||
updated_example
|
||||
.lock()
|
||||
@@ -178,17 +170,10 @@ pub async fn run_prediction(
|
||||
cloud_llm_client::PredictEditsRequestTrigger::Cli,
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let actual_patch = prediction
|
||||
.and_then(|prediction| {
|
||||
let prediction = prediction.prediction.ok()?;
|
||||
prediction.edit_preview.as_unified_diff(&prediction.edits)
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let has_prediction = !actual_patch.is_empty();
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
updated_example
|
||||
.lock()
|
||||
@@ -196,35 +181,28 @@ pub async fn run_prediction(
|
||||
.predictions
|
||||
.last_mut()
|
||||
.unwrap()
|
||||
.actual_patch = actual_patch;
|
||||
|
||||
if ix == repetition_count - 1 {
|
||||
let (info, style) = if has_prediction {
|
||||
("predicted", InfoStyle::Normal)
|
||||
} else {
|
||||
("no prediction", InfoStyle::Warning)
|
||||
};
|
||||
_step_progress.set_info(info, style);
|
||||
}
|
||||
.actual_patch = prediction
|
||||
.and_then(|prediction| {
|
||||
let prediction = prediction.prediction.ok()?;
|
||||
prediction.edit_preview.as_unified_diff(&prediction.edits)
|
||||
})
|
||||
.unwrap_or_default();
|
||||
}
|
||||
|
||||
ep_store.update(&mut cx, |store, _| {
|
||||
store.remove_project(&state.project);
|
||||
})?;
|
||||
debug_task.await?;
|
||||
ep_store
|
||||
.update(&mut cx, |store, _| {
|
||||
store.remove_project(&state.project);
|
||||
})
|
||||
.unwrap();
|
||||
debug_task.await.unwrap();
|
||||
|
||||
*example = Arc::into_inner(updated_example)
|
||||
.ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
|
||||
.unwrap()
|
||||
.into_inner()
|
||||
.map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
|
||||
Ok(())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
async fn predict_anthropic(
|
||||
example: &mut Example,
|
||||
_repetition_count: usize,
|
||||
batched: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) {
|
||||
let llm_model_name = "claude-sonnet-4-5";
|
||||
let max_tokens = 16384;
|
||||
let llm_client = if batched {
|
||||
@@ -232,9 +210,12 @@ async fn predict_anthropic(
|
||||
} else {
|
||||
AnthropicClient::plain()
|
||||
};
|
||||
let llm_client = llm_client.context("Failed to create LLM client")?;
|
||||
let llm_client = llm_client.expect("Failed to create LLM client");
|
||||
|
||||
let prompt = example.prompt.as_ref().context("Prompt is required")?;
|
||||
let prompt = example
|
||||
.prompt
|
||||
.as_ref()
|
||||
.unwrap_or_else(|| panic!("Prompt is required for an example {}", &example.name));
|
||||
|
||||
let messages = vec![anthropic::Message {
|
||||
role: anthropic::Role::User,
|
||||
@@ -246,10 +227,11 @@ async fn predict_anthropic(
|
||||
|
||||
let Some(response) = llm_client
|
||||
.generate(llm_model_name, max_tokens, messages)
|
||||
.await?
|
||||
.await
|
||||
.unwrap()
|
||||
else {
|
||||
// Request stashed for batched processing
|
||||
return Ok(());
|
||||
return;
|
||||
};
|
||||
|
||||
let actual_output = response
|
||||
@@ -262,7 +244,7 @@ async fn predict_anthropic(
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
|
||||
let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
|
||||
let actual_patch = TeacherPrompt::parse(example, &actual_output);
|
||||
|
||||
let prediction = ExamplePrediction {
|
||||
actual_patch,
|
||||
@@ -271,21 +253,19 @@ async fn predict_anthropic(
|
||||
};
|
||||
|
||||
example.predictions.push(prediction);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
|
||||
pub async fn sync_batches(provider: &PredictionProvider) {
|
||||
match provider {
|
||||
PredictionProvider::Teacher => {
|
||||
let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
|
||||
let llm_client =
|
||||
AnthropicClient::batch(cache_path).context("Failed to create LLM client")?;
|
||||
AnthropicClient::batch(cache_path).expect("Failed to create LLM client");
|
||||
llm_client
|
||||
.sync_batches()
|
||||
.await
|
||||
.context("Failed to sync batches")?;
|
||||
.expect("Failed to sync batches");
|
||||
}
|
||||
_ => (),
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,508 +0,0 @@
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
collections::HashMap,
|
||||
io::{IsTerminal, Write},
|
||||
sync::{Arc, Mutex, OnceLock},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use log::{Level, Log, Metadata, Record};
|
||||
|
||||
pub struct Progress {
|
||||
inner: Mutex<ProgressInner>,
|
||||
}
|
||||
|
||||
struct ProgressInner {
|
||||
completed: Vec<CompletedTask>,
|
||||
in_progress: HashMap<String, InProgressTask>,
|
||||
is_tty: bool,
|
||||
terminal_width: usize,
|
||||
max_example_name_len: usize,
|
||||
status_lines_displayed: usize,
|
||||
total_examples: usize,
|
||||
failed_examples: usize,
|
||||
last_line_is_logging: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct InProgressTask {
|
||||
step: Step,
|
||||
started_at: Instant,
|
||||
substatus: Option<String>,
|
||||
info: Option<(String, InfoStyle)>,
|
||||
}
|
||||
|
||||
struct CompletedTask {
|
||||
step: Step,
|
||||
example_name: String,
|
||||
duration: Duration,
|
||||
info: Option<(String, InfoStyle)>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
pub enum Step {
|
||||
LoadProject,
|
||||
Context,
|
||||
FormatPrompt,
|
||||
Predict,
|
||||
Score,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum InfoStyle {
|
||||
Normal,
|
||||
Warning,
|
||||
}
|
||||
|
||||
impl Step {
|
||||
pub fn label(&self) -> &'static str {
|
||||
match self {
|
||||
Step::LoadProject => "Load",
|
||||
Step::Context => "Context",
|
||||
Step::FormatPrompt => "Format",
|
||||
Step::Predict => "Predict",
|
||||
Step::Score => "Score",
|
||||
}
|
||||
}
|
||||
|
||||
fn color_code(&self) -> &'static str {
|
||||
match self {
|
||||
Step::LoadProject => "\x1b[33m",
|
||||
Step::Context => "\x1b[35m",
|
||||
Step::FormatPrompt => "\x1b[34m",
|
||||
Step::Predict => "\x1b[32m",
|
||||
Step::Score => "\x1b[31m",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
|
||||
static LOGGER: ProgressLogger = ProgressLogger;
|
||||
|
||||
const MARGIN: usize = 4;
|
||||
const MAX_STATUS_LINES: usize = 10;
|
||||
|
||||
impl Progress {
|
||||
/// Returns the global Progress instance, initializing it if necessary.
|
||||
pub fn global() -> Arc<Progress> {
|
||||
GLOBAL
|
||||
.get_or_init(|| {
|
||||
let progress = Arc::new(Self {
|
||||
inner: Mutex::new(ProgressInner {
|
||||
completed: Vec::new(),
|
||||
in_progress: HashMap::new(),
|
||||
is_tty: std::io::stderr().is_terminal(),
|
||||
terminal_width: get_terminal_width(),
|
||||
max_example_name_len: 0,
|
||||
status_lines_displayed: 0,
|
||||
total_examples: 0,
|
||||
failed_examples: 0,
|
||||
last_line_is_logging: false,
|
||||
}),
|
||||
});
|
||||
let _ = log::set_logger(&LOGGER);
|
||||
log::set_max_level(log::LevelFilter::Error);
|
||||
progress
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
|
||||
pub fn set_total_examples(&self, total: usize) {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
inner.total_examples = total;
|
||||
}
|
||||
|
||||
pub fn increment_failed(&self) {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
inner.failed_examples += 1;
|
||||
}
|
||||
|
||||
/// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
|
||||
/// This should be used for any output that needs to appear above the status lines.
|
||||
fn log(&self, message: &str) {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
Self::clear_status_lines(&mut inner);
|
||||
|
||||
if !inner.last_line_is_logging {
|
||||
let reset = "\x1b[0m";
|
||||
let dim = "\x1b[2m";
|
||||
let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
|
||||
eprintln!("{dim}{divider}{reset}");
|
||||
inner.last_line_is_logging = true;
|
||||
}
|
||||
|
||||
eprintln!("{}", message);
|
||||
}
|
||||
|
||||
pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
|
||||
Self::clear_status_lines(&mut inner);
|
||||
|
||||
inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
|
||||
inner.in_progress.insert(
|
||||
example_name.to_string(),
|
||||
InProgressTask {
|
||||
step,
|
||||
started_at: Instant::now(),
|
||||
substatus: None,
|
||||
info: None,
|
||||
},
|
||||
);
|
||||
|
||||
Self::print_status_lines(&mut inner);
|
||||
|
||||
StepProgress {
|
||||
progress: self.clone(),
|
||||
step,
|
||||
example_name: example_name.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn finish(&self, step: Step, example_name: &str) {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
|
||||
let Some(task) = inner.in_progress.remove(example_name) else {
|
||||
return;
|
||||
};
|
||||
|
||||
if task.step == step {
|
||||
inner.completed.push(CompletedTask {
|
||||
step: task.step,
|
||||
example_name: example_name.to_string(),
|
||||
duration: task.started_at.elapsed(),
|
||||
info: task.info,
|
||||
});
|
||||
|
||||
Self::clear_status_lines(&mut inner);
|
||||
Self::print_logging_closing_divider(&mut inner);
|
||||
Self::print_completed(&inner, inner.completed.last().unwrap());
|
||||
Self::print_status_lines(&mut inner);
|
||||
} else {
|
||||
inner.in_progress.insert(example_name.to_string(), task);
|
||||
}
|
||||
}
|
||||
|
||||
fn print_logging_closing_divider(inner: &mut ProgressInner) {
|
||||
if inner.last_line_is_logging {
|
||||
let reset = "\x1b[0m";
|
||||
let dim = "\x1b[2m";
|
||||
let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
|
||||
eprintln!("{dim}{divider}{reset}");
|
||||
inner.last_line_is_logging = false;
|
||||
}
|
||||
}
|
||||
|
||||
fn clear_status_lines(inner: &mut ProgressInner) {
|
||||
if inner.is_tty && inner.status_lines_displayed > 0 {
|
||||
// Move up and clear each line we previously displayed
|
||||
for _ in 0..inner.status_lines_displayed {
|
||||
eprint!("\x1b[A\x1b[K");
|
||||
}
|
||||
let _ = std::io::stderr().flush();
|
||||
inner.status_lines_displayed = 0;
|
||||
}
|
||||
}
|
||||
|
||||
fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
|
||||
let duration = format_duration(task.duration);
|
||||
let name_width = inner.max_example_name_len;
|
||||
|
||||
if inner.is_tty {
|
||||
let reset = "\x1b[0m";
|
||||
let bold = "\x1b[1m";
|
||||
let dim = "\x1b[2m";
|
||||
|
||||
let yellow = "\x1b[33m";
|
||||
let info_part = task
|
||||
.info
|
||||
.as_ref()
|
||||
.map(|(s, style)| {
|
||||
if *style == InfoStyle::Warning {
|
||||
format!("{yellow}{s}{reset}")
|
||||
} else {
|
||||
s.to_string()
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let prefix = format!(
|
||||
"{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}│{reset} {info_part}",
|
||||
color = task.step.color_code(),
|
||||
label = task.step.label(),
|
||||
name = task.example_name,
|
||||
);
|
||||
|
||||
let duration_with_margin = format!("{duration} ");
|
||||
let padding_needed = inner
|
||||
.terminal_width
|
||||
.saturating_sub(MARGIN)
|
||||
.saturating_sub(duration_with_margin.len())
|
||||
.saturating_sub(strip_ansi_len(&prefix));
|
||||
let padding = " ".repeat(padding_needed);
|
||||
|
||||
eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
|
||||
} else {
|
||||
let info_part = task
|
||||
.info
|
||||
.as_ref()
|
||||
.map(|(s, _)| format!(" | {}", s))
|
||||
.unwrap_or_default();
|
||||
|
||||
eprintln!(
|
||||
"{label:>12} {name:<name_width$}{info_part} {duration}",
|
||||
label = task.step.label(),
|
||||
name = task.example_name,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn print_status_lines(inner: &mut ProgressInner) {
|
||||
if !inner.is_tty || inner.in_progress.is_empty() {
|
||||
inner.status_lines_displayed = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
let reset = "\x1b[0m";
|
||||
let bold = "\x1b[1m";
|
||||
let dim = "\x1b[2m";
|
||||
|
||||
// Build the done/in-progress/total label
|
||||
let done_count = inner.completed.len();
|
||||
let in_progress_count = inner.in_progress.len();
|
||||
let failed_count = inner.failed_examples;
|
||||
|
||||
let failed_label = if failed_count > 0 {
|
||||
format!(" {} failed ", failed_count)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let range_label = format!(
|
||||
" {}/{}/{} ",
|
||||
done_count, in_progress_count, inner.total_examples
|
||||
);
|
||||
|
||||
// Print a divider line with failed count on left, range label on right
|
||||
let failed_visible_len = strip_ansi_len(&failed_label);
|
||||
let range_visible_len = range_label.len();
|
||||
let middle_divider_len = inner
|
||||
.terminal_width
|
||||
.saturating_sub(MARGIN * 2)
|
||||
.saturating_sub(failed_visible_len)
|
||||
.saturating_sub(range_visible_len);
|
||||
let left_divider = "─".repeat(MARGIN);
|
||||
let middle_divider = "─".repeat(middle_divider_len);
|
||||
let right_divider = "─".repeat(MARGIN);
|
||||
eprintln!(
|
||||
"{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
|
||||
);
|
||||
|
||||
let mut tasks: Vec<_> = inner.in_progress.iter().collect();
|
||||
tasks.sort_by_key(|(name, _)| *name);
|
||||
|
||||
let total_tasks = tasks.len();
|
||||
let mut lines_printed = 0;
|
||||
|
||||
for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
|
||||
let elapsed = format_duration(task.started_at.elapsed());
|
||||
let substatus_part = task
|
||||
.substatus
|
||||
.as_ref()
|
||||
.map(|s| truncate_with_ellipsis(s, 30))
|
||||
.unwrap_or_default();
|
||||
|
||||
let step_label = task.step.label();
|
||||
let step_color = task.step.color_code();
|
||||
let name_width = inner.max_example_name_len;
|
||||
|
||||
let prefix = format!(
|
||||
"{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}│{reset} {substatus_part}",
|
||||
name = name,
|
||||
);
|
||||
|
||||
let duration_with_margin = format!("{elapsed} ");
|
||||
let padding_needed = inner
|
||||
.terminal_width
|
||||
.saturating_sub(MARGIN)
|
||||
.saturating_sub(duration_with_margin.len())
|
||||
.saturating_sub(strip_ansi_len(&prefix));
|
||||
let padding = " ".repeat(padding_needed);
|
||||
|
||||
eprintln!("{prefix}{padding}{dim}{duration_with_margin}{reset}");
|
||||
lines_printed += 1;
|
||||
}
|
||||
|
||||
// Show "+N more" on its own line if there are more tasks
|
||||
if total_tasks > MAX_STATUS_LINES {
|
||||
let remaining = total_tasks - MAX_STATUS_LINES;
|
||||
eprintln!("{:>12} +{remaining} more", "");
|
||||
lines_printed += 1;
|
||||
}
|
||||
|
||||
inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
|
||||
let _ = std::io::stderr().flush();
|
||||
}
|
||||
|
||||
pub fn finalize(&self) {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
Self::clear_status_lines(&mut inner);
|
||||
|
||||
// Print summary if there were failures
|
||||
if inner.failed_examples > 0 {
|
||||
let total_processed = inner.completed.len() + inner.failed_examples;
|
||||
let percentage = if total_processed > 0 {
|
||||
inner.failed_examples as f64 / total_processed as f64 * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
eprintln!(
|
||||
"\n{} of {} examples failed ({:.1}%)",
|
||||
inner.failed_examples, total_processed, percentage
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct StepProgress {
|
||||
progress: Arc<Progress>,
|
||||
step: Step,
|
||||
example_name: String,
|
||||
}
|
||||
|
||||
impl StepProgress {
|
||||
pub fn set_substatus(&self, substatus: impl Into<Cow<'static, str>>) {
|
||||
let mut inner = self.progress.inner.lock().unwrap();
|
||||
if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
|
||||
task.substatus = Some(substatus.into().into_owned());
|
||||
Progress::clear_status_lines(&mut inner);
|
||||
Progress::print_status_lines(&mut inner);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clear_substatus(&self) {
|
||||
let mut inner = self.progress.inner.lock().unwrap();
|
||||
if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
|
||||
task.substatus = None;
|
||||
Progress::clear_status_lines(&mut inner);
|
||||
Progress::print_status_lines(&mut inner);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_info(&self, info: impl Into<String>, style: InfoStyle) {
|
||||
let mut inner = self.progress.inner.lock().unwrap();
|
||||
if let Some(task) = inner.in_progress.get_mut(&self.example_name) {
|
||||
task.info = Some((info.into(), style));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for StepProgress {
|
||||
fn drop(&mut self) {
|
||||
self.progress.finish(self.step, &self.example_name);
|
||||
}
|
||||
}
|
||||
|
||||
struct ProgressLogger;
|
||||
|
||||
impl Log for ProgressLogger {
|
||||
fn enabled(&self, metadata: &Metadata) -> bool {
|
||||
metadata.level() <= Level::Info
|
||||
}
|
||||
|
||||
fn log(&self, record: &Record) {
|
||||
if !self.enabled(record.metadata()) {
|
||||
return;
|
||||
}
|
||||
|
||||
let level_color = match record.level() {
|
||||
Level::Error => "\x1b[31m",
|
||||
Level::Warn => "\x1b[33m",
|
||||
Level::Info => "\x1b[32m",
|
||||
Level::Debug => "\x1b[34m",
|
||||
Level::Trace => "\x1b[35m",
|
||||
};
|
||||
let reset = "\x1b[0m";
|
||||
let bold = "\x1b[1m";
|
||||
|
||||
let level_label = match record.level() {
|
||||
Level::Error => "Error",
|
||||
Level::Warn => "Warn",
|
||||
Level::Info => "Info",
|
||||
Level::Debug => "Debug",
|
||||
Level::Trace => "Trace",
|
||||
};
|
||||
|
||||
let message = format!(
|
||||
"{bold}{level_color}{level_label:>12}{reset} {}",
|
||||
record.args()
|
||||
);
|
||||
|
||||
if let Some(progress) = GLOBAL.get() {
|
||||
progress.log(&message);
|
||||
} else {
|
||||
eprintln!("{}", message);
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&self) {
|
||||
let _ = std::io::stderr().flush();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn get_terminal_width() -> usize {
|
||||
unsafe {
|
||||
let mut winsize: libc::winsize = std::mem::zeroed();
|
||||
if libc::ioctl(libc::STDERR_FILENO, libc::TIOCGWINSZ, &mut winsize) == 0
|
||||
&& winsize.ws_col > 0
|
||||
{
|
||||
winsize.ws_col as usize
|
||||
} else {
|
||||
80
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn get_terminal_width() -> usize {
|
||||
80
|
||||
}
|
||||
|
||||
fn strip_ansi_len(s: &str) -> usize {
|
||||
let mut len = 0;
|
||||
let mut in_escape = false;
|
||||
for c in s.chars() {
|
||||
if c == '\x1b' {
|
||||
in_escape = true;
|
||||
} else if in_escape {
|
||||
if c == 'm' {
|
||||
in_escape = false;
|
||||
}
|
||||
} else {
|
||||
len += 1;
|
||||
}
|
||||
}
|
||||
len
|
||||
}
|
||||
|
||||
fn truncate_with_ellipsis(s: &str, max_len: usize) -> String {
|
||||
if s.len() <= max_len {
|
||||
s.to_string()
|
||||
} else {
|
||||
format!("{}…", &s[..max_len.saturating_sub(1)])
|
||||
}
|
||||
}
|
||||
|
||||
fn format_duration(duration: Duration) -> String {
|
||||
const MINUTE_IN_MILLIS: f32 = 60. * 1000.;
|
||||
|
||||
let millis = duration.as_millis() as f32;
|
||||
if millis < 1000.0 {
|
||||
format!("{}ms", millis)
|
||||
} else if millis < MINUTE_IN_MILLIS {
|
||||
format!("{:.1}s", millis / 1_000.0)
|
||||
} else {
|
||||
format!("{:.1}m", millis / MINUTE_IN_MILLIS)
|
||||
}
|
||||
}
|
||||
@@ -2,51 +2,50 @@ use crate::{
|
||||
example::{Example, ExampleContext},
|
||||
headless::EpAppState,
|
||||
load_project::run_load_project,
|
||||
progress::{InfoStyle, Progress, Step, StepProgress},
|
||||
};
|
||||
use anyhow::Context as _;
|
||||
use anyhow::Result;
|
||||
use collections::HashSet;
|
||||
use edit_prediction::{DebugEvent, EditPredictionStore};
|
||||
use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
|
||||
use gpui::{AsyncApp, Entity};
|
||||
use language::Buffer;
|
||||
use gpui::{AsyncApp, Entity, Task};
|
||||
use language::{Buffer, LanguageNotFound};
|
||||
use project::Project;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
pub async fn run_context_retrieval(
|
||||
example: &mut Example,
|
||||
app_state: Arc<EpAppState>,
|
||||
mut cx: AsyncApp,
|
||||
) -> anyhow::Result<()> {
|
||||
) {
|
||||
if example.context.is_some() {
|
||||
return Ok(());
|
||||
return;
|
||||
}
|
||||
|
||||
run_load_project(example, app_state.clone(), cx.clone()).await?;
|
||||
|
||||
let step_progress: Arc<StepProgress> = Progress::global()
|
||||
.start(Step::Context, &example.name)
|
||||
.into();
|
||||
run_load_project(example, app_state.clone(), cx.clone()).await;
|
||||
|
||||
let state = example.state.as_ref().unwrap();
|
||||
let project = state.project.clone();
|
||||
|
||||
let _lsp_handle = project.update(&mut cx, |project, cx| {
|
||||
project.register_buffer_with_language_servers(&state.buffer, cx)
|
||||
})?;
|
||||
wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
|
||||
let _lsp_handle = project
|
||||
.update(&mut cx, |project, cx| {
|
||||
project.register_buffer_with_language_servers(&state.buffer, cx)
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let ep_store = cx.update(|cx| {
|
||||
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
|
||||
})??;
|
||||
wait_for_language_server_to_start(example, &project, &state.buffer, &mut cx).await;
|
||||
|
||||
let mut events = ep_store.update(&mut cx, |store, cx| {
|
||||
store.register_buffer(&state.buffer, &project, cx);
|
||||
store.set_use_context(true);
|
||||
store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
|
||||
store.debug_info(&project, cx)
|
||||
})?;
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
|
||||
.unwrap();
|
||||
|
||||
let mut events = ep_store
|
||||
.update(&mut cx, |store, cx| {
|
||||
store.register_buffer(&state.buffer, &project, cx);
|
||||
store.set_use_context(true);
|
||||
store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
|
||||
store.debug_info(&project, cx)
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
@@ -57,84 +56,117 @@ pub async fn run_context_retrieval(
|
||||
}
|
||||
}
|
||||
|
||||
let context_files =
|
||||
ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx))?;
|
||||
|
||||
let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
|
||||
step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
|
||||
let context_files = ep_store
|
||||
.update(&mut cx, |store, cx| store.context_for_project(&project, cx))
|
||||
.unwrap();
|
||||
|
||||
example.context = Some(ExampleContext {
|
||||
files: context_files,
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn wait_for_language_servers_to_start(
|
||||
async fn wait_for_language_server_to_start(
|
||||
example: &Example,
|
||||
project: &Entity<Project>,
|
||||
buffer: &Entity<Buffer>,
|
||||
step_progress: &Arc<StepProgress>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> anyhow::Result<()> {
|
||||
let lsp_store = project.read_with(cx, |project, _| project.lsp_store())?;
|
||||
) {
|
||||
let language_registry = project
|
||||
.read_with(cx, |project, _| project.languages().clone())
|
||||
.unwrap();
|
||||
let result = language_registry
|
||||
.load_language_for_file_path(&example.cursor_path)
|
||||
.await;
|
||||
|
||||
let (language_server_ids, mut starting_language_server_ids) = buffer
|
||||
if let Err(error) = result
|
||||
&& !error.is::<LanguageNotFound>()
|
||||
{
|
||||
panic!("Failed to load language for file path: {}", error);
|
||||
}
|
||||
|
||||
let Some(language_id) = buffer
|
||||
.read_with(cx, |buffer, _cx| {
|
||||
buffer.language().map(|language| language.id())
|
||||
})
|
||||
.unwrap()
|
||||
else {
|
||||
panic!("No language for {:?}", example.cursor_path);
|
||||
};
|
||||
|
||||
let mut ready_languages = HashSet::default();
|
||||
let log_prefix = format!("{} | ", example.name);
|
||||
if !ready_languages.contains(&language_id) {
|
||||
wait_for_lang_server(&project, &buffer, log_prefix, cx)
|
||||
.await
|
||||
.unwrap();
|
||||
ready_languages.insert(language_id);
|
||||
}
|
||||
|
||||
let lsp_store = project
|
||||
.read_with(cx, |project, _cx| project.lsp_store())
|
||||
.unwrap();
|
||||
|
||||
// hacky wait for buffer to be registered with the language server
|
||||
for _ in 0..100 {
|
||||
if lsp_store
|
||||
.update(cx, |lsp_store, cx| {
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
lsp_store
|
||||
.language_servers_for_local_buffer(&buffer, cx)
|
||||
.next()
|
||||
.map(|(_, language_server)| language_server.server_id())
|
||||
})
|
||||
})
|
||||
.unwrap()
|
||||
.is_some()
|
||||
{
|
||||
return;
|
||||
} else {
|
||||
cx.background_executor()
|
||||
.timer(Duration::from_millis(10))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
panic!("No language server found for buffer");
|
||||
}
|
||||
|
||||
pub fn wait_for_lang_server(
|
||||
project: &Entity<Project>,
|
||||
buffer: &Entity<Buffer>,
|
||||
log_prefix: String,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<()>> {
|
||||
eprintln!("{}⏵ Waiting for language server", log_prefix);
|
||||
|
||||
let (mut tx, mut rx) = mpsc::channel(1);
|
||||
|
||||
let lsp_store = project
|
||||
.read_with(cx, |project, _| project.lsp_store())
|
||||
.unwrap();
|
||||
|
||||
let has_lang_server = buffer
|
||||
.update(cx, |buffer, cx| {
|
||||
lsp_store.update(cx, |lsp_store, cx| {
|
||||
let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
|
||||
let starting_ids = ids
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|id| !lsp_store.language_server_statuses.contains_key(&id))
|
||||
.collect::<HashSet<_>>();
|
||||
(ids, starting_ids)
|
||||
lsp_store
|
||||
.language_servers_for_local_buffer(buffer, cx)
|
||||
.next()
|
||||
.is_some()
|
||||
})
|
||||
})
|
||||
.unwrap_or_default();
|
||||
.unwrap_or(false);
|
||||
|
||||
step_progress.set_substatus(format!("waiting for {} LSPs", language_server_ids.len()));
|
||||
|
||||
let timeout = cx
|
||||
.background_executor()
|
||||
.timer(Duration::from_secs(60 * 5))
|
||||
.shared();
|
||||
|
||||
let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
|
||||
let added_subscription = cx.subscribe(project, {
|
||||
let step_progress = step_progress.clone();
|
||||
move |_, event, _| match event {
|
||||
project::Event::LanguageServerAdded(language_server_id, name, _) => {
|
||||
step_progress.set_substatus(format!("LSP started: {}", name));
|
||||
tx.try_send(*language_server_id).ok();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
});
|
||||
|
||||
while !starting_language_server_ids.is_empty() {
|
||||
futures::select! {
|
||||
language_server_id = rx.next() => {
|
||||
if let Some(id) = language_server_id {
|
||||
starting_language_server_ids.remove(&id);
|
||||
}
|
||||
},
|
||||
_ = timeout.clone().fuse() => {
|
||||
return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
drop(added_subscription);
|
||||
|
||||
if !language_server_ids.is_empty() {
|
||||
if has_lang_server {
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
|
||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
|
||||
.unwrap()
|
||||
.detach();
|
||||
}
|
||||
let (mut added_tx, mut added_rx) = mpsc::channel(1);
|
||||
|
||||
let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
|
||||
let subscriptions = [
|
||||
cx.subscribe(&lsp_store, {
|
||||
let step_progress = step_progress.clone();
|
||||
let log_prefix = log_prefix.clone();
|
||||
move |_, event, _| {
|
||||
if let project::LspStoreEvent::LanguageServerUpdate {
|
||||
message:
|
||||
@@ -147,46 +179,50 @@ async fn wait_for_language_servers_to_start(
|
||||
..
|
||||
} = event
|
||||
{
|
||||
step_progress.set_substatus(message.clone());
|
||||
eprintln!("{}⟲ {message}", log_prefix)
|
||||
}
|
||||
}
|
||||
}),
|
||||
cx.subscribe(project, {
|
||||
let step_progress = step_progress.clone();
|
||||
move |_, event, cx| match event {
|
||||
project::Event::DiskBasedDiagnosticsFinished { language_server_id } => {
|
||||
let lsp_store = lsp_store.read(cx);
|
||||
let name = lsp_store
|
||||
.language_server_adapter_for_id(*language_server_id)
|
||||
.unwrap()
|
||||
.name();
|
||||
step_progress.set_substatus(format!("LSP idle: {}", name));
|
||||
tx.try_send(*language_server_id).ok();
|
||||
let buffer = buffer.clone();
|
||||
move |project, event, cx| match event {
|
||||
project::Event::LanguageServerAdded(_, _, _) => {
|
||||
let buffer = buffer.clone();
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer, cx))
|
||||
.detach();
|
||||
added_tx.try_send(()).ok();
|
||||
}
|
||||
project::Event::DiskBasedDiagnosticsFinished { .. } => {
|
||||
tx.try_send(()).ok();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}),
|
||||
];
|
||||
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
|
||||
.await?;
|
||||
|
||||
let mut pending_language_server_ids = HashSet::from_iter(language_server_ids.into_iter());
|
||||
while !pending_language_server_ids.is_empty() {
|
||||
futures::select! {
|
||||
language_server_id = rx.next() => {
|
||||
if let Some(id) = language_server_id {
|
||||
pending_language_server_ids.remove(&id);
|
||||
cx.spawn(async move |cx| {
|
||||
if !has_lang_server {
|
||||
// some buffers never have a language server, so this aborts quickly in that case.
|
||||
let timeout = cx.background_executor().timer(Duration::from_secs(500));
|
||||
futures::select! {
|
||||
_ = added_rx.next() => {},
|
||||
_ = timeout.fuse() => {
|
||||
anyhow::bail!("Waiting for language server add timed out after 5 seconds");
|
||||
}
|
||||
},
|
||||
_ = timeout.clone().fuse() => {
|
||||
return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
drop(subscriptions);
|
||||
step_progress.clear_substatus();
|
||||
Ok(())
|
||||
let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
|
||||
let result = futures::select! {
|
||||
_ = rx.next() => {
|
||||
eprintln!("{}⚑ Language server idle", log_prefix);
|
||||
anyhow::Ok(())
|
||||
},
|
||||
_ = timeout.fuse() => {
|
||||
anyhow::bail!("LSP wait timed out after 5 minutes");
|
||||
}
|
||||
};
|
||||
drop(subscriptions);
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ use crate::{
|
||||
headless::EpAppState,
|
||||
metrics::{self, ClassificationMetrics},
|
||||
predict::run_prediction,
|
||||
progress::{Progress, Step},
|
||||
};
|
||||
use edit_prediction::udiff::DiffLine;
|
||||
use gpui::AsyncApp;
|
||||
@@ -15,7 +14,7 @@ pub async fn run_scoring(
|
||||
args: &PredictArgs,
|
||||
app_state: Arc<EpAppState>,
|
||||
cx: AsyncApp,
|
||||
) -> anyhow::Result<()> {
|
||||
) {
|
||||
run_prediction(
|
||||
example,
|
||||
Some(args.provider),
|
||||
@@ -23,9 +22,7 @@ pub async fn run_scoring(
|
||||
app_state,
|
||||
cx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let _progress = Progress::global().start(Step::Score, &example.name);
|
||||
.await;
|
||||
|
||||
let expected_patch = parse_patch(&example.expected_patch);
|
||||
|
||||
@@ -43,7 +40,6 @@ pub async fn run_scoring(
|
||||
}
|
||||
|
||||
example.score = scores;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
|
||||
|
||||
@@ -18,7 +18,6 @@ Focus on:
|
||||
Rules:
|
||||
- Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals.
|
||||
- Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code.
|
||||
- Keep existing formatting unless it's absolutely necessary
|
||||
|
||||
Input format:
|
||||
- You receive small code fragments called context (structs, field definitions, function signatures, etc.). They may or may not be relevant.
|
||||
|
||||
@@ -20,8 +20,8 @@ cloud_llm_client.workspace = true
|
||||
codestral.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
copilot.workspace = true
|
||||
edit_prediction_types.workspace = true
|
||||
edit_prediction.workspace = true
|
||||
edit_prediction_types.workspace = true
|
||||
editor.workspace = true
|
||||
feature_flags.workspace = true
|
||||
fs.workspace = true
|
||||
@@ -41,6 +41,7 @@ telemetry.workspace = true
|
||||
text.workspace = true
|
||||
theme.workspace = true
|
||||
ui.workspace = true
|
||||
ui_input.workspace = true
|
||||
util.workspace = true
|
||||
workspace.workspace = true
|
||||
zed_actions.workspace = true
|
||||
|
||||
@@ -3,9 +3,7 @@ use client::{Client, UserStore, zed_urls};
|
||||
use cloud_llm_client::UsageLimit;
|
||||
use codestral::CodestralEditPredictionDelegate;
|
||||
use copilot::{Copilot, Status};
|
||||
use edit_prediction::{
|
||||
EditPredictionStore, MercuryFeatureFlag, SweepFeatureFlag, Zeta2FeatureFlag,
|
||||
};
|
||||
use edit_prediction::{MercuryFeatureFlag, SweepFeatureFlag, Zeta2FeatureFlag};
|
||||
use edit_prediction_types::EditPredictionDelegateHandle;
|
||||
use editor::{
|
||||
Editor, MultiBufferOffset, SelectionEffects, actions::ShowEditPrediction, scroll::Autoscroll,
|
||||
@@ -44,9 +42,12 @@ use workspace::{
|
||||
StatusItemView, Toast, Workspace, create_and_open_local_file, item::ItemHandle,
|
||||
notifications::NotificationId,
|
||||
};
|
||||
use zed_actions::{OpenBrowser, OpenSettingsAt};
|
||||
use zed_actions::OpenBrowser;
|
||||
|
||||
use crate::{RatePredictions, rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag};
|
||||
use crate::{
|
||||
ExternalProviderApiKeyModal, RatePredictions,
|
||||
rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag,
|
||||
};
|
||||
|
||||
actions!(
|
||||
edit_prediction,
|
||||
@@ -247,21 +248,45 @@ impl Render for EditPredictionButton {
|
||||
EditPredictionProvider::Codestral => {
|
||||
let enabled = self.editor_enabled.unwrap_or(true);
|
||||
let has_api_key = CodestralEditPredictionDelegate::has_api_key(cx);
|
||||
let fs = self.fs.clone();
|
||||
let this = cx.weak_entity();
|
||||
|
||||
let tooltip_meta = if has_api_key {
|
||||
"Powered by Codestral"
|
||||
} else {
|
||||
"Missing API key for Codestral"
|
||||
};
|
||||
|
||||
div().child(
|
||||
PopoverMenu::new("codestral")
|
||||
.menu(move |window, cx| {
|
||||
this.update(cx, |this, cx| {
|
||||
this.build_codestral_context_menu(window, cx)
|
||||
})
|
||||
.ok()
|
||||
if has_api_key {
|
||||
this.update(cx, |this, cx| {
|
||||
this.build_codestral_context_menu(window, cx)
|
||||
})
|
||||
.ok()
|
||||
} else {
|
||||
Some(ContextMenu::build(window, cx, |menu, _, _| {
|
||||
let fs = fs.clone();
|
||||
|
||||
menu.entry(
|
||||
"Configure Codestral API Key",
|
||||
None,
|
||||
move |window, cx| {
|
||||
window.dispatch_action(
|
||||
zed_actions::agent::OpenSettings.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
},
|
||||
)
|
||||
.separator()
|
||||
.entry(
|
||||
"Use Zed AI instead",
|
||||
None,
|
||||
move |_, cx| {
|
||||
set_completion_provider(
|
||||
fs.clone(),
|
||||
cx,
|
||||
EditPredictionProvider::Zed,
|
||||
)
|
||||
},
|
||||
)
|
||||
}))
|
||||
}
|
||||
})
|
||||
.anchor(Corner::BottomRight)
|
||||
.trigger_with_tooltip(
|
||||
@@ -279,14 +304,7 @@ impl Render for EditPredictionButton {
|
||||
cx.theme().colors().status_bar_background,
|
||||
))
|
||||
}),
|
||||
move |_window, cx| {
|
||||
Tooltip::with_meta(
|
||||
"Edit Prediction",
|
||||
Some(&ToggleMenu),
|
||||
tooltip_meta,
|
||||
cx,
|
||||
)
|
||||
},
|
||||
move |_window, cx| Tooltip::for_action("Codestral", &ToggleMenu, cx),
|
||||
)
|
||||
.with_handle(self.popover_menu_handle.clone()),
|
||||
)
|
||||
@@ -295,7 +313,6 @@ impl Render for EditPredictionButton {
|
||||
let enabled = self.editor_enabled.unwrap_or(true);
|
||||
|
||||
let ep_icon;
|
||||
let tooltip_meta;
|
||||
let mut missing_token = false;
|
||||
|
||||
match provider {
|
||||
@@ -303,25 +320,15 @@ impl Render for EditPredictionButton {
|
||||
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
) => {
|
||||
ep_icon = IconName::SweepAi;
|
||||
tooltip_meta = if missing_token {
|
||||
"Missing API key for Sweep"
|
||||
} else {
|
||||
"Powered by Sweep"
|
||||
};
|
||||
missing_token = edit_prediction::EditPredictionStore::try_global(cx)
|
||||
.is_some_and(|ep_store| !ep_store.read(cx).has_sweep_api_token(cx));
|
||||
.is_some_and(|ep_store| !ep_store.read(cx).has_sweep_api_token());
|
||||
}
|
||||
EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
) => {
|
||||
ep_icon = IconName::Inception;
|
||||
missing_token = edit_prediction::EditPredictionStore::try_global(cx)
|
||||
.is_some_and(|ep_store| !ep_store.read(cx).has_mercury_api_token(cx));
|
||||
tooltip_meta = if missing_token {
|
||||
"Missing API key for Mercury"
|
||||
} else {
|
||||
"Powered by Mercury"
|
||||
};
|
||||
.is_some_and(|ep_store| !ep_store.read(cx).has_mercury_api_token());
|
||||
}
|
||||
_ => {
|
||||
ep_icon = if enabled {
|
||||
@@ -329,7 +336,6 @@ impl Render for EditPredictionButton {
|
||||
} else {
|
||||
IconName::ZedPredictDisabled
|
||||
};
|
||||
tooltip_meta = "Powered by Zeta"
|
||||
}
|
||||
};
|
||||
|
||||
@@ -394,26 +400,33 @@ impl Render for EditPredictionButton {
|
||||
})
|
||||
.when(!self.popover_menu_handle.is_deployed(), |element| {
|
||||
let user = user.clone();
|
||||
|
||||
element.tooltip(move |_window, cx| {
|
||||
let description = if enabled {
|
||||
if enabled {
|
||||
if show_editor_predictions {
|
||||
tooltip_meta
|
||||
Tooltip::for_action("Edit Prediction", &ToggleMenu, cx)
|
||||
} else if user.is_none() {
|
||||
"Sign In To Use"
|
||||
Tooltip::with_meta(
|
||||
"Edit Prediction",
|
||||
Some(&ToggleMenu),
|
||||
"Sign In To Use",
|
||||
cx,
|
||||
)
|
||||
} else {
|
||||
"Hidden For This File"
|
||||
Tooltip::with_meta(
|
||||
"Edit Prediction",
|
||||
Some(&ToggleMenu),
|
||||
"Hidden For This File",
|
||||
cx,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
"Disabled For This File"
|
||||
};
|
||||
|
||||
Tooltip::with_meta(
|
||||
"Edit Prediction",
|
||||
Some(&ToggleMenu),
|
||||
description,
|
||||
cx,
|
||||
)
|
||||
Tooltip::with_meta(
|
||||
"Edit Prediction",
|
||||
Some(&ToggleMenu),
|
||||
"Disabled For This File",
|
||||
cx,
|
||||
)
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
@@ -506,12 +519,6 @@ impl EditPredictionButton {
|
||||
|
||||
providers.push(EditPredictionProvider::Zed);
|
||||
|
||||
if cx.has_flag::<Zeta2FeatureFlag>() {
|
||||
providers.push(EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(copilot) = Copilot::global(cx) {
|
||||
if matches!(copilot.read(cx).status(), Status::Authorized) {
|
||||
providers.push(EditPredictionProvider::Copilot);
|
||||
@@ -530,28 +537,24 @@ impl EditPredictionButton {
|
||||
providers.push(EditPredictionProvider::Codestral);
|
||||
}
|
||||
|
||||
let ep_store = EditPredictionStore::try_global(cx);
|
||||
|
||||
if cx.has_flag::<SweepFeatureFlag>()
|
||||
&& ep_store
|
||||
.as_ref()
|
||||
.is_some_and(|ep_store| ep_store.read(cx).has_sweep_api_token(cx))
|
||||
{
|
||||
if cx.has_flag::<SweepFeatureFlag>() {
|
||||
providers.push(EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
));
|
||||
}
|
||||
|
||||
if cx.has_flag::<MercuryFeatureFlag>()
|
||||
&& ep_store
|
||||
.as_ref()
|
||||
.is_some_and(|ep_store| ep_store.read(cx).has_mercury_api_token(cx))
|
||||
{
|
||||
if cx.has_flag::<MercuryFeatureFlag>() {
|
||||
providers.push(EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
));
|
||||
}
|
||||
|
||||
if cx.has_flag::<Zeta2FeatureFlag>() {
|
||||
providers.push(EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
));
|
||||
}
|
||||
|
||||
providers
|
||||
}
|
||||
|
||||
@@ -559,10 +562,13 @@ impl EditPredictionButton {
|
||||
&self,
|
||||
mut menu: ContextMenu,
|
||||
current_provider: EditPredictionProvider,
|
||||
cx: &mut App,
|
||||
cx: &App,
|
||||
) -> ContextMenu {
|
||||
let available_providers = self.get_available_providers(cx);
|
||||
|
||||
const ZED_AI_CALLOUT: &str =
|
||||
"Zed's edit prediction is powered by Zeta, an open-source, dataset mode.";
|
||||
|
||||
let providers: Vec<_> = available_providers
|
||||
.into_iter()
|
||||
.filter(|p| *p != EditPredictionProvider::None)
|
||||
@@ -575,32 +581,153 @@ impl EditPredictionButton {
|
||||
let is_current = provider == current_provider;
|
||||
let fs = self.fs.clone();
|
||||
|
||||
let name = match provider {
|
||||
EditPredictionProvider::Zed => "Zed AI",
|
||||
EditPredictionProvider::Copilot => "GitHub Copilot",
|
||||
EditPredictionProvider::Supermaven => "Supermaven",
|
||||
EditPredictionProvider::Codestral => "Codestral",
|
||||
menu = match provider {
|
||||
EditPredictionProvider::Zed => menu.item(
|
||||
ContextMenuEntry::new("Zed AI")
|
||||
.toggleable(IconPosition::Start, is_current)
|
||||
.documentation_aside(
|
||||
DocumentationSide::Left,
|
||||
DocumentationEdge::Bottom,
|
||||
|_| Label::new(ZED_AI_CALLOUT).into_any_element(),
|
||||
)
|
||||
.handler(move |_, cx| {
|
||||
set_completion_provider(fs.clone(), cx, provider);
|
||||
}),
|
||||
),
|
||||
EditPredictionProvider::Copilot => menu.item(
|
||||
ContextMenuEntry::new("GitHub Copilot")
|
||||
.toggleable(IconPosition::Start, is_current)
|
||||
.handler(move |_, cx| {
|
||||
set_completion_provider(fs.clone(), cx, provider);
|
||||
}),
|
||||
),
|
||||
EditPredictionProvider::Supermaven => menu.item(
|
||||
ContextMenuEntry::new("Supermaven")
|
||||
.toggleable(IconPosition::Start, is_current)
|
||||
.handler(move |_, cx| {
|
||||
set_completion_provider(fs.clone(), cx, provider);
|
||||
}),
|
||||
),
|
||||
EditPredictionProvider::Codestral => menu.item(
|
||||
ContextMenuEntry::new("Codestral")
|
||||
.toggleable(IconPosition::Start, is_current)
|
||||
.handler(move |_, cx| {
|
||||
set_completion_provider(fs.clone(), cx, provider);
|
||||
}),
|
||||
),
|
||||
EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
) => "Sweep",
|
||||
) => {
|
||||
let has_api_token = edit_prediction::EditPredictionStore::try_global(cx)
|
||||
.map_or(false, |ep_store| ep_store.read(cx).has_sweep_api_token());
|
||||
|
||||
let should_open_modal = !has_api_token || is_current;
|
||||
|
||||
let entry = if has_api_token {
|
||||
ContextMenuEntry::new("Sweep")
|
||||
.toggleable(IconPosition::Start, is_current)
|
||||
} else {
|
||||
ContextMenuEntry::new("Sweep")
|
||||
.icon(IconName::XCircle)
|
||||
.icon_color(Color::Error)
|
||||
.documentation_aside(
|
||||
DocumentationSide::Left,
|
||||
DocumentationEdge::Bottom,
|
||||
|_| {
|
||||
Label::new("Click to configure your Sweep API token")
|
||||
.into_any_element()
|
||||
},
|
||||
)
|
||||
};
|
||||
|
||||
let entry = entry.handler(move |window, cx| {
|
||||
if should_open_modal {
|
||||
if let Some(workspace) = window.root::<Workspace>().flatten() {
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
workspace.toggle_modal(window, cx, |window, cx| {
|
||||
ExternalProviderApiKeyModal::new(
|
||||
window,
|
||||
cx,
|
||||
|api_key, store, cx| {
|
||||
store
|
||||
.sweep_ai
|
||||
.set_api_token(api_key, cx)
|
||||
.detach_and_log_err(cx);
|
||||
},
|
||||
)
|
||||
});
|
||||
});
|
||||
};
|
||||
} else {
|
||||
set_completion_provider(fs.clone(), cx, provider);
|
||||
}
|
||||
});
|
||||
|
||||
menu.item(entry)
|
||||
}
|
||||
EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
) => "Mercury",
|
||||
) => {
|
||||
let has_api_token = edit_prediction::EditPredictionStore::try_global(cx)
|
||||
.map_or(false, |ep_store| ep_store.read(cx).has_mercury_api_token());
|
||||
|
||||
let should_open_modal = !has_api_token || is_current;
|
||||
|
||||
let entry = if has_api_token {
|
||||
ContextMenuEntry::new("Mercury")
|
||||
.toggleable(IconPosition::Start, is_current)
|
||||
} else {
|
||||
ContextMenuEntry::new("Mercury")
|
||||
.icon(IconName::XCircle)
|
||||
.icon_color(Color::Error)
|
||||
.documentation_aside(
|
||||
DocumentationSide::Left,
|
||||
DocumentationEdge::Bottom,
|
||||
|_| {
|
||||
Label::new("Click to configure your Mercury API token")
|
||||
.into_any_element()
|
||||
},
|
||||
)
|
||||
};
|
||||
|
||||
let entry = entry.handler(move |window, cx| {
|
||||
if should_open_modal {
|
||||
if let Some(workspace) = window.root::<Workspace>().flatten() {
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
workspace.toggle_modal(window, cx, |window, cx| {
|
||||
ExternalProviderApiKeyModal::new(
|
||||
window,
|
||||
cx,
|
||||
|api_key, store, cx| {
|
||||
store
|
||||
.mercury
|
||||
.set_api_token(api_key, cx)
|
||||
.detach_and_log_err(cx);
|
||||
},
|
||||
)
|
||||
});
|
||||
});
|
||||
};
|
||||
} else {
|
||||
set_completion_provider(fs.clone(), cx, provider);
|
||||
}
|
||||
});
|
||||
|
||||
menu.item(entry)
|
||||
}
|
||||
EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
) => "Zeta2",
|
||||
) => menu.item(
|
||||
ContextMenuEntry::new("Zeta2")
|
||||
.toggleable(IconPosition::Start, is_current)
|
||||
.handler(move |_, cx| {
|
||||
set_completion_provider(fs.clone(), cx, provider);
|
||||
}),
|
||||
),
|
||||
EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => {
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
menu = menu.item(
|
||||
ContextMenuEntry::new(name)
|
||||
.toggleable(IconPosition::Start, is_current)
|
||||
.handler(move |_, cx| {
|
||||
set_completion_provider(fs.clone(), cx, provider);
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -705,7 +832,14 @@ impl EditPredictionButton {
|
||||
let subtle_mode = matches!(current_mode, EditPredictionsMode::Subtle);
|
||||
let eager_mode = matches!(current_mode, EditPredictionsMode::Eager);
|
||||
|
||||
menu = menu
|
||||
if matches!(
|
||||
provider,
|
||||
EditPredictionProvider::Zed
|
||||
| EditPredictionProvider::Copilot
|
||||
| EditPredictionProvider::Supermaven
|
||||
| EditPredictionProvider::Codestral
|
||||
) {
|
||||
menu = menu
|
||||
.separator()
|
||||
.header("Display Modes")
|
||||
.item(
|
||||
@@ -734,111 +868,104 @@ impl EditPredictionButton {
|
||||
}
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
menu = menu.separator().header("Privacy");
|
||||
|
||||
if matches!(
|
||||
provider,
|
||||
EditPredictionProvider::Zed
|
||||
| EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
)
|
||||
) {
|
||||
if let Some(provider) = &self.edit_prediction_provider {
|
||||
let data_collection = provider.data_collection_state(cx);
|
||||
if let Some(provider) = &self.edit_prediction_provider {
|
||||
let data_collection = provider.data_collection_state(cx);
|
||||
|
||||
if data_collection.is_supported() {
|
||||
let provider = provider.clone();
|
||||
let enabled = data_collection.is_enabled();
|
||||
let is_open_source = data_collection.is_project_open_source();
|
||||
let is_collecting = data_collection.is_enabled();
|
||||
let (icon_name, icon_color) = if is_open_source && is_collecting {
|
||||
(IconName::Check, Color::Success)
|
||||
} else {
|
||||
(IconName::Check, Color::Accent)
|
||||
};
|
||||
if data_collection.is_supported() {
|
||||
let provider = provider.clone();
|
||||
let enabled = data_collection.is_enabled();
|
||||
let is_open_source = data_collection.is_project_open_source();
|
||||
let is_collecting = data_collection.is_enabled();
|
||||
let (icon_name, icon_color) = if is_open_source && is_collecting {
|
||||
(IconName::Check, Color::Success)
|
||||
} else {
|
||||
(IconName::Check, Color::Accent)
|
||||
};
|
||||
|
||||
menu = menu.item(
|
||||
ContextMenuEntry::new("Training Data Collection")
|
||||
.toggleable(IconPosition::Start, data_collection.is_enabled())
|
||||
.icon(icon_name)
|
||||
.icon_color(icon_color)
|
||||
.documentation_aside(DocumentationSide::Left, DocumentationEdge::Top, move |cx| {
|
||||
let (msg, label_color, icon_name, icon_color) = match (is_open_source, is_collecting) {
|
||||
(true, true) => (
|
||||
"Project identified as open source, and you're sharing data.",
|
||||
Color::Default,
|
||||
IconName::Check,
|
||||
Color::Success,
|
||||
),
|
||||
(true, false) => (
|
||||
"Project identified as open source, but you're not sharing data.",
|
||||
Color::Muted,
|
||||
IconName::Close,
|
||||
Color::Muted,
|
||||
),
|
||||
(false, true) => (
|
||||
"Project not identified as open source. No data captured.",
|
||||
Color::Muted,
|
||||
IconName::Close,
|
||||
Color::Muted,
|
||||
),
|
||||
(false, false) => (
|
||||
"Project not identified as open source, and setting turned off.",
|
||||
Color::Muted,
|
||||
IconName::Close,
|
||||
Color::Muted,
|
||||
),
|
||||
};
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Label::new(indoc!{
|
||||
"Help us improve our open dataset model by sharing data from open source repositories. \
|
||||
Zed must detect a license file in your repo for this setting to take effect. \
|
||||
Files with sensitive data and secrets are excluded by default."
|
||||
})
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.items_start()
|
||||
.pt_2()
|
||||
.pr_1()
|
||||
.flex_1()
|
||||
.gap_1p5()
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.child(h_flex().flex_shrink_0().h(line_height).child(Icon::new(icon_name).size(IconSize::XSmall).color(icon_color)))
|
||||
.child(div().child(msg).w_full().text_sm().text_color(label_color.color(cx)))
|
||||
)
|
||||
.into_any_element()
|
||||
})
|
||||
.handler(move |_, cx| {
|
||||
provider.toggle_data_collection(cx);
|
||||
|
||||
if !enabled {
|
||||
telemetry::event!(
|
||||
"Data Collection Enabled",
|
||||
source = "Edit Prediction Status Menu"
|
||||
);
|
||||
} else {
|
||||
telemetry::event!(
|
||||
"Data Collection Disabled",
|
||||
source = "Edit Prediction Status Menu"
|
||||
);
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
if is_collecting && !is_open_source {
|
||||
menu = menu.item(
|
||||
ContextMenuEntry::new("Training Data Collection")
|
||||
.toggleable(IconPosition::Start, data_collection.is_enabled())
|
||||
.icon(icon_name)
|
||||
.icon_color(icon_color)
|
||||
.documentation_aside(DocumentationSide::Left, DocumentationEdge::Top, move |cx| {
|
||||
let (msg, label_color, icon_name, icon_color) = match (is_open_source, is_collecting) {
|
||||
(true, true) => (
|
||||
"Project identified as open source, and you're sharing data.",
|
||||
Color::Default,
|
||||
IconName::Check,
|
||||
Color::Success,
|
||||
),
|
||||
(true, false) => (
|
||||
"Project identified as open source, but you're not sharing data.",
|
||||
Color::Muted,
|
||||
IconName::Close,
|
||||
Color::Muted,
|
||||
),
|
||||
(false, true) => (
|
||||
"Project not identified as open source. No data captured.",
|
||||
Color::Muted,
|
||||
IconName::Close,
|
||||
Color::Muted,
|
||||
),
|
||||
(false, false) => (
|
||||
"Project not identified as open source, and setting turned off.",
|
||||
Color::Muted,
|
||||
IconName::Close,
|
||||
Color::Muted,
|
||||
),
|
||||
};
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Label::new(indoc!{
|
||||
"Help us improve our open dataset model by sharing data from open source repositories. \
|
||||
Zed must detect a license file in your repo for this setting to take effect. \
|
||||
Files with sensitive data and secrets are excluded by default."
|
||||
})
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.items_start()
|
||||
.pt_2()
|
||||
.pr_1()
|
||||
.flex_1()
|
||||
.gap_1p5()
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.child(h_flex().flex_shrink_0().h(line_height).child(Icon::new(icon_name).size(IconSize::XSmall).color(icon_color)))
|
||||
.child(div().child(msg).w_full().text_sm().text_color(label_color.color(cx)))
|
||||
)
|
||||
.into_any_element()
|
||||
})
|
||||
.handler(move |_, cx| {
|
||||
provider.toggle_data_collection(cx);
|
||||
|
||||
if !enabled {
|
||||
telemetry::event!(
|
||||
"Data Collection Enabled",
|
||||
source = "Edit Prediction Status Menu"
|
||||
);
|
||||
} else {
|
||||
telemetry::event!(
|
||||
"Data Collection Disabled",
|
||||
source = "Edit Prediction Status Menu"
|
||||
);
|
||||
}
|
||||
})
|
||||
ContextMenuEntry::new("No data captured.")
|
||||
.disabled(true)
|
||||
.icon(IconName::Close)
|
||||
.icon_color(Color::Error)
|
||||
.icon_size(IconSize::Small),
|
||||
);
|
||||
|
||||
if is_collecting && !is_open_source {
|
||||
menu = menu.item(
|
||||
ContextMenuEntry::new("No data captured.")
|
||||
.disabled(true)
|
||||
.icon(IconName::Close)
|
||||
.icon_color(Color::Error)
|
||||
.icon_size(IconSize::Small),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -960,7 +1087,10 @@ impl EditPredictionButton {
|
||||
let menu =
|
||||
self.add_provider_switching_section(menu, EditPredictionProvider::Codestral, cx);
|
||||
|
||||
menu
|
||||
menu.separator()
|
||||
.entry("Configure Codestral API Key", None, move |window, cx| {
|
||||
window.dispatch_action(zed_actions::agent::OpenSettings.boxed_clone(), cx);
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1080,22 +1210,6 @@ impl EditPredictionButton {
|
||||
}
|
||||
|
||||
menu = self.add_provider_switching_section(menu, provider, cx);
|
||||
menu = menu.separator().item(
|
||||
ContextMenuEntry::new("Configure Providers")
|
||||
.icon(IconName::Settings)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
OpenSettingsAt {
|
||||
path: "edit_predictions.providers".to_string(),
|
||||
}
|
||||
.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
);
|
||||
|
||||
menu
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
mod edit_prediction_button;
|
||||
mod edit_prediction_context_view;
|
||||
mod external_provider_api_token_modal;
|
||||
mod rate_prediction_modal;
|
||||
|
||||
use std::any::{Any as _, TypeId};
|
||||
@@ -16,6 +17,7 @@ use ui::{App, prelude::*};
|
||||
use workspace::{SplitDirection, Workspace};
|
||||
|
||||
pub use edit_prediction_button::{EditPredictionButton, ToggleMenu};
|
||||
pub use external_provider_api_token_modal::ExternalProviderApiKeyModal;
|
||||
|
||||
use crate::rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag;
|
||||
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
use edit_prediction::EditPredictionStore;
|
||||
use gpui::{
|
||||
DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, IntoElement, ParentElement, Render,
|
||||
};
|
||||
use ui::{Button, ButtonStyle, Clickable, Headline, HeadlineSize, prelude::*};
|
||||
use ui_input::InputField;
|
||||
use workspace::ModalView;
|
||||
|
||||
pub struct ExternalProviderApiKeyModal {
|
||||
api_key_input: Entity<InputField>,
|
||||
focus_handle: FocusHandle,
|
||||
on_confirm: Box<dyn Fn(Option<String>, &mut EditPredictionStore, &mut App)>,
|
||||
}
|
||||
|
||||
impl ExternalProviderApiKeyModal {
|
||||
pub fn new(
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
on_confirm: impl Fn(Option<String>, &mut EditPredictionStore, &mut App) + 'static,
|
||||
) -> Self {
|
||||
let api_key_input = cx.new(|cx| InputField::new(window, cx, "Enter your API key"));
|
||||
|
||||
Self {
|
||||
api_key_input,
|
||||
focus_handle: cx.focus_handle(),
|
||||
on_confirm: Box::new(on_confirm),
|
||||
}
|
||||
}
|
||||
|
||||
fn cancel(&mut self, _: &menu::Cancel, _window: &mut Window, cx: &mut Context<Self>) {
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
|
||||
fn confirm(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
|
||||
let api_key = self.api_key_input.read(cx).text(cx);
|
||||
let api_key = (!api_key.trim().is_empty()).then_some(api_key);
|
||||
|
||||
if let Some(ep_store) = EditPredictionStore::try_global(cx) {
|
||||
ep_store.update(cx, |ep_store, cx| (self.on_confirm)(api_key, ep_store, cx))
|
||||
}
|
||||
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
}
|
||||
|
||||
impl EventEmitter<DismissEvent> for ExternalProviderApiKeyModal {}
|
||||
|
||||
impl ModalView for ExternalProviderApiKeyModal {}
|
||||
|
||||
impl Focusable for ExternalProviderApiKeyModal {
|
||||
fn focus_handle(&self, _cx: &App) -> FocusHandle {
|
||||
self.focus_handle.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for ExternalProviderApiKeyModal {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
v_flex()
|
||||
.key_context("ExternalApiKeyModal")
|
||||
.on_action(cx.listener(Self::cancel))
|
||||
.on_action(cx.listener(Self::confirm))
|
||||
.elevation_2(cx)
|
||||
.w(px(400.))
|
||||
.p_4()
|
||||
.gap_3()
|
||||
.child(Headline::new("API Token").size(HeadlineSize::Small))
|
||||
.child(self.api_key_input.clone())
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.gap_2()
|
||||
.child(Button::new("cancel", "Cancel").on_click(cx.listener(
|
||||
|_, _, _window, cx| {
|
||||
cx.emit(DismissEvent);
|
||||
},
|
||||
)))
|
||||
.child(
|
||||
Button::new("save", "Save")
|
||||
.style(ButtonStyle::Filled)
|
||||
.on_click(cx.listener(|this, _, window, cx| {
|
||||
this.confirm(&menu::Confirm, window, cx);
|
||||
})),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -45,7 +45,7 @@ impl Editor {
|
||||
|
||||
let bracket_matches_by_accent = self.visible_excerpts(false, cx).into_iter().fold(
|
||||
HashMap::default(),
|
||||
|mut acc, (excerpt_id, (buffer, _, buffer_range))| {
|
||||
|mut acc, (excerpt_id, (buffer, buffer_version, buffer_range))| {
|
||||
let buffer_snapshot = buffer.read(cx).snapshot();
|
||||
if language_settings::language_settings(
|
||||
buffer_snapshot.language().map(|language| language.name()),
|
||||
@@ -62,7 +62,7 @@ impl Editor {
|
||||
let brackets_by_accent = buffer_snapshot
|
||||
.fetch_bracket_ranges(
|
||||
buffer_range.start..buffer_range.end,
|
||||
Some(fetched_chunks),
|
||||
Some((&buffer_version, fetched_chunks)),
|
||||
)
|
||||
.into_iter()
|
||||
.flat_map(|(chunk_range, pairs)| {
|
||||
|
||||
@@ -56,7 +56,6 @@ use sum_tree::{Bias, TreeMap};
|
||||
use text::{BufferId, LineIndent};
|
||||
use ui::{SharedString, px};
|
||||
use unicode_segmentation::UnicodeSegmentation;
|
||||
use ztracing::instrument;
|
||||
|
||||
use std::{
|
||||
any::TypeId,
|
||||
@@ -169,7 +168,6 @@ impl DisplayMap {
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn snapshot(&mut self, cx: &mut Context<Self>) -> DisplaySnapshot {
|
||||
let tab_size = Self::tab_size(&self.buffer, cx);
|
||||
|
||||
@@ -197,7 +195,6 @@ impl DisplayMap {
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn set_state(&mut self, other: &DisplaySnapshot, cx: &mut Context<Self>) {
|
||||
self.fold(
|
||||
other
|
||||
@@ -214,7 +211,6 @@ impl DisplayMap {
|
||||
}
|
||||
|
||||
/// Creates folds for the given creases.
|
||||
#[instrument(skip_all)]
|
||||
pub fn fold<T: Clone + ToOffset>(&mut self, creases: Vec<Crease<T>>, cx: &mut Context<Self>) {
|
||||
let buffer_snapshot = self.buffer.read(cx).snapshot(cx);
|
||||
let edits = self.buffer_subscription.consume().into_inner();
|
||||
@@ -283,7 +279,6 @@ impl DisplayMap {
|
||||
}
|
||||
|
||||
/// Removes any folds with the given ranges.
|
||||
#[instrument(skip_all)]
|
||||
pub fn remove_folds_with_type<T: ToOffset>(
|
||||
&mut self,
|
||||
ranges: impl IntoIterator<Item = Range<T>>,
|
||||
@@ -309,7 +304,6 @@ impl DisplayMap {
|
||||
}
|
||||
|
||||
/// Removes any folds whose ranges intersect any of the given ranges.
|
||||
#[instrument(skip_all)]
|
||||
pub fn unfold_intersecting<T: ToOffset>(
|
||||
&mut self,
|
||||
ranges: impl IntoIterator<Item = Range<T>>,
|
||||
@@ -341,7 +335,6 @@ impl DisplayMap {
|
||||
block_map.remove_intersecting_replace_blocks(offset_ranges, inclusive);
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn disable_header_for_buffer(&mut self, buffer_id: BufferId, cx: &mut Context<Self>) {
|
||||
let snapshot = self.buffer.read(cx).snapshot(cx);
|
||||
let edits = self.buffer_subscription.consume().into_inner();
|
||||
@@ -356,7 +349,6 @@ impl DisplayMap {
|
||||
block_map.disable_header_for_buffer(buffer_id)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn fold_buffers(
|
||||
&mut self,
|
||||
buffer_ids: impl IntoIterator<Item = language::BufferId>,
|
||||
@@ -375,7 +367,6 @@ impl DisplayMap {
|
||||
block_map.fold_buffers(buffer_ids, self.buffer.read(cx), cx)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn unfold_buffers(
|
||||
&mut self,
|
||||
buffer_ids: impl IntoIterator<Item = language::BufferId>,
|
||||
@@ -394,17 +385,14 @@ impl DisplayMap {
|
||||
block_map.unfold_buffers(buffer_ids, self.buffer.read(cx), cx)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub(crate) fn is_buffer_folded(&self, buffer_id: language::BufferId) -> bool {
|
||||
self.block_map.folded_buffers.contains(&buffer_id)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub(crate) fn folded_buffers(&self) -> &HashSet<BufferId> {
|
||||
&self.block_map.folded_buffers
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn insert_creases(
|
||||
&mut self,
|
||||
creases: impl IntoIterator<Item = Crease<Anchor>>,
|
||||
@@ -414,7 +402,6 @@ impl DisplayMap {
|
||||
self.crease_map.insert(creases, &snapshot)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn remove_creases(
|
||||
&mut self,
|
||||
crease_ids: impl IntoIterator<Item = CreaseId>,
|
||||
@@ -424,7 +411,6 @@ impl DisplayMap {
|
||||
self.crease_map.remove(crease_ids, &snapshot)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn insert_blocks(
|
||||
&mut self,
|
||||
blocks: impl IntoIterator<Item = BlockProperties<Anchor>>,
|
||||
@@ -443,7 +429,6 @@ impl DisplayMap {
|
||||
block_map.insert(blocks)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn resize_blocks(&mut self, heights: HashMap<CustomBlockId, u32>, cx: &mut Context<Self>) {
|
||||
let snapshot = self.buffer.read(cx).snapshot(cx);
|
||||
let edits = self.buffer_subscription.consume().into_inner();
|
||||
@@ -458,12 +443,10 @@ impl DisplayMap {
|
||||
block_map.resize(heights);
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn replace_blocks(&mut self, renderers: HashMap<CustomBlockId, RenderBlock>) {
|
||||
self.block_map.replace_blocks(renderers);
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn remove_blocks(&mut self, ids: HashSet<CustomBlockId>, cx: &mut Context<Self>) {
|
||||
let snapshot = self.buffer.read(cx).snapshot(cx);
|
||||
let edits = self.buffer_subscription.consume().into_inner();
|
||||
@@ -478,7 +461,6 @@ impl DisplayMap {
|
||||
block_map.remove(ids);
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn row_for_block(
|
||||
&mut self,
|
||||
block_id: CustomBlockId,
|
||||
@@ -498,7 +480,6 @@ impl DisplayMap {
|
||||
Some(DisplayRow(block_row.0))
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn highlight_text(
|
||||
&mut self,
|
||||
key: HighlightKey,
|
||||
@@ -526,7 +507,6 @@ impl DisplayMap {
|
||||
self.text_highlights.insert(key, to_insert);
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub(crate) fn highlight_inlays(
|
||||
&mut self,
|
||||
type_id: TypeId,
|
||||
@@ -546,7 +526,6 @@ impl DisplayMap {
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn text_highlights(&self, type_id: TypeId) -> Option<(HighlightStyle, &[Range<Anchor>])> {
|
||||
let highlights = self.text_highlights.get(&HighlightKey::Type(type_id))?;
|
||||
Some((highlights.0, &highlights.1))
|
||||
@@ -559,7 +538,6 @@ impl DisplayMap {
|
||||
self.text_highlights.values()
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn clear_highlights(&mut self, type_id: TypeId) -> bool {
|
||||
let mut cleared = self
|
||||
.text_highlights
|
||||
@@ -588,7 +566,6 @@ impl DisplayMap {
|
||||
.update(cx, |map, cx| map.set_wrap_width(width, cx))
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn update_fold_widths(
|
||||
&mut self,
|
||||
widths: impl IntoIterator<Item = (ChunkRendererId, Pixels)>,
|
||||
@@ -620,7 +597,6 @@ impl DisplayMap {
|
||||
self.inlay_map.current_inlays()
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub(crate) fn splice_inlays(
|
||||
&mut self,
|
||||
to_remove: &[InlayId],
|
||||
@@ -650,7 +626,6 @@ impl DisplayMap {
|
||||
self.block_map.read(snapshot, edits);
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
fn tab_size(buffer: &Entity<MultiBuffer>, cx: &App) -> NonZeroU32 {
|
||||
let buffer = buffer.read(cx).as_singleton().map(|buffer| buffer.read(cx));
|
||||
let language = buffer
|
||||
@@ -700,7 +675,6 @@ pub struct HighlightedChunk<'a> {
|
||||
}
|
||||
|
||||
impl<'a> HighlightedChunk<'a> {
|
||||
#[instrument(skip_all)]
|
||||
fn highlight_invisibles(
|
||||
self,
|
||||
editor_style: &'a EditorStyle,
|
||||
@@ -858,7 +832,6 @@ impl DisplaySnapshot {
|
||||
self.buffer_snapshot().widest_line_number()
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn prev_line_boundary(&self, mut point: MultiBufferPoint) -> (Point, DisplayPoint) {
|
||||
loop {
|
||||
let mut inlay_point = self.inlay_snapshot().to_inlay_point(point);
|
||||
@@ -877,7 +850,6 @@ impl DisplaySnapshot {
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn next_line_boundary(
|
||||
&self,
|
||||
mut point: MultiBufferPoint,
|
||||
@@ -916,7 +888,6 @@ impl DisplaySnapshot {
|
||||
new_start..new_end
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn point_to_display_point(&self, point: MultiBufferPoint, bias: Bias) -> DisplayPoint {
|
||||
let inlay_point = self.inlay_snapshot().to_inlay_point(point);
|
||||
let fold_point = self.fold_snapshot().to_fold_point(inlay_point, bias);
|
||||
@@ -946,7 +917,6 @@ impl DisplaySnapshot {
|
||||
.anchor_at(point.to_offset(self, bias), bias)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
fn display_point_to_inlay_point(&self, point: DisplayPoint, bias: Bias) -> InlayPoint {
|
||||
let block_point = point.0;
|
||||
let wrap_point = self.block_snapshot.to_wrap_point(block_point, bias);
|
||||
@@ -958,7 +928,6 @@ impl DisplaySnapshot {
|
||||
fold_point.to_inlay_point(self.fold_snapshot())
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn display_point_to_fold_point(&self, point: DisplayPoint, bias: Bias) -> FoldPoint {
|
||||
let block_point = point.0;
|
||||
let wrap_point = self.block_snapshot.to_wrap_point(block_point, bias);
|
||||
@@ -968,7 +937,6 @@ impl DisplaySnapshot {
|
||||
.0
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn fold_point_to_display_point(&self, fold_point: FoldPoint) -> DisplayPoint {
|
||||
let tab_point = self.tab_snapshot().fold_point_to_tab_point(fold_point);
|
||||
let wrap_point = self.wrap_snapshot().tab_point_to_wrap_point(tab_point);
|
||||
@@ -981,7 +949,6 @@ impl DisplaySnapshot {
|
||||
}
|
||||
|
||||
/// Returns text chunks starting at the given display row until the end of the file
|
||||
#[instrument(skip_all)]
|
||||
pub fn text_chunks(&self, display_row: DisplayRow) -> impl Iterator<Item = &str> {
|
||||
self.block_snapshot
|
||||
.chunks(
|
||||
@@ -994,7 +961,6 @@ impl DisplaySnapshot {
|
||||
}
|
||||
|
||||
/// Returns text chunks starting at the end of the given display row in reverse until the start of the file
|
||||
#[instrument(skip_all)]
|
||||
pub fn reverse_text_chunks(&self, display_row: DisplayRow) -> impl Iterator<Item = &str> {
|
||||
(0..=display_row.0).rev().flat_map(move |row| {
|
||||
self.block_snapshot
|
||||
@@ -1011,7 +977,6 @@ impl DisplaySnapshot {
|
||||
})
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn chunks(
|
||||
&self,
|
||||
display_rows: Range<DisplayRow>,
|
||||
@@ -1030,7 +995,6 @@ impl DisplaySnapshot {
|
||||
)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn highlighted_chunks<'a>(
|
||||
&'a self,
|
||||
display_rows: Range<DisplayRow>,
|
||||
@@ -1107,7 +1071,6 @@ impl DisplaySnapshot {
|
||||
})
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn layout_row(
|
||||
&self,
|
||||
display_row: DisplayRow,
|
||||
@@ -1169,7 +1132,6 @@ impl DisplaySnapshot {
|
||||
layout_line.closest_index_for_x(x) as u32
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn grapheme_at(&self, mut point: DisplayPoint) -> Option<SharedString> {
|
||||
point = DisplayPoint(self.block_snapshot.clip_point(point.0, Bias::Left));
|
||||
let chars = self
|
||||
@@ -1359,7 +1321,6 @@ impl DisplaySnapshot {
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn crease_for_buffer_row(&self, buffer_row: MultiBufferRow) -> Option<Crease<Point>> {
|
||||
let start =
|
||||
MultiBufferPoint::new(buffer_row.0, self.buffer_snapshot().line_len(buffer_row));
|
||||
@@ -1446,7 +1407,6 @@ impl DisplaySnapshot {
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
#[instrument(skip_all)]
|
||||
pub fn text_highlight_ranges<Tag: ?Sized + 'static>(
|
||||
&self,
|
||||
) -> Option<Arc<(HighlightStyle, Vec<Range<Anchor>>)>> {
|
||||
@@ -1457,7 +1417,6 @@ impl DisplaySnapshot {
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
#[instrument(skip_all)]
|
||||
pub fn all_text_highlight_ranges<Tag: ?Sized + 'static>(
|
||||
&self,
|
||||
) -> Vec<(gpui::Hsla, Range<Point>)> {
|
||||
@@ -1507,7 +1466,6 @@ impl DisplaySnapshot {
|
||||
///
|
||||
/// This moves by buffer rows instead of display rows, a distinction that is
|
||||
/// important when soft wrapping is enabled.
|
||||
#[instrument(skip_all)]
|
||||
pub fn start_of_relative_buffer_row(&self, point: DisplayPoint, times: isize) -> DisplayPoint {
|
||||
let start = self.display_point_to_fold_point(point, Bias::Left);
|
||||
let target = start.row() as isize + times;
|
||||
|
||||
@@ -529,7 +529,7 @@ impl BlockMap {
|
||||
BlockMapWriter(self)
|
||||
}
|
||||
|
||||
#[ztracing::instrument(skip_all, fields(edits = ?edits))]
|
||||
#[ztracing::instrument(skip_all, fields(edits))]
|
||||
fn sync(&self, wrap_snapshot: &WrapSnapshot, mut edits: WrapPatch) {
|
||||
let _timer = zlog::time!("BlockMap::sync").warn_if_gt(std::time::Duration::from_millis(50));
|
||||
|
||||
@@ -570,9 +570,6 @@ impl BlockMap {
|
||||
let mut wrap_point_cursor = wrap_snapshot.wrap_point_cursor();
|
||||
|
||||
while let Some(edit) = edits.next() {
|
||||
let span = ztracing::debug_span!("while edits", edit = ?edit);
|
||||
let _enter = span.enter();
|
||||
|
||||
let mut old_start = edit.old.start;
|
||||
let mut new_start = edit.new.start;
|
||||
|
||||
@@ -631,8 +628,6 @@ impl BlockMap {
|
||||
let mut old_end = edit.old.end;
|
||||
let mut new_end = edit.new.end;
|
||||
loop {
|
||||
let span = ztracing::debug_span!("decide where edit ends loop");
|
||||
let _enter = span.enter();
|
||||
// Seek to the transform starting at or after the end of the edit
|
||||
cursor.seek(&old_end, Bias::Left);
|
||||
cursor.next();
|
||||
@@ -741,10 +736,6 @@ impl BlockMap {
|
||||
// and then insert the block itself.
|
||||
let mut just_processed_folded_buffer = false;
|
||||
for (block_placement, block) in blocks_in_edit.drain(..) {
|
||||
let span =
|
||||
ztracing::debug_span!("for block in edits", block_height = block.height());
|
||||
let _enter = span.enter();
|
||||
|
||||
let mut summary = TransformSummary {
|
||||
input_rows: WrapRow(0),
|
||||
output_rows: BlockRow(block.height()),
|
||||
@@ -966,7 +957,6 @@ impl BlockMap {
|
||||
}
|
||||
}
|
||||
|
||||
#[ztracing::instrument(skip(tree, wrap_snapshot))]
|
||||
fn push_isomorphic(tree: &mut SumTree<Transform>, rows: RowDelta, wrap_snapshot: &WrapSnapshot) {
|
||||
if rows == RowDelta(0) {
|
||||
return;
|
||||
|
||||
@@ -840,7 +840,7 @@ impl WrapSnapshot {
|
||||
self.tab_point_to_wrap_point(self.tab_snapshot.clip_point(self.to_tab_point(point), bias))
|
||||
}
|
||||
|
||||
#[ztracing::instrument(skip_all, fields(point=?point, ret))]
|
||||
#[ztracing::instrument(skip_all, fields(point, ret))]
|
||||
pub fn prev_row_boundary(&self, mut point: WrapPoint) -> WrapRow {
|
||||
if self.transforms.is_empty() {
|
||||
return WrapRow(0);
|
||||
@@ -851,14 +851,11 @@ impl WrapSnapshot {
|
||||
let mut cursor = self
|
||||
.transforms
|
||||
.cursor::<Dimensions<WrapPoint, TabPoint>>(());
|
||||
// start
|
||||
cursor.seek(&point, Bias::Right);
|
||||
// end
|
||||
if cursor.item().is_none() {
|
||||
cursor.prev();
|
||||
}
|
||||
|
||||
// start
|
||||
while let Some(transform) = cursor.item() {
|
||||
if transform.is_isomorphic() && cursor.start().1.column() == 0 {
|
||||
return cmp::min(cursor.end().0.row(), point.row());
|
||||
@@ -866,7 +863,6 @@ impl WrapSnapshot {
|
||||
cursor.prev();
|
||||
}
|
||||
}
|
||||
// end
|
||||
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
@@ -7135,7 +7135,6 @@ impl Editor {
|
||||
Some((query, selection_anchor_range))
|
||||
}
|
||||
|
||||
#[ztracing::instrument(skip_all)]
|
||||
fn update_selection_occurrence_highlights(
|
||||
&mut self,
|
||||
query_text: String,
|
||||
@@ -7280,7 +7279,6 @@ impl Editor {
|
||||
});
|
||||
}
|
||||
|
||||
#[ztracing::instrument(skip_all)]
|
||||
fn refresh_selected_text_highlights(
|
||||
&mut self,
|
||||
on_buffer_edit: bool,
|
||||
@@ -20975,22 +20973,9 @@ impl Editor {
|
||||
buffer_ranges.last()
|
||||
}?;
|
||||
|
||||
let start_row_in_buffer = text::ToPoint::to_point(&range.start, buffer).row;
|
||||
let end_row_in_buffer = text::ToPoint::to_point(&range.end, buffer).row;
|
||||
|
||||
let Some(buffer_diff) = multi_buffer.diff_for(buffer.remote_id()) else {
|
||||
let selection = start_row_in_buffer..end_row_in_buffer;
|
||||
|
||||
return Some((multi_buffer.buffer(buffer.remote_id()).unwrap(), selection));
|
||||
};
|
||||
|
||||
let buffer_diff_snapshot = buffer_diff.read(cx).snapshot(cx);
|
||||
|
||||
Some((
|
||||
multi_buffer.buffer(buffer.remote_id()).unwrap(),
|
||||
buffer_diff_snapshot.row_to_base_text_row(start_row_in_buffer, buffer)
|
||||
..buffer_diff_snapshot.row_to_base_text_row(end_row_in_buffer, buffer),
|
||||
))
|
||||
let selection = text::ToPoint::to_point(&range.start, buffer).row
|
||||
..text::ToPoint::to_point(&range.end, buffer).row;
|
||||
Some((multi_buffer.buffer(buffer.remote_id()).unwrap(), selection))
|
||||
});
|
||||
|
||||
let Some((buffer, selection)) = buffer_and_selection else {
|
||||
|
||||
@@ -27701,7 +27701,6 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.handle_input("x", window, cx);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
cx.assert_editor_state(indoc! {"
|
||||
- [ ] Item 1
|
||||
- [ ] Item 1.a
|
||||
@@ -27717,7 +27716,8 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
|
||||
- [ ] Item 1.a
|
||||
- [x] Item 2
|
||||
- [x] Item 2.a
|
||||
- [x] Item 2.bˇ"
|
||||
- [x] Item 2.bˇ
|
||||
"
|
||||
});
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.newline(&Newline, window, cx);
|
||||
@@ -27728,41 +27728,34 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
|
||||
- [x] Item 2
|
||||
- [x] Item 2.a
|
||||
- [x] Item 2.b
|
||||
ˇ"
|
||||
ˇ
|
||||
"
|
||||
});
|
||||
|
||||
// Case 3: Test adding a new nested list item preserves indent
|
||||
cx.set_state(&indoc! {"
|
||||
- [ ] Item 1
|
||||
- [ ] Item 1.a
|
||||
- [x] Item 2
|
||||
- [x] Item 2.a
|
||||
- [x] Item 2.b
|
||||
ˇ"
|
||||
});
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.handle_input("-", window, cx);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
cx.assert_editor_state(indoc! {"
|
||||
- [ ] Item 1
|
||||
- [ ] Item 1.a
|
||||
- [x] Item 2
|
||||
- [x] Item 2.a
|
||||
- [x] Item 2.b
|
||||
-ˇ"
|
||||
-ˇ
|
||||
"
|
||||
});
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.handle_input(" [x] Item 2.c", window, cx);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
cx.assert_editor_state(indoc! {"
|
||||
- [ ] Item 1
|
||||
- [ ] Item 1.a
|
||||
- [x] Item 2
|
||||
- [x] Item 2.a
|
||||
- [x] Item 2.b
|
||||
- [x] Item 2.cˇ"
|
||||
- [x] Item 2.cˇ
|
||||
"
|
||||
});
|
||||
|
||||
// Case 4: Test adding new line after nested ordered list preserves indent of previous line
|
||||
@@ -27771,7 +27764,8 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
|
||||
1. Item 1.a
|
||||
2. Item 2
|
||||
1. Item 2.a
|
||||
2. Item 2.bˇ"
|
||||
2. Item 2.bˇ
|
||||
"
|
||||
});
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.newline(&Newline, window, cx);
|
||||
@@ -27782,81 +27776,60 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
|
||||
2. Item 2
|
||||
1. Item 2.a
|
||||
2. Item 2.b
|
||||
ˇ"
|
||||
ˇ
|
||||
"
|
||||
});
|
||||
|
||||
// Case 5: Adding new ordered list item preserves indent
|
||||
cx.set_state(indoc! {"
|
||||
1. Item 1
|
||||
1. Item 1.a
|
||||
2. Item 2
|
||||
1. Item 2.a
|
||||
2. Item 2.b
|
||||
ˇ"
|
||||
});
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.handle_input("3", window, cx);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
cx.assert_editor_state(indoc! {"
|
||||
1. Item 1
|
||||
1. Item 1.a
|
||||
2. Item 2
|
||||
1. Item 2.a
|
||||
2. Item 2.b
|
||||
3ˇ"
|
||||
3ˇ
|
||||
"
|
||||
});
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.handle_input(".", window, cx);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
cx.assert_editor_state(indoc! {"
|
||||
1. Item 1
|
||||
1. Item 1.a
|
||||
2. Item 2
|
||||
1. Item 2.a
|
||||
2. Item 2.b
|
||||
3.ˇ"
|
||||
3.ˇ
|
||||
"
|
||||
});
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.handle_input(" Item 2.c", window, cx);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
cx.assert_editor_state(indoc! {"
|
||||
1. Item 1
|
||||
1. Item 1.a
|
||||
2. Item 2
|
||||
1. Item 2.a
|
||||
2. Item 2.b
|
||||
3. Item 2.cˇ"
|
||||
3. Item 2.cˇ
|
||||
"
|
||||
});
|
||||
|
||||
// Case 6: Test adding new line after nested ordered list preserves indent of previous line
|
||||
cx.set_state(indoc! {"
|
||||
- Item 1
|
||||
- Item 1.a
|
||||
- Item 1.a
|
||||
ˇ"});
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.handle_input("-", window, cx);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
cx.assert_editor_state(indoc! {"
|
||||
- Item 1
|
||||
- Item 1.a
|
||||
- Item 1.a
|
||||
-ˇ"});
|
||||
|
||||
// Case 7: Test blockquote newline preserves something
|
||||
cx.set_state(indoc! {"
|
||||
> Item 1ˇ"
|
||||
> Item 1ˇ
|
||||
"
|
||||
});
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.newline(&Newline, window, cx);
|
||||
});
|
||||
cx.assert_editor_state(indoc! {"
|
||||
> Item 1
|
||||
ˇ"
|
||||
ˇ
|
||||
"
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ use theme::ActiveTheme;
|
||||
enum MatchingBracketHighlight {}
|
||||
|
||||
impl Editor {
|
||||
#[ztracing::instrument(skip_all)]
|
||||
pub fn refresh_matching_bracket_highlights(
|
||||
&mut self,
|
||||
window: &Window,
|
||||
|
||||
@@ -623,10 +623,7 @@ pub fn hover_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
|
||||
});
|
||||
MarkdownStyle {
|
||||
base_text_style,
|
||||
code_block: StyleRefinement::default()
|
||||
.my(rems(1.))
|
||||
.font_buffer(cx)
|
||||
.font_features(buffer_font_features.clone()),
|
||||
code_block: StyleRefinement::default().my(rems(1.)).font_buffer(cx),
|
||||
inline_code: TextStyleRefinement {
|
||||
background_color: Some(cx.theme().colors().background),
|
||||
font_family: Some(buffer_font_family),
|
||||
|
||||
@@ -892,7 +892,7 @@ pub fn wait_for_lang_server(
|
||||
.update(cx, |buffer, cx| {
|
||||
lsp_store.update(cx, |lsp_store, cx| {
|
||||
lsp_store
|
||||
.running_language_servers_for_local_buffer(buffer, cx)
|
||||
.language_servers_for_local_buffer(buffer, cx)
|
||||
.next()
|
||||
.is_some()
|
||||
})
|
||||
|
||||
@@ -29,6 +29,7 @@ pub struct ExtensionHostProxy {
|
||||
slash_command_proxy: RwLock<Option<Arc<dyn ExtensionSlashCommandProxy>>>,
|
||||
context_server_proxy: RwLock<Option<Arc<dyn ExtensionContextServerProxy>>>,
|
||||
debug_adapter_provider_proxy: RwLock<Option<Arc<dyn ExtensionDebugAdapterProviderProxy>>>,
|
||||
language_model_provider_proxy: RwLock<Option<Arc<dyn ExtensionLanguageModelProviderProxy>>>,
|
||||
}
|
||||
|
||||
impl ExtensionHostProxy {
|
||||
@@ -54,6 +55,7 @@ impl ExtensionHostProxy {
|
||||
slash_command_proxy: RwLock::default(),
|
||||
context_server_proxy: RwLock::default(),
|
||||
debug_adapter_provider_proxy: RwLock::default(),
|
||||
language_model_provider_proxy: RwLock::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,6 +92,15 @@ impl ExtensionHostProxy {
|
||||
.write()
|
||||
.replace(Arc::new(proxy));
|
||||
}
|
||||
|
||||
pub fn register_language_model_provider_proxy(
|
||||
&self,
|
||||
proxy: impl ExtensionLanguageModelProviderProxy,
|
||||
) {
|
||||
self.language_model_provider_proxy
|
||||
.write()
|
||||
.replace(Arc::new(proxy));
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ExtensionThemeProxy: Send + Sync + 'static {
|
||||
@@ -375,6 +386,49 @@ pub trait ExtensionContextServerProxy: Send + Sync + 'static {
|
||||
fn unregister_context_server(&self, server_id: Arc<str>, cx: &mut App);
|
||||
}
|
||||
|
||||
/// A function that registers a language model provider with the registry.
|
||||
/// This allows extension_host to create the provider (which requires WasmExtension)
|
||||
/// and pass a registration closure to the language_models crate.
|
||||
pub type LanguageModelProviderRegistration = Box<dyn FnOnce(&mut App) + Send + Sync + 'static>;
|
||||
|
||||
pub trait ExtensionLanguageModelProviderProxy: Send + Sync + 'static {
|
||||
/// Register an LLM provider from an extension.
|
||||
/// The `register_fn` closure will be called with the App context and should
|
||||
/// register the provider with the LanguageModelRegistry.
|
||||
fn register_language_model_provider(
|
||||
&self,
|
||||
provider_id: Arc<str>,
|
||||
register_fn: LanguageModelProviderRegistration,
|
||||
cx: &mut App,
|
||||
);
|
||||
|
||||
/// Unregister an LLM provider when an extension is unloaded.
|
||||
fn unregister_language_model_provider(&self, provider_id: Arc<str>, cx: &mut App);
|
||||
}
|
||||
|
||||
impl ExtensionLanguageModelProviderProxy for ExtensionHostProxy {
|
||||
fn register_language_model_provider(
|
||||
&self,
|
||||
provider_id: Arc<str>,
|
||||
register_fn: LanguageModelProviderRegistration,
|
||||
cx: &mut App,
|
||||
) {
|
||||
let Some(proxy) = self.language_model_provider_proxy.read().clone() else {
|
||||
return;
|
||||
};
|
||||
|
||||
proxy.register_language_model_provider(provider_id, register_fn, cx)
|
||||
}
|
||||
|
||||
fn unregister_language_model_provider(&self, provider_id: Arc<str>, cx: &mut App) {
|
||||
let Some(proxy) = self.language_model_provider_proxy.read().clone() else {
|
||||
return;
|
||||
};
|
||||
|
||||
proxy.unregister_language_model_provider(provider_id, cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl ExtensionContextServerProxy for ExtensionHostProxy {
|
||||
fn register_context_server(
|
||||
&self,
|
||||
|
||||
@@ -93,6 +93,8 @@ pub struct ExtensionManifest {
|
||||
pub debug_adapters: BTreeMap<Arc<str>, DebugAdapterManifestEntry>,
|
||||
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
|
||||
pub debug_locators: BTreeMap<Arc<str>, DebugLocatorManifestEntry>,
|
||||
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
|
||||
pub language_model_providers: BTreeMap<Arc<str>, LanguageModelProviderManifestEntry>,
|
||||
}
|
||||
|
||||
impl ExtensionManifest {
|
||||
@@ -288,6 +290,71 @@ pub struct DebugAdapterManifestEntry {
|
||||
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
|
||||
pub struct DebugLocatorManifestEntry {}
|
||||
|
||||
/// Manifest entry for a language model provider.
|
||||
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
|
||||
pub struct LanguageModelProviderManifestEntry {
|
||||
/// Display name for the provider.
|
||||
pub name: String,
|
||||
/// Path to an SVG icon file relative to the extension root (e.g., "icons/provider.svg").
|
||||
#[serde(default)]
|
||||
pub icon: Option<String>,
|
||||
/// Default models to show even before API connection.
|
||||
#[serde(default)]
|
||||
pub models: Vec<LanguageModelManifestEntry>,
|
||||
/// Authentication configuration.
|
||||
#[serde(default)]
|
||||
pub auth: Option<LanguageModelAuthConfig>,
|
||||
}
|
||||
|
||||
/// Manifest entry for a language model.
|
||||
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
|
||||
pub struct LanguageModelManifestEntry {
|
||||
/// Unique identifier for the model.
|
||||
pub id: String,
|
||||
/// Display name for the model.
|
||||
pub name: String,
|
||||
/// Maximum input token count.
|
||||
#[serde(default)]
|
||||
pub max_token_count: u64,
|
||||
/// Maximum output tokens (optional).
|
||||
#[serde(default)]
|
||||
pub max_output_tokens: Option<u64>,
|
||||
/// Whether the model supports image inputs.
|
||||
#[serde(default)]
|
||||
pub supports_images: bool,
|
||||
/// Whether the model supports tool/function calling.
|
||||
#[serde(default)]
|
||||
pub supports_tools: bool,
|
||||
/// Whether the model supports extended thinking/reasoning.
|
||||
#[serde(default)]
|
||||
pub supports_thinking: bool,
|
||||
}
|
||||
|
||||
/// Authentication configuration for a language model provider.
|
||||
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
|
||||
pub struct LanguageModelAuthConfig {
|
||||
/// Environment variable name for the API key.
|
||||
#[serde(default)]
|
||||
pub env_var: Option<String>,
|
||||
/// Human-readable name for the credential shown in the UI input field (e.g., "API Key", "Access Token").
|
||||
#[serde(default)]
|
||||
pub credential_label: Option<String>,
|
||||
/// OAuth configuration for web-based authentication flows.
|
||||
#[serde(default)]
|
||||
pub oauth: Option<OAuthConfig>,
|
||||
}
|
||||
|
||||
/// OAuth configuration for web-based authentication.
|
||||
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
|
||||
pub struct OAuthConfig {
|
||||
/// The text to display on the sign-in button (e.g., "Sign in with GitHub").
|
||||
#[serde(default)]
|
||||
pub sign_in_button_label: Option<String>,
|
||||
/// The icon to display on the sign-in button (e.g., "github").
|
||||
#[serde(default)]
|
||||
pub sign_in_button_icon: Option<String>,
|
||||
}
|
||||
|
||||
impl ExtensionManifest {
|
||||
pub async fn load(fs: Arc<dyn Fs>, extension_dir: &Path) -> Result<Self> {
|
||||
let extension_name = extension_dir
|
||||
@@ -358,6 +425,7 @@ fn manifest_from_old_manifest(
|
||||
capabilities: Vec::new(),
|
||||
debug_adapters: Default::default(),
|
||||
debug_locators: Default::default(),
|
||||
language_model_providers: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -391,6 +459,7 @@ mod tests {
|
||||
capabilities: vec![],
|
||||
debug_adapters: Default::default(),
|
||||
debug_locators: Default::default(),
|
||||
language_model_providers: BTreeMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -29,6 +29,27 @@ pub use wit::{
|
||||
GithubRelease, GithubReleaseAsset, GithubReleaseOptions, github_release_by_tag_name,
|
||||
latest_github_release,
|
||||
},
|
||||
zed::extension::llm_provider::{
|
||||
CacheConfiguration as LlmCacheConfiguration, CompletionEvent as LlmCompletionEvent,
|
||||
CompletionRequest as LlmCompletionRequest, CredentialType as LlmCredentialType,
|
||||
ImageData as LlmImageData, MessageContent as LlmMessageContent,
|
||||
MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities,
|
||||
ModelInfo as LlmModelInfo, OauthHttpRequest as LlmOauthHttpRequest,
|
||||
OauthHttpResponse as LlmOauthHttpResponse, OauthWebAuthConfig as LlmOauthWebAuthConfig,
|
||||
OauthWebAuthResult as LlmOauthWebAuthResult, ProviderInfo as LlmProviderInfo,
|
||||
RequestMessage as LlmRequestMessage, StopReason as LlmStopReason,
|
||||
ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage,
|
||||
ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition,
|
||||
ToolInputFormat as LlmToolInputFormat, ToolResult as LlmToolResult,
|
||||
ToolResultContent as LlmToolResultContent, ToolUse as LlmToolUse,
|
||||
ToolUseJsonParseError as LlmToolUseJsonParseError,
|
||||
delete_credential as llm_delete_credential, get_credential as llm_get_credential,
|
||||
get_env_var as llm_get_env_var, oauth_open_browser as llm_oauth_open_browser,
|
||||
oauth_start_web_auth as llm_oauth_start_web_auth,
|
||||
request_credential as llm_request_credential,
|
||||
send_oauth_http_request as llm_oauth_http_request,
|
||||
store_credential as llm_store_credential,
|
||||
},
|
||||
zed::extension::nodejs::{
|
||||
node_binary_path, npm_install_package, npm_package_installed_version,
|
||||
npm_package_latest_version,
|
||||
@@ -259,6 +280,94 @@ pub trait Extension: Send + Sync {
|
||||
) -> Result<DebugRequest, String> {
|
||||
Err("`run_dap_locator` not implemented".to_string())
|
||||
}
|
||||
|
||||
/// Returns information about language model providers offered by this extension.
|
||||
fn llm_providers(&self) -> Vec<LlmProviderInfo> {
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
/// Returns the models available for a provider.
|
||||
fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
/// Returns markdown content to display in the provider's settings UI.
|
||||
/// This can include setup instructions, links to documentation, etc.
|
||||
fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Check if the provider is authenticated.
|
||||
fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Start an OAuth device flow sign-in.
|
||||
/// This is called when the user explicitly clicks "Sign in with GitHub" or similar.
|
||||
/// Opens the browser to the verification URL and returns the user code that should
|
||||
/// be displayed to the user.
|
||||
fn llm_provider_start_device_flow_sign_in(
|
||||
&mut self,
|
||||
_provider_id: &str,
|
||||
) -> Result<String, String> {
|
||||
Err("`llm_provider_start_device_flow_sign_in` not implemented".to_string())
|
||||
}
|
||||
|
||||
/// Poll for device flow sign-in completion.
|
||||
/// This is called after llm_provider_start_device_flow_sign_in returns the user code.
|
||||
/// The extension should poll the OAuth provider until the user authorizes or the flow times out.
|
||||
fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> {
|
||||
Err("`llm_provider_poll_device_flow_sign_in` not implemented".to_string())
|
||||
}
|
||||
|
||||
/// Reset credentials for the provider.
|
||||
fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
|
||||
Err("`llm_provider_reset_credentials` not implemented".to_string())
|
||||
}
|
||||
|
||||
/// Count tokens for a request.
|
||||
fn llm_count_tokens(
|
||||
&self,
|
||||
_provider_id: &str,
|
||||
_model_id: &str,
|
||||
_request: &LlmCompletionRequest,
|
||||
) -> Result<u64, String> {
|
||||
Err("`llm_count_tokens` not implemented".to_string())
|
||||
}
|
||||
|
||||
/// Start streaming a completion from the model.
|
||||
/// Returns a stream ID that can be used with `llm_stream_completion_next` and `llm_stream_completion_close`.
|
||||
fn llm_stream_completion_start(
|
||||
&mut self,
|
||||
_provider_id: &str,
|
||||
_model_id: &str,
|
||||
_request: &LlmCompletionRequest,
|
||||
) -> Result<String, String> {
|
||||
Err("`llm_stream_completion_start` not implemented".to_string())
|
||||
}
|
||||
|
||||
/// Get the next event from a completion stream.
|
||||
/// Returns `Ok(None)` when the stream is complete.
|
||||
fn llm_stream_completion_next(
|
||||
&mut self,
|
||||
_stream_id: &str,
|
||||
) -> Result<Option<LlmCompletionEvent>, String> {
|
||||
Err("`llm_stream_completion_next` not implemented".to_string())
|
||||
}
|
||||
|
||||
/// Close a completion stream and release its resources.
|
||||
fn llm_stream_completion_close(&mut self, _stream_id: &str) {
|
||||
// Default implementation does nothing
|
||||
}
|
||||
|
||||
/// Get cache configuration for a model (if prompt caching is supported).
|
||||
fn llm_cache_configuration(
|
||||
&self,
|
||||
_provider_id: &str,
|
||||
_model_id: &str,
|
||||
) -> Option<LlmCacheConfiguration> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Registers the provided type as a Zed extension.
|
||||
@@ -518,6 +627,65 @@ impl wit::Guest for Component {
|
||||
) -> Result<DebugRequest, String> {
|
||||
extension().run_dap_locator(locator_name, build_task)
|
||||
}
|
||||
|
||||
fn llm_providers() -> Vec<LlmProviderInfo> {
|
||||
extension().llm_providers()
|
||||
}
|
||||
|
||||
fn llm_provider_models(provider_id: String) -> Result<Vec<LlmModelInfo>, String> {
|
||||
extension().llm_provider_models(&provider_id)
|
||||
}
|
||||
|
||||
fn llm_provider_settings_markdown(provider_id: String) -> Option<String> {
|
||||
extension().llm_provider_settings_markdown(&provider_id)
|
||||
}
|
||||
|
||||
fn llm_provider_is_authenticated(provider_id: String) -> bool {
|
||||
extension().llm_provider_is_authenticated(&provider_id)
|
||||
}
|
||||
|
||||
fn llm_provider_start_device_flow_sign_in(provider_id: String) -> Result<String, String> {
|
||||
extension().llm_provider_start_device_flow_sign_in(&provider_id)
|
||||
}
|
||||
|
||||
fn llm_provider_poll_device_flow_sign_in(provider_id: String) -> Result<(), String> {
|
||||
extension().llm_provider_poll_device_flow_sign_in(&provider_id)
|
||||
}
|
||||
|
||||
fn llm_provider_reset_credentials(provider_id: String) -> Result<(), String> {
|
||||
extension().llm_provider_reset_credentials(&provider_id)
|
||||
}
|
||||
|
||||
fn llm_count_tokens(
|
||||
provider_id: String,
|
||||
model_id: String,
|
||||
request: LlmCompletionRequest,
|
||||
) -> Result<u64, String> {
|
||||
extension().llm_count_tokens(&provider_id, &model_id, &request)
|
||||
}
|
||||
|
||||
fn llm_stream_completion_start(
|
||||
provider_id: String,
|
||||
model_id: String,
|
||||
request: LlmCompletionRequest,
|
||||
) -> Result<String, String> {
|
||||
extension().llm_stream_completion_start(&provider_id, &model_id, &request)
|
||||
}
|
||||
|
||||
fn llm_stream_completion_next(stream_id: String) -> Result<Option<LlmCompletionEvent>, String> {
|
||||
extension().llm_stream_completion_next(&stream_id)
|
||||
}
|
||||
|
||||
fn llm_stream_completion_close(stream_id: String) {
|
||||
extension().llm_stream_completion_close(&stream_id)
|
||||
}
|
||||
|
||||
fn llm_cache_configuration(
|
||||
provider_id: String,
|
||||
model_id: String,
|
||||
) -> Option<LlmCacheConfiguration> {
|
||||
extension().llm_cache_configuration(&provider_id, &model_id)
|
||||
}
|
||||
}
|
||||
|
||||
/// The ID of a language server.
|
||||
|
||||
@@ -8,6 +8,7 @@ world extension {
|
||||
import platform;
|
||||
import process;
|
||||
import nodejs;
|
||||
import llm-provider;
|
||||
|
||||
use common.{env-vars, range};
|
||||
use context-server.{context-server-configuration};
|
||||
@@ -15,6 +16,10 @@ world extension {
|
||||
use lsp.{completion, symbol};
|
||||
use process.{command};
|
||||
use slash-command.{slash-command, slash-command-argument-completion, slash-command-output};
|
||||
use llm-provider.{
|
||||
provider-info, model-info, completion-request,
|
||||
credential-type, cache-configuration, completion-event, token-usage
|
||||
};
|
||||
|
||||
/// Initializes the extension.
|
||||
export init-extension: func();
|
||||
@@ -164,4 +169,74 @@ world extension {
|
||||
export dap-config-to-scenario: func(config: debug-config) -> result<debug-scenario, string>;
|
||||
export dap-locator-create-scenario: func(locator-name: string, build-config-template: build-task-template, resolved-label: string, debug-adapter-name: string) -> option<debug-scenario>;
|
||||
export run-dap-locator: func(locator-name: string, config: resolved-task) -> result<debug-request, string>;
|
||||
|
||||
/// Returns information about language model providers offered by this extension.
|
||||
export llm-providers: func() -> list<provider-info>;
|
||||
|
||||
/// Returns the models available for a provider.
|
||||
export llm-provider-models: func(provider-id: string) -> result<list<model-info>, string>;
|
||||
|
||||
/// Returns markdown content to display in the provider's settings UI.
|
||||
/// This can include setup instructions, links to documentation, etc.
|
||||
export llm-provider-settings-markdown: func(provider-id: string) -> option<string>;
|
||||
|
||||
/// Check if the provider is authenticated.
|
||||
export llm-provider-is-authenticated: func(provider-id: string) -> bool;
|
||||
|
||||
/// Start an OAuth device flow sign-in.
|
||||
/// This is called when the user explicitly clicks "Sign in with GitHub" or similar.
|
||||
///
|
||||
/// The device flow works as follows:
|
||||
/// 1. Extension requests a device code from the OAuth provider
|
||||
/// 2. Extension opens the verification URL in the browser
|
||||
/// 3. Extension returns the user code to display to the user
|
||||
/// 4. Host displays the user code and calls llm-provider-poll-device-flow-sign-in
|
||||
/// 5. Extension polls for the access token while user authorizes in browser
|
||||
/// 6. Once authorized, extension stores the credential and returns success
|
||||
///
|
||||
/// Returns the user code that should be displayed to the user while they
|
||||
/// complete authorization in the browser.
|
||||
export llm-provider-start-device-flow-sign-in: func(provider-id: string) -> result<string, string>;
|
||||
|
||||
/// Poll for device flow sign-in completion.
|
||||
/// This is called after llm-provider-start-device-flow-sign-in returns the user code.
|
||||
/// The extension should poll the OAuth provider until the user authorizes or the flow times out.
|
||||
/// Returns Ok(()) on successful authentication, or an error message on failure.
|
||||
export llm-provider-poll-device-flow-sign-in: func(provider-id: string) -> result<_, string>;
|
||||
|
||||
/// Reset credentials for the provider.
|
||||
export llm-provider-reset-credentials: func(provider-id: string) -> result<_, string>;
|
||||
|
||||
/// Count tokens for a request.
|
||||
export llm-count-tokens: func(
|
||||
provider-id: string,
|
||||
model-id: string,
|
||||
request: completion-request
|
||||
) -> result<u64, string>;
|
||||
|
||||
/// Start streaming a completion from the model.
|
||||
/// Returns a stream ID that can be used with llm-stream-next and llm-stream-close.
|
||||
export llm-stream-completion-start: func(
|
||||
provider-id: string,
|
||||
model-id: string,
|
||||
request: completion-request
|
||||
) -> result<string, string>;
|
||||
|
||||
/// Get the next event from a completion stream.
|
||||
/// Returns None when the stream is complete.
|
||||
export llm-stream-completion-next: func(
|
||||
stream-id: string
|
||||
) -> result<option<completion-event>, string>;
|
||||
|
||||
/// Close a completion stream and release its resources.
|
||||
export llm-stream-completion-close: func(
|
||||
stream-id: string
|
||||
);
|
||||
|
||||
/// Get cache configuration for a model (if prompt caching is supported).
|
||||
export llm-cache-configuration: func(
|
||||
provider-id: string,
|
||||
model-id: string
|
||||
) -> option<cache-configuration>;
|
||||
|
||||
}
|
||||
|
||||
348
crates/extension_api/wit/since_v0.8.0/llm-provider.wit
Normal file
348
crates/extension_api/wit/since_v0.8.0/llm-provider.wit
Normal file
@@ -0,0 +1,348 @@
|
||||
interface llm-provider {
|
||||
/// Information about a language model provider.
|
||||
record provider-info {
|
||||
/// Unique identifier for the provider (e.g., "my-extension.my-provider").
|
||||
id: string,
|
||||
/// Display name for the provider.
|
||||
name: string,
|
||||
/// Path to an SVG icon file relative to the extension root (e.g., "icons/provider.svg").
|
||||
icon: option<string>,
|
||||
}
|
||||
|
||||
/// Capabilities of a language model.
|
||||
record model-capabilities {
|
||||
/// Whether the model supports image inputs.
|
||||
supports-images: bool,
|
||||
/// Whether the model supports tool/function calling.
|
||||
supports-tools: bool,
|
||||
/// Whether the model supports the "auto" tool choice.
|
||||
supports-tool-choice-auto: bool,
|
||||
/// Whether the model supports the "any" tool choice.
|
||||
supports-tool-choice-any: bool,
|
||||
/// Whether the model supports the "none" tool choice.
|
||||
supports-tool-choice-none: bool,
|
||||
/// Whether the model supports extended thinking/reasoning.
|
||||
supports-thinking: bool,
|
||||
/// The format for tool input schemas.
|
||||
tool-input-format: tool-input-format,
|
||||
}
|
||||
|
||||
/// Format for tool input schemas.
|
||||
enum tool-input-format {
|
||||
/// Standard JSON Schema format.
|
||||
json-schema,
|
||||
/// Simplified schema format for certain providers.
|
||||
simplified,
|
||||
}
|
||||
|
||||
/// Information about a specific model.
|
||||
record model-info {
|
||||
/// Unique identifier for the model.
|
||||
id: string,
|
||||
/// Display name for the model.
|
||||
name: string,
|
||||
/// Maximum input token count.
|
||||
max-token-count: u64,
|
||||
/// Maximum output tokens (optional).
|
||||
max-output-tokens: option<u64>,
|
||||
/// Model capabilities.
|
||||
capabilities: model-capabilities,
|
||||
/// Whether this is the default model for the provider.
|
||||
is-default: bool,
|
||||
/// Whether this is the default fast model.
|
||||
is-default-fast: bool,
|
||||
}
|
||||
|
||||
/// The role of a message participant.
|
||||
enum message-role {
|
||||
/// User message.
|
||||
user,
|
||||
/// Assistant message.
|
||||
assistant,
|
||||
/// System message.
|
||||
system,
|
||||
}
|
||||
|
||||
/// A message in a completion request.
|
||||
record request-message {
|
||||
/// The role of the message sender.
|
||||
role: message-role,
|
||||
/// The content of the message.
|
||||
content: list<message-content>,
|
||||
/// Whether to cache this message for prompt caching.
|
||||
cache: bool,
|
||||
}
|
||||
|
||||
/// Content within a message.
|
||||
variant message-content {
|
||||
/// Plain text content.
|
||||
text(string),
|
||||
/// Image content.
|
||||
image(image-data),
|
||||
/// A tool use request from the assistant.
|
||||
tool-use(tool-use),
|
||||
/// A tool result from the user.
|
||||
tool-result(tool-result),
|
||||
/// Thinking/reasoning content.
|
||||
thinking(thinking-content),
|
||||
/// Redacted/encrypted thinking content.
|
||||
redacted-thinking(string),
|
||||
}
|
||||
|
||||
/// Image data for vision models.
|
||||
record image-data {
|
||||
/// Base64-encoded image data.
|
||||
source: string,
|
||||
/// Image width in pixels (optional).
|
||||
width: option<u32>,
|
||||
/// Image height in pixels (optional).
|
||||
height: option<u32>,
|
||||
}
|
||||
|
||||
/// A tool use request from the model.
|
||||
record tool-use {
|
||||
/// Unique identifier for this tool use.
|
||||
id: string,
|
||||
/// The name of the tool being used.
|
||||
name: string,
|
||||
/// JSON string of the tool input arguments.
|
||||
input: string,
|
||||
/// Thought signature for providers that support it (e.g., Anthropic).
|
||||
thought-signature: option<string>,
|
||||
}
|
||||
|
||||
/// A tool result to send back to the model.
|
||||
record tool-result {
|
||||
/// The ID of the tool use this is a result for.
|
||||
tool-use-id: string,
|
||||
/// The name of the tool.
|
||||
tool-name: string,
|
||||
/// Whether this result represents an error.
|
||||
is-error: bool,
|
||||
/// The content of the result.
|
||||
content: tool-result-content,
|
||||
}
|
||||
|
||||
/// Content of a tool result.
|
||||
variant tool-result-content {
|
||||
/// Text result.
|
||||
text(string),
|
||||
/// Image result.
|
||||
image(image-data),
|
||||
}
|
||||
|
||||
/// Thinking/reasoning content from models that support extended thinking.
|
||||
record thinking-content {
|
||||
/// The thinking text.
|
||||
text: string,
|
||||
/// Signature for the thinking block (provider-specific).
|
||||
signature: option<string>,
|
||||
}
|
||||
|
||||
/// A tool definition for function calling.
|
||||
record tool-definition {
|
||||
/// The name of the tool.
|
||||
name: string,
|
||||
/// Description of what the tool does.
|
||||
description: string,
|
||||
/// JSON Schema for input parameters.
|
||||
input-schema: string,
|
||||
}
|
||||
|
||||
/// Tool choice preference for the model.
|
||||
enum tool-choice {
|
||||
/// Let the model decide whether to use tools.
|
||||
auto,
|
||||
/// Force the model to use at least one tool.
|
||||
any,
|
||||
/// Prevent the model from using tools.
|
||||
none,
|
||||
}
|
||||
|
||||
/// A completion request to send to the model.
|
||||
record completion-request {
|
||||
/// The messages in the conversation.
|
||||
messages: list<request-message>,
|
||||
/// Available tools for the model to use.
|
||||
tools: list<tool-definition>,
|
||||
/// Tool choice preference.
|
||||
tool-choice: option<tool-choice>,
|
||||
/// Stop sequences to end generation.
|
||||
stop-sequences: list<string>,
|
||||
/// Temperature for sampling (0.0-1.0).
|
||||
temperature: option<f32>,
|
||||
/// Whether thinking/reasoning is allowed.
|
||||
thinking-allowed: bool,
|
||||
/// Maximum tokens to generate.
|
||||
max-tokens: option<u64>,
|
||||
}
|
||||
|
||||
/// Events emitted during completion streaming.
|
||||
variant completion-event {
|
||||
/// Completion has started.
|
||||
started,
|
||||
/// Text content chunk.
|
||||
text(string),
|
||||
/// Thinking/reasoning content chunk.
|
||||
thinking(thinking-content),
|
||||
/// Redacted thinking (encrypted) chunk.
|
||||
redacted-thinking(string),
|
||||
/// Tool use request from the model.
|
||||
tool-use(tool-use),
|
||||
/// JSON parse error when parsing tool input.
|
||||
tool-use-json-parse-error(tool-use-json-parse-error),
|
||||
/// Completion stopped.
|
||||
stop(stop-reason),
|
||||
/// Token usage update.
|
||||
usage(token-usage),
|
||||
/// Reasoning details (provider-specific JSON).
|
||||
reasoning-details(string),
|
||||
}
|
||||
|
||||
/// Error information when tool use JSON parsing fails.
|
||||
record tool-use-json-parse-error {
|
||||
/// The tool use ID.
|
||||
id: string,
|
||||
/// The tool name.
|
||||
tool-name: string,
|
||||
/// The raw input that failed to parse.
|
||||
raw-input: string,
|
||||
/// The parse error message.
|
||||
error: string,
|
||||
}
|
||||
|
||||
/// Reason the completion stopped.
|
||||
enum stop-reason {
|
||||
/// The model finished generating.
|
||||
end-turn,
|
||||
/// Maximum tokens reached.
|
||||
max-tokens,
|
||||
/// The model wants to use a tool.
|
||||
tool-use,
|
||||
/// The model refused to respond.
|
||||
refusal,
|
||||
}
|
||||
|
||||
/// Token usage statistics.
|
||||
record token-usage {
|
||||
/// Number of input tokens used.
|
||||
input-tokens: u64,
|
||||
/// Number of output tokens generated.
|
||||
output-tokens: u64,
|
||||
/// Tokens used for cache creation (if supported).
|
||||
cache-creation-input-tokens: option<u64>,
|
||||
/// Tokens read from cache (if supported).
|
||||
cache-read-input-tokens: option<u64>,
|
||||
}
|
||||
|
||||
/// Credential types that can be requested.
|
||||
enum credential-type {
|
||||
/// An API key.
|
||||
api-key,
|
||||
/// An OAuth token.
|
||||
oauth-token,
|
||||
}
|
||||
|
||||
/// Cache configuration for prompt caching.
|
||||
record cache-configuration {
|
||||
/// Maximum number of cache anchors.
|
||||
max-cache-anchors: u32,
|
||||
/// Whether caching should be applied to tool definitions.
|
||||
should-cache-tool-definitions: bool,
|
||||
/// Minimum token count for a message to be cached.
|
||||
min-total-token-count: u64,
|
||||
}
|
||||
|
||||
/// Configuration for starting an OAuth web authentication flow.
|
||||
record oauth-web-auth-config {
|
||||
/// The URL to open in the user's browser to start authentication.
|
||||
/// This should include client_id, redirect_uri, scope, state, etc.
|
||||
/// Use `{port}` as a placeholder in the URL - it will be replaced with
|
||||
/// the actual localhost port before opening the browser.
|
||||
/// Example: "https://example.com/oauth?redirect_uri=http://127.0.0.1:{port}/callback"
|
||||
auth-url: string,
|
||||
/// The path to listen on for the OAuth callback (e.g., "/callback").
|
||||
/// A localhost server will be started to receive the redirect.
|
||||
callback-path: string,
|
||||
/// Timeout in seconds to wait for the callback (default: 300 = 5 minutes).
|
||||
timeout-secs: option<u32>,
|
||||
}
|
||||
|
||||
/// Result of an OAuth web authentication flow.
|
||||
record oauth-web-auth-result {
|
||||
/// The full callback URL that was received, including query parameters.
|
||||
/// The extension is responsible for parsing the code, state, etc.
|
||||
callback-url: string,
|
||||
/// The port that was used for the localhost callback server.
|
||||
port: u32,
|
||||
}
|
||||
|
||||
/// A generic HTTP request for OAuth token exchange.
|
||||
record oauth-http-request {
|
||||
/// The URL to request.
|
||||
url: string,
|
||||
/// HTTP method (e.g., "POST", "GET").
|
||||
method: string,
|
||||
/// Request headers as key-value pairs.
|
||||
headers: list<tuple<string, string>>,
|
||||
/// Request body as a string (for form-encoded or JSON bodies).
|
||||
body: string,
|
||||
}
|
||||
|
||||
/// Response from an OAuth HTTP request.
|
||||
record oauth-http-response {
|
||||
/// HTTP status code.
|
||||
status: u16,
|
||||
/// Response headers as key-value pairs.
|
||||
headers: list<tuple<string, string>>,
|
||||
/// Response body as a string.
|
||||
body: string,
|
||||
}
|
||||
|
||||
/// Request a credential from the user.
|
||||
/// Returns true if the credential was provided, false if the user cancelled.
|
||||
request-credential: func(
|
||||
provider-id: string,
|
||||
credential-type: credential-type,
|
||||
label: string,
|
||||
placeholder: string
|
||||
) -> result<bool, string>;
|
||||
|
||||
/// Get a stored credential for this provider.
|
||||
get-credential: func(provider-id: string) -> option<string>;
|
||||
|
||||
/// Store a credential for this provider.
|
||||
store-credential: func(provider-id: string, value: string) -> result<_, string>;
|
||||
|
||||
/// Delete a stored credential for this provider.
|
||||
delete-credential: func(provider-id: string) -> result<_, string>;
|
||||
|
||||
/// Read an environment variable.
|
||||
get-env-var: func(name: string) -> option<string>;
|
||||
|
||||
/// Start an OAuth web authentication flow.
|
||||
///
|
||||
/// This will:
|
||||
/// 1. Start a localhost server to receive the OAuth callback
|
||||
/// 2. Open the auth URL in the user's default browser
|
||||
/// 3. Wait for the callback (up to the timeout)
|
||||
/// 4. Return the callback URL with query parameters
|
||||
///
|
||||
/// The extension is responsible for:
|
||||
/// - Constructing the auth URL with client_id, redirect_uri, scope, state, etc.
|
||||
/// - Parsing the callback URL to extract the authorization code
|
||||
/// - Exchanging the code for tokens using oauth-http-request
|
||||
oauth-start-web-auth: func(config: oauth-web-auth-config) -> result<oauth-web-auth-result, string>;
|
||||
|
||||
/// Make an HTTP request for OAuth token exchange.
|
||||
///
|
||||
/// This is a simple HTTP client for OAuth flows, allowing the extension
|
||||
/// to handle token exchange with full control over serialization.
|
||||
send-oauth-http-request: func(request: oauth-http-request) -> result<oauth-http-response, string>;
|
||||
|
||||
/// Open a URL in the user's default browser.
|
||||
///
|
||||
/// Useful for OAuth flows that need to open a browser but handle the
|
||||
/// callback differently (e.g., polling-based flows).
|
||||
oauth-open-browser: func(url: string) -> result<_, string>;
|
||||
}
|
||||
@@ -255,6 +255,21 @@ async fn copy_extension_resources(
|
||||
}
|
||||
}
|
||||
|
||||
for (_, provider_entry) in &manifest.language_model_providers {
|
||||
if let Some(icon_path) = &provider_entry.icon {
|
||||
let source_icon = extension_path.join(icon_path);
|
||||
let dest_icon = output_dir.join(icon_path);
|
||||
|
||||
// Create parent directory if needed
|
||||
if let Some(parent) = dest_icon.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
fs::copy(&source_icon, &dest_icon)
|
||||
.with_context(|| format!("failed to copy LLM provider icon '{}'", icon_path))?;
|
||||
}
|
||||
}
|
||||
|
||||
if !manifest.languages.is_empty() {
|
||||
let output_languages_dir = output_dir.join("languages");
|
||||
fs::create_dir_all(&output_languages_dir)?;
|
||||
|
||||
@@ -22,7 +22,10 @@ async-tar.workspace = true
|
||||
async-trait.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
credentials_provider.workspace = true
|
||||
dap.workspace = true
|
||||
dirs.workspace = true
|
||||
editor.workspace = true
|
||||
extension.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
@@ -30,8 +33,11 @@ gpui.workspace = true
|
||||
gpui_tokio.workspace = true
|
||||
http_client.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
log.workspace = true
|
||||
markdown.workspace = true
|
||||
lsp.workspace = true
|
||||
menu.workspace = true
|
||||
moka.workspace = true
|
||||
node_runtime.workspace = true
|
||||
paths.workspace = true
|
||||
@@ -43,10 +49,13 @@ serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
serde_json_lenient.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
task.workspace = true
|
||||
telemetry.workspace = true
|
||||
tempfile.workspace = true
|
||||
theme.workspace = true
|
||||
toml.workspace = true
|
||||
ui.workspace = true
|
||||
url.workspace = true
|
||||
util.workspace = true
|
||||
wasmparser.workspace = true
|
||||
|
||||
@@ -148,6 +148,7 @@ fn manifest() -> ExtensionManifest {
|
||||
)],
|
||||
debug_adapters: Default::default(),
|
||||
debug_locators: Default::default(),
|
||||
language_model_providers: BTreeMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
124
crates/extension_host/src/anthropic_migration.rs
Normal file
124
crates/extension_host/src/anthropic_migration.rs
Normal file
@@ -0,0 +1,124 @@
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use gpui::App;
|
||||
|
||||
const ANTHROPIC_EXTENSION_ID: &str = "anthropic";
|
||||
const ANTHROPIC_PROVIDER_ID: &str = "anthropic";
|
||||
const ANTHROPIC_DEFAULT_API_URL: &str = "https://api.anthropic.com";
|
||||
|
||||
/// Migrates Anthropic API credentials from the old built-in provider location
|
||||
/// to the new extension-based location.
|
||||
///
|
||||
/// This should only be called during auto-install of the extension.
|
||||
pub fn migrate_anthropic_credentials_if_needed(extension_id: &str, cx: &mut App) {
|
||||
if extension_id != ANTHROPIC_EXTENSION_ID {
|
||||
return;
|
||||
}
|
||||
|
||||
let extension_credential_key = format!(
|
||||
"extension-llm-{}:{}",
|
||||
ANTHROPIC_EXTENSION_ID, ANTHROPIC_PROVIDER_ID
|
||||
);
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
// Read from old location
|
||||
let old_credential = credentials_provider
|
||||
.read_credentials(ANTHROPIC_DEFAULT_API_URL, &cx)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let api_key = match old_credential {
|
||||
Some((_, key_bytes)) => match String::from_utf8(key_bytes) {
|
||||
Ok(key) if !key.is_empty() => key,
|
||||
Ok(_) => {
|
||||
log::debug!("Existing Anthropic API key is empty, nothing to migrate");
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
log::error!("Failed to decode Anthropic API key as UTF-8");
|
||||
return;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
log::debug!("No existing Anthropic API key found to migrate");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
log::info!("Migrating existing Anthropic API key to Anthropic extension");
|
||||
|
||||
match credentials_provider
|
||||
.write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
log::info!("Successfully migrated Anthropic API key to extension");
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("Failed to migrate Anthropic API key: {}", err);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
|
||||
let api_key = "sk-ant-test-key-12345";
|
||||
|
||||
cx.write_credentials(ANTHROPIC_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_anthropic_credentials_if_needed(ANTHROPIC_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let migrated = cx.read_credentials("extension-llm-anthropic:anthropic");
|
||||
assert!(migrated.is_some(), "Credentials should have been migrated");
|
||||
let (username, password) = migrated.unwrap();
|
||||
assert_eq!(username, "Bearer");
|
||||
assert_eq!(String::from_utf8(password).unwrap(), api_key);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_no_migration_if_no_old_credentials(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
migrate_anthropic_credentials_if_needed(ANTHROPIC_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-anthropic:anthropic");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not create credentials if none existed"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
|
||||
let api_key = "sk-ant-test-key";
|
||||
|
||||
cx.write_credentials(ANTHROPIC_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_anthropic_credentials_if_needed("some-other-extension", cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-anthropic:anthropic");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not migrate for other extensions"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -113,6 +113,7 @@ mod tests {
|
||||
capabilities: vec![],
|
||||
debug_adapters: Default::default(),
|
||||
debug_locators: Default::default(),
|
||||
language_model_providers: BTreeMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
216
crates/extension_host/src/copilot_migration.rs
Normal file
216
crates/extension_host/src/copilot_migration.rs
Normal file
@@ -0,0 +1,216 @@
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use gpui::App;
|
||||
use std::path::PathBuf;
|
||||
|
||||
const COPILOT_CHAT_EXTENSION_ID: &str = "copilot-chat";
|
||||
const COPILOT_CHAT_PROVIDER_ID: &str = "copilot-chat";
|
||||
|
||||
/// Migrates Copilot OAuth credentials from the GitHub Copilot config files
|
||||
/// to the new extension-based credential location.
|
||||
///
|
||||
/// This should only be called during auto-install of the extension.
|
||||
pub fn migrate_copilot_credentials_if_needed(extension_id: &str, cx: &mut App) {
|
||||
if extension_id != COPILOT_CHAT_EXTENSION_ID {
|
||||
return;
|
||||
}
|
||||
|
||||
let credential_key = format!(
|
||||
"extension-llm-{}:{}",
|
||||
COPILOT_CHAT_EXTENSION_ID, COPILOT_CHAT_PROVIDER_ID
|
||||
);
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
|
||||
cx.spawn(async move |_cx| {
|
||||
// Read from copilot config files
|
||||
let oauth_token = match read_copilot_oauth_token().await {
|
||||
Some(token) if !token.is_empty() => token,
|
||||
_ => {
|
||||
log::debug!("No existing Copilot OAuth token found to migrate");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
log::info!("Migrating existing Copilot OAuth token to Copilot Chat extension");
|
||||
|
||||
match credentials_provider
|
||||
.write_credentials(&credential_key, "api_key", oauth_token.as_bytes(), &_cx)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
log::info!("Successfully migrated Copilot OAuth token to Copilot Chat extension");
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("Failed to migrate Copilot OAuth token: {}", err);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
async fn read_copilot_oauth_token() -> Option<String> {
|
||||
let config_paths = copilot_config_paths();
|
||||
|
||||
for path in config_paths {
|
||||
if let Some(token) = read_oauth_token_from_file(&path).await {
|
||||
return Some(token);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn copilot_config_paths() -> Vec<PathBuf> {
|
||||
let config_dir = if cfg!(target_os = "windows") {
|
||||
dirs::data_local_dir()
|
||||
} else {
|
||||
std::env::var("XDG_CONFIG_HOME")
|
||||
.map(PathBuf::from)
|
||||
.ok()
|
||||
.or_else(|| dirs::home_dir().map(|h| h.join(".config")))
|
||||
};
|
||||
|
||||
let Some(config_dir) = config_dir else {
|
||||
return Vec::new();
|
||||
};
|
||||
|
||||
let copilot_dir = config_dir.join("github-copilot");
|
||||
|
||||
vec![
|
||||
copilot_dir.join("hosts.json"),
|
||||
copilot_dir.join("apps.json"),
|
||||
]
|
||||
}
|
||||
|
||||
async fn read_oauth_token_from_file(path: &PathBuf) -> Option<String> {
|
||||
let contents = match smol::fs::read_to_string(path).await {
|
||||
Ok(contents) => contents,
|
||||
Err(_) => return None,
|
||||
};
|
||||
|
||||
extract_oauth_token(&contents, "github.com")
|
||||
}
|
||||
|
||||
fn extract_oauth_token(contents: &str, domain: &str) -> Option<String> {
|
||||
let value: serde_json::Value = serde_json::from_str(contents).ok()?;
|
||||
let obj = value.as_object()?;
|
||||
|
||||
for (key, value) in obj.iter() {
|
||||
if key.starts_with(domain) {
|
||||
if let Some(token) = value.get("oauth_token").and_then(|v| v.as_str()) {
|
||||
return Some(token.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
|
||||
#[test]
|
||||
fn test_extract_oauth_token_from_hosts_json() {
|
||||
let contents = r#"{
|
||||
"github.com": {
|
||||
"oauth_token": "ghu_test_token_12345"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let token = extract_oauth_token(contents, "github.com");
|
||||
assert_eq!(token, Some("ghu_test_token_12345".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_oauth_token_with_user_suffix() {
|
||||
let contents = r#"{
|
||||
"github.com:user": {
|
||||
"oauth_token": "ghu_another_token"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let token = extract_oauth_token(contents, "github.com");
|
||||
assert_eq!(token, Some("ghu_another_token".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_oauth_token_wrong_domain() {
|
||||
let contents = r#"{
|
||||
"gitlab.com": {
|
||||
"oauth_token": "some_token"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let token = extract_oauth_token(contents, "github.com");
|
||||
assert_eq!(token, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_oauth_token_invalid_json() {
|
||||
let contents = "not valid json";
|
||||
let token = extract_oauth_token(contents, "github.com");
|
||||
assert_eq!(token, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_oauth_token_missing_oauth_token_field() {
|
||||
let contents = r#"{
|
||||
"github.com": {
|
||||
"user": "testuser"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let token = extract_oauth_token(contents, "github.com");
|
||||
assert_eq!(token, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_oauth_token_multiple_entries_picks_first_match() {
|
||||
let contents = r#"{
|
||||
"gitlab.com": {
|
||||
"oauth_token": "gitlab_token"
|
||||
},
|
||||
"github.com": {
|
||||
"oauth_token": "github_token"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let token = extract_oauth_token(contents, "github.com");
|
||||
assert_eq!(token, Some("github_token".to_string()));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
migrate_copilot_credentials_if_needed("some-other-extension", cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-copilot-chat:copilot-chat");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not create credentials for other extensions"
|
||||
);
|
||||
}
|
||||
|
||||
// Note: Unlike the other migrations, copilot migration reads from the filesystem
|
||||
// (copilot config files), not from the credentials provider. In tests, these files
|
||||
// don't exist, so no migration occurs.
|
||||
#[gpui::test]
|
||||
async fn test_no_credentials_when_no_copilot_config_exists(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
migrate_copilot_credentials_if_needed(COPILOT_CHAT_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-copilot-chat:copilot-chat");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"No credentials should be written when copilot config doesn't exist"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,11 @@
|
||||
mod anthropic_migration;
|
||||
mod capability_granter;
|
||||
mod copilot_migration;
|
||||
pub mod extension_settings;
|
||||
mod google_ai_migration;
|
||||
pub mod headless_host;
|
||||
mod open_router_migration;
|
||||
mod openai_migration;
|
||||
pub mod wasm_host;
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -12,13 +17,14 @@ use async_tar::Archive;
|
||||
use client::ExtensionProvides;
|
||||
use client::{Client, ExtensionMetadata, GetExtensionsResponse, proto, telemetry::Telemetry};
|
||||
use collections::{BTreeMap, BTreeSet, HashSet, btree_map};
|
||||
|
||||
pub use extension::ExtensionManifest;
|
||||
use extension::extension_builder::{CompileExtensionOptions, ExtensionBuilder};
|
||||
use extension::{
|
||||
ExtensionContextServerProxy, ExtensionDebugAdapterProviderProxy, ExtensionEvents,
|
||||
ExtensionGrammarProxy, ExtensionHostProxy, ExtensionLanguageProxy,
|
||||
ExtensionLanguageServerProxy, ExtensionSlashCommandProxy, ExtensionSnippetProxy,
|
||||
ExtensionThemeProxy,
|
||||
ExtensionGrammarProxy, ExtensionHostProxy, ExtensionLanguageModelProviderProxy,
|
||||
ExtensionLanguageProxy, ExtensionLanguageServerProxy, ExtensionSlashCommandProxy,
|
||||
ExtensionSnippetProxy, ExtensionThemeProxy,
|
||||
};
|
||||
use fs::{Fs, RemoveOptions};
|
||||
use futures::future::join_all;
|
||||
@@ -32,8 +38,8 @@ use futures::{
|
||||
select_biased,
|
||||
};
|
||||
use gpui::{
|
||||
App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, Task, WeakEntity,
|
||||
actions,
|
||||
App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, SharedString, Task,
|
||||
WeakEntity, actions,
|
||||
};
|
||||
use http_client::{AsyncBody, HttpClient, HttpClientWithUrl};
|
||||
use language::{
|
||||
@@ -53,15 +59,24 @@ use std::{
|
||||
cmp::Ordering,
|
||||
path::{self, Path, PathBuf},
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
time::Duration,
|
||||
};
|
||||
use url::Url;
|
||||
use util::{ResultExt, paths::RemotePathBuf};
|
||||
use wasm_host::llm_provider::ExtensionLanguageModelProvider;
|
||||
use wasm_host::{
|
||||
WasmExtension, WasmHost,
|
||||
wit::{is_supported_wasm_api_version, wasm_api_version_range},
|
||||
wit::{LlmModelInfo, LlmProviderInfo, is_supported_wasm_api_version, wasm_api_version_range},
|
||||
};
|
||||
|
||||
struct LlmProviderWithModels {
|
||||
provider_info: LlmProviderInfo,
|
||||
models: Vec<LlmModelInfo>,
|
||||
is_authenticated: bool,
|
||||
icon_path: Option<SharedString>,
|
||||
auth_config: Option<extension::LanguageModelAuthConfig>,
|
||||
}
|
||||
|
||||
pub use extension::{
|
||||
ExtensionLibraryKind, GrammarManifestEntry, OldExtensionManifest, SchemaVersion,
|
||||
};
|
||||
@@ -70,6 +85,79 @@ pub use extension_settings::ExtensionSettings;
|
||||
pub const RELOAD_DEBOUNCE_DURATION: Duration = Duration::from_millis(200);
|
||||
const FS_WATCH_LATENCY: Duration = Duration::from_millis(100);
|
||||
|
||||
/// Extension IDs that are being migrated from hardcoded LLM providers.
|
||||
/// For backwards compatibility, if the user has the corresponding env var set,
|
||||
/// we automatically enable env var reading for these extensions on first install.
|
||||
const LEGACY_LLM_EXTENSION_IDS: &[&str] = &[
|
||||
"anthropic",
|
||||
"copilot-chat",
|
||||
"google-ai",
|
||||
"openrouter",
|
||||
"openai",
|
||||
];
|
||||
|
||||
/// Migrates legacy LLM provider extensions by auto-enabling env var reading
|
||||
/// if the env var is currently present in the environment.
|
||||
///
|
||||
/// This is idempotent: if the provider is already in `allowed_env_var_providers`,
|
||||
/// we skip. This means if a user explicitly removes it, it will be re-added on
|
||||
/// next launch if the env var is still set - but that's predictable behavior.
|
||||
fn migrate_legacy_llm_provider_env_var(manifest: &ExtensionManifest, cx: &mut App) {
|
||||
// Only apply migration to known legacy LLM extensions
|
||||
if !LEGACY_LLM_EXTENSION_IDS.contains(&manifest.id.as_ref()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check each provider in the manifest
|
||||
for (provider_id, provider_entry) in &manifest.language_model_providers {
|
||||
let Some(auth_config) = &provider_entry.auth else {
|
||||
continue;
|
||||
};
|
||||
let Some(env_var_name) = &auth_config.env_var else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let full_provider_id: Arc<str> = format!("{}:{}", manifest.id, provider_id).into();
|
||||
|
||||
// Check if the env var is present and non-empty
|
||||
let env_var_is_set = std::env::var(env_var_name)
|
||||
.map(|v| !v.is_empty())
|
||||
.unwrap_or(false);
|
||||
|
||||
// If env var isn't set, no need to do anything
|
||||
if !env_var_is_set {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if already enabled in settings
|
||||
let already_enabled = ExtensionSettings::get_global(cx)
|
||||
.allowed_env_var_providers
|
||||
.contains(full_provider_id.as_ref());
|
||||
|
||||
if already_enabled {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Enable env var reading since the env var is set
|
||||
settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, {
|
||||
let full_provider_id = full_provider_id.clone();
|
||||
move |settings, _| {
|
||||
let providers = settings
|
||||
.extension
|
||||
.allowed_env_var_providers
|
||||
.get_or_insert_with(Vec::new);
|
||||
|
||||
if !providers
|
||||
.iter()
|
||||
.any(|id| id.as_ref() == full_provider_id.as_ref())
|
||||
{
|
||||
providers.push(full_provider_id);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// The current extension [`SchemaVersion`] supported by Zed.
|
||||
const CURRENT_SCHEMA_VERSION: SchemaVersion = SchemaVersion(1);
|
||||
|
||||
@@ -131,6 +219,8 @@ pub struct ExtensionStore {
|
||||
pub enum ExtensionOperation {
|
||||
Upgrade,
|
||||
Install,
|
||||
/// Auto-install from settings - triggers legacy LLM provider migrations
|
||||
AutoInstall,
|
||||
Remove,
|
||||
}
|
||||
|
||||
@@ -606,15 +696,68 @@ impl ExtensionStore {
|
||||
.extension_index
|
||||
.extensions
|
||||
.contains_key(extension_id.as_ref());
|
||||
!is_already_installed && !SUPPRESSED_EXTENSIONS.contains(&extension_id.as_ref())
|
||||
let dominated = SUPPRESSED_EXTENSIONS.contains(&extension_id.as_ref());
|
||||
!is_already_installed && !dominated
|
||||
})
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
for extension_id in extensions_to_install {
|
||||
// When enabled, this checks if an extension exists locally in the repo's extensions/
|
||||
// directory and installs it as a dev extension instead of fetching from the registry.
|
||||
// This is useful for testing auto-installed extensions before they've been published.
|
||||
// Set to `true` only during local development/testing of new auto-install extensions.
|
||||
#[cfg(debug_assertions)]
|
||||
const DEBUG_ALLOW_UNPUBLISHED_AUTO_EXTENSIONS: bool = false;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
if DEBUG_ALLOW_UNPUBLISHED_AUTO_EXTENSIONS {
|
||||
let local_extension_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
|
||||
.parent()
|
||||
.unwrap()
|
||||
.parent()
|
||||
.unwrap()
|
||||
.join("extensions")
|
||||
.join(extension_id.as_ref());
|
||||
|
||||
if local_extension_path.exists() {
|
||||
// Force-remove existing extension directory if it exists and isn't a symlink
|
||||
// This handles the case where the extension was previously installed from the registry
|
||||
if let Some(installed_dir) = this
|
||||
.update(cx, |this, _cx| this.installed_dir.clone())
|
||||
.ok()
|
||||
{
|
||||
let existing_path = installed_dir.join(extension_id.as_ref());
|
||||
if existing_path.exists() {
|
||||
let metadata = std::fs::symlink_metadata(&existing_path);
|
||||
let is_symlink = metadata.map(|m| m.is_symlink()).unwrap_or(false);
|
||||
if !is_symlink {
|
||||
if let Err(e) = std::fs::remove_dir_all(&existing_path) {
|
||||
log::error!(
|
||||
"Failed to remove existing extension directory {:?}: {}",
|
||||
existing_path,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(task) = this
|
||||
.update(cx, |this, cx| {
|
||||
this.install_dev_extension(local_extension_path, cx)
|
||||
})
|
||||
.ok()
|
||||
{
|
||||
task.await.log_err();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.install_latest_extension(extension_id.clone(), cx);
|
||||
this.auto_install_latest_extension(extension_id.clone(), cx);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
@@ -769,7 +912,10 @@ impl ExtensionStore {
|
||||
this.update(cx, |this, cx| this.reload(Some(extension_id.clone()), cx))?
|
||||
.await;
|
||||
|
||||
if let ExtensionOperation::Install = operation {
|
||||
if matches!(
|
||||
operation,
|
||||
ExtensionOperation::Install | ExtensionOperation::AutoInstall
|
||||
) {
|
||||
this.update(cx, |this, cx| {
|
||||
cx.emit(Event::ExtensionInstalled(extension_id.clone()));
|
||||
if let Some(events) = ExtensionEvents::try_global(cx)
|
||||
@@ -779,6 +925,27 @@ impl ExtensionStore {
|
||||
this.emit(extension::Event::ExtensionInstalled(manifest.clone()), cx)
|
||||
});
|
||||
}
|
||||
|
||||
// Run legacy LLM provider migrations only for auto-installed extensions
|
||||
if matches!(operation, ExtensionOperation::AutoInstall) {
|
||||
if let Some(manifest) = this.extension_manifest_for_id(&extension_id) {
|
||||
migrate_legacy_llm_provider_env_var(&manifest, cx);
|
||||
}
|
||||
copilot_migration::migrate_copilot_credentials_if_needed(&extension_id, cx);
|
||||
anthropic_migration::migrate_anthropic_credentials_if_needed(
|
||||
&extension_id,
|
||||
cx,
|
||||
);
|
||||
google_ai_migration::migrate_google_ai_credentials_if_needed(
|
||||
&extension_id,
|
||||
cx,
|
||||
);
|
||||
openai_migration::migrate_openai_credentials_if_needed(&extension_id, cx);
|
||||
open_router_migration::migrate_open_router_credentials_if_needed(
|
||||
&extension_id,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
@@ -788,8 +955,24 @@ impl ExtensionStore {
|
||||
}
|
||||
|
||||
pub fn install_latest_extension(&mut self, extension_id: Arc<str>, cx: &mut Context<Self>) {
|
||||
log::info!("installing extension {extension_id} latest version");
|
||||
self.install_latest_extension_with_operation(extension_id, ExtensionOperation::Install, cx);
|
||||
}
|
||||
|
||||
/// Auto-install an extension, triggering legacy LLM provider migrations.
|
||||
fn auto_install_latest_extension(&mut self, extension_id: Arc<str>, cx: &mut Context<Self>) {
|
||||
self.install_latest_extension_with_operation(
|
||||
extension_id,
|
||||
ExtensionOperation::AutoInstall,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
|
||||
fn install_latest_extension_with_operation(
|
||||
&mut self,
|
||||
extension_id: Arc<str>,
|
||||
operation: ExtensionOperation,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let schema_versions = schema_version_range();
|
||||
let wasm_api_versions = wasm_api_version_range(ReleaseChannel::global(cx));
|
||||
|
||||
@@ -812,13 +995,8 @@ impl ExtensionStore {
|
||||
return;
|
||||
};
|
||||
|
||||
self.install_or_upgrade_extension_at_endpoint(
|
||||
extension_id,
|
||||
url,
|
||||
ExtensionOperation::Install,
|
||||
cx,
|
||||
)
|
||||
.detach_and_log_err(cx);
|
||||
self.install_or_upgrade_extension_at_endpoint(extension_id, url, operation, cx)
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
pub fn upgrade_extension(
|
||||
@@ -837,7 +1015,6 @@ impl ExtensionStore {
|
||||
operation: ExtensionOperation,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
log::info!("installing extension {extension_id} {version}");
|
||||
let Some(url) = self
|
||||
.http_client
|
||||
.build_zed_api_url(
|
||||
@@ -1134,18 +1311,6 @@ impl ExtensionStore {
|
||||
return Task::ready(());
|
||||
}
|
||||
|
||||
let reload_count = extensions_to_unload
|
||||
.iter()
|
||||
.filter(|id| extensions_to_load.contains(id))
|
||||
.count();
|
||||
|
||||
log::info!(
|
||||
"extensions updated. loading {}, reloading {}, unloading {}",
|
||||
extensions_to_load.len() - reload_count,
|
||||
reload_count,
|
||||
extensions_to_unload.len() - reload_count
|
||||
);
|
||||
|
||||
let extension_ids = extensions_to_load
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
@@ -1220,6 +1385,11 @@ impl ExtensionStore {
|
||||
for command_name in extension.manifest.slash_commands.keys() {
|
||||
self.proxy.unregister_slash_command(command_name.clone());
|
||||
}
|
||||
for provider_id in extension.manifest.language_model_providers.keys() {
|
||||
let full_provider_id: Arc<str> = format!("{}:{}", extension_id, provider_id).into();
|
||||
self.proxy
|
||||
.unregister_language_model_provider(full_provider_id, cx);
|
||||
}
|
||||
}
|
||||
|
||||
self.wasm_extensions
|
||||
@@ -1358,7 +1528,11 @@ impl ExtensionStore {
|
||||
})
|
||||
.await;
|
||||
|
||||
let mut wasm_extensions = Vec::new();
|
||||
let mut wasm_extensions: Vec<(
|
||||
Arc<ExtensionManifest>,
|
||||
WasmExtension,
|
||||
Vec<LlmProviderWithModels>,
|
||||
)> = Vec::new();
|
||||
for extension in extension_entries {
|
||||
if extension.manifest.lib.kind.is_none() {
|
||||
continue;
|
||||
@@ -1376,7 +1550,122 @@ impl ExtensionStore {
|
||||
|
||||
match wasm_extension {
|
||||
Ok(wasm_extension) => {
|
||||
wasm_extensions.push((extension.manifest.clone(), wasm_extension))
|
||||
// Query for LLM providers if the manifest declares any
|
||||
let mut llm_providers_with_models = Vec::new();
|
||||
if !extension.manifest.language_model_providers.is_empty() {
|
||||
let providers_result = wasm_extension
|
||||
.call(|ext, store| {
|
||||
async move { ext.call_llm_providers(store).await }.boxed()
|
||||
})
|
||||
.await;
|
||||
|
||||
if let Ok(Ok(providers)) = providers_result {
|
||||
for provider_info in providers {
|
||||
let models_result = wasm_extension
|
||||
.call({
|
||||
let provider_id = provider_info.id.clone();
|
||||
|ext, store| {
|
||||
async move {
|
||||
ext.call_llm_provider_models(store, &provider_id)
|
||||
.await
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
let models: Vec<LlmModelInfo> = match models_result {
|
||||
Ok(Ok(Ok(models))) => models,
|
||||
Ok(Ok(Err(e))) => {
|
||||
log::error!(
|
||||
"Failed to get models for LLM provider {} in extension {}: {}",
|
||||
provider_info.id,
|
||||
extension.manifest.id,
|
||||
e
|
||||
);
|
||||
Vec::new()
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
log::error!(
|
||||
"Wasm error calling llm_provider_models for {} in extension {}: {:?}",
|
||||
provider_info.id,
|
||||
extension.manifest.id,
|
||||
e
|
||||
);
|
||||
Vec::new()
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!(
|
||||
"Extension call failed for llm_provider_models {} in extension {}: {:?}",
|
||||
provider_info.id,
|
||||
extension.manifest.id,
|
||||
e
|
||||
);
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
// Query initial authentication state
|
||||
let is_authenticated = wasm_extension
|
||||
.call({
|
||||
let provider_id = provider_info.id.clone();
|
||||
|ext, store| {
|
||||
async move {
|
||||
ext.call_llm_provider_is_authenticated(
|
||||
store,
|
||||
&provider_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap_or(Ok(false))
|
||||
.unwrap_or(false);
|
||||
|
||||
// Resolve icon path if provided
|
||||
let icon_path = provider_info.icon.as_ref().map(|icon| {
|
||||
let icon_file_path = extension_path.join(icon);
|
||||
// Canonicalize to resolve symlinks (dev extensions are symlinked)
|
||||
let absolute_icon_path = icon_file_path
|
||||
.canonicalize()
|
||||
.unwrap_or(icon_file_path)
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
SharedString::from(absolute_icon_path)
|
||||
});
|
||||
|
||||
let provider_id_arc: Arc<str> =
|
||||
provider_info.id.as_str().into();
|
||||
let auth_config = extension
|
||||
.manifest
|
||||
.language_model_providers
|
||||
.get(&provider_id_arc)
|
||||
.and_then(|entry| entry.auth.clone());
|
||||
|
||||
llm_providers_with_models.push(LlmProviderWithModels {
|
||||
provider_info,
|
||||
models,
|
||||
is_authenticated,
|
||||
icon_path,
|
||||
auth_config,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
log::error!(
|
||||
"Failed to get LLM providers from extension {}: {:?}",
|
||||
extension.manifest.id,
|
||||
providers_result
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
wasm_extensions.push((
|
||||
extension.manifest.clone(),
|
||||
wasm_extension,
|
||||
llm_providers_with_models,
|
||||
))
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!(
|
||||
@@ -1395,7 +1684,7 @@ impl ExtensionStore {
|
||||
this.update(cx, |this, cx| {
|
||||
this.reload_complete_senders.clear();
|
||||
|
||||
for (manifest, wasm_extension) in &wasm_extensions {
|
||||
for (manifest, wasm_extension, llm_providers_with_models) in &wasm_extensions {
|
||||
let extension = Arc::new(wasm_extension.clone());
|
||||
|
||||
for (language_server_id, language_server_config) in &manifest.language_servers {
|
||||
@@ -1449,9 +1738,41 @@ impl ExtensionStore {
|
||||
this.proxy
|
||||
.register_debug_locator(extension.clone(), debug_adapter.clone());
|
||||
}
|
||||
|
||||
// Register LLM providers
|
||||
for llm_provider in llm_providers_with_models {
|
||||
let provider_id: Arc<str> =
|
||||
format!("{}:{}", manifest.id, llm_provider.provider_info.id).into();
|
||||
let wasm_ext = extension.as_ref().clone();
|
||||
let pinfo = llm_provider.provider_info.clone();
|
||||
let mods = llm_provider.models.clone();
|
||||
let auth = llm_provider.is_authenticated;
|
||||
let icon = llm_provider.icon_path.clone();
|
||||
let auth_config = llm_provider.auth_config.clone();
|
||||
|
||||
this.proxy.register_language_model_provider(
|
||||
provider_id.clone(),
|
||||
Box::new(move |cx: &mut App| {
|
||||
let provider = Arc::new(ExtensionLanguageModelProvider::new(
|
||||
wasm_ext, pinfo, mods, auth, icon, auth_config, cx,
|
||||
));
|
||||
language_model::LanguageModelRegistry::global(cx).update(
|
||||
cx,
|
||||
|registry, cx| {
|
||||
registry.register_provider(provider, cx);
|
||||
},
|
||||
);
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
this.wasm_extensions.extend(wasm_extensions);
|
||||
let wasm_extensions_without_llm: Vec<_> = wasm_extensions
|
||||
.into_iter()
|
||||
.map(|(manifest, ext, _)| (manifest, ext))
|
||||
.collect();
|
||||
this.wasm_extensions.extend(wasm_extensions_without_llm);
|
||||
this.proxy.set_extensions_loaded();
|
||||
this.proxy.reload_current_theme(cx);
|
||||
this.proxy.reload_current_icon_theme(cx);
|
||||
@@ -1473,7 +1794,6 @@ impl ExtensionStore {
|
||||
let index_path = self.index_path.clone();
|
||||
let proxy = self.proxy.clone();
|
||||
cx.background_spawn(async move {
|
||||
let start_time = Instant::now();
|
||||
let mut index = ExtensionIndex::default();
|
||||
|
||||
fs.create_dir(&work_dir).await.log_err();
|
||||
@@ -1511,7 +1831,6 @@ impl ExtensionStore {
|
||||
.log_err();
|
||||
}
|
||||
|
||||
log::info!("rebuilt extension index in {:?}", start_time.elapsed());
|
||||
index
|
||||
})
|
||||
}
|
||||
@@ -1785,11 +2104,6 @@ impl ExtensionStore {
|
||||
})?,
|
||||
path_style,
|
||||
);
|
||||
log::info!(
|
||||
"Uploading extension {} to {:?}",
|
||||
missing_extension.clone().id,
|
||||
dest_dir
|
||||
);
|
||||
|
||||
client
|
||||
.update(cx, |client, cx| {
|
||||
@@ -1797,11 +2111,6 @@ impl ExtensionStore {
|
||||
})?
|
||||
.await?;
|
||||
|
||||
log::info!(
|
||||
"Finished uploading extension {}",
|
||||
missing_extension.clone().id
|
||||
);
|
||||
|
||||
let result = client
|
||||
.update(cx, |client, _cx| {
|
||||
client.proto_client().request(proto::InstallExtension {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use collections::HashMap;
|
||||
use collections::{HashMap, HashSet};
|
||||
use extension::{
|
||||
DownloadFileCapability, ExtensionCapability, NpmInstallPackageCapability, ProcessExecCapability,
|
||||
};
|
||||
@@ -16,6 +16,10 @@ pub struct ExtensionSettings {
|
||||
pub auto_install_extensions: HashMap<Arc<str>, bool>,
|
||||
pub auto_update_extensions: HashMap<Arc<str>, bool>,
|
||||
pub granted_capabilities: Vec<ExtensionCapability>,
|
||||
/// The extension language model providers that are allowed to read API keys
|
||||
/// from environment variables. Each entry is a provider ID in the format
|
||||
/// "extension_id:provider_id".
|
||||
pub allowed_env_var_providers: HashSet<Arc<str>>,
|
||||
}
|
||||
|
||||
impl ExtensionSettings {
|
||||
@@ -60,6 +64,13 @@ impl Settings for ExtensionSettings {
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
allowed_env_var_providers: content
|
||||
.extension
|
||||
.allowed_env_var_providers
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -165,6 +165,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
|
||||
capabilities: Vec::new(),
|
||||
debug_adapters: Default::default(),
|
||||
debug_locators: Default::default(),
|
||||
language_model_providers: BTreeMap::default(),
|
||||
}),
|
||||
dev: false,
|
||||
},
|
||||
@@ -196,6 +197,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
|
||||
capabilities: Vec::new(),
|
||||
debug_adapters: Default::default(),
|
||||
debug_locators: Default::default(),
|
||||
language_model_providers: BTreeMap::default(),
|
||||
}),
|
||||
dev: false,
|
||||
},
|
||||
@@ -376,6 +378,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
|
||||
capabilities: Vec::new(),
|
||||
debug_adapters: Default::default(),
|
||||
debug_locators: Default::default(),
|
||||
language_model_providers: BTreeMap::default(),
|
||||
}),
|
||||
dev: false,
|
||||
},
|
||||
|
||||
124
crates/extension_host/src/google_ai_migration.rs
Normal file
124
crates/extension_host/src/google_ai_migration.rs
Normal file
@@ -0,0 +1,124 @@
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use gpui::App;
|
||||
|
||||
const GOOGLE_AI_EXTENSION_ID: &str = "google-ai";
|
||||
const GOOGLE_AI_PROVIDER_ID: &str = "google-ai";
|
||||
const GOOGLE_AI_DEFAULT_API_URL: &str = "https://generativelanguage.googleapis.com";
|
||||
|
||||
/// Migrates Google AI API credentials from the old built-in provider location
|
||||
/// to the new extension-based location.
|
||||
///
|
||||
/// This should only be called during auto-install of the extension.
|
||||
pub fn migrate_google_ai_credentials_if_needed(extension_id: &str, cx: &mut App) {
|
||||
if extension_id != GOOGLE_AI_EXTENSION_ID {
|
||||
return;
|
||||
}
|
||||
|
||||
let extension_credential_key = format!(
|
||||
"extension-llm-{}:{}",
|
||||
GOOGLE_AI_EXTENSION_ID, GOOGLE_AI_PROVIDER_ID
|
||||
);
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
// Read from old location
|
||||
let old_credential = credentials_provider
|
||||
.read_credentials(GOOGLE_AI_DEFAULT_API_URL, &cx)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let api_key = match old_credential {
|
||||
Some((_, key_bytes)) => match String::from_utf8(key_bytes) {
|
||||
Ok(key) if !key.is_empty() => key,
|
||||
Ok(_) => {
|
||||
log::debug!("Existing Google AI API key is empty, nothing to migrate");
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
log::error!("Failed to decode Google AI API key as UTF-8");
|
||||
return;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
log::debug!("No existing Google AI API key found to migrate");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
log::info!("Migrating existing Google AI API key to Google AI extension");
|
||||
|
||||
match credentials_provider
|
||||
.write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
log::info!("Successfully migrated Google AI API key to extension");
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("Failed to migrate Google AI API key: {}", err);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
|
||||
let api_key = "AIzaSy-test-key-12345";
|
||||
|
||||
cx.write_credentials(GOOGLE_AI_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_google_ai_credentials_if_needed(GOOGLE_AI_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let migrated = cx.read_credentials("extension-llm-google-ai:google-ai");
|
||||
assert!(migrated.is_some(), "Credentials should have been migrated");
|
||||
let (username, password) = migrated.unwrap();
|
||||
assert_eq!(username, "Bearer");
|
||||
assert_eq!(String::from_utf8(password).unwrap(), api_key);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_no_migration_if_no_old_credentials(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
migrate_google_ai_credentials_if_needed(GOOGLE_AI_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-google-ai:google-ai");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not create credentials if none existed"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
|
||||
let api_key = "AIzaSy-test-key";
|
||||
|
||||
cx.write_credentials(GOOGLE_AI_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_google_ai_credentials_if_needed("some-other-extension", cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-google-ai:google-ai");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not migrate for other extensions"
|
||||
);
|
||||
}
|
||||
}
|
||||
124
crates/extension_host/src/open_router_migration.rs
Normal file
124
crates/extension_host/src/open_router_migration.rs
Normal file
@@ -0,0 +1,124 @@
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use gpui::App;
|
||||
|
||||
const OPEN_ROUTER_EXTENSION_ID: &str = "openrouter";
|
||||
const OPEN_ROUTER_PROVIDER_ID: &str = "openrouter";
|
||||
const OPEN_ROUTER_DEFAULT_API_URL: &str = "https://openrouter.ai/api/v1";
|
||||
|
||||
/// Migrates OpenRouter API credentials from the old built-in provider location
|
||||
/// to the new extension-based location.
|
||||
///
|
||||
/// This should only be called during auto-install of the extension.
|
||||
pub fn migrate_open_router_credentials_if_needed(extension_id: &str, cx: &mut App) {
|
||||
if extension_id != OPEN_ROUTER_EXTENSION_ID {
|
||||
return;
|
||||
}
|
||||
|
||||
let extension_credential_key = format!(
|
||||
"extension-llm-{}:{}",
|
||||
OPEN_ROUTER_EXTENSION_ID, OPEN_ROUTER_PROVIDER_ID
|
||||
);
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
// Read from old location
|
||||
let old_credential = credentials_provider
|
||||
.read_credentials(OPEN_ROUTER_DEFAULT_API_URL, &cx)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let api_key = match old_credential {
|
||||
Some((_, key_bytes)) => match String::from_utf8(key_bytes) {
|
||||
Ok(key) if !key.is_empty() => key,
|
||||
Ok(_) => {
|
||||
log::debug!("Existing OpenRouter API key is empty, nothing to migrate");
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
log::error!("Failed to decode OpenRouter API key as UTF-8");
|
||||
return;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
log::debug!("No existing OpenRouter API key found to migrate");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
log::info!("Migrating existing OpenRouter API key to OpenRouter extension");
|
||||
|
||||
match credentials_provider
|
||||
.write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
log::info!("Successfully migrated OpenRouter API key to extension");
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("Failed to migrate OpenRouter API key: {}", err);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
|
||||
let api_key = "sk-or-test-key-12345";
|
||||
|
||||
cx.write_credentials(OPEN_ROUTER_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_open_router_credentials_if_needed(OPEN_ROUTER_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let migrated = cx.read_credentials("extension-llm-openrouter:openrouter");
|
||||
assert!(migrated.is_some(), "Credentials should have been migrated");
|
||||
let (username, password) = migrated.unwrap();
|
||||
assert_eq!(username, "Bearer");
|
||||
assert_eq!(String::from_utf8(password).unwrap(), api_key);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_no_migration_if_no_old_credentials(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
migrate_open_router_credentials_if_needed(OPEN_ROUTER_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-openrouter:openrouter");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not create credentials if none existed"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
|
||||
let api_key = "sk-or-test-key";
|
||||
|
||||
cx.write_credentials(OPEN_ROUTER_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_open_router_credentials_if_needed("some-other-extension", cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-openrouter:openrouter");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not migrate for other extensions"
|
||||
);
|
||||
}
|
||||
}
|
||||
124
crates/extension_host/src/openai_migration.rs
Normal file
124
crates/extension_host/src/openai_migration.rs
Normal file
@@ -0,0 +1,124 @@
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use gpui::App;
|
||||
|
||||
const OPENAI_EXTENSION_ID: &str = "openai";
|
||||
const OPENAI_PROVIDER_ID: &str = "openai";
|
||||
const OPENAI_DEFAULT_API_URL: &str = "https://api.openai.com/v1";
|
||||
|
||||
/// Migrates OpenAI API credentials from the old built-in provider location
|
||||
/// to the new extension-based location.
|
||||
///
|
||||
/// This should only be called during auto-install of the extension.
|
||||
pub fn migrate_openai_credentials_if_needed(extension_id: &str, cx: &mut App) {
|
||||
if extension_id != OPENAI_EXTENSION_ID {
|
||||
return;
|
||||
}
|
||||
|
||||
let extension_credential_key = format!(
|
||||
"extension-llm-{}:{}",
|
||||
OPENAI_EXTENSION_ID, OPENAI_PROVIDER_ID
|
||||
);
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
// Read from old location
|
||||
let old_credential = credentials_provider
|
||||
.read_credentials(OPENAI_DEFAULT_API_URL, &cx)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let api_key = match old_credential {
|
||||
Some((_, key_bytes)) => match String::from_utf8(key_bytes) {
|
||||
Ok(key) if !key.is_empty() => key,
|
||||
Ok(_) => {
|
||||
log::debug!("Existing OpenAI API key is empty, nothing to migrate");
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
log::error!("Failed to decode OpenAI API key as UTF-8");
|
||||
return;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
log::debug!("No existing OpenAI API key found to migrate");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
log::info!("Migrating existing OpenAI API key to OpenAI extension");
|
||||
|
||||
match credentials_provider
|
||||
.write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
log::info!("Successfully migrated OpenAI API key to extension");
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("Failed to migrate OpenAI API key: {}", err);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
|
||||
let api_key = "sk-test-key-12345";
|
||||
|
||||
cx.write_credentials(OPENAI_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_openai_credentials_if_needed(OPENAI_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let migrated = cx.read_credentials("extension-llm-openai:openai");
|
||||
assert!(migrated.is_some(), "Credentials should have been migrated");
|
||||
let (username, password) = migrated.unwrap();
|
||||
assert_eq!(username, "Bearer");
|
||||
assert_eq!(String::from_utf8(password).unwrap(), api_key);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_no_migration_if_no_old_credentials(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
migrate_openai_credentials_if_needed(OPENAI_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-openai:openai");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not create credentials if none existed"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
|
||||
let api_key = "sk-test-key";
|
||||
|
||||
cx.write_credentials(OPENAI_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_openai_credentials_if_needed("some-other-extension", cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-openai:openai");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not migrate for other extensions"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,11 @@
|
||||
pub mod llm_provider;
|
||||
pub mod wit;
|
||||
|
||||
use crate::capability_granter::CapabilityGranter;
|
||||
use crate::{ExtensionManifest, ExtensionSettings};
|
||||
use anyhow::{Context as _, Result, anyhow, bail};
|
||||
use async_trait::async_trait;
|
||||
|
||||
use dap::{DebugRequest, StartDebuggingRequestArgumentsRequest};
|
||||
use extension::{
|
||||
CodeLabel, Command, Completion, ContextServerConfiguration, DebugAdapterBinary,
|
||||
@@ -64,7 +66,7 @@ pub struct WasmHost {
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct WasmExtension {
|
||||
tx: UnboundedSender<ExtensionCall>,
|
||||
tx: Arc<UnboundedSender<ExtensionCall>>,
|
||||
pub manifest: Arc<ExtensionManifest>,
|
||||
pub work_dir: Arc<Path>,
|
||||
#[allow(unused)]
|
||||
@@ -74,7 +76,10 @@ pub struct WasmExtension {
|
||||
|
||||
impl Drop for WasmExtension {
|
||||
fn drop(&mut self) {
|
||||
self.tx.close_channel();
|
||||
// Only close the channel when this is the last clone holding the sender
|
||||
if Arc::strong_count(&self.tx) == 1 {
|
||||
self.tx.close_channel();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -671,7 +676,7 @@ impl WasmHost {
|
||||
Ok(WasmExtension {
|
||||
manifest,
|
||||
work_dir,
|
||||
tx,
|
||||
tx: Arc::new(tx),
|
||||
zed_api_version,
|
||||
_task: task,
|
||||
})
|
||||
|
||||
1464
crates/extension_host/src/wasm_host/llm_provider.rs
Normal file
1464
crates/extension_host/src/wasm_host/llm_provider.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -16,7 +16,7 @@ use lsp::LanguageServerName;
|
||||
use release_channel::ReleaseChannel;
|
||||
use task::{DebugScenario, SpawnInTerminal, TaskTemplate, ZedDebugConfig};
|
||||
|
||||
use crate::wasm_host::wit::since_v0_6_0::dap::StartDebuggingRequestArgumentsRequest;
|
||||
use crate::wasm_host::wit::since_v0_8_0::dap::StartDebuggingRequestArgumentsRequest;
|
||||
|
||||
use super::{WasmState, wasm_engine};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
@@ -33,6 +33,19 @@ pub use latest::CodeLabelSpanLiteral;
|
||||
pub use latest::{
|
||||
CodeLabel, CodeLabelSpan, Command, DebugAdapterBinary, ExtensionProject, Range, SlashCommand,
|
||||
zed::extension::context_server::ContextServerConfiguration,
|
||||
zed::extension::llm_provider::{
|
||||
CacheConfiguration as LlmCacheConfiguration, CompletionEvent as LlmCompletionEvent,
|
||||
CompletionRequest as LlmCompletionRequest, CredentialType as LlmCredentialType,
|
||||
ImageData as LlmImageData, MessageContent as LlmMessageContent,
|
||||
MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities,
|
||||
ModelInfo as LlmModelInfo, ProviderInfo as LlmProviderInfo,
|
||||
RequestMessage as LlmRequestMessage, StopReason as LlmStopReason,
|
||||
ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage,
|
||||
ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition,
|
||||
ToolInputFormat as LlmToolInputFormat, ToolResult as LlmToolResult,
|
||||
ToolResultContent as LlmToolResultContent, ToolUse as LlmToolUse,
|
||||
ToolUseJsonParseError as LlmToolUseJsonParseError,
|
||||
},
|
||||
zed::extension::lsp::{
|
||||
Completion, CompletionKind, CompletionLabelDetails, InsertTextFormat, Symbol, SymbolKind,
|
||||
},
|
||||
@@ -1007,6 +1020,20 @@ impl Extension {
|
||||
resource: Resource<Arc<dyn WorktreeDelegate>>,
|
||||
) -> Result<Result<DebugAdapterBinary, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
let dap_binary = ext
|
||||
.call_get_dap_binary(
|
||||
store,
|
||||
&adapter_name,
|
||||
&task.try_into()?,
|
||||
user_installed_path.as_ref().and_then(|p| p.to_str()),
|
||||
resource,
|
||||
)
|
||||
.await?
|
||||
.map_err(|e| anyhow!("{e:?}"))?;
|
||||
|
||||
Ok(Ok(dap_binary))
|
||||
}
|
||||
Extension::V0_6_0(ext) => {
|
||||
let dap_binary = ext
|
||||
.call_get_dap_binary(
|
||||
@@ -1032,6 +1059,16 @@ impl Extension {
|
||||
config: serde_json::Value,
|
||||
) -> Result<Result<StartDebuggingRequestArgumentsRequest, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
let config =
|
||||
serde_json::to_string(&config).context("Adapter config is not a valid JSON")?;
|
||||
let result = ext
|
||||
.call_dap_request_kind(store, &adapter_name, &config)
|
||||
.await?
|
||||
.map_err(|e| anyhow!("{e:?}"))?;
|
||||
|
||||
Ok(Ok(result))
|
||||
}
|
||||
Extension::V0_6_0(ext) => {
|
||||
let config =
|
||||
serde_json::to_string(&config).context("Adapter config is not a valid JSON")?;
|
||||
@@ -1052,6 +1089,15 @@ impl Extension {
|
||||
config: ZedDebugConfig,
|
||||
) -> Result<Result<DebugScenario, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
let config = config.into();
|
||||
let result = ext
|
||||
.call_dap_config_to_scenario(store, &config)
|
||||
.await?
|
||||
.map_err(|e| anyhow!("{e:?}"))?;
|
||||
|
||||
Ok(Ok(result.try_into()?))
|
||||
}
|
||||
Extension::V0_6_0(ext) => {
|
||||
let config = config.into();
|
||||
let dap_binary = ext
|
||||
@@ -1074,6 +1120,20 @@ impl Extension {
|
||||
debug_adapter_name: String,
|
||||
) -> Result<Option<DebugScenario>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
let build_config_template = build_config_template.into();
|
||||
let result = ext
|
||||
.call_dap_locator_create_scenario(
|
||||
store,
|
||||
&locator_name,
|
||||
&build_config_template,
|
||||
&resolved_label,
|
||||
&debug_adapter_name,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(result.map(TryInto::try_into).transpose()?)
|
||||
}
|
||||
Extension::V0_6_0(ext) => {
|
||||
let build_config_template = build_config_template.into();
|
||||
let dap_binary = ext
|
||||
@@ -1099,6 +1159,15 @@ impl Extension {
|
||||
resolved_build_task: SpawnInTerminal,
|
||||
) -> Result<Result<DebugRequest, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
let build_config_template = resolved_build_task.try_into()?;
|
||||
let dap_request = ext
|
||||
.call_run_dap_locator(store, &locator_name, &build_config_template)
|
||||
.await?
|
||||
.map_err(|e| anyhow!("{e:?}"))?;
|
||||
|
||||
Ok(Ok(dap_request.into()))
|
||||
}
|
||||
Extension::V0_6_0(ext) => {
|
||||
let build_config_template = resolved_build_task.try_into()?;
|
||||
let dap_request = ext
|
||||
@@ -1111,6 +1180,174 @@ impl Extension {
|
||||
_ => anyhow::bail!("`dap_locator_create_scenario` not available prior to v0.6.0"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_providers(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
) -> Result<Vec<latest::llm_provider::ProviderInfo>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => ext.call_llm_providers(store).await,
|
||||
_ => Ok(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_provider_models(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
) -> Result<Result<Vec<latest::llm_provider::ModelInfo>, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => ext.call_llm_provider_models(store, provider_id).await,
|
||||
_ => anyhow::bail!("`llm_provider_models` not available prior to v0.8.0"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_provider_settings_markdown(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
) -> Result<Option<String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_provider_settings_markdown(store, provider_id)
|
||||
.await
|
||||
}
|
||||
_ => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_provider_is_authenticated(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
) -> Result<bool> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_provider_is_authenticated(store, provider_id)
|
||||
.await
|
||||
}
|
||||
_ => Ok(false),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_provider_start_device_flow_sign_in(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
) -> Result<Result<String, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_provider_start_device_flow_sign_in(store, provider_id)
|
||||
.await
|
||||
}
|
||||
_ => {
|
||||
anyhow::bail!(
|
||||
"`llm_provider_start_device_flow_sign_in` not available prior to v0.8.0"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_provider_poll_device_flow_sign_in(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
) -> Result<Result<(), String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_provider_poll_device_flow_sign_in(store, provider_id)
|
||||
.await
|
||||
}
|
||||
_ => {
|
||||
anyhow::bail!(
|
||||
"`llm_provider_poll_device_flow_sign_in` not available prior to v0.8.0"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_provider_reset_credentials(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
) -> Result<Result<(), String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_provider_reset_credentials(store, provider_id)
|
||||
.await
|
||||
}
|
||||
_ => anyhow::bail!("`llm_provider_reset_credentials` not available prior to v0.8.0"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_count_tokens(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
model_id: &str,
|
||||
request: &latest::llm_provider::CompletionRequest,
|
||||
) -> Result<Result<u64, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_count_tokens(store, provider_id, model_id, request)
|
||||
.await
|
||||
}
|
||||
_ => anyhow::bail!("`llm_count_tokens` not available prior to v0.8.0"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_stream_completion_start(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
model_id: &str,
|
||||
request: &latest::llm_provider::CompletionRequest,
|
||||
) -> Result<Result<String, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_stream_completion_start(store, provider_id, model_id, request)
|
||||
.await
|
||||
}
|
||||
_ => anyhow::bail!("`llm_stream_completion_start` not available prior to v0.8.0"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_stream_completion_next(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
stream_id: &str,
|
||||
) -> Result<Result<Option<latest::llm_provider::CompletionEvent>, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => ext.call_llm_stream_completion_next(store, stream_id).await,
|
||||
_ => anyhow::bail!("`llm_stream_completion_next` not available prior to v0.8.0"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_stream_completion_close(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
stream_id: &str,
|
||||
) -> Result<()> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => ext.call_llm_stream_completion_close(store, stream_id).await,
|
||||
_ => anyhow::bail!("`llm_stream_completion_close` not available prior to v0.8.0"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_cache_configuration(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
model_id: &str,
|
||||
) -> Result<Option<latest::llm_provider::CacheConfiguration>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_cache_configuration(store, provider_id, model_id)
|
||||
.await
|
||||
}
|
||||
_ => Ok(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trait ToWasmtimeResult<T> {
|
||||
|
||||
@@ -32,8 +32,6 @@ wasmtime::component::bindgen!({
|
||||
},
|
||||
});
|
||||
|
||||
pub use self::zed::extension::*;
|
||||
|
||||
mod settings {
|
||||
#![allow(dead_code)]
|
||||
include!(concat!(env!("OUT_DIR"), "/since_v0.6.0/settings.rs"));
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use crate::wasm_host::wit::since_v0_6_0::{
|
||||
use crate::wasm_host::wit::since_v0_8_0::{
|
||||
dap::{
|
||||
AttachRequest, BuildTaskDefinition, BuildTaskDefinitionTemplatePayload, LaunchRequest,
|
||||
StartDebuggingRequestArguments, TcpArguments, TcpArgumentsTemplate,
|
||||
},
|
||||
lsp::{CompletionKind, CompletionLabelDetails, InsertTextFormat, SymbolKind},
|
||||
slash_command::SlashCommandOutputSection,
|
||||
};
|
||||
use crate::wasm_host::wit::{CompletionKind, CompletionLabelDetails, InsertTextFormat, SymbolKind};
|
||||
use crate::wasm_host::{WasmState, wit::ToWasmtimeResult};
|
||||
use ::http_client::{AsyncBody, HttpRequestExt};
|
||||
use ::settings::{Settings, WorktreeId};
|
||||
@@ -13,6 +13,7 @@ use anyhow::{Context as _, Result, bail};
|
||||
use async_compression::futures::bufread::GzipDecoder;
|
||||
use async_tar::Archive;
|
||||
use async_trait::async_trait;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use extension::{
|
||||
ExtensionLanguageServerProxy, KeyValueStoreDelegate, ProjectDelegate, WorktreeDelegate,
|
||||
};
|
||||
@@ -22,12 +23,14 @@ use gpui::{BackgroundExecutor, SharedString};
|
||||
use language::{BinaryStatus, LanguageName, language_settings::AllLanguageSettings};
|
||||
use project::project_settings::ProjectSettings;
|
||||
use semver::Version;
|
||||
use smol::net::TcpListener;
|
||||
use std::{
|
||||
env,
|
||||
net::Ipv4Addr,
|
||||
path::{Path, PathBuf},
|
||||
str::FromStr,
|
||||
sync::{Arc, OnceLock},
|
||||
time::Duration,
|
||||
};
|
||||
use task::{SpawnInTerminal, ZedDebugConfig};
|
||||
use url::Url;
|
||||
@@ -1107,3 +1110,361 @@ impl ExtensionImports for WasmState {
|
||||
.to_wasmtime_result()
|
||||
}
|
||||
}
|
||||
|
||||
impl llm_provider::Host for WasmState {
|
||||
async fn request_credential(
|
||||
&mut self,
|
||||
_provider_id: String,
|
||||
_credential_type: llm_provider::CredentialType,
|
||||
_label: String,
|
||||
_placeholder: String,
|
||||
) -> wasmtime::Result<Result<bool, String>> {
|
||||
// For now, credential requests return false (not provided)
|
||||
// Extensions should use get_env_var to check for env vars first,
|
||||
// then store_credential/get_credential for manual storage
|
||||
// Full UI credential prompting will be added in a future phase
|
||||
Ok(Ok(false))
|
||||
}
|
||||
|
||||
async fn get_credential(&mut self, provider_id: String) -> wasmtime::Result<Option<String>> {
|
||||
let extension_id = self.manifest.id.clone();
|
||||
|
||||
// Check if this provider has an env var configured and if the user has allowed it
|
||||
let env_var_name = self
|
||||
.manifest
|
||||
.language_model_providers
|
||||
.get(&Arc::<str>::from(provider_id.as_str()))
|
||||
.and_then(|entry| entry.auth.as_ref())
|
||||
.and_then(|auth| auth.env_var.clone());
|
||||
|
||||
if let Some(env_var_name) = env_var_name {
|
||||
let full_provider_id: Arc<str> = format!("{}:{}", extension_id, provider_id).into();
|
||||
// Read settings dynamically to get current allowed_env_var_providers
|
||||
let is_allowed = self
|
||||
.on_main_thread({
|
||||
let full_provider_id = full_provider_id.clone();
|
||||
move |cx| {
|
||||
async move {
|
||||
cx.update(|cx| {
|
||||
crate::extension_settings::ExtensionSettings::get_global(cx)
|
||||
.allowed_env_var_providers
|
||||
.contains(&full_provider_id)
|
||||
})
|
||||
}
|
||||
.boxed_local()
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap_or(false);
|
||||
|
||||
if is_allowed {
|
||||
if let Ok(value) = env::var(&env_var_name) {
|
||||
if !value.is_empty() {
|
||||
return Ok(Some(value));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to credential store
|
||||
let credential_key = format!("extension-llm-{}:{}", extension_id, provider_id);
|
||||
|
||||
self.on_main_thread(move |cx| {
|
||||
async move {
|
||||
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx))?;
|
||||
let result = credentials_provider
|
||||
.read_credentials(&credential_key, cx)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
Ok(result.map(|(_, password)| String::from_utf8_lossy(&password).to_string()))
|
||||
}
|
||||
.boxed_local()
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn store_credential(
|
||||
&mut self,
|
||||
provider_id: String,
|
||||
value: String,
|
||||
) -> wasmtime::Result<Result<(), String>> {
|
||||
let extension_id = self.manifest.id.clone();
|
||||
let credential_key = format!("extension-llm-{}:{}", extension_id, provider_id);
|
||||
|
||||
self.on_main_thread(move |cx| {
|
||||
async move {
|
||||
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx))?;
|
||||
credentials_provider
|
||||
.write_credentials(&credential_key, "api_key", value.as_bytes(), cx)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{}", e))
|
||||
}
|
||||
.boxed_local()
|
||||
})
|
||||
.await
|
||||
.to_wasmtime_result()
|
||||
}
|
||||
|
||||
async fn delete_credential(
|
||||
&mut self,
|
||||
provider_id: String,
|
||||
) -> wasmtime::Result<Result<(), String>> {
|
||||
let extension_id = self.manifest.id.clone();
|
||||
let credential_key = format!("extension-llm-{}:{}", extension_id, provider_id);
|
||||
|
||||
self.on_main_thread(move |cx| {
|
||||
async move {
|
||||
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx))?;
|
||||
credentials_provider
|
||||
.delete_credentials(&credential_key, cx)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{}", e))
|
||||
}
|
||||
.boxed_local()
|
||||
})
|
||||
.await
|
||||
.to_wasmtime_result()
|
||||
}
|
||||
|
||||
async fn get_env_var(&mut self, name: String) -> wasmtime::Result<Option<String>> {
|
||||
let extension_id = self.manifest.id.clone();
|
||||
|
||||
// Find which provider (if any) declares this env var in its auth config
|
||||
let mut allowed_provider_id: Option<Arc<str>> = None;
|
||||
for (provider_id, provider_entry) in &self.manifest.language_model_providers {
|
||||
if let Some(auth_config) = &provider_entry.auth {
|
||||
if auth_config.env_var.as_deref() == Some(&name) {
|
||||
allowed_provider_id = Some(provider_id.clone());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no provider declares this env var, deny access
|
||||
let Some(provider_id) = allowed_provider_id else {
|
||||
log::warn!(
|
||||
"Extension {} attempted to read env var {} which is not declared in any provider auth config",
|
||||
extension_id,
|
||||
name
|
||||
);
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
// Check if the user has allowed this provider to read env vars
|
||||
// Read settings dynamically to get current allowed_env_var_providers
|
||||
let full_provider_id: Arc<str> = format!("{}:{}", extension_id, provider_id).into();
|
||||
let is_allowed = self
|
||||
.on_main_thread({
|
||||
let full_provider_id = full_provider_id.clone();
|
||||
move |cx| {
|
||||
async move {
|
||||
cx.update(|cx| {
|
||||
crate::extension_settings::ExtensionSettings::get_global(cx)
|
||||
.allowed_env_var_providers
|
||||
.contains(&full_provider_id)
|
||||
})
|
||||
}
|
||||
.boxed_local()
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap_or(false);
|
||||
|
||||
if !is_allowed {
|
||||
log::debug!(
|
||||
"Extension {} provider {} is not allowed to read env var {}",
|
||||
extension_id,
|
||||
provider_id,
|
||||
name
|
||||
);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(env::var(&name).ok())
|
||||
}
|
||||
|
||||
async fn oauth_start_web_auth(
|
||||
&mut self,
|
||||
config: llm_provider::OauthWebAuthConfig,
|
||||
) -> wasmtime::Result<Result<llm_provider::OauthWebAuthResult, String>> {
|
||||
let auth_url = config.auth_url;
|
||||
let callback_path = config.callback_path;
|
||||
let timeout_secs = config.timeout_secs.unwrap_or(300);
|
||||
|
||||
self.on_main_thread(move |cx| {
|
||||
async move {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to bind localhost server: {}", e))?;
|
||||
let port = listener
|
||||
.local_addr()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to get local address: {}", e))?
|
||||
.port();
|
||||
|
||||
let auth_url_with_port = auth_url.replace("{port}", &port.to_string());
|
||||
cx.update(|cx| {
|
||||
cx.open_url(&auth_url_with_port);
|
||||
})?;
|
||||
|
||||
let accept_future = async {
|
||||
let (mut stream, _) = listener
|
||||
.accept()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to accept connection: {}", e))?;
|
||||
|
||||
let mut request_line = String::new();
|
||||
{
|
||||
let mut reader = smol::io::BufReader::new(&mut stream);
|
||||
smol::io::AsyncBufReadExt::read_line(&mut reader, &mut request_line)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to read request: {}", e))?;
|
||||
}
|
||||
|
||||
let callback_url = if let Some(path_start) = request_line.find(' ') {
|
||||
if let Some(path_end) = request_line[path_start + 1..].find(' ') {
|
||||
let path = &request_line[path_start + 1..path_start + 1 + path_end];
|
||||
if path.starts_with(&callback_path) || path.starts_with(&format!("/{}", callback_path.trim_start_matches('/'))) {
|
||||
format!("http://localhost:{}{}", port, path)
|
||||
} else {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Unexpected callback path: {}",
|
||||
path
|
||||
));
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow::anyhow!("Malformed HTTP request"));
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow::anyhow!("Malformed HTTP request"));
|
||||
};
|
||||
|
||||
let response = "HTTP/1.1 200 OK\r\n\
|
||||
Content-Type: text/html\r\n\
|
||||
Connection: close\r\n\
|
||||
\r\n\
|
||||
<!DOCTYPE html>\
|
||||
<html><head><title>Authentication Complete</title></head>\
|
||||
<body style=\"font-family: system-ui, sans-serif; display: flex; justify-content: center; align-items: center; height: 100vh; margin: 0;\">\
|
||||
<div style=\"text-align: center;\">\
|
||||
<h1>Authentication Complete</h1>\
|
||||
<p>You can close this window and return to Zed.</p>\
|
||||
</div></body></html>";
|
||||
|
||||
smol::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes())
|
||||
.await
|
||||
.ok();
|
||||
smol::io::AsyncWriteExt::flush(&mut stream).await.ok();
|
||||
|
||||
Ok(callback_url)
|
||||
};
|
||||
|
||||
let timeout_duration = Duration::from_secs(timeout_secs as u64);
|
||||
let callback_url = smol::future::or(
|
||||
accept_future,
|
||||
async {
|
||||
smol::Timer::after(timeout_duration).await;
|
||||
Err(anyhow::anyhow!(
|
||||
"OAuth callback timed out after {} seconds",
|
||||
timeout_secs
|
||||
))
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(llm_provider::OauthWebAuthResult {
|
||||
callback_url,
|
||||
port: port as u32,
|
||||
})
|
||||
}
|
||||
.boxed_local()
|
||||
})
|
||||
.await
|
||||
.to_wasmtime_result()
|
||||
}
|
||||
|
||||
async fn send_oauth_http_request(
|
||||
&mut self,
|
||||
request: llm_provider::OauthHttpRequest,
|
||||
) -> wasmtime::Result<Result<llm_provider::OauthHttpResponse, String>> {
|
||||
let http_client = self.host.http_client.clone();
|
||||
|
||||
self.on_main_thread(move |_cx| {
|
||||
async move {
|
||||
let method = match request.method.to_uppercase().as_str() {
|
||||
"GET" => ::http_client::Method::GET,
|
||||
"POST" => ::http_client::Method::POST,
|
||||
"PUT" => ::http_client::Method::PUT,
|
||||
"DELETE" => ::http_client::Method::DELETE,
|
||||
"PATCH" => ::http_client::Method::PATCH,
|
||||
_ => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Unsupported HTTP method: {}",
|
||||
request.method
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let mut builder = ::http_client::Request::builder()
|
||||
.method(method)
|
||||
.uri(&request.url);
|
||||
|
||||
for (key, value) in &request.headers {
|
||||
builder = builder.header(key.as_str(), value.as_str());
|
||||
}
|
||||
|
||||
let body = if request.body.is_empty() {
|
||||
AsyncBody::empty()
|
||||
} else {
|
||||
AsyncBody::from(request.body.into_bytes())
|
||||
};
|
||||
|
||||
let http_request = builder
|
||||
.body(body)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to build request: {}", e))?;
|
||||
|
||||
let mut response = http_client
|
||||
.send(http_request)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("HTTP request failed: {}", e))?;
|
||||
|
||||
let status = response.status().as_u16();
|
||||
let headers: Vec<(String, String)> = response
|
||||
.headers()
|
||||
.iter()
|
||||
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
|
||||
.collect();
|
||||
|
||||
let mut body_bytes = Vec::new();
|
||||
futures::AsyncReadExt::read_to_end(response.body_mut(), &mut body_bytes)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to read response body: {}", e))?;
|
||||
|
||||
let body = String::from_utf8_lossy(&body_bytes).to_string();
|
||||
|
||||
Ok(llm_provider::OauthHttpResponse {
|
||||
status,
|
||||
headers,
|
||||
body,
|
||||
})
|
||||
}
|
||||
.boxed_local()
|
||||
})
|
||||
.await
|
||||
.to_wasmtime_result()
|
||||
}
|
||||
|
||||
async fn oauth_open_browser(&mut self, url: String) -> wasmtime::Result<Result<(), String>> {
|
||||
self.on_main_thread(move |cx| {
|
||||
async move {
|
||||
cx.update(|cx| {
|
||||
cx.open_url(&url);
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
.boxed_local()
|
||||
})
|
||||
.await
|
||||
.to_wasmtime_result()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -442,7 +442,9 @@ impl ExtensionsPage {
|
||||
let extension_store = ExtensionStore::global(cx).read(cx);
|
||||
|
||||
match extension_store.outstanding_operations().get(extension_id) {
|
||||
Some(ExtensionOperation::Install) => ExtensionStatus::Installing,
|
||||
Some(ExtensionOperation::Install) | Some(ExtensionOperation::AutoInstall) => {
|
||||
ExtensionStatus::Installing
|
||||
}
|
||||
Some(ExtensionOperation::Remove) => ExtensionStatus::Removing,
|
||||
Some(ExtensionOperation::Upgrade) => ExtensionStatus::Upgrading,
|
||||
None => match extension_store.installed_extensions().get(extension_id) {
|
||||
|
||||
@@ -12,10 +12,10 @@ impl FeatureFlag for PanicFeatureFlag {
|
||||
const NAME: &'static str = "panic";
|
||||
}
|
||||
|
||||
pub struct InlineAssistantUseToolFeatureFlag;
|
||||
pub struct InlineAssistantV2FeatureFlag;
|
||||
|
||||
impl FeatureFlag for InlineAssistantUseToolFeatureFlag {
|
||||
const NAME: &'static str = "inline-assistant-use-tool";
|
||||
impl FeatureFlag for InlineAssistantV2FeatureFlag {
|
||||
const NAME: &'static str = "inline-assistant-v2";
|
||||
|
||||
fn enabled_for_staff() -> bool {
|
||||
false
|
||||
|
||||
@@ -23,7 +23,6 @@ use std::{
|
||||
path::PathBuf,
|
||||
sync::{Arc, LazyLock},
|
||||
};
|
||||
use text::LineEnding;
|
||||
use util::{paths::PathStyle, rel_path::RelPath};
|
||||
|
||||
pub static LOAD_INDEX_TEXT_TASK: LazyLock<TaskLabel> = LazyLock::new(TaskLabel::new);
|
||||
@@ -201,7 +200,6 @@ impl GitRepository for FakeGitRepository {
|
||||
async {
|
||||
Ok(CommitDetails {
|
||||
sha: commit.into(),
|
||||
message: "initial commit".into(),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
@@ -453,12 +451,7 @@ impl GitRepository for FakeGitRepository {
|
||||
})
|
||||
}
|
||||
|
||||
fn blame(
|
||||
&self,
|
||||
path: RepoPath,
|
||||
_content: Rope,
|
||||
_line_ending: LineEnding,
|
||||
) -> BoxFuture<'_, Result<git::blame::Blame>> {
|
||||
fn blame(&self, path: RepoPath, _content: Rope) -> BoxFuture<'_, Result<git::blame::Blame>> {
|
||||
self.with_state_async(false, move |state| {
|
||||
state
|
||||
.blames
|
||||
@@ -575,7 +568,7 @@ impl GitRepository for FakeGitRepository {
|
||||
_askpass: AskPassDelegate,
|
||||
_env: Arc<HashMap<String, String>>,
|
||||
) -> BoxFuture<'_, Result<()>> {
|
||||
async { Ok(()) }.boxed()
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn run_hook(
|
||||
@@ -583,7 +576,7 @@ impl GitRepository for FakeGitRepository {
|
||||
_hook: RunHook,
|
||||
_env: Arc<HashMap<String, String>>,
|
||||
) -> BoxFuture<'_, Result<()>> {
|
||||
async { Ok(()) }.boxed()
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn push(
|
||||
|
||||
@@ -803,7 +803,7 @@ impl Fs for RealFs {
|
||||
}
|
||||
let file = smol::fs::File::create(path).await?;
|
||||
let mut writer = smol::io::BufWriter::with_capacity(buffer_size, file);
|
||||
for chunk in text::chunks_with_line_ending(text, line_ending) {
|
||||
for chunk in chunks(text, line_ending) {
|
||||
writer.write_all(chunk.as_bytes()).await?;
|
||||
}
|
||||
writer.flush().await?;
|
||||
@@ -2555,7 +2555,7 @@ impl Fs for FakeFs {
|
||||
async fn save(&self, path: &Path, text: &Rope, line_ending: LineEnding) -> Result<()> {
|
||||
self.simulate_random_delay().await;
|
||||
let path = normalize_path(path);
|
||||
let content = text::chunks_with_line_ending(text, line_ending).collect::<String>();
|
||||
let content = chunks(text, line_ending).collect::<String>();
|
||||
if let Some(path) = path.parent() {
|
||||
self.create_dir(path).await?;
|
||||
}
|
||||
@@ -2773,6 +2773,25 @@ impl Fs for FakeFs {
|
||||
}
|
||||
}
|
||||
|
||||
fn chunks(rope: &Rope, line_ending: LineEnding) -> impl Iterator<Item = &str> {
|
||||
rope.chunks().flat_map(move |chunk| {
|
||||
let mut newline = false;
|
||||
let end_with_newline = chunk.ends_with('\n').then_some(line_ending.as_str());
|
||||
chunk
|
||||
.lines()
|
||||
.flat_map(move |line| {
|
||||
let ending = if newline {
|
||||
Some(line_ending.as_str())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
newline = true;
|
||||
ending.into_iter().chain([line])
|
||||
})
|
||||
.chain(end_with_newline)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn normalize_path(path: &Path) -> PathBuf {
|
||||
let mut components = path.components().peekable();
|
||||
let mut ret = if let Some(c @ Component::Prefix(..)) = components.peek().cloned() {
|
||||
|
||||
@@ -8,7 +8,7 @@ use gpui::SharedString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::process::Stdio;
|
||||
use std::{ops::Range, path::Path};
|
||||
use text::{LineEnding, Rope};
|
||||
use text::Rope;
|
||||
use time::OffsetDateTime;
|
||||
use time::UtcOffset;
|
||||
use time::macros::format_description;
|
||||
@@ -35,10 +35,8 @@ impl Blame {
|
||||
working_directory: &Path,
|
||||
path: &RepoPath,
|
||||
content: &Rope,
|
||||
line_ending: LineEnding,
|
||||
) -> Result<Self> {
|
||||
let output =
|
||||
run_git_blame(git_binary, working_directory, path, content, line_ending).await?;
|
||||
let output = run_git_blame(git_binary, working_directory, path, content).await?;
|
||||
let mut entries = parse_git_blame(&output)?;
|
||||
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
|
||||
|
||||
@@ -65,12 +63,12 @@ async fn run_git_blame(
|
||||
working_directory: &Path,
|
||||
path: &RepoPath,
|
||||
contents: &Rope,
|
||||
line_ending: LineEnding,
|
||||
) -> Result<String> {
|
||||
let mut child = util::command::new_smol_command(git_binary)
|
||||
.current_dir(working_directory)
|
||||
.arg("blame")
|
||||
.arg("--incremental")
|
||||
.arg("-w")
|
||||
.arg("--contents")
|
||||
.arg("-")
|
||||
.arg(path.as_unix_str())
|
||||
@@ -85,7 +83,7 @@ async fn run_git_blame(
|
||||
.as_mut()
|
||||
.context("failed to get pipe to stdin of git blame command")?;
|
||||
|
||||
for chunk in text::chunks_with_line_ending(contents, line_ending) {
|
||||
for chunk in contents.chunks() {
|
||||
stdin.write_all(chunk.as_bytes()).await?;
|
||||
}
|
||||
stdin.flush().await?;
|
||||
|
||||
@@ -14,7 +14,6 @@ use rope::Rope;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use smol::io::{AsyncBufReadExt, AsyncReadExt, BufReader};
|
||||
use text::LineEnding;
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::ffi::{OsStr, OsString};
|
||||
@@ -488,12 +487,7 @@ pub trait GitRepository: Send + Sync {
|
||||
fn show(&self, commit: String) -> BoxFuture<'_, Result<CommitDetails>>;
|
||||
|
||||
fn load_commit(&self, commit: String, cx: AsyncApp) -> BoxFuture<'_, Result<CommitDiff>>;
|
||||
fn blame(
|
||||
&self,
|
||||
path: RepoPath,
|
||||
content: Rope,
|
||||
line_ending: LineEnding,
|
||||
) -> BoxFuture<'_, Result<crate::blame::Blame>>;
|
||||
fn blame(&self, path: RepoPath, content: Rope) -> BoxFuture<'_, Result<crate::blame::Blame>>;
|
||||
fn file_history(&self, path: RepoPath) -> BoxFuture<'_, Result<FileHistory>>;
|
||||
fn file_history_paginated(
|
||||
&self,
|
||||
@@ -1518,12 +1512,7 @@ impl GitRepository for RealGitRepository {
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn blame(
|
||||
&self,
|
||||
path: RepoPath,
|
||||
content: Rope,
|
||||
line_ending: LineEnding,
|
||||
) -> BoxFuture<'_, Result<crate::blame::Blame>> {
|
||||
fn blame(&self, path: RepoPath, content: Rope) -> BoxFuture<'_, Result<crate::blame::Blame>> {
|
||||
let working_directory = self.working_directory();
|
||||
let git_binary_path = self.any_git_binary_path.clone();
|
||||
let executor = self.executor.clone();
|
||||
@@ -1535,7 +1524,6 @@ impl GitRepository for RealGitRepository {
|
||||
&working_directory?,
|
||||
&path,
|
||||
&content,
|
||||
line_ending,
|
||||
)
|
||||
.await
|
||||
})
|
||||
|
||||
@@ -47,13 +47,11 @@ impl BlameRenderer for GitBlameRenderer {
|
||||
let name = util::truncate_and_trailoff(author_name, GIT_BLAME_MAX_AUTHOR_CHARS_DISPLAYED);
|
||||
|
||||
let avatar = if ProjectSettings::get_global(cx).git.blame.show_avatar {
|
||||
Some(
|
||||
CommitAvatar::new(
|
||||
&blame_entry.sha.to_string().into(),
|
||||
details.as_ref().and_then(|it| it.remote.as_ref()),
|
||||
)
|
||||
.render(window, cx),
|
||||
CommitAvatar::new(
|
||||
&blame_entry.sha.to_string().into(),
|
||||
details.as_ref().and_then(|it| it.remote.as_ref()),
|
||||
)
|
||||
.render(window, cx)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -67,7 +65,7 @@ impl BlameRenderer for GitBlameRenderer {
|
||||
.w_full()
|
||||
.gap_2()
|
||||
.justify_between()
|
||||
.font(style.font())
|
||||
.font_family(style.font().family)
|
||||
.line_height(style.line_height)
|
||||
.text_color(cx.theme().status().hint)
|
||||
.child(
|
||||
@@ -266,7 +264,7 @@ impl BlameRenderer for GitBlameRenderer {
|
||||
.flex_wrap()
|
||||
.border_b_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.child(avatar)
|
||||
.children(avatar)
|
||||
.child(author)
|
||||
.when(!author_email.is_empty(), |this| {
|
||||
this.child(
|
||||
|
||||
@@ -636,6 +636,7 @@ impl PickerDelegate for BranchListDelegate {
|
||||
return Task::ready(());
|
||||
};
|
||||
|
||||
const RECENT_BRANCHES_COUNT: usize = 10;
|
||||
let display_remotes = self.display_remotes;
|
||||
cx.spawn_in(window, async move |picker, cx| {
|
||||
let mut matches: Vec<Entry> = if query.is_empty() {
|
||||
@@ -648,6 +649,7 @@ impl PickerDelegate for BranchListDelegate {
|
||||
!branch.is_remote()
|
||||
}
|
||||
})
|
||||
.take(RECENT_BRANCHES_COUNT)
|
||||
.map(|branch| Entry::Branch {
|
||||
branch,
|
||||
positions: Vec::new(),
|
||||
|
||||
@@ -139,7 +139,7 @@ impl CommitModal {
|
||||
&& !git_panel.amend_pending()
|
||||
{
|
||||
git_panel.set_amend_pending(true, cx);
|
||||
git_panel.load_last_commit_message(cx);
|
||||
git_panel.load_last_commit_message_if_empty(cx);
|
||||
}
|
||||
}
|
||||
ForceMode::Commit => {
|
||||
@@ -492,20 +492,53 @@ impl CommitModal {
|
||||
}
|
||||
}
|
||||
|
||||
fn on_commit(&mut self, _: &git::Commit, window: &mut Window, cx: &mut Context<Self>) {
|
||||
if self.git_panel.update(cx, |git_panel, cx| {
|
||||
git_panel.commit(&self.commit_editor.focus_handle(cx), window, cx)
|
||||
}) {
|
||||
telemetry::event!("Git Committed", source = "Git Modal");
|
||||
cx.emit(DismissEvent);
|
||||
fn commit(&mut self, _: &git::Commit, window: &mut Window, cx: &mut Context<Self>) {
|
||||
if self.git_panel.read(cx).amend_pending() {
|
||||
return;
|
||||
}
|
||||
telemetry::event!("Git Committed", source = "Git Modal");
|
||||
self.git_panel.update(cx, |git_panel, cx| {
|
||||
git_panel.commit_changes(
|
||||
CommitOptions {
|
||||
amend: false,
|
||||
signoff: git_panel.signoff_enabled(),
|
||||
},
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
|
||||
fn on_amend(&mut self, _: &git::Amend, window: &mut Window, cx: &mut Context<Self>) {
|
||||
if self.git_panel.update(cx, |git_panel, cx| {
|
||||
git_panel.amend(&self.commit_editor.focus_handle(cx), window, cx)
|
||||
}) {
|
||||
fn amend(&mut self, _: &git::Amend, window: &mut Window, cx: &mut Context<Self>) {
|
||||
if self
|
||||
.git_panel
|
||||
.read(cx)
|
||||
.active_repository
|
||||
.as_ref()
|
||||
.and_then(|repo| repo.read(cx).head_commit.as_ref())
|
||||
.is_none()
|
||||
{
|
||||
return;
|
||||
}
|
||||
if !self.git_panel.read(cx).amend_pending() {
|
||||
self.git_panel.update(cx, |git_panel, cx| {
|
||||
git_panel.set_amend_pending(true, cx);
|
||||
git_panel.load_last_commit_message_if_empty(cx);
|
||||
});
|
||||
} else {
|
||||
telemetry::event!("Git Amended", source = "Git Modal");
|
||||
self.git_panel.update(cx, |git_panel, cx| {
|
||||
git_panel.set_amend_pending(false, cx);
|
||||
git_panel.commit_changes(
|
||||
CommitOptions {
|
||||
amend: true,
|
||||
signoff: git_panel.signoff_enabled(),
|
||||
},
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
});
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
}
|
||||
@@ -531,8 +564,8 @@ impl Render for CommitModal {
|
||||
.id("commit-modal")
|
||||
.key_context("GitCommit")
|
||||
.on_action(cx.listener(Self::dismiss))
|
||||
.on_action(cx.listener(Self::on_commit))
|
||||
.on_action(cx.listener(Self::on_amend))
|
||||
.on_action(cx.listener(Self::commit))
|
||||
.on_action(cx.listener(Self::amend))
|
||||
.when(!DisableAiSettings::get_global(cx).disable_ai, |this| {
|
||||
this.on_action(cx.listener(|this, _: &GenerateCommitMessage, _, cx| {
|
||||
this.git_panel.update(cx, |panel, cx| {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user