Compare commits

..

2 Commits

Author SHA1 Message Date
Conrad Irwin
675c3441ee clipster 2025-03-13 23:24:46 -06:00
Conrad Irwin
38773a9b0b Use new path-based multibuffers in "find references" and friends 2025-03-13 22:51:29 -06:00
190 changed files with 4108 additions and 8904 deletions

View File

@@ -453,6 +453,7 @@ jobs:
[[ "${{ needs.linux_tests.result }}" != 'success' ]] && { RET_CODE=1; echo "Linux tests failed"; }
[[ "${{ needs.windows_tests.result }}" != 'success' ]] && { RET_CODE=1; echo "Windows tests failed"; }
[[ "${{ needs.windows_clippy.result }}" != 'success' ]] && { RET_CODE=1; echo "Windows clippy failed"; }
[[ "${{ needs.migration_checks.result }}" != 'success' ]] && { RET_CODE=1; echo "Migration checks failed"; }
[[ "${{ needs.build_remote_server.result }}" != 'success' ]] && { RET_CODE=1; echo "Remote server build failed"; }
fi
if [[ "$RET_CODE" -eq 0 ]]; then

View File

@@ -17,7 +17,7 @@ jobs:
else
URL="https://zed.dev/releases/stable/latest"
fi
echo "URL=$URL" >> $GITHUB_OUTPUT
echo "::set-output name=URL::$URL"
- name: Get content
uses: 2428392/gh-truncate-string-action@b3ff790d21cf42af3ca7579146eedb93c8fb0757 # v1.4.1
id: get-content
@@ -33,33 +33,3 @@ jobs:
with:
webhook-url: ${{ secrets.DISCORD_WEBHOOK_URL }}
content: ${{ steps.get-content.outputs.string }}
send_release_notes_email:
if: github.repository_owner == 'zed-industries' && !github.event.release.prerelease
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Check if release was promoted from preview
id: check-promotion-from-preview
run: |
VERSION="${{ github.event.release.tag_name }}"
PREVIEW_TAG="${VERSION}-pre"
if git rev-parse "$PREVIEW_TAG" >/dev/null 2>&1; then
echo "was_preview=true" >> $GITHUB_OUTPUT
else
echo "was_preview=false" >> $GITHUB_OUTPUT
fi
- name: Send release notes email
if: steps.check-promotion-from-preview.outputs.was_preview == 'true'
run: |
curl -X POST "https://zed.dev/api/send_release_notes_email" \
-H "Authorization: Bearer ${{ secrets.RELEASE_NOTES_API_TOKEN }}" \
-H "Content-Type: application/json" \
-d '{
"version": "${{ github.event.release.tag_name }}",
"markdown_body": ${{ toJSON(github.event.release.body) }}
}'

183
Cargo.lock generated
View File

@@ -467,7 +467,6 @@ dependencies = [
"fs",
"futures 0.3.31",
"fuzzy",
"git",
"gpui",
"heed",
"html_to_markdown",
@@ -497,7 +496,6 @@ dependencies = [
"settings",
"smol",
"streaming_diff",
"telemetry",
"telemetry_events",
"terminal",
"terminal_view",
@@ -566,41 +564,6 @@ dependencies = [
"workspace",
]
[[package]]
name = "assistant_eval"
version = "0.1.0"
dependencies = [
"anyhow",
"assistant2",
"assistant_tool",
"assistant_tools",
"clap",
"client",
"collections",
"context_server",
"env_logger 0.11.7",
"fs",
"futures 0.3.31",
"gpui",
"gpui_tokio",
"itertools 0.14.0",
"language",
"language_model",
"language_models",
"node_runtime",
"project",
"prompt_store",
"regex",
"release_channel",
"reqwest_client",
"serde",
"serde_json",
"serde_json_lenient",
"settings",
"smol",
"util",
]
[[package]]
name = "assistant_settings"
version = "0.1.0"
@@ -692,11 +655,9 @@ name = "assistant_tool"
version = "0.1.0"
dependencies = [
"anyhow",
"clock",
"collections",
"derive_more",
"gpui",
"language",
"language_model",
"parking_lot",
"project",
@@ -715,10 +676,8 @@ dependencies = [
"feature_flags",
"futures 0.3.31",
"gpui",
"itertools 0.14.0",
"language",
"language_model",
"pretty_assertions",
"project",
"rand 0.8.5",
"release_channel",
@@ -728,10 +687,8 @@ dependencies = [
"settings",
"theme",
"ui",
"unindent",
"util",
"workspace",
"worktree",
]
[[package]]
@@ -1878,7 +1835,7 @@ dependencies = [
"bitflags 2.8.0",
"cexpr",
"clang-sys",
"itertools 0.12.1",
"itertools 0.10.5",
"lazy_static",
"lazycell",
"log",
@@ -1901,7 +1858,7 @@ dependencies = [
"bitflags 2.8.0",
"cexpr",
"clang-sys",
"itertools 0.12.1",
"itertools 0.10.5",
"log",
"prettyplease",
"proc-macro2",
@@ -2744,7 +2701,6 @@ dependencies = [
"futures 0.3.31",
"gpui",
"http_client",
"http_client_tls",
"log",
"parking_lot",
"paths",
@@ -3365,15 +3321,6 @@ dependencies = [
"libc",
]
[[package]]
name = "core_maths"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77745e017f5edba1a9c1d854f6f3a52dac8a12dd5af5d2f54aecf61e43d80d30"
dependencies = [
"libm",
]
[[package]]
name = "coreaudio-rs"
version = "0.11.3"
@@ -3411,16 +3358,16 @@ version = "0.11.2"
source = "git+https://github.com/pop-os/cosmic-text?rev=542b20c#542b20ca4376a3b5de5fa629db1a4ace44e18e0c"
dependencies = [
"bitflags 2.8.0",
"fontdb 0.18.0",
"fontdb",
"log",
"rangemap",
"rayon",
"rustc-hash 1.1.0",
"rustybuzz 0.14.1",
"rustybuzz",
"self_cell",
"swash",
"sys-locale",
"ttf-parser 0.21.1",
"ttf-parser",
"unicode-bidi",
"unicode-linebreak",
"unicode-script",
@@ -5014,21 +4961,7 @@ dependencies = [
"memmap2",
"slotmap",
"tinyvec",
"ttf-parser 0.21.1",
]
[[package]]
name = "fontdb"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "457e789b3d1202543297a350643cf459f836cade38934e7a4cf6a39e7cde2905"
dependencies = [
"fontconfig-parser",
"log",
"memmap2",
"slotmap",
"tinyvec",
"ttf-parser 0.25.1",
"ttf-parser",
]
[[package]]
@@ -5487,9 +5420,8 @@ dependencies = [
[[package]]
name = "git2"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5220b8ba44c68a9a7f7a7659e864dd73692e417ef0211bea133c7b74e031eeb9"
version = "0.20.0"
source = "git+https://github.com/rust-lang/git2-rs?rev=a3b90cb3756c1bb63e2317bf9cfa57838178de5c#a3b90cb3756c1bb63e2317bf9cfa57838178de5c"
dependencies = [
"bitflags 2.8.0",
"libc",
@@ -5511,10 +5443,8 @@ dependencies = [
"indoc",
"pretty_assertions",
"regex",
"schemars",
"serde",
"serde_json",
"settings",
"url",
"util",
]
@@ -6210,19 +6140,13 @@ dependencies = [
"futures 0.3.31",
"http 1.2.0",
"log",
"rustls 0.23.23",
"rustls-platform-verifier",
"serde",
"serde_json",
"url",
]
[[package]]
name = "http_client_tls"
version = "0.1.0"
dependencies = [
"rustls 0.23.23",
"rustls-platform-verifier",
]
[[package]]
name = "httparse"
version = "1.9.5"
@@ -7430,9 +7354,8 @@ dependencies = [
[[package]]
name = "libgit2-sys"
version = "0.18.1+1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1dcb20f84ffcdd825c7a311ae347cce604a6f084a767dec4a4929829645290e"
version = "0.18.0+1.9.0"
source = "git+https://github.com/rust-lang/git2-rs?rev=a3b90cb3756c1bb63e2317bf9cfa57838178de5c#a3b90cb3756c1bb63e2317bf9cfa57838178de5c"
dependencies = [
"cc",
"libc",
@@ -7447,7 +7370,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [
"cfg-if",
"windows-targets 0.52.6",
"windows-targets 0.48.5",
]
[[package]]
@@ -10642,8 +10565,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4"
dependencies = [
"bytes 1.10.1",
"heck 0.5.0",
"itertools 0.12.1",
"heck 0.4.1",
"itertools 0.10.5",
"log",
"multimap 0.10.0",
"once_cell",
@@ -10676,7 +10599,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1"
dependencies = [
"anyhow",
"itertools 0.12.1",
"itertools 0.10.5",
"proc-macro2",
"quote",
"syn 2.0.100",
@@ -11262,7 +11185,6 @@ dependencies = [
"smol",
"tempfile",
"thiserror 1.0.69",
"urlencoding",
"util",
]
@@ -11477,7 +11399,6 @@ dependencies = [
"futures 0.3.31",
"gpui",
"http_client",
"http_client_tls",
"log",
"regex",
"reqwest 0.12.8",
@@ -11488,9 +11409,9 @@ dependencies = [
[[package]]
name = "resvg"
version = "0.45.0"
version = "0.44.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd43d1c474e9dadf09a8fdf22d713ba668b499b5117b9b9079500224e26b5b29"
checksum = "4a325d5e8d1cebddd070b13f44cec8071594ab67d1012797c121f27a669b7958"
dependencies = [
"log",
"pico-args",
@@ -11952,27 +11873,9 @@ dependencies = [
"bytemuck",
"libm",
"smallvec",
"ttf-parser 0.21.1",
"unicode-bidi-mirroring 0.2.0",
"unicode-ccc 0.2.0",
"unicode-properties",
"unicode-script",
]
[[package]]
name = "rustybuzz"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd3c7c96f8a08ee34eff8857b11b49b07d71d1c3f4e88f8a88d4c9e9f90b1702"
dependencies = [
"bitflags 2.8.0",
"bytemuck",
"core_maths",
"log",
"smallvec",
"ttf-parser 0.25.1",
"unicode-bidi-mirroring 0.4.0",
"unicode-ccc 0.4.0",
"ttf-parser",
"unicode-bidi-mirroring",
"unicode-ccc",
"unicode-properties",
"unicode-script",
]
@@ -13420,9 +13323,9 @@ checksum = "ce5d813d71d82c4cbc1742135004e4a79fd870214c155443451c139c9470a0aa"
[[package]]
name = "svgtypes"
version = "0.15.3"
version = "0.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68c7541fff44b35860c1a7a47a7cadf3e4a304c457b58f9870d9706ece028afc"
checksum = "794de53cc48eaabeed0ab6a3404a65f40b3e38c067e4435883a65d2aa4ca000e"
dependencies = [
"kurbo",
"siphasher 1.0.1",
@@ -14754,15 +14657,6 @@ version = "0.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c591d83f69777866b9126b24c6dd9a18351f177e49d625920d19f989fd31cf8"
[[package]]
name = "ttf-parser"
version = "0.25.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2df906b07856748fa3f6e0ad0cbaa047052d4a7dd609e231c4f72cee8c36f31"
dependencies = [
"core_maths",
]
[[package]]
name = "tungstenite"
version = "0.20.1"
@@ -14910,24 +14804,12 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23cb788ffebc92c5948d0e997106233eeb1d8b9512f93f41651f52b6c5f5af86"
[[package]]
name = "unicode-bidi-mirroring"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5dfa6e8c60bb66d49db113e0125ee8711b7647b5579dc7f5f19c42357ed039fe"
[[package]]
name = "unicode-ccc"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1df77b101bcc4ea3d78dafc5ad7e4f58ceffe0b2b16bf446aeb50b6cb4157656"
[[package]]
name = "unicode-ccc"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce61d488bcdc9bc8b5d1772c404828b17fc481c0a582b5581e95fb233aef503e"
[[package]]
name = "unicode-ident"
version = "1.0.14"
@@ -14967,12 +14849,6 @@ version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
[[package]]
name = "unicode-vo"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1d386ff53b415b7fe27b50bb44679e2cc4660272694b7b6f3326d8480823a94"
[[package]]
name = "unicode-width"
version = "0.1.14"
@@ -15023,28 +14899,23 @@ checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
[[package]]
name = "usvg"
version = "0.45.0"
version = "0.44.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2ac8e0e3e4696253dc06167990b3fe9a2668ab66270adf949a464db4088cb354"
checksum = "7447e703d7223b067607655e625e0dbca80822880248937da65966194c4864e6"
dependencies = [
"base64 0.22.1",
"data-url",
"flate2",
"fontdb 0.23.0",
"imagesize",
"kurbo",
"log",
"pico-args",
"roxmltree",
"rustybuzz 0.20.1",
"simplecss",
"siphasher 1.0.1",
"strict-num",
"svgtypes",
"tiny-skia-path",
"unicode-bidi",
"unicode-script",
"unicode-vo",
"xmlwriter",
]
@@ -15202,7 +15073,6 @@ dependencies = [
"collections",
"command_palette",
"command_palette_hooks",
"db",
"editor",
"futures 0.3.31",
"git_ui",
@@ -15228,7 +15098,6 @@ dependencies = [
"serde_json",
"settings",
"task",
"text",
"theme",
"tokio",
"ui",
@@ -16107,7 +15976,7 @@ version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
dependencies = [
"windows-sys 0.59.0",
"windows-sys 0.48.0",
]
[[package]]

View File

@@ -8,7 +8,6 @@ members = [
"crates/assistant",
"crates/assistant2",
"crates/assistant_context_editor",
"crates/assistant_eval",
"crates/assistant_settings",
"crates/assistant_slash_command",
"crates/assistant_slash_commands",
@@ -65,7 +64,6 @@ members = [
"crates/gpui_tokio",
"crates/html_to_markdown",
"crates/http_client",
"crates/http_client_tls",
"crates/image_viewer",
"crates/indexed_docs",
"crates/inline_completion",
@@ -208,7 +206,6 @@ assets = { path = "crates/assets" }
assistant = { path = "crates/assistant" }
assistant2 = { path = "crates/assistant2" }
assistant_context_editor = { path = "crates/assistant_context_editor" }
assistant_eval = { path = "crates/assistant_eval" }
assistant_settings = { path = "crates/assistant_settings" }
assistant_slash_command = { path = "crates/assistant_slash_command" }
assistant_slash_commands = { path = "crates/assistant_slash_commands" }
@@ -263,7 +260,6 @@ gpui_macros = { path = "crates/gpui_macros" }
gpui_tokio = { path = "crates/gpui_tokio" }
html_to_markdown = { path = "crates/html_to_markdown" }
http_client = { path = "crates/http_client" }
http_client_tls = { path = "crates/http_client_tls" }
image_viewer = { path = "crates/image_viewer" }
indexed_docs = { path = "crates/indexed_docs" }
inline_completion = { path = "crates/inline_completion" }
@@ -421,7 +417,8 @@ fork = "0.2.0"
futures = "0.3"
futures-batch = "0.6.1"
futures-lite = "1.13"
git2 = { version = "0.20.1", default-features = false }
# TODO: get back to regular versions when https://github.com/rust-lang/git2-rs/pull/1120 is released
git2 = { git = "https://github.com/rust-lang/git2-rs", rev = "a3b90cb3756c1bb63e2317bf9cfa57838178de5c", default-features = false }
globset = "0.4"
handlebars = "4.3"
heed = { version = "0.21.0", features = ["read-txn-no-tls"] }
@@ -564,7 +561,6 @@ unindent = "0.2.0"
unicode-segmentation = "1.10"
unicode-script = "0.5.7"
url = "2.2"
urlencoding = "2.1.2"
uuid = { version = "1.1.2", features = ["v4", "v5", "v7", "serde"] }
wasmparser = "0.221"
wasm-encoder = "0.221"

View File

@@ -1,6 +1,6 @@
<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M3 4H8" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M6 10L11 10" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<circle cx="4" cy="10" r="1.875" stroke="black" stroke-width="1.5"/>
<circle cx="10" cy="4" r="1.875" stroke="black" stroke-width="1.5"/>
<path d="M3 4H8" stroke="black" stroke-width="1.75" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M6 10L11 10" stroke="black" stroke-width="1.75" stroke-linecap="round" stroke-linejoin="round"/>
<circle cx="4" cy="10" r="1.875" stroke="black" stroke-width="1.75"/>
<circle cx="10" cy="4" r="1.875" stroke="black" stroke-width="1.75"/>
</svg>

Before

Width:  |  Height:  |  Size: 446 B

After

Width:  |  Height:  |  Size: 450 B

View File

@@ -107,7 +107,6 @@
"ctrl-a": "editor::SelectAll",
"ctrl-l": "editor::SelectLine",
"ctrl-shift-i": "editor::Format",
"alt-shift-o": "editor::OrganizeImports",
// "cmd-shift-left": ["editor::SelectToBeginningOfLine", {"stop_at_soft_wraps": true, "stop_at_indent": true }],
// "ctrl-shift-a": ["editor::SelectToBeginningOfLine", { "stop_at_soft_wraps": true, "stop_at_indent": true }],
"shift-home": ["editor::SelectToBeginningOfLine", { "stop_at_soft_wraps": true, "stop_at_indent": true }],
@@ -363,7 +362,6 @@
"ctrl-k ctrl-0": "editor::FoldAll",
"ctrl-k ctrl-j": "editor::UnfoldAll",
"ctrl-space": "editor::ShowCompletions",
"ctrl-shift-space": "editor::ShowWordCompletions",
"ctrl-.": "editor::ToggleCodeActions",
"ctrl-k r": "editor::RevealInFileManager",
"ctrl-k p": "editor::CopyPath",

View File

@@ -466,7 +466,6 @@
// Using `ctrl-space` in Zed requires disabling the macOS global shortcut.
// System Preferences->Keyboard->Keyboard Shortcuts->Input Sources->Select the previous input source (uncheck)
"ctrl-space": "editor::ShowCompletions",
"ctrl-shift-space": "editor::ShowWordCompletions",
"cmd-.": "editor::ToggleCodeActions",
"cmd-k r": "editor::RevealInFileManager",
"cmd-k p": "editor::CopyPath",

View File

@@ -31,7 +31,6 @@
"shift-alt-up": "editor::MoveLineUp",
"shift-alt-down": "editor::MoveLineDown",
"ctrl-alt-l": "editor::Format",
"ctrl-alt-o": "editor::OrganizeImports",
"shift-f6": "editor::Rename",
"ctrl-alt-left": "pane::GoBack",
"ctrl-alt-right": "pane::GoForward",

View File

@@ -29,7 +29,6 @@
"shift-alt-up": "editor::MoveLineUp",
"shift-alt-down": "editor::MoveLineDown",
"cmd-alt-l": "editor::Format",
"ctrl-alt-o": "editor::OrganizeImports",
"shift-f6": "editor::Rename",
"cmd-[": "pane::GoBack",
"cmd-]": "pane::GoForward",

View File

@@ -850,24 +850,8 @@
//
// The minimum column number to show the inline blame information at
// "min_column": 0
},
// How git hunks are displayed visually in the editor.
// This setting can take two values:
//
// 1. Show unstaged hunks filled and staged hunks hollow:
// "hunk_style": "staged_hollow"
// 2. Show unstaged hunks hollow and staged hunks filled:
// "hunk_style": "unstaged_hollow"
"hunk_style": "staged_hollow"
}
},
// The list of custom Git hosting providers.
"git_hosting_providers": [
// {
// "provider": "github",
// "name": "BigCorp GitHub",
// "base_url": "https://code.big-corp.com"
// }
],
// Configuration for how direnv configuration should be loaded. May take 2 values:
// 1. Load direnv configuration using `direnv export json` directly.
// "load_direnv": "direct"
@@ -1108,12 +1092,11 @@
//
// May take 3 values:
// 1. "enabled"
// Always fetch document's words for completions along with LSP completions.
// Always fetch document's words for completions.
// 2. "fallback"
// Only if LSP response errors or times out, use document's words to show completions.
// Only if LSP response errors/times out/is empty, use document's words to show completions.
// 3. "disabled"
// Never fetch or complete document's words for completions.
// (Word-based completions can still be queried via a separate action)
//
// Default: fallback
"words": "fallback",
@@ -1124,8 +1107,8 @@
// When fetching LSP completions, determines how long to wait for a response of a particular server.
// When set to 0, waits indefinitely.
//
// Default: 0
"lsp_fetch_timeout_ms": 0
// Default: 500
"lsp_fetch_timeout_ms": 500
},
// Different settings for specific languages.
"languages": {

View File

@@ -3569,7 +3569,6 @@ impl CodeActionProvider for AssistantCodeActionProvider {
title: "Fix with Assistant".into(),
..Default::default()
})),
resolved: true,
}]))
} else {
Task::ready(Ok(Vec::new()))

View File

@@ -38,7 +38,6 @@ file_icons.workspace = true
fs.workspace = true
futures.workspace = true
fuzzy.workspace = true
git.workspace = true
gpui.workspace = true
heed.workspace = true
html_to_markdown.workspace = true
@@ -66,7 +65,6 @@ serde_json.workspace = true
settings.workspace = true
smol.workspace = true
streaming_diff.workspace = true
telemetry.workspace = true
telemetry_events.workspace = true
terminal.workspace = true
terminal_view.workspace = true

View File

@@ -1,7 +1,6 @@
use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
use crate::thread_store::ThreadStore;
use crate::tool_use::{ToolUse, ToolUseStatus};
use crate::ui::ContextPill;
use std::sync::Arc;
use std::time::Duration;
use collections::HashMap;
use editor::{Editor, MultiBuffer};
use gpui::{
@@ -15,20 +14,19 @@ use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
use markdown::{Markdown, MarkdownStyle};
use scripting_tool::{ScriptingTool, ScriptingToolInput};
use settings::Settings as _;
use std::sync::Arc;
use std::time::Duration;
use theme::ThemeSettings;
use ui::Color;
use ui::{prelude::*, Disclosure, KeyBinding};
use util::ResultExt as _;
use crate::context_store::{refresh_context_store_text, ContextStore};
use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
use crate::thread_store::ThreadStore;
use crate::tool_use::{ToolUse, ToolUseStatus};
use crate::ui::ContextPill;
pub struct ActiveThread {
language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>,
thread: Entity<Thread>,
context_store: Entity<ContextStore>,
save_thread_task: Option<Task<()>>,
messages: Vec<MessageId>,
list_state: ListState,
@@ -49,7 +47,6 @@ impl ActiveThread {
thread: Entity<Thread>,
thread_store: Entity<ThreadStore>,
language_registry: Arc<LanguageRegistry>,
context_store: Entity<ContextStore>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -62,7 +59,6 @@ impl ActiveThread {
language_registry,
thread_store,
thread: thread.clone(),
context_store,
save_thread_task: None,
messages: Vec::new(),
rendered_messages_by_id: HashMap::default(),
@@ -116,7 +112,7 @@ impl ActiveThread {
pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
self.last_error.take();
self.thread
.update(cx, |thread, cx| thread.cancel_last_completion(cx))
.update(cx, |thread, _cx| thread.cancel_last_completion())
}
pub fn last_error(&self) -> Option<ThreadError> {
@@ -298,7 +294,6 @@ impl ActiveThread {
ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
self.save_thread(cx);
}
ThreadEvent::DoneStreaming => {}
ThreadEvent::StreamedAssistantText(message_id, text) => {
if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
markdown.update(cx, |markdown, cx| {
@@ -343,11 +338,8 @@ impl ActiveThread {
});
}
ThreadEvent::ToolFinished {
pending_tool_use,
canceled,
..
pending_tool_use, ..
} => {
let canceled = *canceled;
if let Some(tool_use) = pending_tool_use {
self.render_scripting_tool_use_markdown(
tool_use.id.clone(),
@@ -359,54 +351,11 @@ impl ActiveThread {
}
if self.thread.read(cx).all_tools_finished() {
let pending_refresh_buffers = self.thread.update(cx, |thread, cx| {
thread.action_log().update(cx, |action_log, _cx| {
action_log.take_stale_buffers_in_context()
})
});
let context_update_task = if !pending_refresh_buffers.is_empty() {
let refresh_task = refresh_context_store_text(
self.context_store.clone(),
&pending_refresh_buffers,
cx,
);
cx.spawn(|this, mut cx| async move {
let updated_context_ids = refresh_task.await;
this.update(&mut cx, |this, cx| {
this.context_store.read_with(cx, |context_store, cx| {
context_store
.context()
.iter()
.filter(|context| {
updated_context_ids.contains(&context.id())
})
.flat_map(|context| context.snapshot(cx))
.collect()
})
})
})
} else {
Task::ready(anyhow::Ok(Vec::new()))
};
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(model) = model_registry.active_model() {
cx.spawn(|this, mut cx| async move {
let updated_context = context_update_task.await?;
this.update(&mut cx, |this, cx| {
this.thread.update(cx, |thread, cx| {
thread.attach_tool_results(updated_context, cx);
if !canceled {
thread.send_to_model(model, RequestKind::Chat, cx);
}
});
})
})
.detach();
self.thread.update(cx, |thread, cx| {
thread.send_tool_results_to_model(model, cx);
});
}
}
}
@@ -549,7 +498,7 @@ impl ActiveThread {
};
let thread = self.thread.read(cx);
// Get all the data we need from thread before we start using it in closures
let context = thread.context_for_message(message_id);
let tool_uses = thread.tool_uses_for_message(message_id);
let scripting_tool_uses = thread.scripting_tool_uses_for_message(message_id);
@@ -704,27 +653,28 @@ impl ActiveThread {
)
.child(message_content),
),
Role::Assistant => {
v_flex()
.id(("message-container", ix))
.child(message_content)
.when(
!tool_uses.is_empty() || !scripting_tool_uses.is_empty(),
|parent| {
parent.child(
v_flex()
.children(
tool_uses
.into_iter()
.map(|tool_use| self.render_tool_use(tool_use, cx)),
)
.children(scripting_tool_uses.into_iter().map(|tool_use| {
self.render_scripting_tool_use(tool_use, cx)
})),
Role::Assistant => v_flex()
.id(("message-container", ix))
.child(message_content)
.map(|parent| {
if tool_uses.is_empty() && scripting_tool_uses.is_empty() {
return parent;
}
parent.child(
v_flex()
.children(
tool_uses
.into_iter()
.map(|tool_use| self.render_tool_use(tool_use, cx)),
)
},
.children(
scripting_tool_uses
.into_iter()
.map(|tool_use| self.render_scripting_tool_use(tool_use, cx)),
),
)
}
}),
Role::System => div().id(("message-container", ix)).py_1().px_2().child(
v_flex()
.bg(colors.editor_background)
@@ -743,13 +693,12 @@ impl ActiveThread {
.copied()
.unwrap_or_default();
let lighter_border = cx.theme().colors().border.opacity(0.5);
div().px_2p5().child(
v_flex()
.gap_1()
.rounded_lg()
.border_1()
.border_color(lighter_border)
.border_color(cx.theme().colors().border)
.child(
h_flex()
.justify_between()
@@ -764,7 +713,7 @@ impl ActiveThread {
element.rounded_md()
}
})
.border_color(lighter_border)
.border_color(cx.theme().colors().border)
.child(
h_flex()
.gap_1()
@@ -781,11 +730,7 @@ impl ActiveThread {
}
}),
))
.child(
Label::new(tool_use.name)
.size(LabelSize::Small)
.buffer_font(cx),
),
.child(Label::new(tool_use.name)),
)
.child({
let (icon_name, color, animated) = match &tool_use.status {
@@ -822,88 +767,39 @@ impl ActiveThread {
return parent;
}
let content_container = || v_flex().py_1().gap_0p5().px_2p5();
parent.child(
v_flex()
.gap_1()
.bg(cx.theme().colors().editor_background)
.rounded_b_lg()
.child(
content_container()
v_flex()
.gap_0p5()
.py_1()
.px_2p5()
.border_b_1()
.border_color(lighter_border)
.child(
Label::new("Input")
.size(LabelSize::XSmall)
.color(Color::Muted)
.buffer_font(cx),
)
.child(
Label::new(
serde_json::to_string_pretty(&tool_use.input)
.unwrap_or_default(),
)
.size(LabelSize::Small)
.buffer_font(cx),
),
.border_color(cx.theme().colors().border)
.child(Label::new("Input:"))
.child(Label::new(
serde_json::to_string_pretty(&tool_use.input)
.unwrap_or_default(),
)),
)
.map(|container| match tool_use.status {
ToolUseStatus::Finished(output) => container.child(
content_container()
.child(
Label::new("Result")
.size(LabelSize::XSmall)
.color(Color::Muted)
.buffer_font(cx),
)
.child(
Label::new(output)
.size(LabelSize::Small)
.buffer_font(cx),
),
.map(|parent| match tool_use.status {
ToolUseStatus::Finished(output) => parent.child(
v_flex()
.gap_0p5()
.py_1()
.px_2p5()
.child(Label::new("Result:"))
.child(Label::new(output)),
),
ToolUseStatus::Running => container.child(
content_container().child(
h_flex()
.gap_1()
.pb_1()
.child(
Icon::new(IconName::ArrowCircle)
.size(IconSize::Small)
.color(Color::Accent)
.with_animation(
"arrow-circle",
Animation::new(Duration::from_secs(2))
.repeat(),
|icon, delta| {
icon.transform(Transformation::rotate(
percentage(delta),
))
},
),
)
.child(
Label::new("Running…")
.size(LabelSize::XSmall)
.color(Color::Muted)
.buffer_font(cx),
),
),
ToolUseStatus::Error(err) => parent.child(
v_flex()
.gap_0p5()
.py_1()
.px_2p5()
.child(Label::new("Error:"))
.child(Label::new(err)),
),
ToolUseStatus::Error(err) => container.child(
content_container()
.child(
Label::new("Error")
.size(LabelSize::XSmall)
.color(Color::Muted)
.buffer_font(cx),
)
.child(
Label::new(err).size(LabelSize::Small).buffer_font(cx),
),
),
ToolUseStatus::Pending => container,
ToolUseStatus::Pending | ToolUseStatus::Running => parent,
}),
)
}),

View File

@@ -31,11 +31,8 @@ use gpui::{actions, App};
use prompt_store::PromptBuilder;
use settings::Settings as _;
pub use crate::active_thread::ActiveThread;
pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate};
pub use crate::inline_assistant::InlineAssistant;
pub use crate::thread::{Message, RequestKind, Thread, ThreadEvent};
pub use crate::thread_store::ThreadStore;
actions!(
assistant2,

View File

@@ -155,14 +155,10 @@ impl AssistantPanel {
let workspace = workspace.weak_handle();
let weak_self = cx.entity().downgrade();
let message_editor_context_store =
cx.new(|_cx| crate::context_store::ContextStore::new(workspace.clone()));
let message_editor = cx.new(|cx| {
MessageEditor::new(
fs.clone(),
workspace.clone(),
message_editor_context_store.clone(),
thread_store.downgrade(),
thread.clone(),
window,
@@ -178,7 +174,6 @@ impl AssistantPanel {
thread.clone(),
thread_store.clone(),
language_registry.clone(),
message_editor_context_store.clone(),
window,
cx,
)
@@ -247,16 +242,11 @@ impl AssistantPanel {
.update(cx, |this, cx| this.create_thread(cx));
self.active_view = ActiveView::Thread;
let message_editor_context_store =
cx.new(|_cx| crate::context_store::ContextStore::new(self.workspace.clone()));
self.thread = cx.new(|cx| {
ActiveThread::new(
thread.clone(),
self.thread_store.clone(),
self.language_registry.clone(),
message_editor_context_store.clone(),
window,
cx,
)
@@ -265,7 +255,6 @@ impl AssistantPanel {
MessageEditor::new(
self.fs.clone(),
self.workspace.clone(),
message_editor_context_store,
self.thread_store.downgrade(),
thread,
window,
@@ -386,14 +375,11 @@ impl AssistantPanel {
let thread = open_thread_task.await?;
this.update_in(&mut cx, |this, window, cx| {
this.active_view = ActiveView::Thread;
let message_editor_context_store =
cx.new(|_cx| crate::context_store::ContextStore::new(this.workspace.clone()));
this.thread = cx.new(|cx| {
ActiveThread::new(
thread.clone(),
this.thread_store.clone(),
this.language_registry.clone(),
message_editor_context_store.clone(),
window,
cx,
)
@@ -402,7 +388,6 @@ impl AssistantPanel {
MessageEditor::new(
this.fs.clone(),
this.workspace.clone(),
message_editor_context_store,
this.thread_store.downgrade(),
thread,
window,

View File

@@ -43,6 +43,15 @@ pub enum ContextKind {
}
impl ContextKind {
pub fn label(&self) -> &'static str {
match self {
ContextKind::File => "File",
ContextKind::Directory => "Folder",
ContextKind::FetchedUrl => "Fetch",
ContextKind::Thread => "Thread",
}
}
pub fn icon(&self) -> IconName {
match self {
ContextKind::File => IconName::File,

View File

@@ -1,3 +1,4 @@
mod directory_context_picker;
mod fetch_context_picker;
mod file_context_picker;
mod thread_context_picker;
@@ -14,6 +15,8 @@ use thread_context_picker::{render_thread_context_entry, ThreadContextEntry};
use ui::{prelude::*, ContextMenu, ContextMenuEntry, ContextMenuItem};
use workspace::{notifications::NotifyResultExt, Workspace};
use crate::context::ContextKind;
use crate::context_picker::directory_context_picker::DirectoryContextPicker;
use crate::context_picker::fetch_context_picker::FetchContextPicker;
use crate::context_picker::file_context_picker::FileContextPicker;
use crate::context_picker::thread_context_picker::ThreadContextPicker;
@@ -27,41 +30,17 @@ pub enum ConfirmBehavior {
Close,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ContextPickerMode {
File,
Fetch,
Thread,
}
impl ContextPickerMode {
pub fn label(&self) -> &'static str {
match self {
Self::File => "File/Directory",
Self::Fetch => "Fetch",
Self::Thread => "Thread",
}
}
pub fn icon(&self) -> IconName {
match self {
Self::File => IconName::File,
Self::Fetch => IconName::Globe,
Self::Thread => IconName::MessageCircle,
}
}
}
#[derive(Debug, Clone)]
enum ContextPickerState {
enum ContextPickerMode {
Default(Entity<ContextMenu>),
File(Entity<FileContextPicker>),
Directory(Entity<DirectoryContextPicker>),
Fetch(Entity<FetchContextPicker>),
Thread(Entity<ThreadContextPicker>),
}
pub(super) struct ContextPicker {
mode: ContextPickerState,
mode: ContextPickerMode,
workspace: WeakEntity<Workspace>,
editor: WeakEntity<Editor>,
context_store: WeakEntity<ContextStore>,
@@ -80,7 +59,7 @@ impl ContextPicker {
cx: &mut Context<Self>,
) -> Self {
ContextPicker {
mode: ContextPickerState::Default(ContextMenu::build(
mode: ContextPickerMode::Default(ContextMenu::build(
window,
cx,
|menu, _window, _cx| menu,
@@ -94,7 +73,7 @@ impl ContextPicker {
}
pub fn init(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.mode = ContextPickerState::Default(self.build_menu(window, cx));
self.mode = ContextPickerMode::Default(self.build_menu(window, cx));
cx.notify();
}
@@ -109,9 +88,13 @@ impl ContextPicker {
.enumerate()
.map(|(ix, entry)| self.recent_menu_item(context_picker.clone(), ix, entry));
let mut modes = vec![ContextPickerMode::File, ContextPickerMode::Fetch];
let mut context_kinds = vec![
ContextKind::File,
ContextKind::Directory,
ContextKind::FetchedUrl,
];
if self.allow_threads() {
modes.push(ContextPickerMode::Thread);
context_kinds.push(ContextKind::Thread);
}
let menu = menu
@@ -129,15 +112,15 @@ impl ContextPicker {
})
.extend(recent_entries)
.when(has_recent, |menu| menu.separator())
.extend(modes.into_iter().map(|mode| {
.extend(context_kinds.into_iter().map(|kind| {
let context_picker = context_picker.clone();
ContextMenuEntry::new(mode.label())
.icon(mode.icon())
ContextMenuEntry::new(kind.label())
.icon(kind.icon())
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.handler(move |window, cx| {
context_picker.update(cx, |this, cx| this.select_mode(mode, window, cx))
context_picker.update(cx, |this, cx| this.select_kind(kind, window, cx))
})
}));
@@ -160,17 +143,12 @@ impl ContextPicker {
self.thread_store.is_some()
}
fn select_mode(
&mut self,
mode: ContextPickerMode,
window: &mut Window,
cx: &mut Context<Self>,
) {
fn select_kind(&mut self, kind: ContextKind, window: &mut Window, cx: &mut Context<Self>) {
let context_picker = cx.entity().downgrade();
match mode {
ContextPickerMode::File => {
self.mode = ContextPickerState::File(cx.new(|cx| {
match kind {
ContextKind::File => {
self.mode = ContextPickerMode::File(cx.new(|cx| {
FileContextPicker::new(
context_picker.clone(),
self.workspace.clone(),
@@ -182,8 +160,20 @@ impl ContextPicker {
)
}));
}
ContextPickerMode::Fetch => {
self.mode = ContextPickerState::Fetch(cx.new(|cx| {
ContextKind::Directory => {
self.mode = ContextPickerMode::Directory(cx.new(|cx| {
DirectoryContextPicker::new(
context_picker.clone(),
self.workspace.clone(),
self.context_store.clone(),
self.confirm_behavior,
window,
cx,
)
}));
}
ContextKind::FetchedUrl => {
self.mode = ContextPickerMode::Fetch(cx.new(|cx| {
FetchContextPicker::new(
context_picker.clone(),
self.workspace.clone(),
@@ -194,9 +184,9 @@ impl ContextPicker {
)
}));
}
ContextPickerMode::Thread => {
ContextKind::Thread => {
if let Some(thread_store) = self.thread_store.as_ref() {
self.mode = ContextPickerState::Thread(cx.new(|cx| {
self.mode = ContextPickerMode::Thread(cx.new(|cx| {
ThreadContextPicker::new(
thread_store.clone(),
context_picker.clone(),
@@ -234,7 +224,6 @@ impl ContextPicker {
ElementId::NamedInteger("ctx-recent".into(), ix),
&path,
&path_prefix,
false,
context_store.clone(),
cx,
)
@@ -403,10 +392,11 @@ impl EventEmitter<DismissEvent> for ContextPicker {}
impl Focusable for ContextPicker {
fn focus_handle(&self, cx: &App) -> FocusHandle {
match &self.mode {
ContextPickerState::Default(menu) => menu.focus_handle(cx),
ContextPickerState::File(file_picker) => file_picker.focus_handle(cx),
ContextPickerState::Fetch(fetch_picker) => fetch_picker.focus_handle(cx),
ContextPickerState::Thread(thread_picker) => thread_picker.focus_handle(cx),
ContextPickerMode::Default(menu) => menu.focus_handle(cx),
ContextPickerMode::File(file_picker) => file_picker.focus_handle(cx),
ContextPickerMode::Directory(directory_picker) => directory_picker.focus_handle(cx),
ContextPickerMode::Fetch(fetch_picker) => fetch_picker.focus_handle(cx),
ContextPickerMode::Thread(thread_picker) => thread_picker.focus_handle(cx),
}
}
}
@@ -417,10 +407,13 @@ impl Render for ContextPicker {
.w(px(400.))
.min_w(px(400.))
.map(|parent| match &self.mode {
ContextPickerState::Default(menu) => parent.child(menu.clone()),
ContextPickerState::File(file_picker) => parent.child(file_picker.clone()),
ContextPickerState::Fetch(fetch_picker) => parent.child(fetch_picker.clone()),
ContextPickerState::Thread(thread_picker) => parent.child(thread_picker.clone()),
ContextPickerMode::Default(menu) => parent.child(menu.clone()),
ContextPickerMode::File(file_picker) => parent.child(file_picker.clone()),
ContextPickerMode::Directory(directory_picker) => {
parent.child(directory_picker.clone())
}
ContextPickerMode::Fetch(fetch_picker) => parent.child(fetch_picker.clone()),
ContextPickerMode::Thread(thread_picker) => parent.child(thread_picker.clone()),
})
}
}

View File

@@ -0,0 +1,269 @@
use std::path::Path;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use fuzzy::PathMatch;
use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity};
use picker::{Picker, PickerDelegate};
use project::{PathMatchCandidateSet, ProjectPath, WorktreeId};
use ui::{prelude::*, ListItem};
use util::ResultExt as _;
use workspace::{notifications::NotifyResultExt, Workspace};
use crate::context_picker::{ConfirmBehavior, ContextPicker};
use crate::context_store::ContextStore;
pub struct DirectoryContextPicker {
picker: Entity<Picker<DirectoryContextPickerDelegate>>,
}
impl DirectoryContextPicker {
pub fn new(
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let delegate = DirectoryContextPickerDelegate::new(
context_picker,
workspace,
context_store,
confirm_behavior,
);
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
Self { picker }
}
}
impl Focusable for DirectoryContextPicker {
fn focus_handle(&self, cx: &App) -> FocusHandle {
self.picker.focus_handle(cx)
}
}
impl Render for DirectoryContextPicker {
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
self.picker.clone()
}
}
pub struct DirectoryContextPickerDelegate {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
matches: Vec<PathMatch>,
selected_index: usize,
}
impl DirectoryContextPickerDelegate {
pub fn new(
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
) -> Self {
Self {
context_picker,
workspace,
context_store,
confirm_behavior,
matches: Vec::new(),
selected_index: 0,
}
}
fn search(
&mut self,
query: String,
cancellation_flag: Arc<AtomicBool>,
workspace: &Entity<Workspace>,
cx: &mut Context<Picker<Self>>,
) -> Task<Vec<PathMatch>> {
if query.is_empty() {
let workspace = workspace.read(cx);
let project = workspace.project().read(cx);
let directory_matches = project.worktrees(cx).flat_map(|worktree| {
let worktree = worktree.read(cx);
let path_prefix: Arc<str> = worktree.root_name().into();
worktree.directories(false, 0).map(move |entry| PathMatch {
score: 0.,
positions: Vec::new(),
worktree_id: worktree.id().to_usize(),
path: entry.path.clone(),
path_prefix: path_prefix.clone(),
distance_to_relative_ancestor: 0,
is_dir: true,
})
});
Task::ready(directory_matches.collect())
} else {
let worktrees = workspace.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
let candidate_sets = worktrees
.into_iter()
.map(|worktree| {
let worktree = worktree.read(cx);
PathMatchCandidateSet {
snapshot: worktree.snapshot(),
include_ignored: worktree
.root_entry()
.map_or(false, |entry| entry.is_ignored),
include_root_name: true,
candidates: project::Candidates::Directories,
}
})
.collect::<Vec<_>>();
let executor = cx.background_executor().clone();
cx.foreground_executor().spawn(async move {
fuzzy::match_path_sets(
candidate_sets.as_slice(),
query.as_str(),
None,
false,
100,
&cancellation_flag,
executor,
)
.await
})
}
}
}
impl PickerDelegate for DirectoryContextPickerDelegate {
type ListItem = ListItem;
fn match_count(&self) -> usize {
self.matches.len()
}
fn selected_index(&self) -> usize {
self.selected_index
}
fn set_selected_index(
&mut self,
ix: usize,
_window: &mut Window,
_cx: &mut Context<Picker<Self>>,
) {
self.selected_index = ix;
}
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
"Search folders…".into()
}
fn update_matches(
&mut self,
query: String,
_window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Task<()> {
let Some(workspace) = self.workspace.upgrade() else {
return Task::ready(());
};
let search_task = self.search(query, Arc::<AtomicBool>::default(), &workspace, cx);
cx.spawn(|this, mut cx| async move {
let mut paths = search_task.await;
let empty_path = Path::new("");
paths.retain(|path_match| path_match.path.as_ref() != empty_path);
this.update(&mut cx, |this, _cx| {
this.delegate.matches = paths;
})
.log_err();
})
}
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
let Some(mat) = self.matches.get(self.selected_index) else {
return;
};
let project_path = ProjectPath {
worktree_id: WorktreeId::from_usize(mat.worktree_id),
path: mat.path.clone(),
};
let Some(task) = self
.context_store
.update(cx, |context_store, cx| {
context_store.add_directory(project_path, cx)
})
.ok()
else {
return;
};
let confirm_behavior = self.confirm_behavior;
cx.spawn_in(window, |this, mut cx| async move {
match task.await.notify_async_err(&mut cx) {
None => anyhow::Ok(()),
Some(()) => this.update_in(&mut cx, |this, window, cx| match confirm_behavior {
ConfirmBehavior::KeepOpen => {}
ConfirmBehavior::Close => this.delegate.dismissed(window, cx),
}),
}
})
.detach_and_log_err(cx);
}
fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
self.context_picker
.update(cx, |_, cx| {
cx.emit(DismissEvent);
})
.ok();
}
fn render_match(
&self,
ix: usize,
selected: bool,
_window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Option<Self::ListItem> {
let path_match = &self.matches[ix];
let directory_name = path_match.path.to_string_lossy().to_string();
let added = self.context_store.upgrade().map_or(false, |context_store| {
context_store
.read(cx)
.includes_directory(&path_match.path)
.is_some()
});
Some(
ListItem::new(ix)
.inset(true)
.toggle_state(selected)
.start_slot(
Icon::new(IconName::Folder)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(Label::new(directory_name))
.when(added, |el| {
el.end_slot(
h_flex()
.gap_1()
.child(
Icon::new(IconName::Check)
.size(IconSize::Small)
.color(Color::Success),
)
.child(Label::new("Added").size(LabelSize::Small)),
)
}),
)
}
}

View File

@@ -99,6 +99,7 @@ impl FileContextPickerDelegate {
query: String,
cancellation_flag: Arc<AtomicBool>,
workspace: &Entity<Workspace>,
cx: &mut Context<Picker<Self>>,
) -> Task<Vec<PathMatch>> {
if query.is_empty() {
@@ -123,14 +124,14 @@ impl FileContextPickerDelegate {
let file_matches = project.worktrees(cx).flat_map(|worktree| {
let worktree = worktree.read(cx);
let path_prefix: Arc<str> = worktree.root_name().into();
worktree.entries(false, 0).map(move |entry| PathMatch {
worktree.files(false, 0).map(move |entry| PathMatch {
score: 0.,
positions: Vec::new(),
worktree_id: worktree.id().to_usize(),
path: entry.path.clone(),
path_prefix: path_prefix.clone(),
distance_to_relative_ancestor: 0,
is_dir: entry.is_dir(),
is_dir: false,
})
});
@@ -148,7 +149,7 @@ impl FileContextPickerDelegate {
.root_entry()
.map_or(false, |entry| entry.is_ignored),
include_root_name: true,
candidates: project::Candidates::Entries,
candidates: project::Candidates::Files,
}
})
.collect::<Vec<_>>();
@@ -191,7 +192,7 @@ impl PickerDelegate for FileContextPickerDelegate {
}
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
"Search files & directories".into()
"Search files…".into()
}
fn update_matches(
@@ -222,11 +223,13 @@ impl PickerDelegate for FileContextPickerDelegate {
return;
};
let file_name = mat
let Some(file_name) = mat
.path
.file_name()
.map(|os_str| os_str.to_string_lossy().into_owned())
.unwrap_or(mat.path_prefix.to_string());
else {
return;
};
let full_path = mat.path.display().to_string();
@@ -235,8 +238,6 @@ impl PickerDelegate for FileContextPickerDelegate {
path: mat.path.clone(),
};
let is_directory = mat.is_dir;
let Some(editor_entity) = self.editor.upgrade() else {
return;
};
@@ -287,12 +288,8 @@ impl PickerDelegate for FileContextPickerDelegate {
editor.insert("\n", window, cx); // Needed to end the fold
let file_icon = if is_directory {
FileIcons::get_folder_icon(false, cx)
} else {
FileIcons::get_icon(&Path::new(&full_path), cx)
}
.unwrap_or_else(|| SharedString::new(""));
let file_icon = FileIcons::get_icon(&Path::new(&full_path), cx)
.unwrap_or_else(|| SharedString::new(""));
let placeholder = FoldPlaceholder {
render: render_fold_icon_button(
@@ -333,11 +330,7 @@ impl PickerDelegate for FileContextPickerDelegate {
let Some(task) = self
.context_store
.update(cx, |context_store, cx| {
if is_directory {
context_store.add_directory(project_path, cx)
} else {
context_store.add_file_from_path(project_path, cx)
}
context_store.add_file_from_path(project_path, cx)
})
.ok()
else {
@@ -382,7 +375,6 @@ impl PickerDelegate for FileContextPickerDelegate {
ElementId::NamedInteger("file-ctx-picker".into(), ix),
&path_match.path,
&path_match.path_prefix,
path_match.is_dir,
self.context_store.clone(),
cx,
)),
@@ -394,7 +386,6 @@ pub fn render_file_context_entry(
id: ElementId,
path: &Path,
path_prefix: &Arc<str>,
is_directory: bool,
context_store: WeakEntity<ContextStore>,
cx: &App,
) -> Stateful<Div> {
@@ -418,24 +409,13 @@ pub fn render_file_context_entry(
(file_name, Some(directory))
};
let added = context_store.upgrade().and_then(|context_store| {
if is_directory {
context_store
.read(cx)
.includes_directory(path)
.map(FileInclusion::Direct)
} else {
context_store.read(cx).will_include_file_path(path, cx)
}
});
let added = context_store
.upgrade()
.and_then(|context_store| context_store.read(cx).will_include_file_path(path, cx));
let file_icon = if is_directory {
FileIcons::get_folder_icon(false, cx)
} else {
FileIcons::get_icon(&path, cx)
}
.map(Icon::from_path)
.unwrap_or_else(|| Icon::new(IconName::File));
let file_icon = FileIcons::get_icon(&path, cx)
.map(Icon::from_path)
.unwrap_or_else(|| Icon::new(IconName::File));
h_flex()
.id(id)

View File

@@ -223,18 +223,13 @@ pub fn render_thread_context_entry(
h_flex()
.gap_1p5()
.w_full()
.justify_between()
.child(
h_flex()
.gap_1p5()
.max_w_72()
.child(
Icon::new(IconName::MessageCircle)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(Label::new(thread.summary.clone()).truncate()),
Icon::new(IconName::MessageCircle)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(Label::new(thread.summary.clone()))
.child(div().w_full())
.when(added, |el| {
el.child(
h_flex()

View File

@@ -9,7 +9,6 @@ use language::Buffer;
use project::{ProjectPath, Worktree};
use rope::Rope;
use text::BufferId;
use util::maybe;
use workspace::Workspace;
use crate::context::{
@@ -532,59 +531,35 @@ fn collect_files_in_path(worktree: &Worktree, path: &Path) -> Vec<Arc<Path>> {
pub fn refresh_context_store_text(
context_store: Entity<ContextStore>,
changed_buffers: &HashSet<Entity<Buffer>>,
cx: &App,
) -> impl Future<Output = Vec<ContextId>> {
) -> impl Future<Output = ()> {
let mut tasks = Vec::new();
for context in &context_store.read(cx).context {
let id = context.id();
let task = maybe!({
match context {
AssistantContext::File(file_context) => {
if changed_buffers.is_empty()
|| changed_buffers.contains(&file_context.context_buffer.buffer)
{
let context_store = context_store.clone();
return refresh_file_text(context_store, file_context, cx);
}
match context {
AssistantContext::File(file_context) => {
let context_store = context_store.clone();
if let Some(task) = refresh_file_text(context_store, file_context, cx) {
tasks.push(task);
}
AssistantContext::Directory(directory_context) => {
let should_refresh = changed_buffers.is_empty()
|| changed_buffers.iter().any(|buffer| {
let buffer = buffer.read(cx);
buffer_path_log_err(&buffer)
.map_or(false, |path| path.starts_with(&directory_context.path))
});
if should_refresh {
let context_store = context_store.clone();
return refresh_directory_text(context_store, directory_context, cx);
}
}
AssistantContext::Thread(thread_context) => {
if changed_buffers.is_empty() {
let context_store = context_store.clone();
return Some(refresh_thread_text(context_store, thread_context, cx));
}
}
// Intentionally omit refreshing fetched URLs as it doesn't seem all that useful,
// and doing the caching properly could be tricky (unless it's already handled by
// the HttpClient?).
AssistantContext::FetchedUrl(_) => {}
}
None
});
if let Some(task) = task {
tasks.push(task.map(move |_| id));
AssistantContext::Directory(directory_context) => {
let context_store = context_store.clone();
if let Some(task) = refresh_directory_text(context_store, directory_context, cx) {
tasks.push(task);
}
}
AssistantContext::Thread(thread_context) => {
let context_store = context_store.clone();
tasks.push(refresh_thread_text(context_store, thread_context, cx));
}
// Intentionally omit refreshing fetched URLs as it doesn't seem all that useful,
// and doing the caching properly could be tricky (unless it's already handled by
// the HttpClient?).
AssistantContext::FetchedUrl(_) => {}
}
}
future::join_all(tasks)
future::join_all(tasks).map(|_| ())
}
fn refresh_file_text(

View File

@@ -2,10 +2,10 @@ use assistant_context_editor::SavedContextMetadata;
use chrono::{DateTime, Utc};
use gpui::{prelude::*, Entity};
use crate::thread_store::{SerializedThreadMetadata, ThreadStore};
use crate::thread_store::{SavedThreadMetadata, ThreadStore};
pub enum HistoryEntry {
Thread(SerializedThreadMetadata),
Thread(SavedThreadMetadata),
Context(SavedContextMetadata),
}

View File

@@ -1729,7 +1729,6 @@ impl CodeActionProvider for AssistantCodeActionProvider {
title: "Fix with Assistant".into(),
..Default::default()
})),
resolved: true,
}]))
} else {
Task::ready(Ok(Vec::new()))

View File

@@ -1,6 +1,5 @@
use std::sync::Arc;
use collections::HashSet;
use editor::actions::MoveUp;
use editor::{Editor, EditorElement, EditorEvent, EditorStyle};
use file_icons::FileIcons;
@@ -21,8 +20,7 @@ use ui::{
Tooltip,
};
use vim_mode_setting::VimModeSetting;
use workspace::notifications::{NotificationId, NotifyTaskExt};
use workspace::{Toast, Workspace};
use workspace::Workspace;
use crate::assistant_model_selector::AssistantModelSelector;
use crate::context_picker::{ConfirmBehavior, ContextPicker};
@@ -36,7 +34,6 @@ use crate::{Chat, ChatMode, RemoveAllContext, ToggleContextPicker};
pub struct MessageEditor {
thread: Entity<Thread>,
editor: Entity<Editor>,
workspace: WeakEntity<Workspace>,
context_store: Entity<ContextStore>,
context_strip: Entity<ContextStrip>,
context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
@@ -52,13 +49,13 @@ impl MessageEditor {
pub fn new(
fs: Arc<dyn Fs>,
workspace: WeakEntity<Workspace>,
context_store: Entity<ContextStore>,
thread_store: WeakEntity<ThreadStore>,
thread: Entity<Thread>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let tools = thread.read(cx).tools().clone();
let context_store = cx.new(|_cx| ContextStore::new(workspace.clone()));
let context_picker_menu_handle = PopoverMenuHandle::default();
let inline_context_picker_menu_handle = PopoverMenuHandle::default();
let model_selector_menu_handle = PopoverMenuHandle::default();
@@ -109,7 +106,6 @@ impl MessageEditor {
Self {
thread,
editor: editor.clone(),
workspace,
context_store,
context_strip,
context_picker_menu_handle,
@@ -158,7 +154,7 @@ impl MessageEditor {
return;
}
if self.thread.read(cx).is_generating() {
if self.thread.read(cx).is_streaming() {
return;
}
@@ -201,8 +197,7 @@ impl MessageEditor {
text
});
let refresh_task =
refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx);
let refresh_task = refresh_context_store_text(self.context_store.clone(), cx);
let thread = self.thread.clone();
let context_store = self.context_store.clone();
@@ -285,34 +280,6 @@ impl MessageEditor {
self.context_strip.focus_handle(cx).focus(window);
}
}
fn handle_feedback_click(
&mut self,
is_positive: bool,
window: &mut Window,
cx: &mut Context<Self>,
) {
let workspace = self.workspace.clone();
let report = self
.thread
.update(cx, |thread, cx| thread.report_feedback(is_positive, cx));
cx.spawn(|_, mut cx| async move {
report.await?;
workspace.update(&mut cx, |workspace, cx| {
let message = if is_positive {
"Positive feedback recorded. Thank you!"
} else {
"Negative feedback recorded. Thank you for helping us improve!"
};
struct ThreadFeedback;
let id = NotificationId::unique::<ThreadFeedback>();
workspace.show_toast(Toast::new(id, message).autohide(), cx)
})
})
.detach_and_notify_err(window, cx);
}
}
impl Focusable for MessageEditor {
@@ -328,7 +295,7 @@ impl Render for MessageEditor {
let focus_handle = self.editor.focus_handle(cx);
let inline_context_picker = self.inline_context_picker.clone();
let bg_color = cx.theme().colors().editor_background;
let is_generating = self.thread.read(cx).is_generating();
let is_streaming_completion = self.thread.read(cx).is_streaming();
let is_model_selected = self.is_model_selected(cx);
let is_editor_empty = self.is_editor_empty(cx);
let submit_label_color = if is_editor_empty {
@@ -352,7 +319,7 @@ impl Render for MessageEditor {
v_flex()
.size_full()
.when(is_generating, |parent| {
.when(is_streaming_completion, |parent| {
let focus_handle = self.editor.focus_handle(cx).clone();
parent.child(
h_flex().py_3().w_full().justify_center().child(
@@ -530,45 +497,7 @@ impl Render for MessageEditor {
.bg(bg_color)
.border_t_1()
.border_color(cx.theme().colors().border)
.child(
h_flex()
.justify_between()
.child(self.context_strip.clone())
.when(!self.thread.read(cx).is_empty(), |this| {
this.child(
h_flex()
.gap_2()
.child(
IconButton::new(
"feedback-thumbs-up",
IconName::ThumbsUp,
)
.style(ButtonStyle::Subtle)
.icon_size(IconSize::Small)
.tooltip(Tooltip::text("Helpful"))
.on_click(
cx.listener(|this, _, window, cx| {
this.handle_feedback_click(true, window, cx);
}),
),
)
.child(
IconButton::new(
"feedback-thumbs-down",
IconName::ThumbsDown,
)
.style(ButtonStyle::Subtle)
.icon_size(IconSize::Small)
.tooltip(Tooltip::text("Not Helpful"))
.on_click(
cx.listener(|this, _, window, cx| {
this.handle_feedback_click(false, window, cx);
}),
),
),
)
}),
)
.child(self.context_strip.clone())
.child(
v_flex()
.gap_5()
@@ -625,7 +554,7 @@ impl Render for MessageEditor {
.disabled(
is_editor_empty
|| !is_model_selected
|| is_generating,
|| is_streaming_completion,
)
.child(
h_flex()
@@ -660,7 +589,7 @@ impl Render for MessageEditor {
"Type a message to submit",
))
})
.when(is_generating, |button| {
.when(is_streaming_completion, |button| {
button.tooltip(Tooltip::text(
"Cancel to submit a new message",
))

View File

@@ -1,14 +1,11 @@
use std::fmt::Write as _;
use std::io::Write;
use std::sync::Arc;
use anyhow::{Context as _, Result};
use assistant_tool::{ActionLog, ToolWorkingSet};
use assistant_tool::ToolWorkingSet;
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap, HashSet};
use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
use git;
use futures::StreamExt as _;
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
@@ -24,9 +21,7 @@ use util::{post_inc, ResultExt, TryFutureExt as _};
use uuid::Uuid;
use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
use crate::thread_store::{
SerializedMessage, SerializedThread, SerializedToolResult, SerializedToolUse,
};
use crate::thread_store::SavedThread;
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
#[derive(Debug, Clone, Copy)]
@@ -68,27 +63,6 @@ pub struct Message {
pub text: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
pub worktree_snapshots: Vec<WorktreeSnapshot>,
pub unsaved_buffer_paths: Vec<String>,
pub timestamp: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
pub worktree_path: String,
pub git_state: Option<GitState>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
pub remote_url: Option<String>,
pub head_sha: Option<String>,
pub current_branch: Option<String>,
pub diff: Option<String>,
}
/// A thread of conversation with the LLM.
pub struct Thread {
id: ThreadId,
@@ -105,10 +79,8 @@ pub struct Thread {
prompt_builder: Arc<PromptBuilder>,
tools: Arc<ToolWorkingSet>,
tool_use: ToolUseState,
action_log: Entity<ActionLog>,
scripting_session: Entity<ScriptingSession>,
scripting_tool_use: ToolUseState,
initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
cumulative_token_usage: TokenUsage,
}
@@ -119,6 +91,8 @@ impl Thread {
prompt_builder: Arc<PromptBuilder>,
cx: &mut Context<Self>,
) -> Self {
let scripting_session = cx.new(|cx| ScriptingSession::new(project.clone(), cx));
Self {
id: ThreadId::new(),
updated_at: Utc::now(),
@@ -130,53 +104,43 @@ impl Thread {
context_by_message: HashMap::default(),
completion_count: 0,
pending_completions: Vec::new(),
project: project.clone(),
project,
prompt_builder,
tools,
tool_use: ToolUseState::new(),
scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)),
scripting_session,
scripting_tool_use: ToolUseState::new(),
action_log: cx.new(|_| ActionLog::new()),
initial_project_snapshot: {
let project_snapshot = Self::project_snapshot(project, cx);
cx.foreground_executor()
.spawn(async move { Some(project_snapshot.await) })
.shared()
},
cumulative_token_usage: TokenUsage::default(),
}
}
pub fn deserialize(
pub fn from_saved(
id: ThreadId,
serialized: SerializedThread,
saved: SavedThread,
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut Context<Self>,
) -> Self {
let next_message_id = MessageId(
serialized
saved
.messages
.last()
.map(|message| message.id.0 + 1)
.unwrap_or(0),
);
let tool_use = ToolUseState::from_serialized_messages(&serialized.messages, |name| {
name != ScriptingTool::NAME
});
let tool_use =
ToolUseState::from_saved_messages(&saved.messages, |name| name != ScriptingTool::NAME);
let scripting_tool_use =
ToolUseState::from_serialized_messages(&serialized.messages, |name| {
name == ScriptingTool::NAME
});
ToolUseState::from_saved_messages(&saved.messages, |name| name == ScriptingTool::NAME);
let scripting_session = cx.new(|cx| ScriptingSession::new(project.clone(), cx));
Self {
id,
updated_at: serialized.updated_at,
summary: Some(serialized.summary),
updated_at: saved.updated_at,
summary: Some(saved.summary),
pending_summary: Task::ready(None),
messages: serialized
messages: saved
.messages
.into_iter()
.map(|message| Message {
@@ -194,10 +158,8 @@ impl Thread {
prompt_builder,
tools,
tool_use,
action_log: cx.new(|_| ActionLog::new()),
scripting_session,
scripting_tool_use,
initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
// TODO: persist token usage?
cumulative_token_usage: TokenUsage::default(),
}
@@ -241,7 +203,7 @@ impl Thread {
self.messages.iter()
}
pub fn is_generating(&self) -> bool {
pub fn is_streaming(&self) -> bool {
!self.pending_completions.is_empty() || !self.all_tools_finished()
}
@@ -268,8 +230,8 @@ impl Thread {
.into_iter()
.chain(self.scripting_tool_use.pending_tool_uses());
// If the only pending tool uses left are the ones with errors, then
// that means that we've finished running all of the pending tools.
// If the only pending tool uses left are the ones with errors, then that means that we've finished running all
// of the pending tools.
all_pending_tool_uses.all(|tool_use| tool_use.status.is_error())
}
@@ -285,10 +247,6 @@ impl Thread {
self.tool_use.tool_results_for_message(id)
}
pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
self.tool_use.tool_result(id)
}
pub fn scripting_tool_results_for_message(
&self,
id: MessageId,
@@ -391,47 +349,6 @@ impl Thread {
text
}
/// Serializes this thread into a format for storage or telemetry.
pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
let initial_project_snapshot = self.initial_project_snapshot.clone();
cx.spawn(|this, cx| async move {
let initial_project_snapshot = initial_project_snapshot.await;
this.read_with(&cx, |this, _| SerializedThread {
summary: this.summary_or_default(),
updated_at: this.updated_at(),
messages: this
.messages()
.map(|message| SerializedMessage {
id: message.id,
role: message.role,
text: message.text.clone(),
tool_uses: this
.tool_uses_for_message(message.id)
.into_iter()
.chain(this.scripting_tool_uses_for_message(message.id))
.map(|tool_use| SerializedToolUse {
id: tool_use.id,
name: tool_use.name,
input: tool_use.input,
})
.collect(),
tool_results: this
.tool_results_for_message(message.id)
.into_iter()
.chain(this.scripting_tool_results_for_message(message.id))
.map(|tool_result| SerializedToolResult {
tool_use_id: tool_result.tool_use_id.clone(),
is_error: tool_result.is_error,
content: tool_result.content.clone(),
})
.collect(),
})
.collect(),
initial_project_snapshot,
})
})
}
pub fn send_to_model(
&mut self,
model: Arc<dyn LanguageModel>,
@@ -561,39 +478,9 @@ impl Thread {
request.messages.push(context_message);
}
self.attach_stale_files(&mut request.messages, cx);
request
}
fn attach_stale_files(&self, messages: &mut Vec<LanguageModelRequestMessage>, cx: &App) {
const STALE_FILES_HEADER: &str = "These files changed since last read:";
let mut stale_message = String::new();
for stale_file in self.action_log.read(cx).stale_buffers(cx) {
let Some(file) = stale_file.read(cx).file() else {
continue;
};
if stale_message.is_empty() {
write!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
}
writeln!(&mut stale_message, "- {}", file.path().display()).ok();
}
if !stale_message.is_empty() {
let context_message = LanguageModelRequestMessage {
role: Role::User,
content: vec![stale_message.into()],
cache: false,
};
messages.push(context_message);
}
}
pub fn stream_completion(
&mut self,
request: LanguageModelRequest,
@@ -687,37 +574,32 @@ impl Thread {
let result = stream_completion.await;
thread
.update(&mut cx, |thread, cx| {
match result.as_ref() {
Ok(stop_reason) => match stop_reason {
StopReason::ToolUse => {
cx.emit(ThreadEvent::UsePendingTools);
}
StopReason::EndTurn => {}
StopReason::MaxTokens => {}
},
Err(error) => {
if error.is::<PaymentRequiredError>() {
cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
} else if error.is::<MaxMonthlySpendReachedError>() {
cx.emit(ThreadEvent::ShowError(
ThreadError::MaxMonthlySpendReached,
));
} else {
let error_message = error
.chain()
.map(|err| err.to_string())
.collect::<Vec<_>>()
.join("\n");
cx.emit(ThreadEvent::ShowError(ThreadError::Message(
SharedString::from(error_message.clone()),
)));
}
thread.cancel_last_completion(cx);
.update(&mut cx, |thread, cx| match result.as_ref() {
Ok(stop_reason) => match stop_reason {
StopReason::ToolUse => {
cx.emit(ThreadEvent::UsePendingTools);
}
StopReason::EndTurn => {}
StopReason::MaxTokens => {}
},
Err(error) => {
if error.is::<PaymentRequiredError>() {
cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
} else if error.is::<MaxMonthlySpendReachedError>() {
cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
} else {
let error_message = error
.chain()
.map(|err| err.to_string())
.collect::<Vec<_>>()
.join("\n");
cx.emit(ThreadEvent::ShowError(ThreadError::Message(
SharedString::from(error_message.clone()),
)));
}
thread.cancel_last_completion();
}
cx.emit(ThreadEvent::DoneStreaming);
})
.ok();
});
@@ -793,13 +675,7 @@ impl Thread {
for tool_use in pending_tool_uses {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
let task = tool.run(
tool_use.input,
&request.messages,
self.project.clone(),
self.action_log.clone(),
cx,
);
let task = tool.run(tool_use.input, &request.messages, self.project.clone(), cx);
self.insert_tool_output(tool_use.id.clone(), task, cx);
}
@@ -864,7 +740,6 @@ impl Thread {
cx.emit(ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,
canceled: false,
});
})
.ok();
@@ -894,7 +769,6 @@ impl Thread {
cx.emit(ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,
canceled: false,
});
})
.ok();
@@ -905,17 +779,11 @@ impl Thread {
.run_pending_tool(tool_use_id, insert_output_task);
}
pub fn attach_tool_results(
pub fn send_tool_results_to_model(
&mut self,
updated_context: Vec<ContextSnapshot>,
model: Arc<dyn LanguageModel>,
cx: &mut Context<Self>,
) {
self.context.extend(
updated_context
.into_iter()
.map(|context| (context.id, context)),
);
// Insert a user message to contain the tool results.
self.insert_user_message(
// TODO: Sending up a user message without any content results in the model sending back
@@ -925,155 +793,20 @@ impl Thread {
Vec::new(),
cx,
);
self.send_to_model(model, RequestKind::Chat, cx);
}
/// Cancels the last pending completion, if there are any pending.
///
/// Returns whether a completion was canceled.
pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
if self.pending_completions.pop().is_some() {
pub fn cancel_last_completion(&mut self) -> bool {
if let Some(_last_completion) = self.pending_completions.pop() {
true
} else {
let mut canceled = false;
for pending_tool_use in self.tool_use.cancel_pending() {
canceled = true;
cx.emit(ThreadEvent::ToolFinished {
tool_use_id: pending_tool_use.id.clone(),
pending_tool_use: Some(pending_tool_use),
canceled: true,
});
}
canceled
false
}
}
/// Reports feedback about the thread and stores it in our telemetry backend.
pub fn report_feedback(&self, is_positive: bool, cx: &mut Context<Self>) -> Task<Result<()>> {
let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
let serialized_thread = self.serialize(cx);
let thread_id = self.id().clone();
let client = self.project.read(cx).client();
cx.background_spawn(async move {
let final_project_snapshot = final_project_snapshot.await;
let serialized_thread = serialized_thread.await?;
let thread_data =
serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
let rating = if is_positive { "positive" } else { "negative" };
telemetry::event!(
"Assistant Thread Rated",
rating,
thread_id,
thread_data,
final_project_snapshot
);
client.telemetry().flush_events();
Ok(())
})
}
/// Create a snapshot of the current project state including git information and unsaved buffers.
fn project_snapshot(
project: Entity<Project>,
cx: &mut Context<Self>,
) -> Task<Arc<ProjectSnapshot>> {
let worktree_snapshots: Vec<_> = project
.read(cx)
.visible_worktrees(cx)
.map(|worktree| Self::worktree_snapshot(worktree, cx))
.collect();
cx.spawn(move |_, cx| async move {
let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
let mut unsaved_buffers = Vec::new();
cx.update(|app_cx| {
let buffer_store = project.read(app_cx).buffer_store();
for buffer_handle in buffer_store.read(app_cx).buffers() {
let buffer = buffer_handle.read(app_cx);
if buffer.is_dirty() {
if let Some(file) = buffer.file() {
let path = file.path().to_string_lossy().to_string();
unsaved_buffers.push(path);
}
}
}
})
.ok();
Arc::new(ProjectSnapshot {
worktree_snapshots,
unsaved_buffer_paths: unsaved_buffers,
timestamp: Utc::now(),
})
})
}
fn worktree_snapshot(worktree: Entity<project::Worktree>, cx: &App) -> Task<WorktreeSnapshot> {
cx.spawn(move |cx| async move {
// Get worktree path and snapshot
let worktree_info = cx.update(|app_cx| {
let worktree = worktree.read(app_cx);
let path = worktree.abs_path().to_string_lossy().to_string();
let snapshot = worktree.snapshot();
(path, snapshot)
});
let Ok((worktree_path, snapshot)) = worktree_info else {
return WorktreeSnapshot {
worktree_path: String::new(),
git_state: None,
};
};
// Extract git information
let git_state = match snapshot.repositories().first() {
None => None,
Some(repo_entry) => {
// Get branch information
let current_branch = repo_entry.branch().map(|branch| branch.name.to_string());
// Get repository info
let repo_result = worktree.read_with(&cx, |worktree, _cx| {
if let project::Worktree::Local(local_worktree) = &worktree {
local_worktree.get_local_repo(repo_entry).map(|local_repo| {
let repo = local_repo.repo();
(repo.remote_url("origin"), repo.head_sha(), repo.clone())
})
} else {
None
}
});
match repo_result {
Ok(Some((remote_url, head_sha, repository))) => {
// Get diff asynchronously
let diff = repository
.diff(git::repository::DiffType::HeadToWorktree, cx)
.await
.ok();
Some(GitState {
remote_url,
head_sha,
current_branch,
diff,
})
}
Err(_) | Ok(None) => None,
}
}
};
WorktreeSnapshot {
worktree_path,
git_state,
}
})
}
pub fn to_markdown(&self) -> Result<String> {
let mut markdown = Vec::new();
@@ -1122,10 +855,6 @@ impl Thread {
Ok(String::from_utf8_lossy(&markdown).to_string())
}
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
pub fn cumulative_token_usage(&self) -> TokenUsage {
self.cumulative_token_usage.clone()
}
@@ -1143,7 +872,6 @@ pub enum ThreadEvent {
ShowError(ThreadError),
StreamedCompletion,
StreamedAssistantText(MessageId, String),
DoneStreaming,
MessageAdded(MessageId),
MessageEdited(MessageId),
MessageDeleted(MessageId),
@@ -1154,8 +882,6 @@ pub enum ThreadEvent {
tool_use_id: LanguageModelToolUseId,
/// The pending tool use that corresponds to this tool.
pending_tool_use: Option<PendingToolUse>,
/// Whether the tool was canceled by the user.
canceled: bool,
},
}

View File

@@ -7,7 +7,7 @@ use time::{OffsetDateTime, UtcOffset};
use ui::{prelude::*, IconButtonShape, ListItem, ListItemSpacing, Tooltip};
use crate::history_store::{HistoryEntry, HistoryStore};
use crate::thread_store::SerializedThreadMetadata;
use crate::thread_store::SavedThreadMetadata;
use crate::{AssistantPanel, RemoveSelectedThread};
pub struct ThreadHistory {
@@ -221,14 +221,14 @@ impl Render for ThreadHistory {
#[derive(IntoElement)]
pub struct PastThread {
thread: SerializedThreadMetadata,
thread: SavedThreadMetadata,
assistant_panel: WeakEntity<AssistantPanel>,
selected: bool,
}
impl PastThread {
pub fn new(
thread: SerializedThreadMetadata,
thread: SavedThreadMetadata,
assistant_panel: WeakEntity<AssistantPanel>,
selected: bool,
) -> Self {

View File

@@ -20,7 +20,7 @@ use prompt_store::PromptBuilder;
use serde::{Deserialize, Serialize};
use util::ResultExt as _;
use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadId};
use crate::thread::{MessageId, Thread, ThreadId};
pub fn init(cx: &mut App) {
ThreadsDatabase::init(cx);
@@ -32,7 +32,7 @@ pub struct ThreadStore {
prompt_builder: Arc<PromptBuilder>,
context_server_manager: Entity<ContextServerManager>,
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
threads: Vec<SerializedThreadMetadata>,
threads: Vec<SavedThreadMetadata>,
}
impl ThreadStore {
@@ -70,13 +70,13 @@ impl ThreadStore {
self.threads.len()
}
pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
pub fn threads(&self) -> Vec<SavedThreadMetadata> {
let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
threads
}
pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
pub fn recent_threads(&self, limit: usize) -> Vec<SavedThreadMetadata> {
self.threads().into_iter().take(limit).collect()
}
@@ -107,7 +107,7 @@ impl ThreadStore {
this.update(&mut cx, |this, cx| {
cx.new(|cx| {
Thread::deserialize(
Thread::from_saved(
id.clone(),
thread,
this.project.clone(),
@@ -121,14 +121,53 @@ impl ThreadStore {
}
pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
let (metadata, serialized_thread) =
thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
let (metadata, thread) = thread.update(cx, |thread, _cx| {
let id = thread.id().clone();
let thread = SavedThread {
summary: thread.summary_or_default(),
updated_at: thread.updated_at(),
messages: thread
.messages()
.map(|message| {
let all_tool_uses = thread
.tool_uses_for_message(message.id)
.into_iter()
.chain(thread.scripting_tool_uses_for_message(message.id))
.map(|tool_use| SavedToolUse {
id: tool_use.id,
name: tool_use.name,
input: tool_use.input,
})
.collect();
let all_tool_results = thread
.tool_results_for_message(message.id)
.into_iter()
.chain(thread.scripting_tool_results_for_message(message.id))
.map(|tool_result| SavedToolResult {
tool_use_id: tool_result.tool_use_id.clone(),
is_error: tool_result.is_error,
content: tool_result.content.clone(),
})
.collect();
SavedMessage {
id: message.id,
role: message.role,
text: message.text.clone(),
tool_uses: all_tool_uses,
tool_results: all_tool_results,
}
})
.collect(),
};
(id, thread)
});
let database_future = ThreadsDatabase::global_future(cx);
cx.spawn(|this, mut cx| async move {
let serialized_thread = serialized_thread.await?;
let database = database_future.await.map_err(|err| anyhow!(err))?;
database.save_thread(metadata, serialized_thread).await?;
database.save_thread(metadata, thread).await?;
this.update(&mut cx, |this, cx| this.reload(cx))?.await
})
@@ -231,41 +270,39 @@ impl ThreadStore {
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
pub struct SavedThreadMetadata {
pub id: ThreadId,
pub summary: SharedString,
pub updated_at: DateTime<Utc>,
}
#[derive(Serialize, Deserialize)]
pub struct SerializedThread {
pub struct SavedThread {
pub summary: SharedString,
pub updated_at: DateTime<Utc>,
pub messages: Vec<SerializedMessage>,
#[serde(default)]
pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
pub messages: Vec<SavedMessage>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
pub struct SavedMessage {
pub id: MessageId,
pub role: Role,
pub text: String,
#[serde(default)]
pub tool_uses: Vec<SerializedToolUse>,
pub tool_uses: Vec<SavedToolUse>,
#[serde(default)]
pub tool_results: Vec<SerializedToolResult>,
pub tool_results: Vec<SavedToolResult>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
pub struct SavedToolUse {
pub id: LanguageModelToolUseId,
pub name: SharedString,
pub input: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
pub struct SavedToolResult {
pub tool_use_id: LanguageModelToolUseId,
pub is_error: bool,
pub content: Arc<str>,
@@ -280,7 +317,7 @@ impl Global for GlobalThreadsDatabase {}
pub(crate) struct ThreadsDatabase {
executor: BackgroundExecutor,
env: heed::Env,
threads: Database<SerdeBincode<ThreadId>, SerdeJson<SerializedThread>>,
threads: Database<SerdeBincode<ThreadId>, SerdeJson<SavedThread>>,
}
impl ThreadsDatabase {
@@ -327,7 +364,7 @@ impl ThreadsDatabase {
})
}
pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
pub fn list_threads(&self) -> Task<Result<Vec<SavedThreadMetadata>>> {
let env = self.env.clone();
let threads = self.threads;
@@ -336,7 +373,7 @@ impl ThreadsDatabase {
let mut iter = threads.iter(&txn)?;
let mut threads = Vec::new();
while let Some((key, value)) = iter.next().transpose()? {
threads.push(SerializedThreadMetadata {
threads.push(SavedThreadMetadata {
id: key,
summary: value.summary,
updated_at: value.updated_at,
@@ -347,7 +384,7 @@ impl ThreadsDatabase {
})
}
pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SavedThread>>> {
let env = self.env.clone();
let threads = self.threads;
@@ -358,7 +395,7 @@ impl ThreadsDatabase {
})
}
pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
pub fn save_thread(&self, id: ThreadId, thread: SavedThread) -> Task<Result<()>> {
let env = self.env.clone();
let threads = self.threads;

View File

@@ -3,7 +3,7 @@ use std::sync::Arc;
use assistant_tool::{ToolSource, ToolWorkingSet};
use gpui::Entity;
use scripting_tool::ScriptingTool;
use ui::{prelude::*, ContextMenu, PopoverMenu, Tooltip};
use ui::{prelude::*, ContextMenu, IconButtonShape, PopoverMenu, Tooltip};
pub struct ToolSelector {
tools: Arc<ToolWorkingSet>,
@@ -19,22 +19,27 @@ impl ToolSelector {
window: &mut Window,
cx: &mut Context<Self>,
) -> Entity<ContextMenu> {
let tool_set = self.tools.clone();
ContextMenu::build_persistent(window, cx, move |mut menu, _window, cx| {
ContextMenu::build(window, cx, |mut menu, _window, cx| {
let icon_position = IconPosition::End;
let tools_by_source = tool_set.tools_by_source(cx);
let tools_by_source = self.tools.tools_by_source(cx);
let all_tools_enabled = tool_set.are_all_tools_enabled();
menu = menu.toggleable_entry("All Tools", all_tools_enabled, icon_position, None, {
let tools = tool_set.clone();
move |_window, cx| {
if all_tools_enabled {
tools.disable_all_tools(cx);
} else {
tools.enable_all_tools();
let all_tools_enabled = self.tools.are_all_tools_enabled();
menu = menu.header("Tools").toggleable_entry(
"All Tools",
all_tools_enabled,
icon_position,
None,
{
let tools = self.tools.clone();
move |_window, cx| {
if all_tools_enabled {
tools.disable_all_tools(cx);
} else {
tools.enable_all_tools();
}
}
}
});
},
);
for (source, tools) in tools_by_source {
let mut tools = tools
@@ -42,7 +47,7 @@ impl ToolSelector {
.map(|tool| {
let source = tool.source();
let name = tool.name().into();
let is_enabled = tool_set.is_enabled(&source, &name);
let is_enabled = self.tools.is_enabled(&source, &name);
(source, name, is_enabled)
})
@@ -52,16 +57,16 @@ impl ToolSelector {
tools.push((
ToolSource::Native,
ScriptingTool::NAME.into(),
tool_set.is_scripting_tool_enabled(),
self.tools.is_scripting_tool_enabled(),
));
tools.sort_by(|(_, name_a, _), (_, name_b, _)| name_a.cmp(name_b));
}
menu = match &source {
ToolSource::Native => menu.separator().header("Zed Tools"),
ToolSource::Native => menu.header("Zed"),
ToolSource::ContextServer { id } => {
let all_tools_from_source_enabled =
tool_set.are_all_tools_from_source_enabled(&source);
self.tools.are_all_tools_from_source_enabled(&source);
menu.separator().header(id).toggleable_entry(
"All Tools",
@@ -69,7 +74,7 @@ impl ToolSelector {
icon_position,
None,
{
let tools = tool_set.clone();
let tools = self.tools.clone();
let source = source.clone();
move |_window, cx| {
if all_tools_from_source_enabled {
@@ -85,7 +90,7 @@ impl ToolSelector {
for (source, name, is_enabled) in tools {
menu = menu.toggleable_entry(name.clone(), is_enabled, icon_position, None, {
let tools = tool_set.clone();
let tools = self.tools.clone();
move |_window, _cx| {
if name.as_ref() == ScriptingTool::NAME {
if is_enabled {
@@ -119,6 +124,7 @@ impl Render for ToolSelector {
})
.trigger_with_tooltip(
IconButton::new("tool-selector-button", IconName::SettingsAlt)
.shape(IconButtonShape::Square)
.icon_size(IconSize::Small)
.icon_color(Color::Muted),
Tooltip::text("Customize Tools"),

View File

@@ -11,7 +11,7 @@ use language_model::{
};
use crate::thread::MessageId;
use crate::thread_store::SerializedMessage;
use crate::thread_store::SavedMessage;
#[derive(Debug)]
pub struct ToolUse {
@@ -46,11 +46,11 @@ impl ToolUseState {
}
}
/// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
/// Constructs a [`ToolUseState`] from the given list of [`SavedMessage`]s.
///
/// Accepts a function to filter the tools that should be used to populate the state.
pub fn from_serialized_messages(
messages: &[SerializedMessage],
pub fn from_saved_messages(
messages: &[SavedMessage],
mut filter_by_tool_name: impl FnMut(&str) -> bool,
) -> Self {
let mut this = Self::new();
@@ -118,22 +118,6 @@ impl ToolUseState {
this
}
pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
let mut pending_tools = Vec::new();
for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id,
content: "Tool canceled by user".into(),
is_error: true,
},
);
pending_tools.push(tool_use.clone());
}
pending_tools
}
pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
self.pending_tool_uses_by_id.values().collect()
}
@@ -198,13 +182,6 @@ impl ToolUseState {
.map_or(false, |results| !results.is_empty())
}
pub fn tool_result(
&self,
tool_use_id: &LanguageModelToolUseId,
) -> Option<&LanguageModelToolResult> {
self.tool_results.get(tool_use_id)
}
pub fn request_tool_use(
&mut self,
assistant_message_id: MessageId,
@@ -249,12 +226,12 @@ impl ToolUseState {
output: Result<String>,
) -> Option<PendingToolUse> {
match output {
Ok(tool_result) => {
Ok(output) => {
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
content: tool_result.into(),
content: output.into(),
is_error: false,
},
);

View File

@@ -48,7 +48,7 @@ impl SlashCommandCompletionProvider {
name_range: Range<Anchor>,
window: &mut Window,
cx: &mut App,
) -> Task<Result<Option<Vec<project::Completion>>>> {
) -> Task<Result<Vec<project::Completion>>> {
let slash_commands = self.slash_commands.clone();
let candidates = slash_commands
.command_names(cx)
@@ -71,67 +71,65 @@ impl SlashCommandCompletionProvider {
.await;
cx.update(|_, cx| {
Some(
matches
.into_iter()
.filter_map(|mat| {
let command = slash_commands.command(&mat.string, cx)?;
let mut new_text = mat.string.clone();
let requires_argument = command.requires_argument();
let accepts_arguments = command.accepts_arguments();
if requires_argument || accepts_arguments {
new_text.push(' ');
}
matches
.into_iter()
.filter_map(|mat| {
let command = slash_commands.command(&mat.string, cx)?;
let mut new_text = mat.string.clone();
let requires_argument = command.requires_argument();
let accepts_arguments = command.accepts_arguments();
if requires_argument || accepts_arguments {
new_text.push(' ');
}
let confirm =
editor
.clone()
.zip(workspace.clone())
.map(|(editor, workspace)| {
let command_name = mat.string.clone();
let command_range = command_range.clone();
let editor = editor.clone();
let workspace = workspace.clone();
Arc::new(
move |intent: CompletionIntent,
window: &mut Window,
cx: &mut App| {
if !requires_argument
&& (!accepts_arguments || intent.is_complete())
{
editor
.update(cx, |editor, cx| {
editor.run_command(
command_range.clone(),
&command_name,
&[],
true,
workspace.clone(),
window,
cx,
);
})
.ok();
false
} else {
requires_argument || accepts_arguments
}
},
) as Arc<_>
});
Some(project::Completion {
old_range: name_range.clone(),
documentation: Some(CompletionDocumentation::SingleLine(
command.description().into(),
)),
new_text,
label: command.label(cx),
confirm,
source: CompletionSource::Custom,
})
let confirm =
editor
.clone()
.zip(workspace.clone())
.map(|(editor, workspace)| {
let command_name = mat.string.clone();
let command_range = command_range.clone();
let editor = editor.clone();
let workspace = workspace.clone();
Arc::new(
move |intent: CompletionIntent,
window: &mut Window,
cx: &mut App| {
if !requires_argument
&& (!accepts_arguments || intent.is_complete())
{
editor
.update(cx, |editor, cx| {
editor.run_command(
command_range.clone(),
&command_name,
&[],
true,
workspace.clone(),
window,
cx,
);
})
.ok();
false
} else {
requires_argument || accepts_arguments
}
},
) as Arc<_>
});
Some(project::Completion {
old_range: name_range.clone(),
documentation: Some(CompletionDocumentation::SingleLine(
command.description().into(),
)),
new_text,
label: command.label(cx),
confirm,
source: CompletionSource::Custom,
})
.collect(),
)
})
.collect()
})
})
}
@@ -145,7 +143,7 @@ impl SlashCommandCompletionProvider {
last_argument_range: Range<Anchor>,
window: &mut Window,
cx: &mut App,
) -> Task<Result<Option<Vec<project::Completion>>>> {
) -> Task<Result<Vec<project::Completion>>> {
let new_cancel_flag = Arc::new(AtomicBool::new(false));
let mut flag = self.cancel_flag.lock();
flag.store(true, SeqCst);
@@ -163,28 +161,27 @@ impl SlashCommandCompletionProvider {
let workspace = self.workspace.clone();
let arguments = arguments.to_vec();
cx.background_spawn(async move {
Ok(Some(
completions
.await?
.into_iter()
.map(|new_argument| {
let confirm =
editor
.clone()
.zip(workspace.clone())
.map(|(editor, workspace)| {
Arc::new({
let mut completed_arguments = arguments.clone();
if new_argument.replace_previous_arguments {
completed_arguments.clear();
} else {
completed_arguments.pop();
}
completed_arguments.push(new_argument.new_text.clone());
Ok(completions
.await?
.into_iter()
.map(|new_argument| {
let confirm =
editor
.clone()
.zip(workspace.clone())
.map(|(editor, workspace)| {
Arc::new({
let mut completed_arguments = arguments.clone();
if new_argument.replace_previous_arguments {
completed_arguments.clear();
} else {
completed_arguments.pop();
}
completed_arguments.push(new_argument.new_text.clone());
let command_range = command_range.clone();
let command_name = command_name.clone();
move |intent: CompletionIntent,
let command_range = command_range.clone();
let command_name = command_name.clone();
move |intent: CompletionIntent,
window: &mut Window,
cx: &mut App| {
if new_argument.after_completion.run()
@@ -208,32 +205,31 @@ impl SlashCommandCompletionProvider {
!new_argument.after_completion.run()
}
}
}) as Arc<_>
});
}) as Arc<_>
});
let mut new_text = new_argument.new_text.clone();
if new_argument.after_completion == AfterCompletion::Continue {
new_text.push(' ');
}
let mut new_text = new_argument.new_text.clone();
if new_argument.after_completion == AfterCompletion::Continue {
new_text.push(' ');
}
project::Completion {
old_range: if new_argument.replace_previous_arguments {
argument_range.clone()
} else {
last_argument_range.clone()
},
label: new_argument.label,
new_text,
documentation: None,
confirm,
source: CompletionSource::Custom,
}
})
.collect(),
))
project::Completion {
old_range: if new_argument.replace_previous_arguments {
argument_range.clone()
} else {
last_argument_range.clone()
},
label: new_argument.label,
new_text,
documentation: None,
confirm,
source: CompletionSource::Custom,
}
})
.collect())
})
} else {
Task::ready(Ok(Some(Vec::new())))
Task::ready(Ok(Vec::new()))
}
}
}
@@ -246,7 +242,7 @@ impl CompletionProvider for SlashCommandCompletionProvider {
_: editor::CompletionContext,
window: &mut Window,
cx: &mut Context<Editor>,
) -> Task<Result<Option<Vec<project::Completion>>>> {
) -> Task<Result<Vec<project::Completion>>> {
let Some((name, arguments, command_range, last_argument_range)) =
buffer.update(cx, |buffer, _cx| {
let position = buffer_position.to_point(buffer);
@@ -290,7 +286,7 @@ impl CompletionProvider for SlashCommandCompletionProvider {
Some((name, arguments, command_range, last_argument_range))
})
else {
return Task::ready(Ok(Some(Vec::new())));
return Task::ready(Ok(Vec::new()));
};
if let Some((arguments, argument_range)) = arguments {

View File

@@ -1,44 +0,0 @@
[package]
name = "assistant_eval"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[[bin]]
name = "assistant_eval"
path = "src/main.rs"
[dependencies]
anyhow.workspace = true
assistant2.workspace = true
assistant_tool.workspace = true
assistant_tools.workspace = true
clap.workspace = true
client.workspace = true
collections.workspace = true
context_server.workspace = true
env_logger.workspace = true
fs.workspace = true
futures.workspace = true
gpui.workspace = true
gpui_tokio.workspace = true
itertools.workspace = true
language.workspace = true
language_model.workspace = true
language_models.workspace = true
node_runtime.workspace = true
project.workspace = true
prompt_store.workspace = true
regex.workspace = true
release_channel.workspace = true
reqwest_client.workspace = true
serde.workspace = true
serde_json.workspace = true
serde_json_lenient.workspace = true
settings.workspace = true
smol.workspace = true
util.workspace = true

View File

@@ -1 +0,0 @@
../../LICENSE-GPL

View File

@@ -1,77 +0,0 @@
# Tool Evals
A framework for evaluating and benchmarking AI assistant performance in the Zed editor.
## Overview
Tool Evals provides a headless environment for running assistants evaluations on code repositories. It automates the process of:
1. Cloning and setting up test repositories
2. Sending prompts to language models
3. Allowing the assistant to use tools to modify code
4. Collecting metrics on performance
5. Evaluating results against known good solutions
## How It Works
The system consists of several key components:
- **Eval**: Loads test cases from the evaluation_data directory, clones repos, and executes evaluations
- **HeadlessAssistant**: Provides a headless environment for running the AI assistant
- **Judge**: Compares AI-generated diffs with reference solutions and scores their functional similarity
The evaluation flow:
1. An evaluation is loaded from the evaluation_data directory
2. The target repository is cloned and checked out at a specific commit
3. A HeadlessAssistant instance is created with the specified language model
4. The user prompt is sent to the assistant
5. The assistant responds and uses tools to modify code
6. Upon completion, a diff is generated from the changes
7. Results are saved including the diff, assistant's response, and performance metrics
8. If a reference solution exists, a Judge evaluates the similarity of the solution
## Setup Requirements
### Prerequisites
- Rust and Cargo
- Git
- Network access to clone repositories
- Appropriate API keys for language models and git services (Anthropic, GitHub, etc.)
### Environment Variables
Ensure you have the required API keys set, either from a dev run of Zed or via these environment variables:
- `ZED_ANTHROPIC_API_KEY` for Claude models
- `ZED_OPENAI_API_KEY` for OpenAI models
- `ZED_GITHUB_API_KEY` for GitHub API (or similar)
## Usage
### Running a Single Evaluation
To run a specific evaluation:
```bash
cargo run -p assistant_eval -- bubbletea-add-set-window-title
```
The arguments are regex patterns for the evaluation names to run, so to run all evaluations that contain `bubbletea`, run:
```bash
cargo run -p assistant_eval -- bubbletea
```
To run all evaluations:
```bash
cargo run -p assistant_eval -- --all
```
## Evaluation Data Structure
Each evaluation should be placed in the `evaluation_data` directory with the following structure:
* `prompt.txt`: The user's prompt.
* `original.diff`: The `git diff` of the change anticipated for this prompt.
* `setup.json`: Information about the repo used for the evaluation.

View File

@@ -1,61 +0,0 @@
// Copied from `crates/zed/build.rs`, with removal of code for including the zed icon on windows.
use std::process::Command;
fn main() {
if cfg!(target_os = "macos") {
println!("cargo:rustc-env=MACOSX_DEPLOYMENT_TARGET=10.15.7");
println!("cargo:rerun-if-env-changed=ZED_BUNDLE");
if std::env::var("ZED_BUNDLE").ok().as_deref() == Some("true") {
// Find WebRTC.framework in the Frameworks folder when running as part of an application bundle.
println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path/../Frameworks");
} else {
// Find WebRTC.framework as a sibling of the executable when running outside of an application bundle.
println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path");
}
// Weakly link ReplayKit to ensure Zed can be used on macOS 10.15+.
println!("cargo:rustc-link-arg=-Wl,-weak_framework,ReplayKit");
// Seems to be required to enable Swift concurrency
println!("cargo:rustc-link-arg=-Wl,-rpath,/usr/lib/swift");
// Register exported Objective-C selectors, protocols, etc
println!("cargo:rustc-link-arg=-Wl,-ObjC");
}
// Populate git sha environment variable if git is available
println!("cargo:rerun-if-changed=../../.git/logs/HEAD");
println!(
"cargo:rustc-env=TARGET={}",
std::env::var("TARGET").unwrap()
);
if let Ok(output) = Command::new("git").args(["rev-parse", "HEAD"]).output() {
if output.status.success() {
let git_sha = String::from_utf8_lossy(&output.stdout);
let git_sha = git_sha.trim();
println!("cargo:rustc-env=ZED_COMMIT_SHA={git_sha}");
if let Ok(build_profile) = std::env::var("PROFILE") {
if build_profile == "release" {
// This is currently the best way to make `cargo build ...`'s build script
// to print something to stdout without extra verbosity.
println!(
"cargo:warning=Info: using '{git_sha}' hash for ZED_COMMIT_SHA env var"
);
}
}
}
}
#[cfg(target_os = "windows")]
{
#[cfg(target_env = "msvc")]
{
// todo(windows): This is to avoid stack overflow. Remove it when solved.
println!("cargo:rustc-link-arg=/stack:{}", 8 * 1024 * 1024);
}
}
}

View File

@@ -1,252 +0,0 @@
use crate::headless_assistant::{HeadlessAppState, HeadlessAssistant};
use anyhow::anyhow;
use assistant2::RequestKind;
use collections::HashMap;
use gpui::{App, Task};
use language_model::{LanguageModel, TokenUsage};
use serde::{Deserialize, Serialize};
use std::{
fs,
io::Write,
path::{Path, PathBuf},
sync::Arc,
time::Duration,
};
use util::command::new_smol_command;
pub struct Eval {
pub name: String,
pub path: PathBuf,
pub repo_path: PathBuf,
pub eval_setup: EvalSetup,
pub user_prompt: String,
}
#[derive(Debug, Serialize)]
pub struct EvalOutput {
pub diff: String,
pub last_message: String,
pub elapsed_time: Duration,
pub assistant_response_count: usize,
pub tool_use_counts: HashMap<Arc<str>, u32>,
pub token_usage: TokenUsage,
}
#[derive(Deserialize)]
pub struct EvalSetup {
pub url: String,
pub base_sha: String,
}
impl Eval {
/// Loads the eval from a path (typically in `evaluation_data`). Clones and checks out the repo
/// if necessary.
pub async fn load(name: String, path: PathBuf, repos_dir: &Path) -> anyhow::Result<Self> {
let prompt_path = path.join("prompt.txt");
let user_prompt = smol::unblock(|| std::fs::read_to_string(prompt_path)).await?;
let setup_path = path.join("setup.json");
let setup_contents = smol::unblock(|| std::fs::read_to_string(setup_path)).await?;
let eval_setup = serde_json_lenient::from_str_lenient::<EvalSetup>(&setup_contents)?;
let repo_path = repos_dir.join(repo_dir_name(&eval_setup.url));
Ok(Eval {
name,
path,
repo_path,
eval_setup,
user_prompt,
})
}
pub fn run(
self,
app_state: Arc<HeadlessAppState>,
model: Arc<dyn LanguageModel>,
cx: &mut App,
) -> Task<anyhow::Result<EvalOutput>> {
cx.spawn(move |mut cx| async move {
checkout_repo(&self.eval_setup, &self.repo_path).await?;
let (assistant, done_rx) =
cx.update(|cx| HeadlessAssistant::new(app_state.clone(), cx))??;
let _worktree = assistant
.update(&mut cx, |assistant, cx| {
assistant.project.update(cx, |project, cx| {
project.create_worktree(&self.repo_path, true, cx)
})
})?
.await?;
let start_time = std::time::SystemTime::now();
assistant.update(&mut cx, |assistant, cx| {
assistant.thread.update(cx, |thread, cx| {
let context = vec![];
thread.insert_user_message(self.user_prompt.clone(), context, cx);
thread.send_to_model(model, RequestKind::Chat, cx);
});
})?;
done_rx.recv().await??;
let elapsed_time = start_time.elapsed()?;
let diff = query_git(&self.repo_path, vec!["diff"]).await?;
assistant.update(&mut cx, |assistant, cx| {
let thread = assistant.thread.read(cx);
let last_message = thread.messages().last().unwrap();
if last_message.role != language_model::Role::Assistant {
return Err(anyhow!("Last message is not from assistant"));
}
let assistant_response_count = thread
.messages()
.filter(|message| message.role == language_model::Role::Assistant)
.count();
Ok(EvalOutput {
diff,
last_message: last_message.text.clone(),
elapsed_time,
assistant_response_count,
tool_use_counts: assistant.tool_use_counts.clone(),
token_usage: thread.cumulative_token_usage(),
})
})?
})
}
}
impl EvalOutput {
// Method to save the output to a directory
pub fn save_to_directory(
&self,
output_dir: &Path,
eval_output_value: String,
) -> anyhow::Result<()> {
// Create the output directory if it doesn't exist
fs::create_dir_all(&output_dir)?;
// Save the diff to a file
let diff_path = output_dir.join("diff.patch");
let mut diff_file = fs::File::create(&diff_path)?;
diff_file.write_all(self.diff.as_bytes())?;
// Save the last message to a file
let message_path = output_dir.join("assistant_response.txt");
let mut message_file = fs::File::create(&message_path)?;
message_file.write_all(self.last_message.as_bytes())?;
// Current metrics for this run
let current_metrics = serde_json::json!({
"elapsed_time_ms": self.elapsed_time.as_millis(),
"assistant_response_count": self.assistant_response_count,
"tool_use_counts": self.tool_use_counts,
"token_usage": self.token_usage,
"eval_output_value": eval_output_value,
});
// Get current timestamp in milliseconds
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)?
.as_millis()
.to_string();
// Path to metrics file
let metrics_path = output_dir.join("metrics.json");
// Load existing metrics if the file exists, or create a new object
let mut historical_metrics = if metrics_path.exists() {
let metrics_content = fs::read_to_string(&metrics_path)?;
serde_json::from_str::<serde_json::Value>(&metrics_content)
.unwrap_or_else(|_| serde_json::json!({}))
} else {
serde_json::json!({})
};
// Add new run with timestamp as key
if let serde_json::Value::Object(ref mut map) = historical_metrics {
map.insert(timestamp, current_metrics);
}
// Write updated metrics back to file
let metrics_json = serde_json::to_string_pretty(&historical_metrics)?;
let mut metrics_file = fs::File::create(&metrics_path)?;
metrics_file.write_all(metrics_json.as_bytes())?;
Ok(())
}
}
fn repo_dir_name(url: &str) -> String {
url.trim_start_matches("https://")
.replace(|c: char| !c.is_alphanumeric(), "_")
}
async fn checkout_repo(eval_setup: &EvalSetup, repo_path: &Path) -> anyhow::Result<()> {
if !repo_path.exists() {
smol::unblock({
let repo_path = repo_path.to_path_buf();
|| std::fs::create_dir_all(repo_path)
})
.await?;
run_git(repo_path, vec!["init"]).await?;
run_git(repo_path, vec!["remote", "add", "origin", &eval_setup.url]).await?;
} else {
let actual_origin = query_git(repo_path, vec!["remote", "get-url", "origin"]).await?;
if actual_origin != eval_setup.url {
return Err(anyhow!(
"remote origin {} does not match expected origin {}",
actual_origin,
eval_setup.url
));
}
// TODO: consider including "-x" to remove ignored files. The downside of this is that it will
// also remove build artifacts, and so prevent incremental reuse there.
run_git(repo_path, vec!["clean", "--force", "-d"]).await?;
run_git(repo_path, vec!["reset", "--hard", "HEAD"]).await?;
}
run_git(
repo_path,
vec!["fetch", "--depth", "1", "origin", &eval_setup.base_sha],
)
.await?;
run_git(repo_path, vec!["checkout", &eval_setup.base_sha]).await?;
Ok(())
}
async fn run_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<()> {
let exit_status = new_smol_command("git")
.current_dir(repo_path)
.args(args.clone())
.status()
.await?;
if exit_status.success() {
Ok(())
} else {
Err(anyhow!(
"`git {}` failed with {}",
args.join(" "),
exit_status,
))
}
}
async fn query_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<String> {
let output = new_smol_command("git")
.current_dir(repo_path)
.args(args.clone())
.output()
.await?;
if output.status.success() {
Ok(String::from_utf8(output.stdout)?.trim().to_string())
} else {
Err(anyhow!(
"`git {}` failed with {}",
args.join(" "),
output.status
))
}
}

View File

@@ -1,241 +0,0 @@
use anyhow::anyhow;
use assistant2::{RequestKind, Thread, ThreadEvent, ThreadStore};
use assistant_tool::ToolWorkingSet;
use client::{Client, UserStore};
use collections::HashMap;
use futures::StreamExt;
use gpui::{prelude::*, App, AsyncApp, Entity, SemanticVersion, Subscription, Task};
use language::LanguageRegistry;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
LanguageModelRequest,
};
use node_runtime::NodeRuntime;
use project::{Project, RealFs};
use prompt_store::PromptBuilder;
use settings::SettingsStore;
use smol::channel;
use std::sync::Arc;
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
pub struct HeadlessAppState {
pub languages: Arc<LanguageRegistry>,
pub client: Arc<Client>,
pub user_store: Entity<UserStore>,
pub fs: Arc<dyn fs::Fs>,
pub node_runtime: NodeRuntime,
// Additional fields not present in `workspace::AppState`.
pub prompt_builder: Arc<PromptBuilder>,
}
pub struct HeadlessAssistant {
pub thread: Entity<Thread>,
pub project: Entity<Project>,
#[allow(dead_code)]
pub thread_store: Entity<ThreadStore>,
pub tool_use_counts: HashMap<Arc<str>, u32>,
pub done_tx: channel::Sender<anyhow::Result<()>>,
_subscription: Subscription,
}
impl HeadlessAssistant {
pub fn new(
app_state: Arc<HeadlessAppState>,
cx: &mut App,
) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
let env = None;
let project = Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
app_state.fs.clone(),
env,
cx,
);
let tools = Arc::new(ToolWorkingSet::default());
let thread_store =
ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
let headless_thread = cx.new(move |cx| Self {
_subscription: cx.subscribe(&thread, Self::handle_thread_event),
thread,
project,
thread_store,
tool_use_counts: HashMap::default(),
done_tx,
});
Ok((headless_thread, done_rx))
}
fn handle_thread_event(
&mut self,
thread: Entity<Thread>,
event: &ThreadEvent,
cx: &mut Context<Self>,
) {
match event {
ThreadEvent::ShowError(err) => self
.done_tx
.send_blocking(Err(anyhow!("{:?}", err)))
.unwrap(),
ThreadEvent::DoneStreaming => {
let thread = thread.read(cx);
if let Some(message) = thread.messages().last() {
println!("Message: {}", message.text,);
}
if thread.all_tools_finished() {
self.done_tx.send_blocking(Ok(())).unwrap()
}
}
ThreadEvent::UsePendingTools => {
thread.update(cx, |thread, cx| {
thread.use_pending_tools(cx);
});
}
ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,
..
} => {
if let Some(pending_tool_use) = pending_tool_use {
println!(
"Used tool {} with input: {}",
pending_tool_use.name, pending_tool_use.input
);
*self
.tool_use_counts
.entry(pending_tool_use.name.clone())
.or_insert(0) += 1;
}
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
println!("Tool result: {:?}", tool_result);
}
if thread.read(cx).all_tools_finished() {
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(model) = model_registry.active_model() {
thread.update(cx, |thread, cx| {
thread.attach_tool_results(vec![], cx);
thread.send_to_model(model, RequestKind::Chat, cx);
});
}
}
}
ThreadEvent::StreamedCompletion
| ThreadEvent::SummaryChanged
| ThreadEvent::StreamedAssistantText(_, _)
| ThreadEvent::MessageAdded(_)
| ThreadEvent::MessageEdited(_)
| ThreadEvent::MessageDeleted(_) => {}
}
}
}
pub fn init(cx: &mut App) -> Arc<HeadlessAppState> {
release_channel::init(SemanticVersion::default(), cx);
gpui_tokio::init(cx);
let mut settings_store = SettingsStore::new(cx);
settings_store
.set_default_settings(settings::default_settings().as_ref(), cx)
.unwrap();
cx.set_global(settings_store);
client::init_settings(cx);
Project::init_settings(cx);
let client = Client::production(cx);
cx.set_http_client(client.http_client().clone());
let git_binary_path = None;
let fs = Arc::new(RealFs::new(git_binary_path));
let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language::init(cx);
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
assistant_tools::init(cx);
context_server::init(cx);
let stdout_is_a_pty = false;
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
assistant2::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
Arc::new(HeadlessAppState {
languages,
client,
user_store,
fs,
node_runtime: NodeRuntime::unavailable(),
prompt_builder,
})
}
pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
let model_registry = LanguageModelRegistry::read_global(cx);
let model = model_registry
.available_models(cx)
.find(|model| model.id().0 == model_name);
let Some(model) = model else {
return Err(anyhow!(
"No language model named {} was available. Available models: {}",
model_name,
model_registry
.available_models(cx)
.map(|model| model.id().0.clone())
.collect::<Vec<_>>()
.join(", ")
));
};
Ok(model)
}
pub fn authenticate_model_provider(
provider_id: LanguageModelProviderId,
cx: &mut App,
) -> Task<std::result::Result<(), AuthenticateError>> {
let model_registry = LanguageModelRegistry::read_global(cx);
let model_provider = model_registry.provider(&provider_id).unwrap();
model_provider.authenticate(cx)
}
pub async fn send_language_model_request(
model: Arc<dyn LanguageModel>,
request: LanguageModelRequest,
cx: AsyncApp,
) -> anyhow::Result<String> {
match model.stream_completion_text(request, &cx).await {
Ok(mut stream) => {
let mut full_response = String::new();
// Process the response stream
while let Some(chunk_result) = stream.stream.next().await {
match chunk_result {
Ok(chunk_str) => {
full_response.push_str(&chunk_str);
}
Err(err) => {
return Err(anyhow!(
"Error receiving response from language model: {err}"
));
}
}
}
Ok(full_response)
}
Err(err) => Err(anyhow!(
"Failed to get response from language model. Error was: {err}"
)),
}
}

View File

@@ -1,121 +0,0 @@
use crate::eval::EvalOutput;
use crate::headless_assistant::send_language_model_request;
use anyhow::anyhow;
use gpui::{App, Task};
use language_model::{
LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
};
use std::{path::Path, sync::Arc};
pub struct Judge {
pub original_diff: Option<String>,
#[allow(dead_code)]
pub original_message: Option<String>,
pub model: Arc<dyn LanguageModel>,
}
impl Judge {
pub async fn load(eval_path: &Path, model: Arc<dyn LanguageModel>) -> anyhow::Result<Judge> {
let original_diff_path = eval_path.join("original.diff");
let original_diff = smol::unblock(move || {
if std::fs::exists(&original_diff_path)? {
anyhow::Ok(Some(std::fs::read_to_string(&original_diff_path)?))
} else {
anyhow::Ok(None)
}
});
let original_message_path = eval_path.join("original_message.txt");
let original_message = smol::unblock(move || {
if std::fs::exists(&original_message_path)? {
anyhow::Ok(Some(std::fs::read_to_string(&original_message_path)?))
} else {
anyhow::Ok(None)
}
});
Ok(Self {
original_diff: original_diff.await?,
original_message: original_message.await?,
model,
})
}
pub fn run(&self, eval_output: &EvalOutput, cx: &mut App) -> Task<anyhow::Result<String>> {
let Some(original_diff) = self.original_diff.as_ref() else {
return Task::ready(Err(anyhow!("No original.diff found")));
};
// TODO: check for empty diff?
let prompt = diff_comparison_prompt(&original_diff, &eval_output.diff);
let request = LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
role: Role::User,
content: vec![MessageContent::Text(prompt)],
cache: false,
}],
temperature: Some(0.0),
tools: Vec::new(),
stop: Vec::new(),
};
let model = self.model.clone();
cx.spawn(move |cx| send_language_model_request(model, request, cx))
}
}
pub fn diff_comparison_prompt(original_diff: &str, new_diff: &str) -> String {
format!(
r#"# Git Diff Similarity Evaluation Template
## Instructions
Compare the two diffs and score them between 0.0 and 1.0 based on their functional similarity.
- 1.0 = Perfect functional match (achieves identical results)
- 0.0 = No functional similarity whatsoever
## Evaluation Criteria
Please consider the following aspects in order of importance:
1. **Functional Equivalence (60%)**
- Do both diffs achieve the same end result?
- Are the changes functionally equivalent despite possibly using different approaches?
- Do the modifications address the same issues or implement the same features?
2. **Logical Structure (20%)**
- Are the logical flows similar?
- Do the modifications affect the same code paths?
- Are control structures (if/else, loops, etc.) modified in similar ways?
3. **Code Content (15%)**
- Are similar lines added/removed?
- Are the same variables, functions, or methods being modified?
- Are the same APIs or libraries being used?
4. **File Layout (5%)**
- Are the same files being modified?
- Are changes occurring in similar locations within files?
## Input
Original Diff:
```git
{}
```
New Diff:
```git
{}
```
## Output Format
THE ONLY OUTPUT SHOULD BE A SCORE BETWEEN 0.0 AND 1.0.
Example output:
0.85"#,
original_diff, new_diff
)
}

View File

@@ -1,243 +0,0 @@
mod eval;
mod headless_assistant;
mod judge;
use clap::Parser;
use eval::{Eval, EvalOutput};
use futures::future;
use gpui::{Application, AsyncApp};
use headless_assistant::{authenticate_model_provider, find_model, HeadlessAppState};
use itertools::Itertools;
use judge::Judge;
use language_model::{LanguageModel, LanguageModelRegistry};
use regex::Regex;
use reqwest_client::ReqwestClient;
use std::{cmp, path::PathBuf, sync::Arc};
#[derive(Parser, Debug)]
#[command(
name = "assistant_eval",
disable_version_flag = true,
before_help = "Tool eval runner"
)]
struct Args {
/// Regexes to match the names of evals to run.
eval_name_regexes: Vec<String>,
/// Runs all evals in `evaluation_data`, causes the regex to be ignored.
#[arg(long)]
all: bool,
/// Name of the model (default: "claude-3-7-sonnet-latest")
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
model_name: String,
/// Name of the editor model (default: value of `--model_name`).
#[arg(long)]
editor_model_name: Option<String>,
/// Name of the judge model (default: value of `--model_name`).
#[arg(long)]
judge_model_name: Option<String>,
/// Number of evaluations to run concurrently (default: 10)
#[arg(short, long, default_value = "10")]
concurrency: usize,
}
fn main() {
env_logger::init();
let args = Args::parse();
let http_client = Arc::new(ReqwestClient::new());
let app = Application::headless().with_http_client(http_client.clone());
let crate_dir = PathBuf::from("../zed-agent-bench");
let evaluation_data_dir = crate_dir.join("evaluation_data").canonicalize().unwrap();
let repos_dir = crate_dir.join("repos");
if !repos_dir.exists() {
std::fs::create_dir_all(&repos_dir).unwrap();
}
let repos_dir = repos_dir.canonicalize().unwrap();
let all_evals = std::fs::read_dir(&evaluation_data_dir)
.unwrap()
.map(|path| path.unwrap().file_name().to_string_lossy().to_string())
.collect::<Vec<_>>();
let evals_to_run = if args.all {
all_evals
} else {
args.eval_name_regexes
.into_iter()
.map(|regex_string| Regex::new(&regex_string).unwrap())
.flat_map(|regex| {
all_evals
.iter()
.filter(|eval_name| regex.is_match(eval_name))
.cloned()
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
};
if evals_to_run.is_empty() {
panic!("Names of evals to run must be provided or `--all` specified");
}
println!("Will run the following evals: {evals_to_run:?}");
println!("Running up to {} evals concurrently", args.concurrency);
let editor_model_name = if let Some(model_name) = args.editor_model_name {
model_name
} else {
args.model_name.clone()
};
let judge_model_name = if let Some(model_name) = args.judge_model_name {
model_name
} else {
args.model_name.clone()
};
app.run(move |cx| {
let app_state = headless_assistant::init(cx);
let model = find_model(&args.model_name, cx).unwrap();
let editor_model = find_model(&editor_model_name, cx).unwrap();
let judge_model = find_model(&judge_model_name, cx).unwrap();
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.set_active_model(Some(model.clone()), cx);
registry.set_editor_model(Some(editor_model.clone()), cx);
});
let model_provider_id = model.provider_id();
let editor_model_provider_id = editor_model.provider_id();
let judge_model_provider_id = judge_model.provider_id();
cx.spawn(move |cx| async move {
// Authenticate all model providers first
cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
.unwrap()
.await
.unwrap();
cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx))
.unwrap()
.await
.unwrap();
cx.update(|cx| authenticate_model_provider(judge_model_provider_id.clone(), cx))
.unwrap()
.await
.unwrap();
let eval_load_futures = evals_to_run
.into_iter()
.map(|eval_name| {
let eval_path = evaluation_data_dir.join(&eval_name);
let load_future = Eval::load(eval_name.clone(), eval_path, &repos_dir);
async move {
match load_future.await {
Ok(eval) => Some(eval),
Err(err) => {
// TODO: Persist errors / surface errors at the end.
println!("Error loading {eval_name}: {err}");
None
}
}
}
})
.collect::<Vec<_>>();
let loaded_evals = future::join_all(eval_load_futures)
.await
.into_iter()
.flatten()
.collect::<Vec<_>>();
// The evals need to be loaded and grouped by URL before concurrently running, since
// evals that use the same remote URL will use the same working directory.
let mut evals_grouped_by_url: Vec<Vec<Eval>> = loaded_evals
.into_iter()
.map(|eval| (eval.eval_setup.url.clone(), eval))
.into_group_map()
.into_values()
.collect::<Vec<_>>();
// Sort groups in descending order, so that bigger groups start first.
evals_grouped_by_url.sort_by_key(|evals| cmp::Reverse(evals.len()));
let result_futures = evals_grouped_by_url
.into_iter()
.map(|evals| {
let model = model.clone();
let judge_model = judge_model.clone();
let app_state = app_state.clone();
let cx = cx.clone();
async move {
let mut results = Vec::new();
for eval in evals {
let name = eval.name.clone();
println!("Starting eval named {}", name);
let result = run_eval(
eval,
model.clone(),
judge_model.clone(),
app_state.clone(),
cx.clone(),
)
.await;
results.push((name, result));
}
results
}
})
.collect::<Vec<_>>();
let results = future::join_all(result_futures)
.await
.into_iter()
.flatten()
.collect::<Vec<_>>();
// Process results in order of completion
for (eval_name, result) in results {
match result {
Ok((eval_output, judge_output)) => {
println!("Generated diff for {eval_name}:\n");
println!("{}\n", eval_output.diff);
println!("Last message for {eval_name}:\n");
println!("{}\n", eval_output.last_message);
println!("Elapsed time: {:?}", eval_output.elapsed_time);
println!(
"Assistant response count: {}",
eval_output.assistant_response_count
);
println!("Tool use counts: {:?}", eval_output.tool_use_counts);
println!("Judge output for {eval_name}: {judge_output}");
}
Err(err) => {
// TODO: Persist errors / surface errors at the end.
println!("Error running {eval_name}: {err}");
}
}
}
cx.update(|cx| cx.quit()).unwrap();
})
.detach();
});
println!("Done running evals");
}
async fn run_eval(
eval: Eval,
model: Arc<dyn LanguageModel>,
judge_model: Arc<dyn LanguageModel>,
app_state: Arc<HeadlessAppState>,
cx: AsyncApp,
) -> anyhow::Result<(EvalOutput, String)> {
let path = eval.path.clone();
let judge = Judge::load(&path, judge_model).await?;
let eval_output = cx.update(|cx| eval.run(app_state, model, cx))?.await?;
let judge_output = cx.update(|cx| judge.run(&eval_output, cx))?.await?;
eval_output.save_to_directory(&path, judge_output.to_string())?;
Ok((eval_output, judge_output))
}

View File

@@ -14,11 +14,9 @@ path = "src/assistant_tool.rs"
[dependencies]
anyhow.workspace = true
collections.workspace = true
clock.workspace = true
derive_more.workspace = true
gpui.workspace = true
language.workspace = true
language_model.workspace = true
gpui.workspace = true
parking_lot.workspace = true
project.workspace = true
serde.workspace = true

View File

@@ -4,10 +4,7 @@ mod tool_working_set;
use std::sync::Arc;
use anyhow::Result;
use collections::{HashMap, HashSet};
use gpui::Context;
use gpui::{App, Entity, SharedString, Task};
use language::Buffer;
use language_model::LanguageModelRequestMessage;
use project::Project;
@@ -50,61 +47,6 @@ pub trait Tool: 'static + Send + Sync {
input: serde_json::Value,
messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>>;
}
/// Tracks actions performed by tools in a thread
#[derive(Debug)]
pub struct ActionLog {
/// Buffers that user manually added to the context, and whose content has
/// changed since the model last saw them.
stale_buffers_in_context: HashSet<Entity<Buffer>>,
/// Buffers that we want to notify the model about when they change.
tracked_buffers: HashMap<Entity<Buffer>, TrackedBuffer>,
}
#[derive(Debug, Default)]
struct TrackedBuffer {
version: clock::Global,
}
impl ActionLog {
/// Creates a new, empty action log.
pub fn new() -> Self {
Self {
stale_buffers_in_context: HashSet::default(),
tracked_buffers: HashMap::default(),
}
}
/// Track a buffer as read, so we can notify the model about user edits.
pub fn buffer_read(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
tracked_buffer.version = buffer.read(cx).version();
}
/// Mark a buffer as edited, so we can refresh it in the context
pub fn buffer_edited(&mut self, buffers: HashSet<Entity<Buffer>>, cx: &mut Context<Self>) {
for buffer in &buffers {
let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
tracked_buffer.version = buffer.read(cx).version();
}
self.stale_buffers_in_context.extend(buffers);
}
/// Iterate over buffers changed since last read or edited by the model
pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator<Item = &'a Entity<Buffer>> {
self.tracked_buffers
.iter()
.filter(|(buffer, tracked)| tracked.version != buffer.read(cx).version)
.map(|(buffer, _)| buffer)
}
/// Takes and returns the set of buffers pending refresh, clearing internal state.
pub fn take_stale_buffers_in_context(&mut self) -> HashSet<Entity<Buffer>> {
std::mem::take(&mut self.stale_buffers_in_context)
}
}

View File

@@ -18,7 +18,6 @@ chrono.workspace = true
collections.workspace = true
feature_flags.workspace = true
futures.workspace = true
itertools.workspace = true
gpui.workspace = true
language.workspace = true
language_model.workspace = true
@@ -31,7 +30,6 @@ theme.workspace = true
ui.workspace = true
util.workspace = true
workspace.workspace = true
worktree.workspace = true
settings.workspace = true
[dev-dependencies]
@@ -39,7 +37,4 @@ rand.workspace = true
collections = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] }
unindent.workspace = true
workspace = { workspace = true, features = ["test-support"] }

View File

@@ -7,7 +7,6 @@ mod now_tool;
mod path_search_tool;
mod read_file_tool;
mod regex_search;
mod thinking_tool;
use assistant_tool::ToolRegistry;
use gpui::App;
@@ -21,7 +20,6 @@ use crate::now_tool::NowTool;
use crate::path_search_tool::PathSearchTool;
use crate::read_file_tool::ReadFileTool;
use crate::regex_search::RegexSearchTool;
use crate::thinking_tool::ThinkingTool;
pub fn init(cx: &mut App) {
assistant_tool::init(cx);
@@ -37,5 +35,4 @@ pub fn init(cx: &mut App) {
registry.register_tool(PathSearchTool);
registry.register_tool(ReadFileTool);
registry.register_tool(RegexSearchTool);
registry.register_tool(ThinkingTool);
}

View File

@@ -1,5 +1,5 @@
use anyhow::{anyhow, Context as _, Result};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::Tool;
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use project::Project;
@@ -37,7 +37,6 @@ impl Tool for BashTool {
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let input: BashToolInput = match serde_json::from_value(input) {

View File

@@ -1,15 +1,16 @@
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, AppContext, Entity, Task};
use assistant_tool::Tool;
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::{fs, path::PathBuf, sync::Arc};
use util::paths::PathMatcher;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct DeletePathToolInput {
/// The path of the file or directory to delete.
/// The glob to match files in the project to delete.
///
/// <example>
/// If the project has the following files:
@@ -18,9 +19,9 @@ pub struct DeletePathToolInput {
/// - directory2/a/things.txt
/// - directory3/a/other.txt
///
/// You can delete the first file by providing a path of "directory1/a/something.txt"
/// You can delete the first two files by providing a glob of "*thing*.txt"
/// </example>
pub path: String,
pub glob: String,
}
pub struct DeletePathTool;
@@ -44,29 +45,121 @@ impl Tool for DeletePathTool {
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let path_str = match serde_json::from_value::<DeletePathToolInput>(input) {
Ok(input) => input.path,
let glob = match serde_json::from_value::<DeletePathToolInput>(input) {
Ok(input) => input.glob,
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let path_matcher = match PathMatcher::new(&[glob.clone()]) {
Ok(matcher) => matcher,
Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {}", err))),
};
match project
.read(cx)
.find_project_path(&path_str, cx)
.and_then(|path| project.update(cx, |project, cx| project.delete_file(path, false, cx)))
{
Some(deletion_task) => cx.background_spawn(async move {
match deletion_task.await {
Ok(()) => Ok(format!("Deleted {}", &path_str)),
Err(err) => Err(anyhow!("Failed to delete {}: {}", &path_str, err)),
struct Match {
display_path: String,
path: PathBuf,
}
let mut matches = Vec::new();
let mut deleted_paths = Vec::new();
let mut errors = Vec::new();
for worktree_handle in project.read(cx).worktrees(cx) {
let worktree = worktree_handle.read(cx);
let worktree_root = worktree.abs_path().to_path_buf();
// Don't consider ignored entries.
for entry in worktree.entries(false, 0) {
if path_matcher.is_match(&entry.path) {
matches.push(Match {
path: worktree_root.join(&entry.path),
display_path: entry.path.display().to_string(),
});
}
}),
None => Task::ready(Err(anyhow!(
"Couldn't delete {} because that path isn't in this project.",
path_str
))),
}
}
if matches.is_empty() {
return Task::ready(Ok(format!("No paths in the project matched {glob:?}")));
}
let paths_matched = matches.len();
// Delete the files
for Match { path, display_path } in matches {
match fs::remove_file(&path) {
Ok(()) => {
deleted_paths.push(display_path);
}
Err(file_err) => {
// Try to remove directory if it's not a file. Retrying as a directory
// on error saves a syscall compared to checking whether it's
// a directory up front for every single file.
if let Err(dir_err) = fs::remove_dir_all(&path) {
let error = if path.is_dir() {
format!("Failed to delete directory {}: {dir_err}", display_path)
} else {
format!("Failed to delete file {}: {file_err}", display_path)
};
errors.push(error);
} else {
deleted_paths.push(display_path);
}
}
}
}
if errors.is_empty() {
// 0 deleted paths should never happen if there were no errors;
// we already returned if matches was empty.
let answer = if deleted_paths.len() == 1 {
format!(
"Deleted {}",
deleted_paths.first().unwrap_or(&String::new())
)
} else {
// Sort to group entries in the same directory together
deleted_paths.sort();
let mut buf = format!("Deleted these {} paths:\n", deleted_paths.len());
for path in deleted_paths.iter() {
buf.push('\n');
buf.push_str(path);
}
buf
};
Task::ready(Ok(answer))
} else {
if deleted_paths.is_empty() {
Task::ready(Err(anyhow!(
"{glob:?} matched {} deleted because of {}:\n{}",
if paths_matched == 1 {
"1 path, but it was not".to_string()
} else {
format!("{} paths, but none were", paths_matched)
},
if errors.len() == 1 {
"this error".to_string()
} else {
format!("{} errors", errors.len())
},
errors.join("\n")
)))
} else {
// Sort to group entries in the same directory together
deleted_paths.sort();
Task::ready(Ok(format!(
"Deleted {} paths matching glob {glob:?}:\n{}\n\nErrors:\n{}",
deleted_paths.len(),
deleted_paths.join("\n"),
errors.join("\n")
)))
}
}
}
}

View File

@@ -1 +1 @@
Deletes the file or directory (and the directory's contents, recursively) at the specified path in the project, and returns confirmation of the deletion.
Deletes all files and directories in the project which match the given glob, and returns a list of the paths that were deleted.

View File

@@ -1,5 +1,5 @@
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::Tool;
use gpui::{App, Entity, Task};
use language::{DiagnosticSeverity, OffsetRangeExt};
use language_model::LanguageModelRequestMessage;
@@ -51,7 +51,6 @@ impl Tool for DiagnosticsTool {
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let input = match serde_json::from_value::<DiagnosticsToolInput>(input) {

View File

@@ -1,24 +1,22 @@
mod edit_action;
pub mod log;
mod resolve_search_block;
use anyhow::{anyhow, Context, Result};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::Tool;
use collections::HashSet;
use edit_action::{EditAction, EditActionParser};
use futures::StreamExt;
use gpui::{App, AsyncApp, Entity, Task};
use language::OffsetRangeExt;
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
use log::{EditToolLog, EditToolRequestId};
use project::Project;
use resolve_search_block::resolve_search_block;
use project::{search::SearchQuery, Project};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::fmt::Write;
use std::sync::Arc;
use util::paths::PathMatcher;
use util::ResultExt;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
@@ -82,7 +80,6 @@ impl Tool for EditFilesTool {
input: serde_json::Value,
messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let input = match serde_json::from_value::<EditFilesToolInput>(input) {
@@ -96,14 +93,8 @@ impl Tool for EditFilesTool {
log.new_request(input.edit_instructions.clone(), cx)
});
let task = EditToolRequest::new(
input,
messages,
project,
action_log,
Some((log.clone(), req_id)),
cx,
);
let task =
EditToolRequest::new(input, messages, project, Some((log.clone(), req_id)), cx);
cx.spawn(|mut cx| async move {
let result = task.await;
@@ -122,18 +113,29 @@ impl Tool for EditFilesTool {
})
}
None => EditToolRequest::new(input, messages, project, action_log, None, cx),
None => EditToolRequest::new(input, messages, project, None, cx),
}
}
}
struct EditToolRequest {
parser: EditActionParser,
output: String,
changed_buffers: HashSet<Entity<language::Buffer>>,
bad_searches: Vec<BadSearch>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
}
#[derive(Debug)]
enum DiffResult {
BadSearch(BadSearch),
Diff(language::Diff),
}
#[derive(Debug)]
struct BadSearch {
file_path: String,
search: String,
}
impl EditToolRequest {
@@ -141,8 +143,7 @@ impl EditToolRequest {
input: EditFilesToolInput,
messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
cx: &mut App,
) -> Task<Result<String>> {
let model_registry = LanguageModelRegistry::read_global(cx);
@@ -151,23 +152,12 @@ impl EditToolRequest {
};
let mut messages = messages.to_vec();
// Remove the last tool use (this run) to prevent an invalid request
'outer: for message in messages.iter_mut().rev() {
for (index, content) in message.content.iter().enumerate().rev() {
match content {
MessageContent::ToolUse(_) => {
message.content.remove(index);
break 'outer;
}
MessageContent::ToolResult(_) => {
// If we find any tool results before a tool use, the request is already valid
break 'outer;
}
MessageContent::Text(_) | MessageContent::Image(_) => {}
}
}
if let Some(last_message) = messages.last_mut() {
// Strip out tool use from the last message because we're in the middle of executing a tool call.
last_message
.content
.retain(|content| !matches!(content, language_model::MessageContent::ToolUse(_)))
}
messages.push(LanguageModelRequestMessage {
role: Role::User,
content: vec![
@@ -190,12 +180,10 @@ impl EditToolRequest {
let mut request = Self {
parser: EditActionParser::new(),
// we start with the success header so we don't need to shift the output in the common case
output: Self::SUCCESS_OUTPUT_HEADER.to_string(),
changed_buffers: HashSet::default(),
action_log,
bad_searches: Vec::new(),
project,
tool_log,
log,
};
while let Some(chunk) = chunks.stream.next().await {
@@ -209,7 +197,7 @@ impl EditToolRequest {
async fn process_response_chunk(&mut self, chunk: &str, cx: &mut AsyncApp) -> Result<()> {
let new_actions = self.parser.parse_chunk(chunk);
if let Some((ref log, req_id)) = self.tool_log {
if let Some((ref log, req_id)) = self.log {
log.update(cx, |log, cx| {
log.push_editor_response_chunk(req_id, chunk, &new_actions, cx)
})
@@ -223,11 +211,7 @@ impl EditToolRequest {
Ok(())
}
async fn apply_action(
&mut self,
(action, source): (EditAction, String),
cx: &mut AsyncApp,
) -> Result<()> {
async fn apply_action(&mut self, action: EditAction, cx: &mut AsyncApp) -> Result<()> {
let project_path = self.project.read_with(cx, |project, cx| {
project
.find_project_path(action.file_path(), cx)
@@ -239,30 +223,35 @@ impl EditToolRequest {
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?;
let diff = match action {
let result = match action {
EditAction::Replace {
old,
new,
file_path: _,
file_path,
} => {
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let diff = cx
.background_executor()
.spawn(Self::replace_diff(old, new, snapshot))
.await;
anyhow::Ok(diff)
cx.background_executor()
.spawn(Self::replace_diff(old, new, file_path, snapshot))
.await
}
EditAction::Write { content, .. } => Ok(buffer
.read_with(cx, |buffer, cx| buffer.diff(content, cx))?
.await),
EditAction::Write { content, .. } => Ok(DiffResult::Diff(
buffer
.read_with(cx, |buffer, cx| buffer.diff(content, cx))?
.await,
)),
}?;
let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
match result {
DiffResult::BadSearch(invalid_replace) => {
self.bad_searches.push(invalid_replace);
}
DiffResult::Diff(diff) => {
let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
write!(&mut self.output, "\n\n{}", source)?;
self.changed_buffers.insert(buffer);
self.changed_buffers.insert(buffer);
}
}
Ok(())
}
@@ -270,9 +259,29 @@ impl EditToolRequest {
async fn replace_diff(
old: String,
new: String,
file_path: std::path::PathBuf,
snapshot: language::BufferSnapshot,
) -> language::Diff {
let edit_range = resolve_search_block(&snapshot, &old).to_offset(&snapshot);
) -> Result<DiffResult> {
let query = SearchQuery::text(
old.clone(),
false,
true,
true,
PathMatcher::new(&[])?,
PathMatcher::new(&[])?,
None,
)?;
let matches = query.search(&snapshot, None).await;
if matches.is_empty() {
return Ok(DiffResult::BadSearch(BadSearch {
search: new.clone(),
file_path: file_path.display().to_string(),
}));
}
let edit_range = matches[0].clone();
let diff = language::text_diff(&old, &new);
let edits = diff
@@ -290,71 +299,77 @@ impl EditToolRequest {
edits,
};
diff
anyhow::Ok(DiffResult::Diff(diff))
}
const SUCCESS_OUTPUT_HEADER: &str = "Successfully applied. Here's a list of changes:";
const ERROR_OUTPUT_HEADER_NO_EDITS: &str = "I couldn't apply any edits!";
const ERROR_OUTPUT_HEADER_WITH_EDITS: &str =
"Errors occurred. First, here's a list of the edits we managed to apply:";
async fn finalize(self, cx: &mut AsyncApp) -> Result<String> {
let changed_buffer_count = self.changed_buffers.len();
let mut answer = match self.changed_buffers.len() {
0 => "No files were edited.".to_string(),
1 => "Successfully edited ".to_string(),
_ => "Successfully edited these files:\n\n".to_string(),
};
// Save each buffer once at the end
for buffer in &self.changed_buffers {
self.project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.await?;
}
for buffer in self.changed_buffers {
let (path, save_task) = self.project.update(cx, |project, cx| {
let path = buffer
.read(cx)
.file()
.map(|file| file.path().display().to_string());
self.action_log
.update(cx, |log, cx| log.buffer_edited(self.changed_buffers, cx))
.log_err();
let task = project.save_buffer(buffer.clone(), cx);
(path, task)
})?;
save_task.await?;
if let Some(path) = path {
writeln!(&mut answer, "{}", path)?;
}
}
let errors = self.parser.errors();
if errors.is_empty() {
if changed_buffer_count == 0 {
return Err(anyhow!(
"The instructions didn't lead to any changes. You might need to consult the file contents first."
));
}
Ok(self.output)
if errors.is_empty() && self.bad_searches.is_empty() {
Ok(answer.trim_end().to_string())
} else {
let mut output = self.output;
if !self.bad_searches.is_empty() {
writeln!(
&mut answer,
"\nThese searches failed because they didn't match any strings:"
)?;
if output.is_empty() {
output.replace_range(
0..Self::SUCCESS_OUTPUT_HEADER.len(),
Self::ERROR_OUTPUT_HEADER_NO_EDITS,
);
} else {
output.replace_range(
0..Self::SUCCESS_OUTPUT_HEADER.len(),
Self::ERROR_OUTPUT_HEADER_WITH_EDITS,
);
for replace in self.bad_searches {
writeln!(
&mut answer,
"- '{}' does not appear in `{}`",
replace.search.replace("\r", "\\r").replace("\n", "\\n"),
replace.file_path
)?;
}
writeln!(&mut answer, "Make sure to use exact searches.")?;
}
if !errors.is_empty() {
writeln!(
&mut output,
"\n\nThese SEARCH/REPLACE blocks failed to parse:"
&mut answer,
"\nThese SEARCH/REPLACE blocks failed to parse:"
)?;
for error in errors {
writeln!(&mut output, "- {}", error)?;
writeln!(&mut answer, "- {}", error)?;
}
}
writeln!(
&mut output,
"\nYou can fix errors by running the tool again. You can include instructions, \
&mut answer,
"\nYou can fix errors by running the tool again. You can include instructions,\
but errors are part of the conversation so you don't need to repeat them."
)?;
Err(anyhow!(output))
Err(anyhow!(answer))
}
}
}

View File

@@ -1,8 +1,4 @@
use std::{
mem::take,
ops::Range,
path::{Path, PathBuf},
};
use std::path::{Path, PathBuf};
use util::ResultExt;
/// Represents an edit action to be performed on a file.
@@ -32,14 +28,12 @@ impl EditAction {
#[derive(Debug)]
pub struct EditActionParser {
state: State,
pre_fence_line: Vec<u8>,
marker_ix: usize,
line: usize,
column: usize,
marker_ix: usize,
action_source: Vec<u8>,
fence_start_offset: usize,
block_range: Range<usize>,
old_range: Range<usize>,
new_range: Range<usize>,
old_bytes: Vec<u8>,
new_bytes: Vec<u8>,
errors: Vec<ParseError>,
}
@@ -64,14 +58,12 @@ impl EditActionParser {
pub fn new() -> Self {
Self {
state: State::Default,
pre_fence_line: Vec::new(),
marker_ix: 0,
line: 1,
column: 0,
action_source: Vec::new(),
fence_start_offset: 0,
marker_ix: 0,
block_range: Range::default(),
old_range: Range::default(),
new_range: Range::default(),
old_bytes: Vec::new(),
new_bytes: Vec::new(),
errors: Vec::new(),
}
}
@@ -84,7 +76,7 @@ impl EditActionParser {
///
/// If a block fails to parse, it will simply be skipped and an error will be recorded.
/// All errors can be accessed through the `EditActionsParser::errors` method.
pub fn parse_chunk(&mut self, input: &str) -> Vec<(EditAction, String)> {
pub fn parse_chunk(&mut self, input: &str) -> Vec<EditAction> {
use State::*;
const FENCE: &[u8] = b"```";
@@ -105,21 +97,20 @@ impl EditActionParser {
self.column += 1;
}
let action_offset = self.action_source.len();
match &self.state {
Default => match self.match_marker(byte, FENCE, false) {
Default => match match_marker(byte, FENCE, false, &mut self.marker_ix) {
MarkerMatch::Complete => {
self.fence_start_offset = action_offset + 1 - FENCE.len();
self.to_state(OpenFence);
}
MarkerMatch::Partial => {}
MarkerMatch::None => {
if self.marker_ix > 0 {
self.marker_ix = 0;
} else if self.action_source.ends_with(b"\n") {
self.action_source.clear();
} else if self.pre_fence_line.ends_with(b"\n") {
self.pre_fence_line.clear();
}
self.pre_fence_line.push(byte);
}
},
OpenFence => {
@@ -134,34 +125,39 @@ impl EditActionParser {
}
}
SearchBlock => {
if self.extend_block_range(byte, DIVIDER, NL_DIVIDER) {
self.old_range = take(&mut self.block_range);
if collect_until_marker(
byte,
DIVIDER,
NL_DIVIDER,
true,
&mut self.marker_ix,
&mut self.old_bytes,
) {
self.to_state(ReplaceBlock);
}
}
ReplaceBlock => {
if self.extend_block_range(byte, REPLACE_MARKER, NL_REPLACE_MARKER) {
self.new_range = take(&mut self.block_range);
if collect_until_marker(
byte,
REPLACE_MARKER,
NL_REPLACE_MARKER,
true,
&mut self.marker_ix,
&mut self.new_bytes,
) {
self.to_state(CloseFence);
}
}
CloseFence => {
if self.expect_marker(byte, FENCE, false) {
self.action_source.push(byte);
if let Some(action) = self.action() {
actions.push(action);
}
self.errors();
self.reset();
continue;
}
}
};
self.action_source.push(byte);
}
actions
@@ -172,76 +168,37 @@ impl EditActionParser {
&self.errors
}
fn action(&mut self) -> Option<(EditAction, String)> {
let old_range = take(&mut self.old_range);
let new_range = take(&mut self.new_range);
let action_source = take(&mut self.action_source);
let action_source = String::from_utf8(action_source).log_err()?;
let mut file_path_bytes = action_source[..self.fence_start_offset].to_owned();
if file_path_bytes.ends_with("\n") {
file_path_bytes.pop();
if file_path_bytes.ends_with("\r") {
file_path_bytes.pop();
}
}
let file_path = PathBuf::from(file_path_bytes);
if old_range.is_empty() && new_range.is_empty() {
fn action(&mut self) -> Option<EditAction> {
if self.old_bytes.is_empty() && self.new_bytes.is_empty() {
self.push_error(ParseErrorKind::NoOp);
return None;
}
if old_range.is_empty() {
return Some((
EditAction::Write {
file_path,
content: action_source[new_range].to_owned(),
},
action_source,
));
let mut pre_fence_line = std::mem::take(&mut self.pre_fence_line);
if pre_fence_line.ends_with(b"\n") {
pre_fence_line.pop();
pop_carriage_return(&mut pre_fence_line);
}
let old = action_source[old_range].to_owned();
let new = action_source[new_range].to_owned();
let file_path = PathBuf::from(String::from_utf8(pre_fence_line).log_err()?);
let content = String::from_utf8(std::mem::take(&mut self.new_bytes)).log_err()?;
let action = EditAction::Replace {
file_path,
old,
new,
};
if self.old_bytes.is_empty() {
Some(EditAction::Write { file_path, content })
} else {
let old = String::from_utf8(std::mem::take(&mut self.old_bytes)).log_err()?;
Some((action, action_source))
}
fn to_state(&mut self, state: State) {
self.state = state;
self.marker_ix = 0;
}
fn reset(&mut self) {
self.action_source.clear();
self.block_range = Range::default();
self.old_range = Range::default();
self.new_range = Range::default();
self.fence_start_offset = 0;
self.marker_ix = 0;
self.to_state(State::Default);
}
fn push_error(&mut self, kind: ParseErrorKind) {
self.errors.push(ParseError {
line: self.line,
column: self.column,
kind,
});
Some(EditAction::Replace {
file_path,
old,
new: content,
})
}
}
fn expect_marker(&mut self, byte: u8, marker: &'static [u8], trailing_newline: bool) -> bool {
match self.match_marker(byte, marker, trailing_newline) {
match match_marker(byte, marker, trailing_newline, &mut self.marker_ix) {
MarkerMatch::Complete => true,
MarkerMatch::Partial => false,
MarkerMatch::None => {
@@ -255,68 +212,24 @@ impl EditActionParser {
}
}
fn extend_block_range(&mut self, byte: u8, marker: &[u8], nl_marker: &[u8]) -> bool {
let marker = if self.block_range.is_empty() {
// do not require another newline if block is empty
marker
} else {
nl_marker
};
let offset = self.action_source.len();
match self.match_marker(byte, marker, true) {
MarkerMatch::Complete => {
if self.action_source[self.block_range.clone()].ends_with(b"\r") {
self.block_range.end -= 1;
}
true
}
MarkerMatch::Partial => false,
MarkerMatch::None => {
if self.marker_ix > 0 {
self.marker_ix = 0;
self.block_range.end = offset;
// The beginning of marker might match current byte
match self.match_marker(byte, marker, true) {
MarkerMatch::Complete => return true,
MarkerMatch::Partial => return false,
MarkerMatch::None => { /* no match, keep collecting */ }
}
}
if self.block_range.is_empty() {
self.block_range.start = offset;
}
self.block_range.end = offset + 1;
false
}
}
fn to_state(&mut self, state: State) {
self.state = state;
self.marker_ix = 0;
}
fn match_marker(&mut self, byte: u8, marker: &[u8], trailing_newline: bool) -> MarkerMatch {
if trailing_newline && self.marker_ix >= marker.len() {
if byte == b'\n' {
MarkerMatch::Complete
} else if byte == b'\r' {
MarkerMatch::Partial
} else {
MarkerMatch::None
}
} else if byte == marker[self.marker_ix] {
self.marker_ix += 1;
fn reset(&mut self) {
self.pre_fence_line.clear();
self.old_bytes.clear();
self.new_bytes.clear();
self.to_state(State::Default);
}
if self.marker_ix < marker.len() || trailing_newline {
MarkerMatch::Partial
} else {
MarkerMatch::Complete
}
} else {
MarkerMatch::None
}
fn push_error(&mut self, kind: ParseErrorKind) {
self.errors.push(ParseError {
line: self.line,
column: self.column,
kind,
});
}
}
@@ -327,6 +240,80 @@ enum MarkerMatch {
Complete,
}
fn match_marker(
byte: u8,
marker: &[u8],
trailing_newline: bool,
marker_ix: &mut usize,
) -> MarkerMatch {
if trailing_newline && *marker_ix >= marker.len() {
if byte == b'\n' {
MarkerMatch::Complete
} else if byte == b'\r' {
MarkerMatch::Partial
} else {
MarkerMatch::None
}
} else if byte == marker[*marker_ix] {
*marker_ix += 1;
if *marker_ix < marker.len() || trailing_newline {
MarkerMatch::Partial
} else {
MarkerMatch::Complete
}
} else {
MarkerMatch::None
}
}
fn collect_until_marker(
byte: u8,
marker: &[u8],
nl_marker: &[u8],
trailing_newline: bool,
marker_ix: &mut usize,
buf: &mut Vec<u8>,
) -> bool {
let marker = if buf.is_empty() {
// do not require another newline if block is empty
marker
} else {
nl_marker
};
match match_marker(byte, marker, trailing_newline, marker_ix) {
MarkerMatch::Complete => {
pop_carriage_return(buf);
true
}
MarkerMatch::Partial => false,
MarkerMatch::None => {
if *marker_ix > 0 {
buf.extend_from_slice(&marker[..*marker_ix]);
*marker_ix = 0;
// The beginning of marker might match current byte
match match_marker(byte, marker, trailing_newline, marker_ix) {
MarkerMatch::Complete => return true,
MarkerMatch::Partial => return false,
MarkerMatch::None => { /* no match, keep collecting */ }
}
}
buf.push(byte);
false
}
}
}
fn pop_carriage_return(buf: &mut Vec<u8>) {
if buf.ends_with(b"\r") {
buf.pop();
}
}
#[derive(Debug, PartialEq, Eq)]
pub struct ParseError {
line: usize,
@@ -368,7 +355,6 @@ impl std::fmt::Display for ParseError {
mod tests {
use super::*;
use rand::prelude::*;
use util::line_endings;
#[test]
fn test_simple_edit_action() {
@@ -385,16 +371,16 @@ fn replacement() {}
let mut parser = EditActionParser::new();
let actions = parser.parse_chunk(input);
assert_no_errors(&parser);
assert_eq!(actions.len(), 1);
assert_eq!(
actions[0].0,
actions[0],
EditAction::Replace {
file_path: PathBuf::from("src/main.rs"),
old: "fn original() {}".to_string(),
new: "fn replacement() {}".to_string(),
}
);
assert_eq!(parser.errors().len(), 0);
}
#[test]
@@ -412,16 +398,16 @@ fn replacement() {}
let mut parser = EditActionParser::new();
let actions = parser.parse_chunk(input);
assert_no_errors(&parser);
assert_eq!(actions.len(), 1);
assert_eq!(
actions[0].0,
actions[0],
EditAction::Replace {
file_path: PathBuf::from("src/main.rs"),
old: "fn original() {}".to_string(),
new: "fn replacement() {}".to_string(),
}
);
assert_eq!(parser.errors().len(), 0);
}
#[test]
@@ -443,16 +429,16 @@ This change makes the function better.
let mut parser = EditActionParser::new();
let actions = parser.parse_chunk(input);
assert_no_errors(&parser);
assert_eq!(actions.len(), 1);
assert_eq!(
actions[0].0,
actions[0],
EditAction::Replace {
file_path: PathBuf::from("src/main.rs"),
old: "fn original() {}".to_string(),
new: "fn replacement() {}".to_string(),
}
);
assert_eq!(parser.errors().len(), 0);
}
#[test]
@@ -481,27 +467,24 @@ fn new_util() -> bool { true }
let mut parser = EditActionParser::new();
let actions = parser.parse_chunk(input);
assert_no_errors(&parser);
assert_eq!(actions.len(), 2);
let (action, _) = &actions[0];
assert_eq!(
action,
&EditAction::Replace {
actions[0],
EditAction::Replace {
file_path: PathBuf::from("src/main.rs"),
old: "fn original() {}".to_string(),
new: "fn replacement() {}".to_string(),
}
);
let (action2, _) = &actions[1];
assert_eq!(
action2,
&EditAction::Replace {
actions[1],
EditAction::Replace {
file_path: PathBuf::from("src/utils.rs"),
old: "fn old_util() -> bool { false }".to_string(),
new: "fn new_util() -> bool { true }".to_string(),
}
);
assert_eq!(parser.errors().len(), 0);
}
#[test]
@@ -533,18 +516,16 @@ fn replacement() {
let mut parser = EditActionParser::new();
let actions = parser.parse_chunk(input);
assert_no_errors(&parser);
assert_eq!(actions.len(), 1);
let (action, _) = &actions[0];
assert_eq!(
action,
&EditAction::Replace {
actions[0],
EditAction::Replace {
file_path: PathBuf::from("src/main.rs"),
old: "fn original() {\n println!(\"This is the original function\");\n let x = 42;\n if x > 0 {\n println!(\"Positive number\");\n }\n}".to_string(),
new: "fn replacement() {\n println!(\"This is the replacement function\");\n let x = 100;\n if x > 50 {\n println!(\"Large number\");\n } else {\n println!(\"Small number\");\n }\n}".to_string(),
}
);
assert_eq!(parser.errors().len(), 0);
}
#[test]
@@ -565,16 +546,16 @@ fn new_function() {
let mut parser = EditActionParser::new();
let actions = parser.parse_chunk(input);
assert_no_errors(&parser);
assert_eq!(actions.len(), 1);
assert_eq!(
actions[0].0,
actions[0],
EditAction::Write {
file_path: PathBuf::from("src/main.rs"),
content: "fn new_function() {\n println!(\"This function is being added\");\n}"
.to_string(),
}
);
assert_eq!(parser.errors().len(), 0);
}
#[test]
@@ -592,11 +573,9 @@ fn this_will_be_deleted() {
let mut parser = EditActionParser::new();
let actions = parser.parse_chunk(&input);
assert_no_errors(&parser);
assert_eq!(actions.len(), 1);
assert_eq!(
actions[0].0,
actions[0],
EditAction::Replace {
file_path: PathBuf::from("src/main.rs"),
old: "fn this_will_be_deleted() {\n println!(\"Deleting this function\");\n}"
@@ -604,13 +583,12 @@ fn this_will_be_deleted() {
new: "".to_string(),
}
);
assert_eq!(parser.errors().len(), 0);
let mut parser = EditActionParser::new();
let actions = parser.parse_chunk(&input.replace("\n", "\r\n"));
assert_no_errors(&parser);
assert_eq!(actions.len(), 1);
assert_eq!(
actions[0].0,
actions[0],
EditAction::Replace {
file_path: PathBuf::from("src/main.rs"),
old:
@@ -619,6 +597,7 @@ fn this_will_be_deleted() {
new: "".to_string(),
}
);
assert_eq!(parser.errors().len(), 0);
}
#[test]
@@ -663,27 +642,26 @@ fn replacement() {}"#;
let mut parser = EditActionParser::new();
let actions1 = parser.parse_chunk(input_part1);
assert_no_errors(&parser);
assert_eq!(actions1.len(), 0);
assert_eq!(parser.errors().len(), 0);
let actions2 = parser.parse_chunk(input_part2);
// No actions should be complete yet
assert_no_errors(&parser);
assert_eq!(actions2.len(), 0);
assert_eq!(parser.errors().len(), 0);
let actions3 = parser.parse_chunk(input_part3);
// The third chunk should complete the action
assert_no_errors(&parser);
assert_eq!(actions3.len(), 1);
let (action, _) = &actions3[0];
assert_eq!(
action,
&EditAction::Replace {
actions3[0],
EditAction::Replace {
file_path: PathBuf::from("src/main.rs"),
old: "fn original() {}".to_string(),
new: "fn replacement() {}".to_string(),
}
);
assert_eq!(parser.errors().len(), 0);
}
#[test]
@@ -692,35 +670,28 @@ fn replacement() {}"#;
let actions1 = parser.parse_chunk("src/main.rs\n```rust\n<<<<<<< SEARCH\n");
// Check parser is in the correct state
assert_no_errors(&parser);
assert_eq!(parser.state, State::SearchBlock);
assert_eq!(
parser.action_source,
b"src/main.rs\n```rust\n<<<<<<< SEARCH\n"
);
assert_eq!(parser.pre_fence_line, b"src/main.rs\n");
assert_eq!(parser.errors().len(), 0);
// Continue parsing
let actions2 = parser.parse_chunk("original code\n=======\n");
assert_no_errors(&parser);
assert_eq!(parser.state, State::ReplaceBlock);
assert_eq!(
&parser.action_source[parser.old_range.clone()],
b"original code"
);
assert_eq!(parser.old_bytes, b"original code");
assert_eq!(parser.errors().len(), 0);
let actions3 = parser.parse_chunk("replacement code\n>>>>>>> REPLACE\n```\n");
// After complete parsing, state should reset
assert_no_errors(&parser);
assert_eq!(parser.state, State::Default);
assert_eq!(parser.action_source, b"\n");
assert!(parser.old_range.is_empty());
assert!(parser.new_range.is_empty());
assert_eq!(parser.pre_fence_line, b"\n");
assert!(parser.old_bytes.is_empty());
assert!(parser.new_bytes.is_empty());
assert_eq!(actions1.len(), 0);
assert_eq!(actions2.len(), 0);
assert_eq!(actions3.len(), 1);
assert_eq!(parser.errors().len(), 0);
}
#[test]
@@ -774,10 +745,9 @@ fn new_utils_func() {}
// Only the second block should be parsed
assert_eq!(actions.len(), 1);
let (action, _) = &actions[0];
assert_eq!(
action,
&EditAction::Replace {
actions[0],
EditAction::Replace {
file_path: PathBuf::from("src/utils.rs"),
old: "fn utils_func() {}".to_string(),
new: "fn new_utils_func() {}".to_string(),
@@ -813,65 +783,64 @@ fn new_utils_func() {}
let (chunk, rest) = remaining.split_at(chunk_size);
let chunk_actions = parser.parse_chunk(chunk);
actions.extend(chunk_actions);
actions.extend(parser.parse_chunk(chunk));
remaining = rest;
}
assert_examples_in_system_prompt(&actions, parser.errors());
}
fn assert_examples_in_system_prompt(actions: &[(EditAction, String)], errors: &[ParseError]) {
fn assert_examples_in_system_prompt(actions: &[EditAction], errors: &[ParseError]) {
assert_eq!(actions.len(), 5);
assert_eq!(
actions[0].0,
actions[0],
EditAction::Replace {
file_path: PathBuf::from("mathweb/flask/app.py"),
old: "from flask import Flask".to_string(),
new: line_endings!("import math\nfrom flask import Flask").to_string(),
},
new: "import math\nfrom flask import Flask".to_string(),
}
.fix_lf(),
);
assert_eq!(
actions[1].0,
actions[1],
EditAction::Replace {
file_path: PathBuf::from("mathweb/flask/app.py"),
old: line_endings!("def factorial(n):\n \"compute factorial\"\n\n if n == 0:\n return 1\n else:\n return n * factorial(n-1)\n").to_string(),
old: "def factorial(n):\n \"compute factorial\"\n\n if n == 0:\n return 1\n else:\n return n * factorial(n-1)\n".to_string(),
new: "".to_string(),
}
.fix_lf()
);
assert_eq!(
actions[2].0,
actions[2],
EditAction::Replace {
file_path: PathBuf::from("mathweb/flask/app.py"),
old: " return str(factorial(n))".to_string(),
new: " return str(math.factorial(n))".to_string(),
},
}
.fix_lf(),
);
assert_eq!(
actions[3].0,
actions[3],
EditAction::Write {
file_path: PathBuf::from("hello.py"),
content: line_endings!(
"def hello():\n \"print a greeting\"\n\n print(\"hello\")"
)
.to_string(),
},
content: "def hello():\n \"print a greeting\"\n\n print(\"hello\")"
.to_string(),
}
.fix_lf(),
);
assert_eq!(
actions[4].0,
actions[4],
EditAction::Replace {
file_path: PathBuf::from("main.py"),
old: line_endings!(
"def hello():\n \"print a greeting\"\n\n print(\"hello\")"
)
.to_string(),
old: "def hello():\n \"print a greeting\"\n\n print(\"hello\")".to_string(),
new: "from hello import hello".to_string(),
},
}
.fix_lf(),
);
// The system prompt includes some text that would produce errors
@@ -891,6 +860,29 @@ fn new_utils_func() {}
);
}
impl EditAction {
fn fix_lf(self: EditAction) -> EditAction {
#[cfg(windows)]
match self {
EditAction::Replace {
file_path,
old,
new,
} => EditAction::Replace {
file_path: file_path.clone(),
old: old.replace("\n", "\r\n"),
new: new.replace("\n", "\r\n"),
},
EditAction::Write { file_path, content } => EditAction::Write {
file_path: file_path.clone(),
content: content.replace("\n", "\r\n"),
},
}
#[cfg(not(windows))]
self
}
}
#[test]
fn test_print_error() {
let input = r#"src/main.rs
@@ -912,20 +904,4 @@ fn replacement() {}
assert_eq!(format!("{}", error), expected_error);
}
// helpers
fn assert_no_errors(parser: &EditActionParser) {
let errors = parser.errors();
assert!(
errors.is_empty(),
"Expected no errors, but found:\n\n{}",
errors
.iter()
.map(|e| e.to_string())
.collect::<Vec<String>>()
.join("\n")
);
}
}

View File

@@ -80,7 +80,7 @@ impl EditToolLog {
&mut self,
id: EditToolRequestId,
chunk: &str,
new_actions: &[(EditAction, String)],
new_actions: &[EditAction],
cx: &mut Context<Self>,
) {
if let Some(request) = self.requests.get_mut(id.0 as usize) {
@@ -92,9 +92,7 @@ impl EditToolLog {
response.push_str(chunk);
}
}
request
.parsed_edits
.extend(new_actions.iter().cloned().map(|(action, _)| action));
request.parsed_edits.extend(new_actions.iter().cloned());
cx.emit(EditToolLogEvent::Updated);
}

View File

@@ -1,226 +0,0 @@
use language::{Anchor, Bias, BufferSnapshot};
use std::ops::Range;
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum SearchDirection {
Up,
Left,
Diagonal,
}
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct SearchState {
cost: u32,
direction: SearchDirection,
}
impl SearchState {
fn new(cost: u32, direction: SearchDirection) -> Self {
Self { cost, direction }
}
}
struct SearchMatrix {
cols: usize,
data: Vec<SearchState>,
}
impl SearchMatrix {
fn new(rows: usize, cols: usize) -> Self {
SearchMatrix {
cols,
data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols],
}
}
fn get(&self, row: usize, col: usize) -> SearchState {
self.data[row * self.cols + col]
}
fn set(&mut self, row: usize, col: usize, cost: SearchState) {
self.data[row * self.cols + col] = cost;
}
}
pub fn resolve_search_block(buffer: &BufferSnapshot, search_query: &str) -> Range<Anchor> {
const INSERTION_COST: u32 = 3;
const DELETION_COST: u32 = 10;
const WHITESPACE_INSERTION_COST: u32 = 1;
const WHITESPACE_DELETION_COST: u32 = 1;
let buffer_len = buffer.len();
let query_len = search_query.len();
let mut matrix = SearchMatrix::new(query_len + 1, buffer_len + 1);
let mut leading_deletion_cost = 0_u32;
for (row, query_byte) in search_query.bytes().enumerate() {
let deletion_cost = if query_byte.is_ascii_whitespace() {
WHITESPACE_DELETION_COST
} else {
DELETION_COST
};
leading_deletion_cost = leading_deletion_cost.saturating_add(deletion_cost);
matrix.set(
row + 1,
0,
SearchState::new(leading_deletion_cost, SearchDirection::Diagonal),
);
for (col, buffer_byte) in buffer.bytes_in_range(0..buffer.len()).flatten().enumerate() {
let insertion_cost = if buffer_byte.is_ascii_whitespace() {
WHITESPACE_INSERTION_COST
} else {
INSERTION_COST
};
let up = SearchState::new(
matrix.get(row, col + 1).cost.saturating_add(deletion_cost),
SearchDirection::Up,
);
let left = SearchState::new(
matrix.get(row + 1, col).cost.saturating_add(insertion_cost),
SearchDirection::Left,
);
let diagonal = SearchState::new(
if query_byte == *buffer_byte {
matrix.get(row, col).cost
} else {
matrix
.get(row, col)
.cost
.saturating_add(deletion_cost + insertion_cost)
},
SearchDirection::Diagonal,
);
matrix.set(row + 1, col + 1, up.min(left).min(diagonal));
}
}
// Traceback to find the best match
let mut best_buffer_end = buffer_len;
let mut best_cost = u32::MAX;
for col in 1..=buffer_len {
let cost = matrix.get(query_len, col).cost;
if cost < best_cost {
best_cost = cost;
best_buffer_end = col;
}
}
let mut query_ix = query_len;
let mut buffer_ix = best_buffer_end;
while query_ix > 0 && buffer_ix > 0 {
let current = matrix.get(query_ix, buffer_ix);
match current.direction {
SearchDirection::Diagonal => {
query_ix -= 1;
buffer_ix -= 1;
}
SearchDirection::Up => {
query_ix -= 1;
}
SearchDirection::Left => {
buffer_ix -= 1;
}
}
}
let mut start = buffer.offset_to_point(buffer.clip_offset(buffer_ix, Bias::Left));
start.column = 0;
let mut end = buffer.offset_to_point(buffer.clip_offset(best_buffer_end, Bias::Right));
if end.column > 0 {
end.column = buffer.line_len(end.row);
}
buffer.anchor_after(start)..buffer.anchor_before(end)
}
#[cfg(test)]
mod tests {
use crate::edit_files_tool::resolve_search_block::resolve_search_block;
use gpui::{prelude::*, App};
use language::{Buffer, OffsetRangeExt as _};
use unindent::Unindent as _;
use util::test::{generate_marked_text, marked_text_ranges};
#[gpui::test]
fn test_resolve_search_block(cx: &mut App) {
assert_resolved(
concat!(
" Lorem\n",
"« ipsum\n",
" dolor sit amet»\n",
" consecteur",
),
"ipsum\ndolor",
cx,
);
assert_resolved(
&"
«fn foo1(a: usize) -> usize {
40
fn foo2(b: usize) -> usize {
42
}
"
.unindent(),
"fn foo1(b: usize) {\n40\n}",
cx,
);
assert_resolved(
&"
fn main() {
« Foo
.bar()
.baz()
.qux()»
}
fn foo2(b: usize) -> usize {
42
}
"
.unindent(),
"Foo.bar.baz.qux()",
cx,
);
assert_resolved(
&"
class Something {
one() { return 1; }
« two() { return 2222; }
three() { return 333; }
four() { return 4444; }
five() { return 5555; }
six() { return 6666; }
» seven() { return 7; }
eight() { return 8; }
}
"
.unindent(),
&"
two() { return 2222; }
four() { return 4444; }
five() { return 5555; }
six() { return 6666; }
"
.unindent(),
cx,
);
}
#[track_caller]
fn assert_resolved(text_with_expected_range: &str, query: &str, cx: &mut App) {
let (text, _) = marked_text_ranges(text_with_expected_range, false);
let buffer = cx.new(|cx| Buffer::local(text.clone(), cx));
let snapshot = buffer.read(cx).snapshot();
let range = resolve_search_block(&snapshot, query).to_offset(&snapshot);
let text_with_actual_range = generate_marked_text(&text, &[range], false);
pretty_assertions::assert_eq!(text_with_actual_range, text_with_expected_range);
}
}

View File

@@ -1,5 +1,5 @@
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::Tool;
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use project::Project;
@@ -55,7 +55,6 @@ impl Tool for ListDirectoryTool {
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let input = match serde_json::from_value::<ListDirectoryToolInput>(input) {

View File

@@ -1,7 +1,7 @@
use std::sync::Arc;
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::Tool;
use chrono::{Local, Utc};
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
@@ -45,7 +45,6 @@ impl Tool for NowTool {
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_cx: &mut App,
) -> Task<Result<String>> {
let input: NowToolInput = match serde_json::from_value(input) {

View File

@@ -1,13 +1,12 @@
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, AppContext, Entity, Task};
use assistant_tool::Tool;
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, sync::Arc};
use util::paths::PathMatcher;
use worktree::Snapshot;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct PathSearchToolInput {
@@ -23,15 +22,8 @@ pub struct PathSearchToolInput {
/// You can get back the first two paths by providing a glob of "*thing*.txt"
/// </example>
pub glob: String,
/// Optional starting position for paginated results (0-based).
/// When not provided, starts from the beginning.
#[serde(default)]
pub offset: Option<usize>,
}
const RESULTS_PER_PAGE: usize = 50;
pub struct PathSearchTool;
impl Tool for PathSearchTool {
@@ -53,69 +45,44 @@ impl Tool for PathSearchTool {
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let (offset, glob) = match serde_json::from_value::<PathSearchToolInput>(input) {
Ok(input) => (input.offset.unwrap_or(0), input.glob),
let glob = match serde_json::from_value::<PathSearchToolInput>(input) {
Ok(input) => input.glob,
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let path_matcher = match PathMatcher::new(&[glob.clone()]) {
Ok(matcher) => matcher,
Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {}", err))),
};
let snapshots: Vec<Snapshot> = project
.read(cx)
.worktrees(cx)
.map(|worktree| worktree.read(cx).snapshot())
.collect();
cx.background_spawn(async move {
let mut matches = Vec::new();
let mut matches = Vec::new();
for worktree in snapshots {
let root_name = worktree.root_name();
for worktree_handle in project.read(cx).worktrees(cx) {
let worktree = worktree_handle.read(cx);
let root_name = worktree.root_name();
// Don't consider ignored entries.
for entry in worktree.entries(false, 0) {
if path_matcher.is_match(&entry.path) {
matches.push(
PathBuf::from(root_name)
.join(&entry.path)
.to_string_lossy()
.to_string(),
);
}
// Don't consider ignored entries.
for entry in worktree.entries(false, 0) {
if path_matcher.is_match(&entry.path) {
matches.push(
PathBuf::from(root_name)
.join(&entry.path)
.to_string_lossy()
.to_string(),
);
}
}
}
if matches.is_empty() {
Ok(format!("No paths in the project matched the glob {glob:?}"))
} else {
// Sort to group entries in the same directory together.
matches.sort();
let total_matches = matches.len();
let response = if total_matches > offset + RESULTS_PER_PAGE {
let paginated_matches: Vec<_> = matches
.into_iter()
.skip(offset)
.take(RESULTS_PER_PAGE)
.collect();
format!(
"Found {} total matches. Showing results {}-{} (provide 'offset' parameter for more results):\n\n{}",
total_matches,
offset + 1,
offset + paginated_matches.len(),
paginated_matches.join("\n")
)
} else {
matches.join("\n")
};
Ok(response)
}
})
if matches.is_empty() {
Task::ready(Ok(format!(
"No paths in the project matched the glob {glob:?}"
)))
} else {
// Sort to group entries in the same directory together.
matches.sort();
Task::ready(Ok(matches.join("\n")))
}
}
}

View File

@@ -1,3 +1 @@
Returns paths in the project which match the given glob.
Results are paginated with 50 matches per page. Use the optional 'offset' parameter to request subsequent pages.
Returns all the paths in the project which match the given glob.

View File

@@ -2,9 +2,8 @@ use std::path::Path;
use std::sync::Arc;
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::Tool;
use gpui::{App, Entity, Task};
use itertools::Itertools;
use language_model::LanguageModelRequestMessage;
use project::Project;
use schemars::JsonSchema;
@@ -27,14 +26,6 @@ pub struct ReadFileToolInput {
/// If you wanna access `file.txt` in `directory2`, you should use the path `directory2/file.txt`.
/// </example>
pub path: Arc<Path>,
/// Optional line number to start reading on (1-based index)
#[serde(default)]
pub start_line: Option<usize>,
/// Optional line number to end reading on (1-based index)
#[serde(default)]
pub end_line: Option<usize>,
}
pub struct ReadFileTool;
@@ -58,7 +49,6 @@ impl Tool for ReadFileTool {
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let input = match serde_json::from_value::<ReadFileToolInput>(input) {
@@ -69,44 +59,23 @@ impl Tool for ReadFileTool {
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path not found in project")));
};
cx.spawn(|mut cx| async move {
cx.spawn(|cx| async move {
let buffer = cx
.update(|cx| {
project.update(cx, |project, cx| project.open_buffer(project_path, cx))
})?
.await?;
let result = buffer.read_with(&cx, |buffer, _cx| {
buffer.read_with(&cx, |buffer, _cx| {
if buffer
.file()
.map_or(false, |file| file.disk_state().exists())
{
let text = buffer.text();
let string = if input.start_line.is_some() || input.end_line.is_some() {
let start = input.start_line.unwrap_or(1);
let lines = text.split('\n').skip(start - 1);
if let Some(end) = input.end_line {
let count = end.saturating_sub(start);
Itertools::intersperse(lines.take(count), "\n").collect()
} else {
Itertools::intersperse(lines, "\n").collect()
}
} else {
text
};
Ok(string)
Ok(buffer.text())
} else {
Err(anyhow!("File does not exist"))
}
})??;
action_log.update(&mut cx, |log, cx| {
log.buffer_read(buffer, cx);
})?;
anyhow::Ok(result)
})?
})
}
}

View File

@@ -1,16 +1,13 @@
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::Tool;
use futures::StreamExt;
use gpui::{App, Entity, Task};
use language::{OffsetRangeExt, Point};
use language::OffsetRangeExt;
use language_model::LanguageModelRequestMessage;
use project::{
search::{SearchQuery, SearchResult},
Project,
};
use project::{search::SearchQuery, Project};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{cmp, fmt::Write, ops::Range, sync::Arc};
use std::{cmp, fmt::Write, sync::Arc};
use util::paths::PathMatcher;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
@@ -18,17 +15,8 @@ pub struct RegexSearchToolInput {
/// A regex pattern to search for in the entire project. Note that the regex
/// will be parsed by the Rust `regex` crate.
pub regex: String,
/// Optional starting position for paginated results (0-based).
/// When not provided, starts from the beginning.
#[serde(default)]
pub offset: Option<usize>,
}
const RESULTS_PER_PAGE: usize = 20;
const MAX_LINE_LENGTH: u32 = 240;
const LONG_LINE_CONTEXT: usize = 120;
pub struct RegexSearchTool;
impl Tool for RegexSearchTool {
@@ -50,22 +38,17 @@ impl Tool for RegexSearchTool {
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
const CONTEXT_LINES: u32 = 2;
let (offset, regex) = match serde_json::from_value::<RegexSearchToolInput>(input) {
Ok(input) => (input.offset.unwrap_or(0), input.regex),
let input = match serde_json::from_value::<RegexSearchToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
};
if regex.is_empty() {
return Task::ready(Err(anyhow!("Empty regex pattern is not allowed")));
};
let query = match SearchQuery::regex(
&regex,
&input.regex,
false,
false,
false,
@@ -78,345 +61,59 @@ impl Tool for RegexSearchTool {
};
let results = project.update(cx, |project, cx| project.search(query, cx));
enum MatchRange {
Lines(Range<Point>),
LongLine(Range<Point>),
}
cx.spawn(|cx| async move {
futures::pin_mut!(results);
let mut output = String::new();
let mut skips_remaining = offset;
let mut matches_found = 0;
let mut has_more_matches = false;
while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
while let Some(project::search::SearchResult::Buffer { buffer, ranges }) =
results.next().await
{
if ranges.is_empty() {
continue;
}
buffer.read_with(&cx, |buffer, cx| -> Result<(), anyhow::Error> {
buffer.read_with(&cx, |buffer, cx| {
if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
let mut file_header_written = false;
writeln!(output, "### Found matches in {}:\n", path.display()).unwrap();
let mut ranges = ranges
.into_iter()
.map(|range| -> MatchRange {
.map(|range| {
let mut point_range = range.to_point(buffer);
let is_long_line = buffer.line_len(point_range.start.row) > MAX_LINE_LENGTH;
if is_long_line {
let line_range = Point::new(point_range.start.row, 0)..Point::new(point_range.start.row, buffer.line_len(point_range.start.row));
let line_text = buffer.text_for_range(line_range).collect::<String>();
for (match_start, match_end) in find_matches(line_text.clone(), &regex) {
let start_char = match_start.saturating_sub(LONG_LINE_CONTEXT);
let end_char = (match_end + LONG_LINE_CONTEXT).min(buffer.line_len(point_range.start.row) as usize);
writeln!(output, "\n### Line {}, chars {}-{}\n```", point_range.start.row + 1, start_char, end_char)?;
output.push_str(&line_text[start_char..end_char]);
output.push_str("\n```\n");
}
matches_found += 1;
if matches_found >= RESULTS_PER_PAGE {
has_more_matches = true;
}
Ok((Point::new(point_range.start.row + 1, 0)..Point::new(point_range.start.row + 1, 0), true))
} else {
point_range.start.row = point_range.start.row.saturating_sub(CONTEXT_LINES);
point_range.start.column = 0;
point_range.end.row = cmp::min(
buffer.max_point().row,
point_range.end.row + CONTEXT_LINES,
);
point_range.end.column = buffer.line_len(point_range.end.row);
Ok((point_range, is_long_line))
}
point_range.start.row =
point_range.start.row.saturating_sub(CONTEXT_LINES);
point_range.start.column = 0;
point_range.end.row = cmp::min(
buffer.max_point().row,
point_range.end.row + CONTEXT_LINES,
);
point_range.end.column = buffer.line_len(point_range.end.row);
point_range
})
.peekable();
while let Some(range_result) = ranges.next() {
let (range, is_long) = range_result?;
// Skip long lines as they were already handled
if is_long {
continue;
while let Some(mut range) = ranges.next() {
while let Some(next_range) = ranges.peek() {
if range.end.row >= next_range.start.row {
range.end = next_range.end;
ranges.next();
} else {
break;
}
}
if skips_remaining > 0 {
skips_remaining -= 1;
continue;
}
// We've found a full page of matches
if matches_found >= RESULTS_PER_PAGE {
has_more_matches = true;
break;
}
// Write file header if needed
if !file_header_written {
let _ = writeln!(output, "\n## Matches in {}", path.display());
file_header_written = true;
}
// Show the match with context lines
let context_start = range.start.row;
let context_end = range.end.row;
let context_range = Point::new(context_start, 0)..Point::new(context_end, buffer.line_len(context_end));
let context_text = buffer.text_for_range(context_range).collect::<String>();
if context_text.contains(&regex) {
let _ = writeln!(output, "\n### Lines {}-{}\n```", context_start + 1, context_end + 1);
output.push_str(&context_text);
output.push_str("\n```\n");
matches_found += 1;
}
writeln!(output, "```").unwrap();
output.extend(buffer.text_for_range(range));
writeln!(output, "\n```\n").unwrap();
}
}
})?;
}
if matches_found == 0 {
Ok("No matches found".to_string())
} else if has_more_matches {
Ok(format!(
"Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
offset + 1,
offset + matches_found,
offset + RESULTS_PER_PAGE,
))
if output.is_empty() {
Ok("No matches found".into())
} else {
Ok(format!("Found {matches_found} matches:\n{output}"))
Ok(output)
}
})
}
}
fn find_matches(text: String, pattern: &str) -> Vec<(usize, usize)> {
let mut matches = Vec::new();
if pattern.is_empty() {
return matches;
}
let mut start = 0;
while start < text.len() {
match text[start..].find(pattern) {
Some(pos) => {
let match_start = start + pos;
let match_end = match_start + pattern.len();
if match_end <= text.len() {
matches.push((match_start, match_end));
}
start = match_start + 1;
}
None => break,
}
}
matches
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_find_matches() {
// Test basic match finding
let text = "hello world hello".to_string();
let matches = find_matches(text, "hello");
assert_eq!(matches.len(), 2);
assert_eq!(matches[0], (0, 5));
assert_eq!(matches[1], (12, 17));
// Test overlapping matches
let overlaps = find_matches("abababa".to_string(), "aba");
assert_eq!(overlaps.len(), 3);
assert_eq!(overlaps[0], (0, 3));
assert_eq!(overlaps[1], (2, 5));
assert_eq!(overlaps[2], (4, 7));
// Test edge cases
assert_eq!(find_matches("".to_string(), "pattern"), vec![]);
assert_eq!(find_matches("text".to_string(), ""), vec![]);
}
#[test]
fn test_long_line_context_calculation() {
// Create a line that exceeds MAX_LINE_LENGTH
let prefix = "x".repeat(100);
let target = "TARGET";
let suffix = "y".repeat(300);
let long_line = format!("{}{}{}", prefix, target, suffix);
// Find the target in the long line
let matches = find_matches(long_line.clone(), target);
assert_eq!(matches.len(), 1);
let (match_start, match_end) = matches[0];
// Verify context calculation
let start_char = match_start.saturating_sub(LONG_LINE_CONTEXT);
let end_char = (match_end + LONG_LINE_CONTEXT).min(long_line.len());
// Context should start no more than LONG_LINE_CONTEXT chars before match
assert!(match_start - start_char <= LONG_LINE_CONTEXT);
// Context should end no more than LONG_LINE_CONTEXT chars after match
assert!(end_char - match_end <= LONG_LINE_CONTEXT);
// Context should contain the target
let context = &long_line[start_char..end_char];
assert!(context.contains(target));
}
#[test]
fn test_line_length_classification() {
// Test if lines are correctly classified as long or regular
let max_len = MAX_LINE_LENGTH as usize;
// Line shorter than MAX_LINE_LENGTH
let regular_line = "x".repeat(max_len - 1);
assert!(regular_line.len() < max_len);
// Line exactly at MAX_LINE_LENGTH
let boundary_line = "x".repeat(max_len);
assert_eq!(boundary_line.len(), max_len);
// Line longer than MAX_LINE_LENGTH
let long_line = "x".repeat(max_len + 1);
assert!(long_line.len() > max_len);
}
#[test]
fn test_heading_format() {
// For long lines: "### Line X, chars Y-Z"
// For regular lines: "### Lines X-Y"
let long_line_row = 42_usize;
let start_char = 100_usize;
let end_char = 340_usize;
// In the implementation, the heading format for long lines is "### Line"
let long_line_heading = format!(
"### Line {}, chars {}-{}",
long_line_row + 1,
start_char,
end_char
);
assert_eq!(long_line_heading, "### Line 43, chars 100-340");
let context_start = 40_usize;
let context_end = 44_usize;
let regular_heading = format!("### Lines {}-{}", context_start + 1, context_end + 1);
assert_eq!(regular_heading, "### Lines 41-45");
}
#[test]
fn test_pagination() {
// Initialize variables to simulate pagination
let mut skips_remaining = 2;
let mut matches_found = 0;
// Simulate processing 5 matches with offset 2
for _ in 0..5 {
if skips_remaining > 0 {
skips_remaining -= 1;
continue;
}
matches_found += 1;
}
// Should have processed 3 matches after skipping 2
assert_eq!(matches_found, 3);
assert_eq!(skips_remaining, 0);
// Test page limits
let mut matches_found = 0;
let mut has_more_matches = false;
// Simulate processing more matches than fit in a page
for _ in 0..(RESULTS_PER_PAGE + 5) {
// We'd already found a full page of matches, and we just found one more.
if matches_found >= RESULTS_PER_PAGE {
has_more_matches = true;
break;
}
matches_found += 1;
}
// Should have stopped at RESULTS_PER_PAGE
assert_eq!(matches_found, RESULTS_PER_PAGE);
assert!(has_more_matches);
}
#[test]
fn test_handling_of_very_long_lines() {
// Test very long lines with matches at different positions
// Create a very long line (3 * MAX_LINE_LENGTH)
let max_len = MAX_LINE_LENGTH as usize;
let long_prefix = "prefix_".repeat(max_len / 7);
let middle = "middle_".repeat(max_len / 7);
let long_suffix = "suffix_".repeat(max_len / 7);
let very_long_line = format!("{}{}{}", long_prefix, middle, long_suffix);
assert!(very_long_line.len() > max_len);
// Find matches for "middle" in the very long line
let matches = find_matches(very_long_line.clone(), "middle_");
assert!(!matches.is_empty());
// First match should be after the prefix
let (first_match_start, _) = matches[0];
assert!(first_match_start >= long_prefix.len());
// With 120 chars of context, we should not see the start of the string
let context_start = first_match_start.saturating_sub(LONG_LINE_CONTEXT);
assert!(context_start > 0); // Context should not include the very beginning
// But context should include the match
let context_end = first_match_start + 7 + LONG_LINE_CONTEXT; // "middle_" is 7 chars
let context = &very_long_line[context_start..context_end.min(very_long_line.len())];
assert!(context.contains("middle_"));
}
#[test]
fn test_output_format() {
// Test format consistency
// For a long line match (> 240 chars), we show:
// - Only the matched line (no context lines)
// - Format: "### Line X, chars Y-Z"
// - Context: 120 chars before and after the match
// For regular lines (< 240 chars), we show:
// - The matched line plus context (2 lines before, 2 lines after)
// - Format: "### Lines X-Y"
// Test that multiple matches in a long line are shown separately
let row = 42_usize;
let match1_start = 100_usize;
let match1_end = 110_usize;
let match2_start = 300_usize;
let match2_end = 310_usize;
let heading1 = format!(
"### Line {}, chars {}-{}",
row + 1,
match1_start.saturating_sub(LONG_LINE_CONTEXT),
(match1_end + LONG_LINE_CONTEXT)
);
let heading2 = format!(
"### Line {}, chars {}-{}",
row + 1,
match2_start.saturating_sub(LONG_LINE_CONTEXT),
(match2_end + LONG_LINE_CONTEXT)
);
// Each match should have its own heading
assert_ne!(heading1, heading2);
}
}

View File

@@ -1,5 +1,3 @@
Searches the entire project for the given regular expression.
Returns a list of paths that matched the query. For each path, it returns a list of excerpts of the matched text.
Results are paginated with 20 matches per page. Use the optional 'offset' parameter to request subsequent pages.

View File

@@ -1,48 +0,0 @@
use std::sync::Arc;
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ThinkingToolInput {
/// Content to think about. This should be a description of what to think about or
/// a problem to solve.
content: String,
}
pub struct ThinkingTool;
impl Tool for ThinkingTool {
fn name(&self) -> String {
"thinking".to_string()
}
fn description(&self) -> String {
include_str!("./thinking_tool/description.md").to_string()
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(ThinkingToolInput);
serde_json::to_value(&schema).unwrap()
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_cx: &mut App,
) -> Task<Result<String>> {
// This tool just "thinks out loud" and doesn't perform any actions.
Task::ready(match serde_json::from_value::<ThinkingToolInput>(input) {
Ok(_input) => Ok("Finished thinking.".to_string()),
Err(err) => Err(anyhow!(err)),
})
}
}

View File

@@ -1 +0,0 @@
A tool for thinking through problems, brainstorming ideas, or planning without executing any actions. Use this tool when you need to work through complex problems, develop strategies, or outline approaches before taking action.

View File

@@ -27,7 +27,6 @@ feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
http_client.workspace = true
http_client_tls.workspace = true
log.workspace = true
paths.workspace = true
parking_lot.workspace = true

View File

@@ -1154,7 +1154,7 @@ impl Client {
async_tungstenite::async_tls::client_async_tls_with_connector(
request,
stream,
Some(http_client_tls::tls_config().into()),
Some(http_client::tls_config().into()),
)
.await?;
Ok(Connection::new(

View File

@@ -29,12 +29,6 @@ impl std::fmt::Display for ChannelId {
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
pub struct ProjectId(pub u64);
impl ProjectId {
pub fn to_proto(&self) -> u64 {
self.0
}
}
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
)]

View File

@@ -660,10 +660,6 @@ fn for_snowflake(
e.event_type.clone(),
serde_json::to_value(&e.event_properties).unwrap(),
),
Event::AssistantThreadFeedback(e) => (
"Assistant Feedback".to_string(),
serde_json::to_value(&e).unwrap(),
),
};
if let serde_json::Value::Object(ref mut map) = event_properties {

View File

@@ -43,20 +43,6 @@ pub enum Relation {
Contributor,
}
impl Model {
/// Returns the timestamp of when the user's account was created.
///
/// This will be the earlier of the `created_at` and `github_user_created_at` timestamps.
pub fn account_created_at(&self) -> NaiveDateTime {
let mut account_created_at = self.created_at;
if let Some(github_created_at) = self.github_user_created_at {
account_created_at = account_created_at.min(github_created_at);
}
account_created_at
}
}
impl Related<super::access_token::Entity> for Entity {
fn to() -> RelationDef {
Relation::AccessToken.def()

View File

@@ -5,7 +5,6 @@ mod token;
use crate::api::events::SnowflakeRow;
use crate::api::CloudflareIpCountryHeader;
use crate::build_kinesis_client;
use crate::rpc::MIN_ACCOUNT_AGE_FOR_LLM_USE;
use crate::{db::UserId, executor::Executor, Cents, Config, Error, Result};
use anyhow::{anyhow, Context as _};
use authorization::authorize_access_to_language_model;
@@ -218,13 +217,6 @@ async fn perform_completion(
params.model,
);
let bypass_account_age_check = claims.has_llm_subscription || claims.bypass_account_age_check;
if !bypass_account_age_check {
if Utc::now().naive_utc() - claims.account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
Err(anyhow!("account too young"))?
}
}
authorize_access_to_language_model(
&state.config,
&claims,

View File

@@ -3,7 +3,7 @@ use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
use crate::Cents;
use crate::{db::billing_preference, Config};
use anyhow::{anyhow, Result};
use chrono::{NaiveDateTime, Utc};
use chrono::Utc;
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use std::time::Duration;
@@ -20,10 +20,9 @@ pub struct LlmTokenClaims {
pub system_id: Option<String>,
pub metrics_id: Uuid,
pub github_user_login: String,
pub account_created_at: NaiveDateTime,
pub is_staff: bool,
pub has_llm_closed_beta_feature_flag: bool,
pub bypass_account_age_check: bool,
#[serde(default)]
pub has_predict_edits_feature_flag: bool,
pub has_llm_subscription: bool,
pub max_monthly_spend_in_cents: u32,
@@ -38,7 +37,8 @@ impl LlmTokenClaims {
user: &user::Model,
is_staff: bool,
billing_preferences: Option<billing_preference::Model>,
feature_flags: &Vec<String>,
has_llm_closed_beta_feature_flag: bool,
has_predict_edits_feature_flag: bool,
has_llm_subscription: bool,
plan: rpc::proto::Plan,
system_id: Option<String>,
@@ -58,17 +58,9 @@ impl LlmTokenClaims {
system_id,
metrics_id: user.metrics_id,
github_user_login: user.github_login.clone(),
account_created_at: user.account_created_at(),
is_staff,
has_llm_closed_beta_feature_flag: feature_flags
.iter()
.any(|flag| flag == "llm-closed-beta"),
bypass_account_age_check: feature_flags
.iter()
.any(|flag| flag == "bypass-account-age-check"),
has_predict_edits_feature_flag: feature_flags
.iter()
.any(|flag| flag == "predict-edits"),
has_llm_closed_beta_feature_flag,
has_predict_edits_feature_flag,
has_llm_subscription,
max_monthly_spend_in_cents: billing_preferences
.map_or(DEFAULT_MAX_MONTHLY_SPEND.0, |preferences| {

View File

@@ -307,7 +307,6 @@ impl Server {
.add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
.add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
.add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
.add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
.add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
.add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
.add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
@@ -348,7 +347,6 @@ impl Server {
.add_message_handler(create_buffer_for_peer)
.add_request_handler(update_buffer)
.add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
.add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
.add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
.add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
.add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
@@ -4036,7 +4034,7 @@ async fn accept_terms_of_service(
}
/// The minimum account age an account must have in order to use the LLM service.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
async fn get_llm_api_token(
_request: proto::GetLlmToken,
@@ -4047,6 +4045,8 @@ async fn get_llm_api_token(
let flags = db.get_user_flags(session.user_id()).await?;
let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");
if !session.is_staff() && !has_language_models_feature_flag {
Err(anyhow!("permission denied"))?
@@ -4063,13 +4063,27 @@ async fn get_llm_api_token(
}
let has_llm_subscription = session.has_llm_subscription(&db).await?;
let bypass_account_age_check =
has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
if !bypass_account_age_check {
let mut account_created_at = user.created_at;
if let Some(github_created_at) = user.github_user_created_at {
account_created_at = account_created_at.min(github_created_at);
}
if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
Err(anyhow!("account too young"))?
}
}
let billing_preferences = db.get_billing_preferences(user.id).await?;
let token = LlmTokenClaims::create(
&user,
session.is_staff(),
billing_preferences,
&flags,
has_llm_closed_beta_feature_flag,
has_predict_edits_feature_flag,
has_llm_subscription,
session.current_plan(&db).await?,
session.system_id.clone(),

View File

@@ -1337,7 +1337,7 @@ impl RandomizedTest for ProjectCollaborationTest {
let host_diff_base = host_project.read_with(host_cx, |project, cx| {
project
.git_store()
.buffer_store()
.read(cx)
.get_unstaged_diff(host_buffer.read(cx).remote_id(), cx)
.unwrap()
@@ -1346,7 +1346,7 @@ impl RandomizedTest for ProjectCollaborationTest {
});
let guest_diff_base = guest_project.read_with(client_cx, |project, cx| {
project
.git_store()
.buffer_store()
.read(cx)
.get_unstaged_diff(guest_buffer.read(cx).remote_id(), cx)
.unwrap()

View File

@@ -271,7 +271,7 @@ impl TestServer {
let git_hosting_provider_registry = cx.update(GitHostingProviderRegistry::default_global);
git_hosting_provider_registry
.register_hosting_provider(Arc::new(git_hosting_providers::Github::public_instance()));
.register_hosting_provider(Arc::new(git_hosting_providers::Github::new()));
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx));

View File

@@ -61,9 +61,9 @@ impl CompletionProvider for MessageEditorCompletionProvider {
_: editor::CompletionContext,
_window: &mut Window,
cx: &mut Context<Editor>,
) -> Task<Result<Option<Vec<Completion>>>> {
) -> Task<anyhow::Result<Vec<Completion>>> {
let Some(handle) = self.0.upgrade() else {
return Task::ready(Ok(None));
return Task::ready(Ok(Vec::new()));
};
handle.update(cx, |message_editor, cx| {
message_editor.completions(buffer, buffer_position, cx)
@@ -246,22 +246,20 @@ impl MessageEditor {
buffer: &Entity<Buffer>,
end_anchor: Anchor,
cx: &mut Context<Self>,
) -> Task<Result<Option<Vec<Completion>>>> {
) -> Task<Result<Vec<Completion>>> {
if let Some((start_anchor, query, candidates)) =
self.collect_mention_candidates(buffer, end_anchor, cx)
{
if !candidates.is_empty() {
return cx.spawn(|_, cx| async move {
Ok(Some(
Self::resolve_completions_for_candidates(
&cx,
query.as_str(),
&candidates,
start_anchor..end_anchor,
Self::completion_for_mention,
)
.await,
))
Ok(Self::resolve_completions_for_candidates(
&cx,
query.as_str(),
&candidates,
start_anchor..end_anchor,
Self::completion_for_mention,
)
.await)
});
}
}
@@ -271,21 +269,19 @@ impl MessageEditor {
{
if !candidates.is_empty() {
return cx.spawn(|_, cx| async move {
Ok(Some(
Self::resolve_completions_for_candidates(
&cx,
query.as_str(),
candidates,
start_anchor..end_anchor,
Self::completion_for_emoji,
)
.await,
))
Ok(Self::resolve_completions_for_candidates(
&cx,
query.as_str(),
candidates,
start_anchor..end_anchor,
Self::completion_for_emoji,
)
.await)
});
}
}
Task::ready(Ok(Some(Vec::new())))
Task::ready(Ok(vec![]))
}
async fn resolve_completions_for_candidates(

View File

@@ -1,7 +1,7 @@
use std::sync::Arc;
use anyhow::{anyhow, bail, Result};
use assistant_tool::{ActionLog, Tool, ToolSource};
use assistant_tool::{Tool, ToolSource};
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use project::Project;
@@ -61,7 +61,6 @@ impl Tool for ContextServerTool {
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) {

View File

@@ -170,9 +170,7 @@ enum SignInStatus {
prompt: Option<request::PromptUserDeviceFlow>,
task: Shared<Task<Result<(), Arc<anyhow::Error>>>>,
},
SignedOut {
awaiting_signing_in: bool,
},
SignedOut,
}
#[derive(Debug, Clone)]
@@ -182,9 +180,7 @@ pub enum Status {
},
Error(Arc<str>),
Disabled,
SignedOut {
awaiting_signing_in: bool,
},
SignedOut,
SigningIn {
prompt: Option<request::PromptUserDeviceFlow>,
},
@@ -349,8 +345,8 @@ impl Copilot {
buffers: Default::default(),
_subscription: cx.on_app_quit(Self::shutdown_language_server),
};
this.start_copilot(true, false, cx);
cx.observe_global::<SettingsStore>(move |this, cx| this.start_copilot(true, false, cx))
this.enable_or_disable_copilot(cx);
cx.observe_global::<SettingsStore>(move |this, cx| this.enable_or_disable_copilot(cx))
.detach();
this
}
@@ -368,40 +364,26 @@ impl Copilot {
}
}
fn start_copilot(
&mut self,
check_edit_prediction_provider: bool,
awaiting_sign_in_after_start: bool,
cx: &mut Context<Self>,
) {
if !matches!(self.server, CopilotServer::Disabled) {
return;
}
let language_settings = all_language_settings(None, cx);
if check_edit_prediction_provider
&& language_settings.edit_predictions.provider != EditPredictionProvider::Copilot
{
return;
}
fn enable_or_disable_copilot(&mut self, cx: &mut Context<Self>) {
let server_id = self.server_id;
let http = self.http.clone();
let node_runtime = self.node_runtime.clone();
let env = self.build_env(&language_settings.edit_predictions.copilot);
let start_task = cx
.spawn(move |this, cx| {
Self::start_language_server(
server_id,
http,
node_runtime,
env,
this,
awaiting_sign_in_after_start,
cx,
)
})
.shared();
self.server = CopilotServer::Starting { task: start_task };
cx.notify();
let language_settings = all_language_settings(None, cx);
if language_settings.edit_predictions.provider == EditPredictionProvider::Copilot {
if matches!(self.server, CopilotServer::Disabled) {
let env = self.build_env(&language_settings.edit_predictions.copilot);
let start_task = cx
.spawn(move |this, cx| {
Self::start_language_server(server_id, http, node_runtime, env, this, cx)
})
.shared();
self.server = CopilotServer::Starting { task: start_task };
cx.notify();
}
} else {
self.server = CopilotServer::Disabled;
cx.notify();
}
}
fn build_env(&self, copilot_settings: &CopilotSettings) -> Option<HashMap<String, String>> {
@@ -467,7 +449,6 @@ impl Copilot {
node_runtime: NodeRuntime,
env: Option<HashMap<String, String>>,
this: WeakEntity<Self>,
awaiting_sign_in_after_start: bool,
mut cx: AsyncApp,
) {
let start_language_server = async {
@@ -541,9 +522,7 @@ impl Copilot {
Ok((server, status)) => {
this.server = CopilotServer::Running(RunningCopilotServer {
lsp: server,
sign_in_status: SignInStatus::SignedOut {
awaiting_signing_in: awaiting_sign_in_after_start,
},
sign_in_status: SignInStatus::SignedOut,
registered_buffers: Default::default(),
});
cx.emit(Event::CopilotLanguageServerStarted);
@@ -566,7 +545,7 @@ impl Copilot {
cx.notify();
task.clone()
}
SignInStatus::SignedOut { .. } | SignInStatus::Unauthorized { .. } => {
SignInStatus::SignedOut | SignInStatus::Unauthorized { .. } => {
let lsp = server.lsp.clone();
let task = cx
.spawn(|this, mut cx| async move {
@@ -654,7 +633,7 @@ impl Copilot {
anyhow::Ok(())
})
}
CopilotServer::Disabled => cx.background_spawn(async {
CopilotServer::Disabled => cx.background_spawn(async move {
clear_copilot_config_dir().await;
anyhow::Ok(())
}),
@@ -672,8 +651,7 @@ impl Copilot {
let server_id = self.server_id;
move |this, cx| async move {
clear_copilot_dir().await;
Self::start_language_server(server_id, http, node_runtime, env, this, false, cx)
.await
Self::start_language_server(server_id, http, node_runtime, env, this, cx).await
}
})
.shared();
@@ -983,11 +961,7 @@ impl Copilot {
SignInStatus::SigningIn { prompt, .. } => Status::SigningIn {
prompt: prompt.clone(),
},
SignInStatus::SignedOut {
awaiting_signing_in,
} => Status::SignedOut {
awaiting_signing_in: *awaiting_signing_in,
},
SignInStatus::SignedOut => Status::SignedOut,
}
}
}
@@ -1016,11 +990,7 @@ impl Copilot {
}
}
request::SignInStatus::Ok { user: None } | request::SignInStatus::NotSignedIn => {
if !matches!(server.sign_in_status, SignInStatus::SignedOut { .. }) {
server.sign_in_status = SignInStatus::SignedOut {
awaiting_signing_in: false,
};
}
server.sign_in_status = SignInStatus::SignedOut;
cx.emit(Event::CopilotAuthSignedOut);
for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
self.unregister_buffer(&buffer);

View File

@@ -1,10 +1,9 @@
use crate::{request::PromptUserDeviceFlow, Copilot, Status};
use gpui::{
div, percentage, svg, Animation, AnimationExt, App, ClipboardItem, Context, DismissEvent,
Element, Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement,
MouseDownEvent, ParentElement, Render, Styled, Subscription, Transformation, Window,
div, App, ClipboardItem, Context, DismissEvent, Element, Entity, EventEmitter, FocusHandle,
Focusable, InteractiveElement, IntoElement, MouseDownEvent, ParentElement, Render, Styled,
Subscription, Window,
};
use std::time::Duration;
use ui::{prelude::*, Button, Label, Vector, VectorName};
use util::ResultExt as _;
use workspace::notifications::NotificationId;
@@ -18,13 +17,11 @@ pub fn initiate_sign_in(window: &mut Window, cx: &mut App) {
let Some(copilot) = Copilot::global(cx) else {
return;
};
let status = copilot.read(cx).status();
let Some(workspace) = window.root::<Workspace>().flatten() else {
return;
};
if matches!(copilot.read(cx).status(), Status::Disabled) {
copilot.update(cx, |this, cx| this.start_copilot(false, true, cx));
}
match copilot.read(cx).status() {
match status {
Status::Starting { task } => {
workspace.update(cx, |workspace, cx| {
workspace.show_toast(
@@ -57,15 +54,6 @@ pub fn initiate_sign_in(window: &mut Window, cx: &mut App) {
copilot
.update(cx, |copilot, cx| copilot.sign_in(cx))
.detach_and_log_err(cx);
if let Some(window_handle) = cx.active_window() {
window_handle
.update(cx, |_, window, cx| {
workspace.toggle_modal(window, cx, |_, cx| {
CopilotCodeVerification::new(&copilot, cx)
});
})
.log_err();
}
}
})
.log_err();
@@ -88,7 +76,6 @@ pub struct CopilotCodeVerification {
status: Status,
connect_clicked: bool,
focus_handle: FocusHandle,
copilot: Entity<Copilot>,
_subscription: Subscription,
}
@@ -99,20 +86,7 @@ impl Focusable for CopilotCodeVerification {
}
impl EventEmitter<DismissEvent> for CopilotCodeVerification {}
impl ModalView for CopilotCodeVerification {
fn on_before_dismiss(
&mut self,
_: &mut Window,
cx: &mut Context<Self>,
) -> workspace::DismissDecision {
self.copilot.update(cx, |copilot, cx| {
if matches!(copilot.status(), Status::SigningIn { .. }) {
copilot.sign_out(cx).detach_and_log_err(cx);
}
});
workspace::DismissDecision::Dismiss(true)
}
}
impl ModalView for CopilotCodeVerification {}
impl CopilotCodeVerification {
pub fn new(copilot: &Entity<Copilot>, cx: &mut Context<Self>) -> Self {
@@ -121,7 +95,6 @@ impl CopilotCodeVerification {
status,
connect_clicked: false,
focus_handle: cx.focus_handle(),
copilot: copilot.clone(),
_subscription: cx.observe(copilot, |this, copilot, cx| {
let status = copilot.read(cx).status();
match status {
@@ -207,12 +180,9 @@ impl CopilotCodeVerification {
.child(
Button::new("copilot-enable-cancel-button", "Cancel")
.full_width()
.on_click(cx.listener(|_, _, _, cx| {
cx.emit(DismissEvent);
})),
.on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))),
)
}
fn render_enabled_modal(cx: &mut Context<Self>) -> impl Element {
v_flex()
.gap_2()
@@ -246,27 +216,16 @@ impl CopilotCodeVerification {
)
}
fn render_loading(window: &mut Window, _: &mut Context<Self>) -> impl Element {
let loading_icon = svg()
.size_8()
.path(IconName::ArrowCircle.path())
.text_color(window.text_style().color)
.with_animation(
"icon_circle_arrow",
Animation::new(Duration::from_secs(2)).repeat(),
|svg, delta| svg.with_transformation(Transformation::rotate(percentage(delta))),
);
h_flex().justify_center().child(loading_icon)
fn render_disabled_modal() -> impl Element {
v_flex()
.child(Headline::new("Copilot is disabled").size(HeadlineSize::Large))
.child(Label::new("You can enable Copilot in your settings."))
}
}
impl Render for CopilotCodeVerification {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let prompt = match &self.status {
Status::SigningIn { prompt: None } => {
Self::render_loading(window, cx).into_any_element()
}
Status::SigningIn {
prompt: Some(prompt),
} => Self::render_prompting_modal(self.connect_clicked, prompt, cx).into_any_element(),
@@ -278,6 +237,10 @@ impl Render for CopilotCodeVerification {
self.connect_clicked = false;
Self::render_enabled_modal(cx).into_any_element()
}
Status::Disabled => {
self.connect_clicked = false;
Self::render_disabled_modal().into_any_element()
}
_ => div().into_any_element(),
};

View File

@@ -1,6 +1,6 @@
//! This module contains all actions supported by [`Editor`].
use super::*;
use gpui::{action_as, action_with_deprecated_aliases, actions};
use gpui::{action_as, action_with_deprecated_aliases};
use schemars::JsonSchema;
use util::serde::default_true;
#[derive(PartialEq, Clone, Deserialize, Default, JsonSchema)]
@@ -248,7 +248,7 @@ impl_actions!(
]
);
actions!(
gpui::actions!(
editor,
[
AcceptEditPrediction,
@@ -404,7 +404,6 @@ actions!(
ShowCharacterPalette,
ShowEditPrediction,
ShowSignatureHelp,
ShowWordCompletions,
ShuffleLines,
SortLinesCaseInsensitive,
SortLinesCaseSensitive,

View File

@@ -180,7 +180,6 @@ pub struct CompletionsMenu {
scroll_handle: UniformListScrollHandle,
resolve_completions: bool,
show_completion_documentation: bool,
pub(super) ignore_completion_provider: bool,
last_rendered_range: Rc<RefCell<Option<Range<usize>>>>,
markdown_element: Option<Entity<Markdown>>,
}
@@ -190,7 +189,6 @@ impl CompletionsMenu {
id: CompletionId,
sort_completions: bool,
show_completion_documentation: bool,
ignore_completion_provider: bool,
initial_position: Anchor,
buffer: Entity<Buffer>,
completions: Box<[Completion]>,
@@ -207,7 +205,6 @@ impl CompletionsMenu {
initial_position,
buffer,
show_completion_documentation,
ignore_completion_provider,
completions: RefCell::new(completions).into(),
match_candidates,
entries: RefCell::new(Vec::new()).into(),
@@ -269,7 +266,6 @@ impl CompletionsMenu {
scroll_handle: UniformListScrollHandle::new(),
resolve_completions: false,
show_completion_documentation: false,
ignore_completion_provider: false,
last_rendered_range: RefCell::new(None).into(),
markdown_element: None,
}

View File

@@ -240,7 +240,7 @@ pub struct BlockContext<'a, 'b> {
#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Hash)]
pub enum BlockId {
ExcerptBoundary(ExcerptId),
ExcerptBoundary(Option<ExcerptId>),
FoldedBuffer(ExcerptId),
Custom(CustomBlockId),
}
@@ -249,9 +249,10 @@ impl From<BlockId> for ElementId {
fn from(value: BlockId) -> Self {
match value {
BlockId::Custom(CustomBlockId(id)) => ("Block", id).into(),
BlockId::ExcerptBoundary(excerpt_id) => {
("ExcerptBoundary", EntityId::from(excerpt_id)).into()
}
BlockId::ExcerptBoundary(next_excerpt) => match next_excerpt {
Some(id) => ("ExcerptBoundary", EntityId::from(id)).into(),
None => "LastExcerptBoundary".into(),
},
BlockId::FoldedBuffer(id) => ("FoldedBuffer", EntityId::from(id)).into(),
}
}
@@ -279,10 +280,12 @@ pub enum Block {
Custom(Arc<CustomBlock>),
FoldedBuffer {
first_excerpt: ExcerptInfo,
prev_excerpt: Option<ExcerptInfo>,
height: u32,
},
ExcerptBoundary {
excerpt: ExcerptInfo,
prev_excerpt: Option<ExcerptInfo>,
next_excerpt: Option<ExcerptInfo>,
height: u32,
starts_new_buffer: bool,
},
@@ -292,10 +295,9 @@ impl Block {
pub fn id(&self) -> BlockId {
match self {
Block::Custom(block) => BlockId::Custom(block.id),
Block::ExcerptBoundary {
excerpt: next_excerpt,
..
} => BlockId::ExcerptBoundary(next_excerpt.id),
Block::ExcerptBoundary { next_excerpt, .. } => {
BlockId::ExcerptBoundary(next_excerpt.as_ref().map(|info| info.id))
}
Block::FoldedBuffer { first_excerpt, .. } => BlockId::FoldedBuffer(first_excerpt.id),
}
}
@@ -318,7 +320,7 @@ impl Block {
match self {
Block::Custom(block) => matches!(block.placement, BlockPlacement::Above(_)),
Block::FoldedBuffer { .. } => false,
Block::ExcerptBoundary { .. } => true,
Block::ExcerptBoundary { next_excerpt, .. } => next_excerpt.is_some(),
}
}
@@ -326,7 +328,7 @@ impl Block {
match self {
Block::Custom(block) => matches!(block.placement, BlockPlacement::Below(_)),
Block::FoldedBuffer { .. } => false,
Block::ExcerptBoundary { .. } => false,
Block::ExcerptBoundary { next_excerpt, .. } => next_excerpt.is_none(),
}
}
@@ -345,16 +347,6 @@ impl Block {
Block::ExcerptBoundary { .. } => true,
}
}
pub fn is_buffer_header(&self) -> bool {
match self {
Block::Custom(_) => false,
Block::FoldedBuffer { .. } => true,
Block::ExcerptBoundary {
starts_new_buffer, ..
} => *starts_new_buffer,
}
}
}
impl Debug for Block {
@@ -363,21 +355,24 @@ impl Debug for Block {
Self::Custom(block) => f.debug_struct("Custom").field("block", block).finish(),
Self::FoldedBuffer {
first_excerpt,
prev_excerpt,
height,
} => f
.debug_struct("FoldedBuffer")
.field("first_excerpt", &first_excerpt)
.field("prev_excerpt", prev_excerpt)
.field("height", height)
.finish(),
Self::ExcerptBoundary {
starts_new_buffer,
excerpt,
height,
next_excerpt,
prev_excerpt,
..
} => f
.debug_struct("ExcerptBoundary")
.field("excerpt", excerpt)
.field("prev_excerpt", prev_excerpt)
.field("next_excerpt", next_excerpt)
.field("starts_new_buffer", starts_new_buffer)
.field("height", height)
.finish(),
}
}
@@ -729,13 +724,23 @@ impl BlockMap {
std::iter::from_fn(move || {
let excerpt_boundary = boundaries.next()?;
let wrap_row = wrap_snapshot
.make_wrap_point(Point::new(excerpt_boundary.row.0, 0), Bias::Left)
.row();
let wrap_row = if excerpt_boundary.next.is_some() {
wrap_snapshot.make_wrap_point(Point::new(excerpt_boundary.row.0, 0), Bias::Left)
} else {
wrap_snapshot.make_wrap_point(
Point::new(
excerpt_boundary.row.0,
buffer.line_len(excerpt_boundary.row),
),
Bias::Left,
)
}
.row();
let new_buffer_id = match (&excerpt_boundary.prev, &excerpt_boundary.next) {
(None, next) => Some(next.buffer_id),
(Some(prev), next) => {
(_, None) => None,
(None, Some(next)) => Some(next.buffer_id),
(Some(prev), Some(next)) => {
if prev.buffer_id != next.buffer_id {
Some(next.buffer_id)
} else {
@@ -744,18 +749,24 @@ impl BlockMap {
}
};
let prev_excerpt = excerpt_boundary
.prev
.filter(|prev| !folded_buffers.contains(&prev.buffer_id));
let mut height = 0;
if let Some(new_buffer_id) = new_buffer_id {
let first_excerpt = excerpt_boundary.next.clone();
let first_excerpt = excerpt_boundary.next.clone().unwrap();
if folded_buffers.contains(&new_buffer_id) {
let mut last_excerpt_end_row = first_excerpt.end_row;
while let Some(next_boundary) = boundaries.peek() {
if next_boundary.next.buffer_id == new_buffer_id {
last_excerpt_end_row = next_boundary.next.end_row;
} else {
break;
if let Some(next_excerpt_boundary) = &next_boundary.next {
if next_excerpt_boundary.buffer_id == new_buffer_id {
last_excerpt_end_row = next_excerpt_boundary.end_row;
} else {
break;
}
}
boundaries.next();
@@ -774,6 +785,7 @@ impl BlockMap {
return Some((
BlockPlacement::Replace(WrapRow(wrap_row)..=WrapRow(wrap_end_row)),
Block::FoldedBuffer {
prev_excerpt,
height: height + buffer_header_height,
first_excerpt,
},
@@ -781,16 +793,27 @@ impl BlockMap {
}
}
if new_buffer_id.is_some() {
height += buffer_header_height;
} else {
height += excerpt_header_height;
if excerpt_boundary.next.is_some() {
if new_buffer_id.is_some() {
height += buffer_header_height;
} else {
height += excerpt_header_height;
}
}
if height == 0 {
return None;
}
Some((
BlockPlacement::Above(WrapRow(wrap_row)),
if excerpt_boundary.next.is_some() {
BlockPlacement::Above(WrapRow(wrap_row))
} else {
BlockPlacement::Below(WrapRow(wrap_row))
},
Block::ExcerptBoundary {
excerpt: excerpt_boundary.next,
prev_excerpt,
next_excerpt: excerpt_boundary.next,
height,
starts_new_buffer: new_buffer_id.is_some(),
},
@@ -838,14 +861,31 @@ impl BlockMap {
placement_comparison.then_with(|| match (block_a, block_b) {
(
Block::ExcerptBoundary {
excerpt: excerpt_a, ..
next_excerpt: next_excerpt_a,
..
},
Block::ExcerptBoundary {
excerpt: excerpt_b, ..
next_excerpt: next_excerpt_b,
..
},
) => Some(excerpt_a.id).cmp(&Some(excerpt_b.id)),
(Block::ExcerptBoundary { .. }, Block::Custom(_)) => Ordering::Less,
(Block::Custom(_), Block::ExcerptBoundary { .. }) => Ordering::Greater,
) => next_excerpt_a
.as_ref()
.map(|excerpt| excerpt.id)
.cmp(&next_excerpt_b.as_ref().map(|excerpt| excerpt.id)),
(Block::ExcerptBoundary { next_excerpt, .. }, Block::Custom(_)) => {
if next_excerpt.is_some() {
Ordering::Less
} else {
Ordering::Greater
}
}
(Block::Custom(_), Block::ExcerptBoundary { next_excerpt, .. }) => {
if next_excerpt.is_some() {
Ordering::Greater
} else {
Ordering::Less
}
}
(Block::Custom(block_a), Block::Custom(block_b)) => block_a
.priority
.cmp(&block_b.priority)
@@ -1365,19 +1405,51 @@ impl BlockSnapshot {
pub fn sticky_header_excerpt(&self, position: f32) -> Option<StickyHeaderExcerpt<'_>> {
let top_row = position as u32;
let mut cursor = self.transforms.cursor::<BlockRow>(&());
cursor.seek(&BlockRow(top_row), Bias::Right, &());
cursor.seek(&BlockRow(top_row), Bias::Left, &());
while let Some(transform) = cursor.item() {
let start = cursor.start().0;
let end = cursor.end(&()).0;
match &transform.block {
Some(Block::ExcerptBoundary { excerpt, .. }) => {
return Some(StickyHeaderExcerpt { excerpt })
Some(Block::ExcerptBoundary {
prev_excerpt,
next_excerpt,
starts_new_buffer,
..
}) => {
let matches_start = (start as f32) < position;
if matches_start && top_row <= end {
return next_excerpt.as_ref().map(|excerpt| StickyHeaderExcerpt {
next_buffer_row: None,
excerpt,
});
}
let next_buffer_row = if *starts_new_buffer { Some(end) } else { None };
return prev_excerpt.as_ref().map(|excerpt| StickyHeaderExcerpt {
excerpt,
next_buffer_row,
});
}
Some(block) if block.is_buffer_header() => return None,
_ => {
cursor.prev(&());
continue;
Some(Block::FoldedBuffer {
prev_excerpt: Some(excerpt),
..
}) if top_row <= start => {
return Some(StickyHeaderExcerpt {
next_buffer_row: Some(end),
excerpt,
});
}
Some(Block::FoldedBuffer { .. }) | Some(Block::Custom(_)) | None => {}
}
// This is needed to iterate past None / FoldedBuffer / Custom blocks. For FoldedBuffer,
// if scrolled slightly past the header of a folded block, the next block is needed for
// the sticky header.
cursor.next(&());
}
None
@@ -1391,9 +1463,14 @@ impl BlockSnapshot {
return Some(Block::Custom(custom_block.clone()));
}
BlockId::ExcerptBoundary(next_excerpt_id) => {
let excerpt_range = buffer.range_for_excerpt(next_excerpt_id)?;
self.wrap_snapshot
.make_wrap_point(excerpt_range.start, Bias::Left)
if let Some(next_excerpt_id) = next_excerpt_id {
let excerpt_range = buffer.range_for_excerpt(next_excerpt_id)?;
self.wrap_snapshot
.make_wrap_point(excerpt_range.start, Bias::Left)
} else {
self.wrap_snapshot
.make_wrap_point(buffer.max_point(), Bias::Left)
}
}
BlockId::FoldedBuffer(excerpt_id) => self
.wrap_snapshot
@@ -1671,6 +1748,7 @@ impl BlockChunks<'_> {
pub struct StickyHeaderExcerpt<'a> {
pub excerpt: &'a ExcerptInfo,
pub next_buffer_row: Option<u32>,
}
impl<'a> Iterator for BlockChunks<'a> {
@@ -2176,9 +2254,9 @@ mod tests {
assert_eq!(
blocks,
vec![
(0..1, BlockId::ExcerptBoundary(excerpt_ids[0])), // path, header
(3..4, BlockId::ExcerptBoundary(excerpt_ids[1])), // path, header
(6..7, BlockId::ExcerptBoundary(excerpt_ids[2])), // path, header
(0..1, BlockId::ExcerptBoundary(Some(excerpt_ids[0]))), // path, header
(3..4, BlockId::ExcerptBoundary(Some(excerpt_ids[1]))), // path, header
(6..7, BlockId::ExcerptBoundary(Some(excerpt_ids[2]))), // path, header
]
);
}
@@ -2875,7 +2953,10 @@ mod tests {
.iter()
.filter(|(_, block)| {
match block {
Block::FoldedBuffer { .. } => true,
Block::FoldedBuffer { prev_excerpt, .. } => {
assert!(prev_excerpt.is_none());
true
}
_ => false,
}
})

View File

@@ -69,7 +69,7 @@ pub use element::{
CursorLayout, EditorElement, HighlightedRange, HighlightedRangeLine, PointForPosition,
};
use futures::{
future::{self, join, Shared},
future::{self, Shared},
FutureExt,
};
use fuzzy::StringMatchCandidate;
@@ -82,10 +82,10 @@ use code_context_menus::{
use git::blame::GitBlame;
use gpui::{
div, impl_actions, point, prelude::*, pulsating_between, px, relative, size, Action, Animation,
AnimationExt, AnyElement, App, AppContext, AsyncWindowContext, AvailableSpace, Background,
Bounds, ClipboardEntry, ClipboardItem, Context, DispatchPhase, Edges, Entity,
EntityInputHandler, EventEmitter, FocusHandle, FocusOutEvent, Focusable, FontId, FontWeight,
Global, HighlightStyle, Hsla, KeyContext, Modifiers, MouseButton, MouseDownEvent, PaintQuad,
AnimationExt, AnyElement, App, AsyncWindowContext, AvailableSpace, Background, Bounds,
ClipboardEntry, ClipboardItem, Context, DispatchPhase, Edges, Entity, EntityInputHandler,
EventEmitter, FocusHandle, FocusOutEvent, Focusable, FontId, FontWeight, Global,
HighlightStyle, Hsla, KeyContext, Modifiers, MouseButton, MouseDownEvent, PaintQuad,
ParentElement, Pixels, Render, SharedString, Size, Stateful, Styled, StyledText, Subscription,
Task, TextStyle, TextStyleRefinement, UTF16Selection, UnderlineStyle, UniformListScrollHandle,
WeakEntity, WeakFocusHandle, Window,
@@ -106,7 +106,7 @@ use language::{
point_from_lsp, text_diff_with_options, AutoindentMode, BracketMatch, BracketPair, Buffer,
Capability, CharKind, CodeLabel, CursorShape, Diagnostic, DiffOptions, EditPredictionsMode,
EditPreview, HighlightedText, IndentKind, IndentSize, Language, OffsetRangeExt, Point,
Selection, SelectionGoal, TextObject, TransactionId, TreeSitterOptions, WordsQuery,
Selection, SelectionGoal, TextObject, TransactionId, TreeSitterOptions,
};
use language::{point_to_lsp, BufferRow, CharClassifier, Runnable, RunnableRange};
use linked_editing_ranges::refresh_linked_ranges;
@@ -134,7 +134,7 @@ pub use multi_buffer::{
};
use multi_buffer::{
ExcerptInfo, ExpandExcerptDirection, MultiBufferDiffHunk, MultiBufferPoint, MultiBufferRow,
MultiOrSingleBufferOffsetRange, ToOffsetUtf16,
MultiOrSingleBufferOffsetRange, PathKey, ToOffsetUtf16,
};
use project::{
lsp_store::{CompletionDocumentation, FormatTrigger, LspFormatTarget, OpenLspBufferHandle},
@@ -211,7 +211,6 @@ pub(crate) const SCROLL_CENTER_TOP_BOTTOM_DEBOUNCE_TIMEOUT: Duration = Duration:
pub(crate) const EDIT_PREDICTION_KEY_CONTEXT: &str = "edit_prediction";
pub(crate) const EDIT_PREDICTION_CONFLICT_KEY_CONTEXT: &str = "edit_prediction_conflict";
pub(crate) const MIN_LINE_NUMBER_DIGITS: u32 = 4;
const COLUMNAR_SELECTION_MODIFIERS: Modifiers = Modifiers {
alt: true,
@@ -1234,15 +1233,11 @@ impl Editor {
project_subscriptions.push(cx.subscribe_in(
project,
window,
|editor, _, event, window, cx| match event {
project::Event::RefreshCodeLens => {
// we always query lens with actions, without storing them, always refreshing them
}
project::Event::RefreshInlayHints => {
|editor, _, event, window, cx| {
if let project::Event::RefreshInlayHints = event {
editor
.refresh_inlay_hints(InlayHintRefreshReason::RefreshRequested, cx);
}
project::Event::SnippetEdit(id, snippet_edits) => {
} else if let project::Event::SnippetEdit(id, snippet_edits) = event {
if let Some(buffer) = editor.buffer.read(cx).buffer(*id) {
let focus_handle = editor.focus_handle(cx);
if focus_handle.is_focused(window) {
@@ -1262,7 +1257,6 @@ impl Editor {
}
}
}
_ => {}
},
));
if let Some(task_inventory) = project
@@ -2924,17 +2918,6 @@ impl Editor {
.next()
.map_or(true, |c| scope.should_autoclose_before(c));
let preceding_text_allows_autoclose = selection.start.column == 0
|| snapshot.reversed_chars_at(selection.start).next().map_or(
true,
|c| {
bracket_pair.start != bracket_pair.end
|| !snapshot
.char_classifier_at(selection.start)
.is_word(c)
},
);
let is_closing_quote = if bracket_pair.end == bracket_pair.start
&& bracket_pair.start.len() == 1
{
@@ -2952,7 +2935,6 @@ impl Editor {
if autoclose
&& bracket_pair.close
&& following_text_allows_autoclose
&& preceding_text_allows_autoclose
&& !is_closing_quote
{
let anchor = snapshot.anchor_before(selection.end);
@@ -3546,21 +3528,7 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Self>,
) {
let ignore_completion_provider = self
.context_menu
.borrow()
.as_ref()
.map(|menu| match menu {
CodeContextMenu::Completions(completions_menu) => {
completions_menu.ignore_completion_provider
}
CodeContextMenu::CodeActions(_) => false,
})
.unwrap_or(false);
if ignore_completion_provider {
self.show_word_completions(&ShowWordCompletions, window, cx);
} else if self.is_completion_trigger(text, trigger_in_words, cx) {
if self.is_completion_trigger(text, trigger_in_words, cx) {
self.show_completions(
&ShowCompletions {
trigger: Some(text.to_owned()).filter(|x| !x.is_empty()),
@@ -4009,34 +3977,20 @@ impl Editor {
}))
}
pub fn show_word_completions(
&mut self,
_: &ShowWordCompletions,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.open_completions_menu(true, None, window, cx);
}
pub fn show_completions(
&mut self,
options: &ShowCompletions,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.open_completions_menu(false, options.trigger.as_deref(), window, cx);
}
fn open_completions_menu(
&mut self,
ignore_completion_provider: bool,
trigger: Option<&str>,
window: &mut Window,
cx: &mut Context<Self>,
) {
if self.pending_rename.is_some() {
return;
}
let Some(provider) = self.completion_provider.as_ref() else {
return;
};
if !self.snippet_stack.is_empty() && self.context_menu.borrow().as_ref().is_some() {
return;
}
@@ -4058,14 +4012,14 @@ impl Editor {
let query = Self::completion_query(&self.buffer.read(cx).read(cx), position);
let trigger_kind = match trigger {
let trigger_kind = match &options.trigger {
Some(trigger) if buffer.read(cx).completion_triggers().contains(trigger) => {
CompletionTriggerKind::TRIGGER_CHARACTER
}
_ => CompletionTriggerKind::INVOKED,
};
let completion_context = CompletionContext {
trigger_character: trigger.and_then(|trigger| {
trigger_character: options.trigger.as_ref().and_then(|trigger| {
if trigger_kind == CompletionTriggerKind::TRIGGER_CHARACTER {
Some(String::from(trigger))
} else {
@@ -4074,7 +4028,8 @@ impl Editor {
}),
trigger_kind,
};
let completions =
provider.completions(&buffer, buffer_position, completion_context, window, cx);
let (old_range, word_kind) = buffer_snapshot.surrounding_word(buffer_position);
let (old_range, word_to_exclude) = if word_kind == Some(CharKind::Word) {
let word_to_exclude = buffer_snapshot
@@ -4112,49 +4067,15 @@ impl Editor {
);
let word_search_range = buffer_snapshot.point_to_offset(min_word_search)
..buffer_snapshot.point_to_offset(max_word_search);
let provider = self
.completion_provider
.as_ref()
.filter(|_| !ignore_completion_provider);
let skip_digits = query
.as_ref()
.map_or(true, |query| !query.chars().any(|c| c.is_digit(10)));
let (mut words, provided_completions) = match provider {
Some(provider) => {
let completions =
provider.completions(&buffer, buffer_position, completion_context, window, cx);
let words = match completion_settings.words {
WordsCompletionMode::Disabled => Task::ready(HashMap::default()),
WordsCompletionMode::Enabled | WordsCompletionMode::Fallback => cx
.background_spawn(async move {
buffer_snapshot.words_in_range(WordsQuery {
fuzzy_contents: None,
range: word_search_range,
skip_digits,
})
}),
};
(words, completions)
}
None => (
let words = match completion_settings.words {
WordsCompletionMode::Disabled => Task::ready(HashMap::default()),
WordsCompletionMode::Enabled | WordsCompletionMode::Fallback => {
cx.background_spawn(async move {
buffer_snapshot.words_in_range(WordsQuery {
fuzzy_contents: None,
range: word_search_range,
skip_digits,
})
}),
Task::ready(Ok(None)),
),
buffer_snapshot.words_in_range(None, word_search_range)
})
}
};
let sort_completions = provider
.as_ref()
.map_or(true, |provider| provider.sort_completions());
let sort_completions = provider.sort_completions();
let id = post_inc(&mut self.next_completion_id);
let task = cx.spawn_in(window, |editor, mut cx| {
@@ -4162,34 +4083,55 @@ impl Editor {
editor.update(&mut cx, |this, _| {
this.completion_tasks.retain(|(task_id, _)| *task_id >= id);
})?;
let mut completions = completions.await.log_err().unwrap_or_default();
let mut completions = Vec::new();
if let Some(provided_completions) = provided_completions.await.log_err().flatten() {
completions.extend(provided_completions);
if completion_settings.words == WordsCompletionMode::Fallback {
words = Task::ready(HashMap::default());
match completion_settings.words {
WordsCompletionMode::Enabled => {
let mut words = words.await;
if let Some(word_to_exclude) = &word_to_exclude {
words.remove(word_to_exclude);
}
for lsp_completion in &completions {
words.remove(&lsp_completion.new_text);
}
completions.extend(words.into_iter().map(|(word, word_range)| {
Completion {
old_range: old_range.clone(),
new_text: word.clone(),
label: CodeLabel::plain(word, None),
documentation: None,
source: CompletionSource::BufferWord {
word_range,
resolved: false,
},
confirm: None,
}
}));
}
WordsCompletionMode::Fallback => {
if completions.is_empty() {
completions.extend(
words
.await
.into_iter()
.filter(|(word, _)| word_to_exclude.as_ref() != Some(word))
.map(|(word, word_range)| Completion {
old_range: old_range.clone(),
new_text: word.clone(),
label: CodeLabel::plain(word, None),
documentation: None,
source: CompletionSource::BufferWord {
word_range,
resolved: false,
},
confirm: None,
}),
);
}
}
WordsCompletionMode::Disabled => {}
}
let mut words = words.await;
if let Some(word_to_exclude) = &word_to_exclude {
words.remove(word_to_exclude);
}
for lsp_completion in &completions {
words.remove(&lsp_completion.new_text);
}
completions.extend(words.into_iter().map(|(word, word_range)| Completion {
old_range: old_range.clone(),
new_text: word.clone(),
label: CodeLabel::plain(word, None),
documentation: None,
source: CompletionSource::BufferWord {
word_range,
resolved: false,
},
confirm: None,
}));
let menu = if completions.is_empty() {
None
} else {
@@ -4197,7 +4139,6 @@ impl Editor {
id,
sort_completions,
show_completion_documentation,
ignore_completion_provider,
position,
buffer.clone(),
completions.into(),
@@ -4247,7 +4188,7 @@ impl Editor {
}
})?;
anyhow::Ok(())
Ok::<_, anyhow::Error>(())
}
.log_err()
});
@@ -4718,18 +4659,18 @@ impl Editor {
let mut multibuffer = MultiBuffer::new(Capability::ReadWrite).with_title(title);
for (buffer_handle, transaction) in &entries {
let buffer = buffer_handle.read(cx);
ranges_to_highlight.extend(
multibuffer.push_excerpts_with_context_lines(
buffer_handle.clone(),
buffer
.edited_ranges_for_transaction::<usize>(transaction)
.collect(),
DEFAULT_MULTIBUFFER_CONTEXT,
cx,
),
multibuffer.set_excerpts_for_path(
PathKey::for_buffer(&buffer, cx),
buffer_handle.clone(),
buffer
.edited_ranges_for_transaction::<Point>(transaction)
.collect(),
DEFAULT_MULTIBUFFER_CONTEXT,
cx,
);
}
multibuffer.push_transaction(entries.iter().map(|(b, t)| (b, t)), cx);
ranges_to_highlight.extend(multibuffer.primary_excerpt_ranges(cx));
multibuffer
})?;
@@ -4970,9 +4911,6 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Editor>,
) {
if matches!(self.mode, EditorMode::SingleLine { .. }) {
return;
}
self.selection_highlight_task.take();
if !EditorSettings::get_global(cx).selection_highlight {
self.clear_background_highlights::<SelectedTextHighlight>(cx);
@@ -5221,9 +5159,6 @@ impl Editor {
cx: &App,
) -> bool {
maybe!({
if self.read_only(cx) {
return Some(false);
}
let provider = self.edit_prediction_provider()?;
if !provider.is_enabled(&buffer, buffer_position, cx) {
return Some(false);
@@ -12393,20 +12328,20 @@ impl Editor {
// If there are multiple definitions, open them in a multibuffer
locations.sort_by_key(|location| location.buffer.read(cx).remote_id());
let mut locations = locations.into_iter().peekable();
let mut ranges = Vec::new();
let capability = workspace.project().read(cx).capability();
let mut ranges = Vec::new();
let excerpt_buffer = cx.new(|cx| {
let mut multibuffer = MultiBuffer::new(capability);
while let Some(location) = locations.next() {
let buffer = location.buffer.read(cx);
let mut ranges_for_buffer = Vec::new();
let range = location.range.to_offset(buffer);
let range = location.range.to_point(buffer);
ranges_for_buffer.push(range.clone());
while let Some(next_location) = locations.peek() {
if next_location.buffer == location.buffer {
ranges_for_buffer.push(next_location.range.to_offset(buffer));
ranges_for_buffer.push(next_location.range.to_point(buffer));
locations.next();
} else {
break;
@@ -12414,14 +12349,16 @@ impl Editor {
}
ranges_for_buffer.sort_by_key(|range| (range.start, Reverse(range.end)));
ranges.extend(multibuffer.push_excerpts_with_context_lines(
multibuffer.set_excerpts_for_path(
PathKey::for_buffer(&buffer, cx),
location.buffer.clone(),
ranges_for_buffer,
DEFAULT_MULTIBUFFER_CONTEXT,
cx,
))
);
}
ranges.extend(multibuffer.primary_excerpt_ranges(cx));
multibuffer.with_title(title)
});
@@ -12464,22 +12401,14 @@ impl Editor {
if split {
workspace.split_item(SplitDirection::Right, item.clone(), window, cx);
} else {
if PreviewTabsSettings::get_global(cx).enable_preview_from_code_navigation {
let (preview_item_id, preview_item_idx) =
workspace.active_pane().update(cx, |pane, _| {
(pane.preview_item_id(), pane.preview_item_idx())
});
workspace.add_item_to_active_pane(item.clone(), preview_item_idx, true, window, cx);
if let Some(preview_item_id) = preview_item_id {
workspace.active_pane().update(cx, |pane, cx| {
pane.remove_item(preview_item_id, false, false, window, cx);
});
let destination_index = workspace.active_pane().update(cx, |pane, cx| {
if PreviewTabsSettings::get_global(cx).enable_preview_from_code_navigation {
pane.close_current_preview_item(window, cx)
} else {
None
}
} else {
workspace.add_item_to_active_pane(item.clone(), None, true, window, cx);
}
});
workspace.add_item_to_active_pane(item.clone(), destination_index, true, window, cx);
}
workspace.active_pane().update(cx, |pane, cx| {
pane.set_preview_item_id(Some(item_id), cx);
@@ -12867,11 +12796,11 @@ impl Editor {
cx.spawn_in(window, |_, mut cx| async move {
let transaction = futures::select_biased! {
transaction = format.log_err().fuse() => transaction,
() = timeout => {
log::warn!("timed out waiting for formatting");
None
}
transaction = format.log_err().fuse() => transaction,
};
buffer
@@ -14044,6 +13973,8 @@ impl Editor {
self.change_selections(Some(autoscroll), window, cx, |s| {
s.select_ranges([destination..destination]);
});
} else if all_diff_hunks_expanded {
window.dispatch_action(::git::ExpandCommitEditor.boxed_clone(), cx);
}
}
@@ -16984,7 +16915,7 @@ pub trait CompletionProvider {
trigger: CompletionContext,
window: &mut Window,
cx: &mut Context<Editor>,
) -> Task<Result<Option<Vec<Completion>>>>;
) -> Task<Result<Vec<Completion>>>;
fn resolve_completions(
&self,
@@ -17054,16 +16985,7 @@ impl CodeActionProvider for Entity<Project> {
cx: &mut App,
) -> Task<Result<Vec<CodeAction>>> {
self.update(cx, |project, cx| {
let code_lens = project.code_lens(buffer, range.clone(), cx);
let code_actions = project.code_actions(buffer, range, None, cx);
cx.background_spawn(async move {
let (code_lens, code_actions) = join(code_lens, code_actions).await;
Ok(code_lens
.context("code lens fetch")?
.into_iter()
.chain(code_actions.context("code action fetch")?)
.collect())
})
project.code_actions(buffer, range, None, cx)
})
}
@@ -17233,25 +17155,15 @@ impl CompletionProvider for Entity<Project> {
options: CompletionContext,
_window: &mut Window,
cx: &mut Context<Editor>,
) -> Task<Result<Option<Vec<Completion>>>> {
) -> Task<Result<Vec<Completion>>> {
self.update(cx, |project, cx| {
let snippets = snippet_completions(project, buffer, buffer_position, cx);
let project_completions = project.completions(buffer, buffer_position, options, cx);
cx.background_spawn(async move {
let mut completions = project_completions.await?;
let snippets_completions = snippets.await?;
match project_completions.await? {
Some(mut completions) => {
completions.extend(snippets_completions);
Ok(Some(completions))
}
None => {
if snippets_completions.is_empty() {
Ok(None)
} else {
Ok(Some(snippets_completions))
}
}
}
completions.extend(snippets_completions);
Ok(completions)
})
})
}
@@ -17634,7 +17546,7 @@ impl EditorSnapshot {
.unwrap_or(gutter_settings.line_numbers);
let line_gutter_width = if show_line_numbers {
// Avoid flicker-like gutter resizes when the line number gains another digit and only resize the gutter on files with N*10^5 lines.
let min_width_for_number_on_gutter = em_advance * MIN_LINE_NUMBER_DIGITS as f32;
let min_width_for_number_on_gutter = em_advance * 4.0;
max_line_number_width.max(min_width_for_number_on_gutter)
} else {
0.0.into()
@@ -17663,12 +17575,8 @@ impl EditorSnapshot {
em_advance * max_char_count
});
let is_singleton = self.buffer_snapshot.is_singleton();
let mut left_padding = git_blame_entries_width.unwrap_or(Pixels::ZERO);
left_padding += if !is_singleton {
em_width * 4.0
} else if show_code_actions || show_runnables {
left_padding += if show_code_actions || show_runnables {
em_width * 3.0
} else if show_git_gutter && show_line_numbers {
em_width * 2.0
@@ -17678,11 +17586,9 @@ impl EditorSnapshot {
px(0.)
};
let shows_folds = is_singleton && gutter_settings.folds;
let right_padding = if shows_folds && show_line_numbers {
let right_padding = if gutter_settings.folds && show_line_numbers {
em_width * 4.0
} else if shows_folds || (!is_singleton && show_line_numbers) {
} else if gutter_settings.folds {
em_width * 3.0
} else if show_line_numbers {
em_width

View File

@@ -6357,29 +6357,12 @@ async fn test_autoclose_and_auto_surround_pairs(cx: &mut TestAppContext) {
cx.update_editor(|editor, window, cx| editor.handle_input("{", window, cx));
cx.assert_editor_state("{«aˇ»} b");
// Autoclose when not immediately after a word character
cx.set_state("a ˇ");
cx.update_editor(|editor, window, cx| editor.handle_input("\"", window, cx));
cx.assert_editor_state("a \"ˇ\"");
// Autoclose pair where the start and end characters are the same
cx.update_editor(|editor, window, cx| editor.handle_input("\"", window, cx));
cx.assert_editor_state("a \"\"ˇ");
// Don't autoclose when immediately after a word character
// Autclose pair where the start and end characters are the same
cx.set_state("");
cx.update_editor(|editor, window, cx| editor.handle_input("\"", window, cx));
cx.assert_editor_state("a\"ˇ");
// Do autoclose when after a non-word character
cx.set_state("");
cx.assert_editor_state("a\"ˇ\"");
cx.update_editor(|editor, window, cx| editor.handle_input("\"", window, cx));
cx.assert_editor_state("{\"ˇ\"");
// Non identical pairs autoclose regardless of preceding character
cx.set_state("");
cx.update_editor(|editor, window, cx| editor.handle_input("{", window, cx));
cx.assert_editor_state("a{ˇ}");
cx.assert_editor_state("a\"\"ˇ");
// Don't autoclose pair if autoclose is disabled
cx.set_state("ˇ");
@@ -9212,7 +9195,7 @@ async fn test_completion(cx: &mut TestAppContext) {
}
#[gpui::test]
async fn test_word_completion(cx: &mut TestAppContext) {
async fn test_words_completion(cx: &mut TestAppContext) {
let lsp_fetch_timeout_ms = 10;
init_test(cx, |language_settings| {
language_settings.defaults.completions = Some(CompletionSettings {
@@ -9354,7 +9337,7 @@ async fn test_word_completions_do_not_duplicate_lsp_ones(cx: &mut TestAppContext
cx.executor().run_until_parked();
cx.condition(|editor, _| editor.context_menu_visible())
.await;
cx.update_editor(|editor, _, _| {
cx.update_editor(|editor, window, cx| {
if let Some(CodeContextMenu::Completions(menu)) = editor.context_menu.borrow_mut().as_ref()
{
assert_eq!(
@@ -9365,140 +9348,8 @@ async fn test_word_completions_do_not_duplicate_lsp_ones(cx: &mut TestAppContext
} else {
panic!("expected completion menu to be open");
}
});
}
#[gpui::test]
async fn test_word_completions_continue_on_typing(cx: &mut TestAppContext) {
init_test(cx, |language_settings| {
language_settings.defaults.completions = Some(CompletionSettings {
words: WordsCompletionMode::Disabled,
lsp: true,
lsp_fetch_timeout_ms: 0,
});
});
let mut cx = EditorLspTestContext::new_rust(
lsp::ServerCapabilities {
completion_provider: Some(lsp::CompletionOptions {
trigger_characters: Some(vec![".".to_string(), ":".to_string()]),
..lsp::CompletionOptions::default()
}),
signature_help_provider: Some(lsp::SignatureHelpOptions::default()),
..lsp::ServerCapabilities::default()
},
cx,
)
.await;
let _completion_requests_handler =
cx.lsp
.server
.on_request::<lsp::request::Completion, _, _>(move |_, _| async move {
panic!("LSP completions should not be queried when dealing with word completions")
});
cx.set_state(indoc! {"ˇ
first
last
second
"});
cx.update_editor(|editor, window, cx| {
editor.show_word_completions(&ShowWordCompletions, window, cx);
});
cx.executor().run_until_parked();
cx.condition(|editor, _| editor.context_menu_visible())
.await;
cx.update_editor(|editor, _, _| {
if let Some(CodeContextMenu::Completions(menu)) = editor.context_menu.borrow_mut().as_ref()
{
assert_eq!(
completion_menu_entries(&menu),
&["first", "last", "second"],
"`ShowWordCompletions` action should show word completions"
);
} else {
panic!("expected completion menu to be open");
}
});
cx.simulate_keystroke("s");
cx.executor().run_until_parked();
cx.condition(|editor, _| editor.context_menu_visible())
.await;
cx.update_editor(|editor, _, _| {
if let Some(CodeContextMenu::Completions(menu)) = editor.context_menu.borrow_mut().as_ref()
{
assert_eq!(
completion_menu_entries(&menu),
&["second"],
"After showing word completions, further editing should filter them and not query the LSP"
);
} else {
panic!("expected completion menu to be open");
}
});
}
#[gpui::test]
async fn test_word_completions_usually_skip_digits(cx: &mut TestAppContext) {
init_test(cx, |language_settings| {
language_settings.defaults.completions = Some(CompletionSettings {
words: WordsCompletionMode::Fallback,
lsp: false,
lsp_fetch_timeout_ms: 0,
});
});
let mut cx = EditorLspTestContext::new_rust(lsp::ServerCapabilities::default(), cx).await;
cx.set_state(indoc! {"ˇ
0_usize
let
33
4.5f32
"});
cx.update_editor(|editor, window, cx| {
editor.show_completions(&ShowCompletions::default(), window, cx);
});
cx.executor().run_until_parked();
cx.condition(|editor, _| editor.context_menu_visible())
.await;
cx.update_editor(|editor, window, cx| {
if let Some(CodeContextMenu::Completions(menu)) = editor.context_menu.borrow_mut().as_ref()
{
assert_eq!(
completion_menu_entries(&menu),
&["let"],
"With no digits in the completion query, no digits should be in the word completions"
);
} else {
panic!("expected completion menu to be open");
}
editor.cancel(&Cancel, window, cx);
});
cx.set_state(indoc! {"
0_usize
let
3
33.35f32
"});
cx.update_editor(|editor, window, cx| {
editor.show_completions(&ShowCompletions::default(), window, cx);
});
cx.executor().run_until_parked();
cx.condition(|editor, _| editor.context_menu_visible())
.await;
cx.update_editor(|editor, _, _| {
if let Some(CodeContextMenu::Completions(menu)) = editor.context_menu.borrow_mut().as_ref()
{
assert_eq!(completion_menu_entries(&menu), &["33", "35f32"], "The digit is in the completion query, \
return matching words with digits (`33`, `35f32`) but exclude query duplicates (`3`)");
} else {
panic!("expected completion menu to be open");
}
});
}
#[gpui::test]
@@ -15797,7 +15648,7 @@ async fn test_display_diff_hunks(cx: &mut TestAppContext) {
for buffer in &buffers {
let snapshot = buffer.read(cx).snapshot();
multibuffer.set_excerpts_for_path(
PathKey::namespaced("", buffer.read(cx).file().unwrap().path().clone()),
PathKey::namespaced(0, buffer.read(cx).file().unwrap().path().clone()),
buffer.clone(),
vec![text::Anchor::MIN.to_point(&snapshot)..text::Anchor::MAX.to_point(&snapshot)],
DEFAULT_MULTIBUFFER_CONTEXT,
@@ -17304,187 +17155,6 @@ async fn test_tree_sitter_brackets_newline_insertion(cx: &mut TestAppContext) {
"});
}
#[gpui::test(iterations = 10)]
async fn test_apply_code_lens_actions_with_commands(cx: &mut gpui::TestAppContext) {
init_test(cx, |_| {});
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/dir"),
json!({
"a.ts": "a",
}),
)
.await;
let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
let workspace = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx));
let cx = &mut VisualTestContext::from_window(*workspace.deref(), cx);
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
language_registry.add(Arc::new(Language::new(
LanguageConfig {
name: "TypeScript".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["ts".to_string()],
..Default::default()
},
..Default::default()
},
Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()),
)));
let mut fake_language_servers = language_registry.register_fake_lsp(
"TypeScript",
FakeLspAdapter {
capabilities: lsp::ServerCapabilities {
code_lens_provider: Some(lsp::CodeLensOptions {
resolve_provider: Some(true),
}),
execute_command_provider: Some(lsp::ExecuteCommandOptions {
commands: vec!["_the/command".to_string()],
..lsp::ExecuteCommandOptions::default()
}),
..lsp::ServerCapabilities::default()
},
..FakeLspAdapter::default()
},
);
let (buffer, _handle) = project
.update(cx, |p, cx| {
p.open_local_buffer_with_lsp(path!("/dir/a.ts"), cx)
})
.await
.unwrap();
cx.executor().run_until_parked();
let fake_server = fake_language_servers.next().await.unwrap();
let buffer_snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
let anchor = buffer_snapshot.anchor_at(0, text::Bias::Left);
drop(buffer_snapshot);
let actions = cx
.update_window(*workspace, |_, window, cx| {
project.code_actions(&buffer, anchor..anchor, window, cx)
})
.unwrap();
fake_server
.handle_request::<lsp::request::CodeLensRequest, _, _>(|_, _| async move {
Ok(Some(vec![
lsp::CodeLens {
range: lsp::Range::default(),
command: Some(lsp::Command {
title: "Code lens command".to_owned(),
command: "_the/command".to_owned(),
arguments: None,
}),
data: None,
},
lsp::CodeLens {
range: lsp::Range::default(),
command: Some(lsp::Command {
title: "Command not in capabilities".to_owned(),
command: "not in capabilities".to_owned(),
arguments: None,
}),
data: None,
},
lsp::CodeLens {
range: lsp::Range {
start: lsp::Position {
line: 1,
character: 1,
},
end: lsp::Position {
line: 1,
character: 1,
},
},
command: Some(lsp::Command {
title: "Command not in range".to_owned(),
command: "_the/command".to_owned(),
arguments: None,
}),
data: None,
},
]))
})
.next()
.await;
let actions = actions.await.unwrap();
assert_eq!(
actions.len(),
1,
"Should have only one valid action for the 0..0 range"
);
let action = actions[0].clone();
let apply = project.update(cx, |project, cx| {
project.apply_code_action(buffer.clone(), action, true, cx)
});
// Resolving the code action does not populate its edits. In absence of
// edits, we must execute the given command.
fake_server.handle_request::<lsp::request::CodeLensResolve, _, _>(|mut lens, _| async move {
let lens_command = lens.command.as_mut().expect("should have a command");
assert_eq!(lens_command.title, "Code lens command");
lens_command.arguments = Some(vec![json!("the-argument")]);
Ok(lens)
});
// While executing the command, the language server sends the editor
// a `workspaceEdit` request.
fake_server
.handle_request::<lsp::request::ExecuteCommand, _, _>({
let fake = fake_server.clone();
move |params, _| {
assert_eq!(params.command, "_the/command");
let fake = fake.clone();
async move {
fake.server
.request::<lsp::request::ApplyWorkspaceEdit>(
lsp::ApplyWorkspaceEditParams {
label: None,
edit: lsp::WorkspaceEdit {
changes: Some(
[(
lsp::Url::from_file_path(path!("/dir/a.ts")).unwrap(),
vec![lsp::TextEdit {
range: lsp::Range::new(
lsp::Position::new(0, 0),
lsp::Position::new(0, 0),
),
new_text: "X".into(),
}],
)]
.into_iter()
.collect(),
),
..Default::default()
},
},
)
.await
.unwrap();
Ok(Some(json!(null)))
}
}
})
.next()
.await;
// Applying the code lens command returns a project transaction containing the edits
// sent by the language server in its `workspaceEdit` request.
let transaction = apply.await.unwrap();
assert!(transaction.0.contains_key(&buffer));
buffer.update(cx, |buffer, cx| {
assert_eq!(buffer.text(), "Xa");
buffer.undo(cx);
assert_eq!(buffer.text(), "a");
});
}
mod autoclose_tags {
use super::*;
use language::language_settings::JsxTagAutoCloseSettings;

View File

@@ -23,7 +23,7 @@ use crate::{
InlineCompletion, JumpData, LineDown, LineHighlight, LineUp, OpenExcerpts, PageDown, PageUp,
Point, RowExt, RowRangeExt, SelectPhase, SelectedTextHighlight, Selection, SoftWrap,
StickyHeaderExcerpt, ToPoint, ToggleFold, COLUMNAR_SELECTION_MODIFIERS, CURSORS_VISIBLE_FOR,
FILE_HEADER_HEIGHT, GIT_BLAME_MAX_AUTHOR_CHARS_DISPLAYED, MAX_LINE_LEN, MIN_LINE_NUMBER_DIGITS,
FILE_HEADER_HEIGHT, GIT_BLAME_MAX_AUTHOR_CHARS_DISPLAYED, MAX_LINE_LEN,
MULTI_BUFFER_EXCERPT_HEADER_HEIGHT,
};
use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
@@ -55,7 +55,7 @@ use multi_buffer::{
Anchor, ExcerptId, ExcerptInfo, ExpandExcerptDirection, ExpandInfo, MultiBufferPoint,
MultiBufferRow, RowInfo,
};
use project::project_settings::{self, GitGutterSetting, GitHunkStyleSetting, ProjectSettings};
use project::project_settings::{self, GitGutterSetting, ProjectSettings};
use settings::Settings;
use smallvec::{smallvec, SmallVec};
use std::{
@@ -390,7 +390,6 @@ impl EditorElement {
register_action(editor, window, Editor::set_mark);
register_action(editor, window, Editor::swap_selection_ends);
register_action(editor, window, Editor::show_completions);
register_action(editor, window, Editor::show_word_completions);
register_action(editor, window, Editor::toggle_code_actions);
register_action(editor, window, Editor::open_excerpts);
register_action(editor, window, Editor::open_excerpts_in_split);
@@ -2119,11 +2118,9 @@ impl EditorElement {
})
}
fn layout_expand_toggles(
fn layout_excerpt_gutter(
&self,
gutter_hitbox: &Hitbox,
gutter_dimensions: GutterDimensions,
em_width: Pixels,
line_height: Pixels,
scroll_position: gpui::Point<f32>,
buffer_rows: &[RowInfo],
@@ -2132,17 +2129,10 @@ impl EditorElement {
) -> Vec<Option<(AnyElement, gpui::Point<Pixels>)>> {
let editor_font_size = self.style.text.font_size.to_pixels(window.rem_size()) * 1.2;
let scroll_top = scroll_position.y * line_height;
let icon_size = editor_font_size.round();
let button_h_padding = ((icon_size - px(1.0)) / 2.0).round() - px(2.0);
let max_line_number_length = self
.editor
.read(cx)
.buffer()
.read(cx)
.snapshot(cx)
.widest_line_number()
.ilog10()
+ 1;
let scroll_top = scroll_position.y * line_height;
let elements = buffer_rows
.into_iter()
@@ -2159,27 +2149,27 @@ impl EditorElement {
ExpandExcerptDirection::UpAndDown => IconName::ExpandVertical,
};
let git_gutter_width = Self::gutter_strip_width(line_height);
let available_width = gutter_dimensions.left_padding - git_gutter_width;
let editor = self.editor.clone();
let is_wide = max_line_number_length >= MIN_LINE_NUMBER_DIGITS
let max_row = self
.editor
.read(cx)
.buffer()
.read(cx)
.snapshot(cx)
.widest_line_number();
let is_wide = max_row > 999
&& row_info
.buffer_row
.is_some_and(|row| (row + 1).ilog10() + 1 == max_line_number_length)
|| gutter_dimensions.right_padding == px(0.);
let width = if is_wide {
available_width - px(2.)
} else {
available_width + em_width - px(2.)
};
.is_some_and(|row| row.ilog10() == max_row.ilog10());
let toggle = IconButton::new(("expand", ix), icon_name)
.icon_color(Color::Custom(cx.theme().colors().editor_line_number))
.selected_icon_color(Color::Custom(cx.theme().colors().editor_foreground))
.icon_size(IconSize::Custom(rems(editor_font_size / window.rem_size())))
.width(width.into())
.width((icon_size + button_h_padding * 2).into())
.when(is_wide, |el| {
el.width((icon_size + button_h_padding).into())
})
.on_click(move |_, _, cx| {
editor.update(cx, |editor, cx| {
editor.expand_excerpt(excerpt_id, direction, cx);
@@ -2192,7 +2182,7 @@ impl EditorElement {
.into_any_element();
let position = point(
git_gutter_width + px(1.),
px(1.),
ix as f32 * line_height - (scroll_top % line_height) + px(1.),
);
let origin = gutter_hitbox.origin + position;
@@ -2647,7 +2637,7 @@ impl EditorElement {
}
Block::ExcerptBoundary {
excerpt,
next_excerpt,
height,
starts_new_buffer,
..
@@ -2655,31 +2645,40 @@ impl EditorElement {
let color = cx.theme().colors().clone();
let mut result = v_flex().id(block_id).w_full();
let jump_data = header_jump_data(snapshot, block_row_start, *height, excerpt);
if let Some(next_excerpt) = next_excerpt {
let jump_data =
header_jump_data(snapshot, block_row_start, *height, next_excerpt);
if *starts_new_buffer {
if sticky_header_excerpt_id != Some(excerpt.id) {
let selected = selected_buffer_ids.contains(&excerpt.buffer_id);
if *starts_new_buffer {
if sticky_header_excerpt_id != Some(next_excerpt.id) {
let selected = selected_buffer_ids.contains(&next_excerpt.buffer_id);
result = result.child(self.render_buffer_header(
excerpt, false, selected, false, jump_data, window, cx,
));
result = result.child(self.render_buffer_header(
next_excerpt,
false,
selected,
false,
jump_data,
window,
cx,
));
} else {
result = result
.child(div().h(FILE_HEADER_HEIGHT as f32 * window.line_height()));
}
} else {
result =
result.child(div().h(FILE_HEADER_HEIGHT as f32 * window.line_height()));
}
} else {
result = result.child(
h_flex().relative().child(
div()
.top(line_height / 2.)
.absolute()
.w_full()
.h_px()
.bg(color.border_variant),
),
);
};
result = result.child(
h_flex().relative().child(
div()
.top(line_height / 2.)
.absolute()
.w_full()
.h_px()
.bg(color.border_variant),
),
);
};
}
result.into_any()
}
@@ -2716,7 +2715,7 @@ impl EditorElement {
for_excerpt: &ExcerptInfo,
is_folded: bool,
is_selected: bool,
is_sticky: bool,
_is_sticky: bool,
jump_data: JumpData,
window: &mut Window,
cx: &mut App,
@@ -2751,7 +2750,8 @@ impl EditorElement {
let colors = cx.theme().colors();
div()
.p_1()
.px_2()
.pt_2()
.w_full()
.h(FILE_HEADER_HEIGHT as f32 * window.line_height())
.child(
@@ -2762,7 +2762,7 @@ impl EditorElement {
.pl_0p5()
.pr_5()
.rounded_sm()
.when(is_sticky, |el| el.shadow_md())
// .when(is_sticky, |el| el.shadow_md())
.border_1()
.map(|div| {
let border_color = if is_selected
@@ -2790,7 +2790,7 @@ impl EditorElement {
ButtonLike::new("toggle-buffer-fold")
.style(ui::ButtonStyle::Transparent)
.size(ButtonSize::Large)
.width(px(24.).into())
.width(px(30.).into())
.children(toggle_chevron_icon)
.tooltip({
let focus_handle = focus_handle.clone();
@@ -2969,7 +2969,6 @@ impl EditorElement {
element,
available_space: size(AvailableSpace::MinContent, element_size.height.into()),
style: BlockStyle::Fixed,
is_buffer_header: block.is_buffer_header(),
});
}
@@ -3020,7 +3019,6 @@ impl EditorElement {
element,
available_space: size(width.into(), element_size.height.into()),
style,
is_buffer_header: block.is_buffer_header(),
});
}
@@ -3071,7 +3069,6 @@ impl EditorElement {
element,
available_space: size(width, element_size.height.into()),
style,
is_buffer_header: block.is_buffer_header(),
});
}
}
@@ -3133,13 +3130,15 @@ impl EditorElement {
fn layout_sticky_buffer_header(
&self,
StickyHeaderExcerpt { excerpt }: StickyHeaderExcerpt<'_>,
StickyHeaderExcerpt {
excerpt,
next_buffer_row,
}: StickyHeaderExcerpt<'_>,
scroll_position: f32,
line_height: Pixels,
snapshot: &EditorSnapshot,
hitbox: &Hitbox,
selected_buffer_ids: &Vec<BufferId>,
blocks: &[BlockLayout],
window: &mut Window,
cx: &mut App,
) -> AnyElement {
@@ -3175,23 +3174,17 @@ impl EditorElement {
.into_any_element();
let mut origin = hitbox.origin;
// Move floating header up to avoid colliding with the next buffer header.
for block in blocks.iter() {
if !block.is_buffer_header {
continue;
}
let Some(display_row) = block.row.filter(|row| row.0 > scroll_position as u32) else {
continue;
};
if let Some(next_buffer_row) = next_buffer_row {
// Push up the sticky header when the excerpt is getting close to the top of the viewport
let max_row = next_buffer_row - FILE_HEADER_HEIGHT * 2;
let max_row = display_row.0.saturating_sub(FILE_HEADER_HEIGHT);
let offset = scroll_position - max_row as f32;
if offset > 0.0 {
origin.y -= Pixels(offset) * line_height;
}
break;
}
let size = size(
@@ -4382,7 +4375,7 @@ impl EditorElement {
}),
};
if let Some((hunk_bounds, background_color, corner_radii, status)) = hunk_to_paint {
if let Some((hunk_bounds, background_color, corner_radii, _)) = hunk_to_paint {
// Flatten the background color with the editor color to prevent
// elements below transparent hunks from showing through
let flattened_background_color = cx
@@ -4391,38 +4384,18 @@ impl EditorElement {
.editor_background
.blend(background_color);
if !Self::diff_hunk_hollow(status, cx) {
window.paint_quad(quad(
hunk_bounds,
corner_radii,
flattened_background_color,
Edges::default(),
transparent_black(),
));
} else {
let flattened_unstaged_background_color = cx
.theme()
.colors()
.editor_background
.blend(background_color.opacity(0.3));
window.paint_quad(quad(
hunk_bounds,
corner_radii,
flattened_unstaged_background_color,
Edges::all(Pixels(1.0)),
flattened_background_color,
));
}
window.paint_quad(quad(
hunk_bounds,
corner_radii,
flattened_background_color,
Edges::default(),
transparent_black(),
));
}
}
});
}
fn gutter_strip_width(line_height: Pixels) -> Pixels {
(0.275 * line_height).floor()
}
fn diff_hunk_bounds(
snapshot: &EditorSnapshot,
line_height: Pixels,
@@ -4431,7 +4404,7 @@ impl EditorElement {
) -> Bounds<Pixels> {
let scroll_position = snapshot.scroll_position();
let scroll_top = scroll_position.y * line_height;
let gutter_strip_width = Self::gutter_strip_width(line_height);
let gutter_strip_width = (0.275 * line_height).floor();
match hunk {
DisplayDiffHunk::Folded { display_row, .. } => {
@@ -5604,8 +5577,8 @@ impl EditorElement {
window: &mut Window,
cx: &mut App,
) -> Pixels {
let digit_count = snapshot.widest_line_number().ilog10() + 1;
self.column_pixels(digit_count as usize, window, cx)
let digit_count = (snapshot.widest_line_number() as f32).log10().floor() as usize + 1;
self.column_pixels(digit_count, window, cx)
}
fn shape_line_number(
@@ -5628,18 +5601,6 @@ impl EditorElement {
&[run],
)
}
fn diff_hunk_hollow(status: DiffHunkStatus, cx: &mut App) -> bool {
let unstaged = status.has_secondary_hunk();
let unstaged_hollow = ProjectSettings::get_global(cx)
.git
.hunk_style
.map_or(false, |style| {
matches!(style, GitHunkStyleSetting::UnstagedHollow)
});
unstaged == unstaged_hollow
}
}
fn header_jump_data(
@@ -6791,9 +6752,10 @@ impl Element for EditorElement {
}
};
let unstaged = diff_status.has_secondary_hunk();
let hunk_opacity = if is_light { 0.16 } else { 0.12 };
let hollow_highlight = LineHighlight {
let staged_highlight = LineHighlight {
background: (background_color.opacity(if is_light {
0.08
} else {
@@ -6807,13 +6769,13 @@ impl Element for EditorElement {
}),
};
let filled_highlight =
let unstaged_highlight =
solid_background(background_color.opacity(hunk_opacity)).into();
let background = if Self::diff_hunk_hollow(diff_status, cx) {
hollow_highlight
let background = if unstaged {
unstaged_highlight
} else {
filled_highlight
staged_highlight
};
highlighted_rows
@@ -6897,10 +6859,8 @@ impl Element for EditorElement {
let mut expand_toggles =
window.with_element_namespace("expand_toggles", |window| {
self.layout_expand_toggles(
self.layout_excerpt_gutter(
&gutter_hitbox,
gutter_dimensions,
em_width,
line_height,
scroll_position,
&row_infos,
@@ -7053,7 +7013,6 @@ impl Element for EditorElement {
&snapshot,
&hitbox,
&selected_buffer_ids,
&blocks,
window,
cx,
)
@@ -7937,7 +7896,6 @@ struct BlockLayout {
element: AnyElement,
available_space: Size<AvailableSpace>,
style: BlockStyle,
is_buffer_header: bool,
}
pub fn layout_line(

View File

@@ -1109,14 +1109,14 @@ mod tests {
px(14.0),
None,
0,
1,
2,
FoldPlaceholder::test(),
cx,
)
});
let snapshot = display_map.update(cx, |map, cx| map.snapshot(cx));
assert_eq!(snapshot.text(), "abc\ndefg\n\nhijkl\nmn");
assert_eq!(snapshot.text(), "abc\ndefg\nhijkl\nmn");
let col_2_x = snapshot
.x_for_display_point(DisplayPoint::new(DisplayRow(0), 2), &text_layout_details);
@@ -1181,13 +1181,13 @@ mod tests {
);
let col_5_x = snapshot
.x_for_display_point(DisplayPoint::new(DisplayRow(3), 5), &text_layout_details);
.x_for_display_point(DisplayPoint::new(DisplayRow(2), 5), &text_layout_details);
// Move up and down across second excerpt's header
assert_eq!(
up(
&snapshot,
DisplayPoint::new(DisplayRow(3), 5),
DisplayPoint::new(DisplayRow(2), 5),
SelectionGoal::HorizontalPosition(col_5_x.0),
false,
&text_layout_details
@@ -1206,38 +1206,38 @@ mod tests {
&text_layout_details
),
(
DisplayPoint::new(DisplayRow(3), 5),
DisplayPoint::new(DisplayRow(2), 5),
SelectionGoal::HorizontalPosition(col_5_x.0)
),
);
let max_point_x = snapshot
.x_for_display_point(DisplayPoint::new(DisplayRow(4), 2), &text_layout_details);
.x_for_display_point(DisplayPoint::new(DisplayRow(3), 2), &text_layout_details);
// Can't move down off the end, and attempting to do so leaves the selection goal unchanged
assert_eq!(
down(
&snapshot,
DisplayPoint::new(DisplayRow(4), 0),
DisplayPoint::new(DisplayRow(3), 0),
SelectionGoal::HorizontalPosition(0.0),
false,
&text_layout_details
),
(
DisplayPoint::new(DisplayRow(4), 2),
DisplayPoint::new(DisplayRow(3), 2),
SelectionGoal::HorizontalPosition(0.0)
),
);
assert_eq!(
down(
&snapshot,
DisplayPoint::new(DisplayRow(4), 2),
DisplayPoint::new(DisplayRow(3), 2),
SelectionGoal::HorizontalPosition(max_point_x.0),
false,
&text_layout_details
),
(
DisplayPoint::new(DisplayRow(4), 2),
DisplayPoint::new(DisplayRow(3), 2),
SelectionGoal::HorizontalPosition(max_point_x.0)
),
);

View File

@@ -714,7 +714,6 @@ async fn test_extension_store_with_test_extension(cx: &mut TestAppContext) {
})
.await
.unwrap()
.unwrap()
.into_iter()
.map(|c| c.label.text)
.collect::<Vec<_>>();

View File

@@ -36,7 +36,7 @@ serde_json.workspace = true
smol.workspace = true
sysinfo.workspace = true
ui.workspace = true
urlencoding.workspace = true
urlencoding = "2.1.2"
util.workspace = true
workspace.workspace = true
zed_actions.workspace = true

View File

@@ -52,7 +52,7 @@ use util::ResultExt;
#[cfg(any(test, feature = "test-support"))]
use collections::{btree_map, BTreeMap};
#[cfg(any(test, feature = "test-support"))]
use git::FakeGitRepositoryState;
use git::repository::FakeGitRepositoryState;
#[cfg(any(test, feature = "test-support"))]
use parking_lot::Mutex;
#[cfg(any(test, feature = "test-support"))]
@@ -885,7 +885,7 @@ enum FakeFsEntry {
mtime: MTime,
len: u64,
entries: BTreeMap<String, Arc<Mutex<FakeFsEntry>>>,
git_repo_state: Option<Arc<Mutex<git::FakeGitRepositoryState>>>,
git_repo_state: Option<Arc<Mutex<git::repository::FakeGitRepositoryState>>>,
},
Symlink {
target: PathBuf,
@@ -2095,7 +2095,7 @@ impl Fs for FakeFs {
)))
})
.clone();
Some(git::FakeGitRepository::open(state))
Some(git::repository::FakeGitRepository::open(state))
} else {
None
}

View File

@@ -42,4 +42,3 @@ pretty_assertions.workspace = true
serde_json.workspace = true
text = { workspace = true, features = ["test-support"] }
unindent.workspace = true
gpui = { workspace = true, features = ["test-support"] }

View File

@@ -1,294 +0,0 @@
use crate::{
blame::Blame,
repository::{
Branch, CommitDetails, DiffType, GitRepository, PushOptions, Remote, RemoteCommandOutput,
RepoPath, ResetMode,
},
status::{FileStatus, GitStatus},
};
use anyhow::{Context, Result};
use askpass::AskPassSession;
use collections::{HashMap, HashSet};
use futures::{future::BoxFuture, FutureExt as _};
use gpui::{AsyncApp, SharedString};
use parking_lot::Mutex;
use rope::Rope;
use std::{path::PathBuf, sync::Arc};
#[derive(Debug, Clone)]
pub struct FakeGitRepository {
state: Arc<Mutex<FakeGitRepositoryState>>,
}
#[derive(Debug, Clone)]
pub struct FakeGitRepositoryState {
pub path: PathBuf,
pub event_emitter: smol::channel::Sender<PathBuf>,
pub head_contents: HashMap<RepoPath, String>,
pub index_contents: HashMap<RepoPath, String>,
pub blames: HashMap<RepoPath, Blame>,
pub statuses: HashMap<RepoPath, FileStatus>,
pub current_branch_name: Option<String>,
pub branches: HashSet<String>,
pub simulated_index_write_error_message: Option<String>,
}
impl FakeGitRepository {
pub fn open(state: Arc<Mutex<FakeGitRepositoryState>>) -> Arc<dyn GitRepository> {
Arc::new(FakeGitRepository { state })
}
}
impl FakeGitRepositoryState {
pub fn new(path: PathBuf, event_emitter: smol::channel::Sender<PathBuf>) -> Self {
FakeGitRepositoryState {
path,
event_emitter,
head_contents: Default::default(),
index_contents: Default::default(),
blames: Default::default(),
statuses: Default::default(),
current_branch_name: Default::default(),
branches: Default::default(),
simulated_index_write_error_message: None,
}
}
}
impl GitRepository for FakeGitRepository {
fn reload_index(&self) {}
fn load_index_text(&self, path: RepoPath, _: AsyncApp) -> BoxFuture<Option<String>> {
let state = self.state.lock();
let content = state.index_contents.get(path.as_ref()).cloned();
async { content }.boxed()
}
fn load_committed_text(&self, path: RepoPath, _: AsyncApp) -> BoxFuture<Option<String>> {
let state = self.state.lock();
let content = state.head_contents.get(path.as_ref()).cloned();
async { content }.boxed()
}
fn set_index_text(
&self,
path: RepoPath,
content: Option<String>,
_env: HashMap<String, String>,
cx: AsyncApp,
) -> BoxFuture<anyhow::Result<()>> {
let state = self.state.clone();
let executor = cx.background_executor().clone();
async move {
executor.simulate_random_delay().await;
let mut state = state.lock();
if let Some(message) = state.simulated_index_write_error_message.clone() {
return Err(anyhow::anyhow!(message));
}
if let Some(content) = content {
state.index_contents.insert(path.clone(), content);
} else {
state.index_contents.remove(&path);
}
state
.event_emitter
.try_send(state.path.clone())
.expect("Dropped repo change event");
Ok(())
}
.boxed()
}
fn remote_url(&self, _name: &str) -> Option<String> {
None
}
fn head_sha(&self) -> Option<String> {
None
}
fn merge_head_shas(&self) -> Vec<String> {
vec![]
}
fn show(&self, _: String, _: AsyncApp) -> BoxFuture<Result<CommitDetails>> {
unimplemented!()
}
fn reset(&self, _: String, _: ResetMode, _: HashMap<String, String>) -> BoxFuture<Result<()>> {
unimplemented!()
}
fn checkout_files(
&self,
_: String,
_: Vec<RepoPath>,
_: HashMap<String, String>,
) -> BoxFuture<Result<()>> {
unimplemented!()
}
fn path(&self) -> PathBuf {
let state = self.state.lock();
state.path.clone()
}
fn main_repository_path(&self) -> PathBuf {
self.path()
}
fn status(&self, path_prefixes: &[RepoPath]) -> Result<GitStatus> {
let state = self.state.lock();
let mut entries = state
.statuses
.iter()
.filter_map(|(repo_path, status)| {
if path_prefixes
.iter()
.any(|path_prefix| repo_path.0.starts_with(path_prefix))
{
Some((repo_path.to_owned(), *status))
} else {
None
}
})
.collect::<Vec<_>>();
entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(&b));
Ok(GitStatus {
entries: entries.into(),
})
}
fn branches(&self) -> BoxFuture<Result<Vec<Branch>>> {
let state = self.state.lock();
let current_branch = &state.current_branch_name;
let result = Ok(state
.branches
.iter()
.map(|branch_name| Branch {
is_head: Some(branch_name) == current_branch.as_ref(),
name: branch_name.into(),
most_recent_commit: None,
upstream: None,
})
.collect());
async { result }.boxed()
}
fn change_branch(&self, name: String, _: AsyncApp) -> BoxFuture<Result<()>> {
let mut state = self.state.lock();
state.current_branch_name = Some(name.to_owned());
state
.event_emitter
.try_send(state.path.clone())
.expect("Dropped repo change event");
async { Ok(()) }.boxed()
}
fn create_branch(&self, name: String, _: AsyncApp) -> BoxFuture<Result<()>> {
let mut state = self.state.lock();
state.branches.insert(name.to_owned());
state
.event_emitter
.try_send(state.path.clone())
.expect("Dropped repo change event");
async { Ok(()) }.boxed()
}
fn blame(
&self,
path: RepoPath,
_content: Rope,
_cx: AsyncApp,
) -> BoxFuture<Result<crate::blame::Blame>> {
let state = self.state.lock();
let result = state
.blames
.get(&path)
.with_context(|| format!("failed to get blame for {:?}", path.0))
.cloned();
async { result }.boxed()
}
fn stage_paths(
&self,
_paths: Vec<RepoPath>,
_env: HashMap<String, String>,
_cx: AsyncApp,
) -> BoxFuture<Result<()>> {
unimplemented!()
}
fn unstage_paths(
&self,
_paths: Vec<RepoPath>,
_env: HashMap<String, String>,
_cx: AsyncApp,
) -> BoxFuture<Result<()>> {
unimplemented!()
}
fn commit(
&self,
_message: SharedString,
_name_and_email: Option<(SharedString, SharedString)>,
_env: HashMap<String, String>,
_: AsyncApp,
) -> BoxFuture<Result<()>> {
unimplemented!()
}
fn push(
&self,
_branch: String,
_remote: String,
_options: Option<PushOptions>,
_ask_pass: AskPassSession,
_env: HashMap<String, String>,
_cx: AsyncApp,
) -> BoxFuture<Result<RemoteCommandOutput>> {
unimplemented!()
}
fn pull(
&self,
_branch: String,
_remote: String,
_ask_pass: AskPassSession,
_env: HashMap<String, String>,
_cx: AsyncApp,
) -> BoxFuture<Result<RemoteCommandOutput>> {
unimplemented!()
}
fn fetch(
&self,
_ask_pass: AskPassSession,
_env: HashMap<String, String>,
_cx: AsyncApp,
) -> BoxFuture<Result<RemoteCommandOutput>> {
unimplemented!()
}
fn get_remotes(
&self,
_branch: Option<String>,
_cx: AsyncApp,
) -> BoxFuture<Result<Vec<Remote>>> {
unimplemented!()
}
fn check_for_pushed_commit(&self, _cx: AsyncApp) -> BoxFuture<Result<Vec<SharedString>>> {
unimplemented!()
}
fn diff(&self, _diff: DiffType, _cx: AsyncApp) -> BoxFuture<Result<String>> {
unimplemented!()
}
}

View File

@@ -5,25 +5,20 @@ mod remote;
pub mod repository;
pub mod status;
#[cfg(any(test, feature = "test-support"))]
mod fake_repository;
#[cfg(any(test, feature = "test-support"))]
pub use fake_repository::*;
pub use crate::hosting_provider::*;
pub use crate::remote::*;
use anyhow::{anyhow, Context as _, Result};
pub use git2 as libgit;
use gpui::action_with_deprecated_aliases;
use gpui::actions;
pub use repository::WORK_DIRECTORY_REPO_PATH;
use serde::{Deserialize, Serialize};
use std::ffi::OsStr;
use std::fmt;
use std::str::FromStr;
use std::sync::LazyLock;
pub use crate::hosting_provider::*;
pub use crate::remote::*;
pub use git2 as libgit;
pub use repository::WORK_DIRECTORY_REPO_PATH;
pub static DOT_GIT: LazyLock<&'static OsStr> = LazyLock::new(|| OsStr::new(".git"));
pub static GITIGNORE: LazyLock<&'static OsStr> = LazyLock::new(|| OsStr::new(".gitignore"));
pub static FSMONITOR_DAEMON: LazyLock<&'static OsStr> =

View File

@@ -1,8 +1,9 @@
use crate::status::GitStatus;
use crate::status::FileStatus;
use crate::SHORT_SHA_LENGTH;
use anyhow::{anyhow, Context as _, Result};
use crate::{blame::Blame, status::GitStatus};
use anyhow::{anyhow, Context, Result};
use askpass::{AskPassResult, AskPassSession};
use collections::HashMap;
use collections::{HashMap, HashSet};
use futures::future::BoxFuture;
use futures::{select_biased, AsyncWriteExt, FutureExt as _};
use git2::BranchType;
@@ -12,12 +13,11 @@ use rope::Rope;
use schemars::JsonSchema;
use serde::Deserialize;
use std::borrow::Borrow;
use std::path::Component;
use std::process::Stdio;
use std::sync::LazyLock;
use std::{
cmp::Ordering,
path::{Path, PathBuf},
path::{Component, Path, PathBuf},
sync::Arc,
};
use sum_tree::MapSeekTarget;
@@ -1056,6 +1056,304 @@ async fn run_remote_command(
}
}
#[derive(Debug, Clone)]
pub struct FakeGitRepository {
state: Arc<Mutex<FakeGitRepositoryState>>,
}
#[derive(Debug, Clone)]
pub struct FakeGitRepositoryState {
pub path: PathBuf,
pub event_emitter: smol::channel::Sender<PathBuf>,
pub head_contents: HashMap<RepoPath, String>,
pub index_contents: HashMap<RepoPath, String>,
pub blames: HashMap<RepoPath, Blame>,
pub statuses: HashMap<RepoPath, FileStatus>,
pub current_branch_name: Option<String>,
pub branches: HashSet<String>,
pub simulated_index_write_error_message: Option<String>,
}
impl FakeGitRepository {
pub fn open(state: Arc<Mutex<FakeGitRepositoryState>>) -> Arc<dyn GitRepository> {
Arc::new(FakeGitRepository { state })
}
}
impl FakeGitRepositoryState {
pub fn new(path: PathBuf, event_emitter: smol::channel::Sender<PathBuf>) -> Self {
FakeGitRepositoryState {
path,
event_emitter,
head_contents: Default::default(),
index_contents: Default::default(),
blames: Default::default(),
statuses: Default::default(),
current_branch_name: Default::default(),
branches: Default::default(),
simulated_index_write_error_message: None,
}
}
}
impl GitRepository for FakeGitRepository {
fn reload_index(&self) {}
fn load_index_text(&self, path: RepoPath, _: AsyncApp) -> BoxFuture<Option<String>> {
let state = self.state.lock();
let content = state.index_contents.get(path.as_ref()).cloned();
async { content }.boxed()
}
fn load_committed_text(&self, path: RepoPath, _: AsyncApp) -> BoxFuture<Option<String>> {
let state = self.state.lock();
let content = state.head_contents.get(path.as_ref()).cloned();
async { content }.boxed()
}
fn set_index_text(
&self,
path: RepoPath,
content: Option<String>,
_env: HashMap<String, String>,
_cx: AsyncApp,
) -> BoxFuture<anyhow::Result<()>> {
let mut state = self.state.lock();
if let Some(message) = state.simulated_index_write_error_message.clone() {
return async { Err(anyhow::anyhow!(message)) }.boxed();
}
if let Some(content) = content {
state.index_contents.insert(path.clone(), content);
} else {
state.index_contents.remove(&path);
}
state
.event_emitter
.try_send(state.path.clone())
.expect("Dropped repo change event");
async { Ok(()) }.boxed()
}
fn remote_url(&self, _name: &str) -> Option<String> {
None
}
fn head_sha(&self) -> Option<String> {
None
}
fn merge_head_shas(&self) -> Vec<String> {
vec![]
}
fn show(&self, _: String, _: AsyncApp) -> BoxFuture<Result<CommitDetails>> {
unimplemented!()
}
fn reset(&self, _: String, _: ResetMode, _: HashMap<String, String>) -> BoxFuture<Result<()>> {
unimplemented!()
}
fn checkout_files(
&self,
_: String,
_: Vec<RepoPath>,
_: HashMap<String, String>,
) -> BoxFuture<Result<()>> {
unimplemented!()
}
fn path(&self) -> PathBuf {
let state = self.state.lock();
state.path.clone()
}
fn main_repository_path(&self) -> PathBuf {
self.path()
}
fn status(&self, path_prefixes: &[RepoPath]) -> Result<GitStatus> {
let state = self.state.lock();
let mut entries = state
.statuses
.iter()
.filter_map(|(repo_path, status)| {
if path_prefixes
.iter()
.any(|path_prefix| repo_path.0.starts_with(path_prefix))
{
Some((repo_path.to_owned(), *status))
} else {
None
}
})
.collect::<Vec<_>>();
entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(&b));
Ok(GitStatus {
entries: entries.into(),
})
}
fn branches(&self) -> BoxFuture<Result<Vec<Branch>>> {
let state = self.state.lock();
let current_branch = &state.current_branch_name;
let result = Ok(state
.branches
.iter()
.map(|branch_name| Branch {
is_head: Some(branch_name) == current_branch.as_ref(),
name: branch_name.into(),
most_recent_commit: None,
upstream: None,
})
.collect());
async { result }.boxed()
}
fn change_branch(&self, name: String, _: AsyncApp) -> BoxFuture<Result<()>> {
let mut state = self.state.lock();
state.current_branch_name = Some(name.to_owned());
state
.event_emitter
.try_send(state.path.clone())
.expect("Dropped repo change event");
async { Ok(()) }.boxed()
}
fn create_branch(&self, name: String, _: AsyncApp) -> BoxFuture<Result<()>> {
let mut state = self.state.lock();
state.branches.insert(name.to_owned());
state
.event_emitter
.try_send(state.path.clone())
.expect("Dropped repo change event");
async { Ok(()) }.boxed()
}
fn blame(
&self,
path: RepoPath,
_content: Rope,
_cx: AsyncApp,
) -> BoxFuture<Result<crate::blame::Blame>> {
let state = self.state.lock();
let result = state
.blames
.get(&path)
.with_context(|| format!("failed to get blame for {:?}", path.0))
.cloned();
async { result }.boxed()
}
fn stage_paths(
&self,
_paths: Vec<RepoPath>,
_env: HashMap<String, String>,
_cx: AsyncApp,
) -> BoxFuture<Result<()>> {
unimplemented!()
}
fn unstage_paths(
&self,
_paths: Vec<RepoPath>,
_env: HashMap<String, String>,
_cx: AsyncApp,
) -> BoxFuture<Result<()>> {
unimplemented!()
}
fn commit(
&self,
_message: SharedString,
_name_and_email: Option<(SharedString, SharedString)>,
_env: HashMap<String, String>,
_: AsyncApp,
) -> BoxFuture<Result<()>> {
unimplemented!()
}
fn push(
&self,
_branch: String,
_remote: String,
_options: Option<PushOptions>,
_ask_pass: AskPassSession,
_env: HashMap<String, String>,
_cx: AsyncApp,
) -> BoxFuture<Result<RemoteCommandOutput>> {
unimplemented!()
}
fn pull(
&self,
_branch: String,
_remote: String,
_ask_pass: AskPassSession,
_env: HashMap<String, String>,
_cx: AsyncApp,
) -> BoxFuture<Result<RemoteCommandOutput>> {
unimplemented!()
}
fn fetch(
&self,
_ask_pass: AskPassSession,
_env: HashMap<String, String>,
_cx: AsyncApp,
) -> BoxFuture<Result<RemoteCommandOutput>> {
unimplemented!()
}
fn get_remotes(
&self,
_branch: Option<String>,
_cx: AsyncApp,
) -> BoxFuture<Result<Vec<Remote>>> {
unimplemented!()
}
fn check_for_pushed_commit(&self, _cx: AsyncApp) -> BoxFuture<Result<Vec<SharedString>>> {
unimplemented!()
}
fn diff(&self, _diff: DiffType, _cx: AsyncApp) -> BoxFuture<Result<String>> {
unimplemented!()
}
}
fn check_path_to_repo_path_errors(relative_file_path: &Path) -> Result<()> {
match relative_file_path.components().next() {
None => anyhow::bail!("repo path should not be empty"),
Some(Component::Prefix(_)) => anyhow::bail!(
"repo path `{}` should be relative, not a windows prefix",
relative_file_path.to_string_lossy()
),
Some(Component::RootDir) => {
anyhow::bail!(
"repo path `{}` should be relative",
relative_file_path.to_string_lossy()
)
}
Some(Component::CurDir) => {
anyhow::bail!(
"repo path `{}` should not start with `.`",
relative_file_path.to_string_lossy()
)
}
Some(Component::ParentDir) => {
anyhow::bail!(
"repo path `{}` should not start with `..`",
relative_file_path.to_string_lossy()
)
}
_ => Ok(()),
}
}
pub static WORK_DIRECTORY_REPO_PATH: LazyLock<RepoPath> =
LazyLock::new(|| RepoPath(Path::new("").into()));
@@ -1228,35 +1526,6 @@ fn parse_upstream_track(upstream_track: &str) -> Result<UpstreamTracking> {
}))
}
fn check_path_to_repo_path_errors(relative_file_path: &Path) -> Result<()> {
match relative_file_path.components().next() {
None => anyhow::bail!("repo path should not be empty"),
Some(Component::Prefix(_)) => anyhow::bail!(
"repo path `{}` should be relative, not a windows prefix",
relative_file_path.to_string_lossy()
),
Some(Component::RootDir) => {
anyhow::bail!(
"repo path `{}` should be relative",
relative_file_path.to_string_lossy()
)
}
Some(Component::CurDir) => {
anyhow::bail!(
"repo path `{}` should not start with `.`",
relative_file_path.to_string_lossy()
)
}
Some(Component::ParentDir) => {
anyhow::bail!(
"repo path `{}` should not start with `..`",
relative_file_path.to_string_lossy()
)
}
_ => Ok(()),
}
}
#[test]
fn test_branches_parsing() {
// suppress "help: octal escapes are not supported, `\0` is always null"

View File

@@ -19,10 +19,8 @@ git.workspace = true
gpui.workspace = true
http_client.workspace = true
regex.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
url.workspace = true
util.workspace = true

View File

@@ -1,5 +1,4 @@
mod providers;
mod settings;
use std::sync::Arc;
@@ -11,19 +10,16 @@ use url::Url;
use util::maybe;
pub use crate::providers::*;
pub use crate::settings::*;
/// Initializes the Git hosting providers.
pub fn init(cx: &mut App) {
crate::settings::init(cx);
pub fn init(cx: &App) {
let provider_registry = GitHostingProviderRegistry::global(cx);
provider_registry.register_hosting_provider(Arc::new(Bitbucket::public_instance()));
provider_registry.register_hosting_provider(Arc::new(Bitbucket));
provider_registry.register_hosting_provider(Arc::new(Chromium));
provider_registry.register_hosting_provider(Arc::new(Codeberg));
provider_registry.register_hosting_provider(Arc::new(Gitee));
provider_registry.register_hosting_provider(Arc::new(Github::public_instance()));
provider_registry.register_hosting_provider(Arc::new(Gitlab::public_instance()));
provider_registry.register_hosting_provider(Arc::new(Github::new()));
provider_registry.register_hosting_provider(Arc::new(Gitlab::new()));
provider_registry.register_hosting_provider(Arc::new(Sourcehut));
}

View File

@@ -7,31 +7,15 @@ use git::{
RemoteUrl,
};
pub struct Bitbucket {
name: String,
base_url: Url,
}
impl Bitbucket {
pub fn new(name: impl Into<String>, base_url: Url) -> Self {
Self {
name: name.into(),
base_url,
}
}
pub fn public_instance() -> Self {
Self::new("Bitbucket", Url::parse("https://bitbucket.org").unwrap())
}
}
pub struct Bitbucket;
impl GitHostingProvider for Bitbucket {
fn name(&self) -> String {
self.name.clone()
"Bitbucket".to_string()
}
fn base_url(&self) -> Url {
self.base_url.clone()
Url::parse("https://bitbucket.org").unwrap()
}
fn supports_avatars(&self) -> bool {
@@ -106,7 +90,7 @@ mod tests {
#[test]
fn test_parse_remote_url_given_ssh_url() {
let parsed_remote = Bitbucket::public_instance()
let parsed_remote = Bitbucket
.parse_remote_url("git@bitbucket.org:zed-industries/zed.git")
.unwrap();
@@ -121,7 +105,7 @@ mod tests {
#[test]
fn test_parse_remote_url_given_https_url() {
let parsed_remote = Bitbucket::public_instance()
let parsed_remote = Bitbucket
.parse_remote_url("https://bitbucket.org/zed-industries/zed.git")
.unwrap();
@@ -136,7 +120,7 @@ mod tests {
#[test]
fn test_parse_remote_url_given_https_url_with_username() {
let parsed_remote = Bitbucket::public_instance()
let parsed_remote = Bitbucket
.parse_remote_url("https://thorstenballzed@bitbucket.org/zed-industries/zed.git")
.unwrap();
@@ -151,7 +135,7 @@ mod tests {
#[test]
fn test_build_bitbucket_permalink() {
let permalink = Bitbucket::public_instance().build_permalink(
let permalink = Bitbucket.build_permalink(
ParsedGitRemote {
owner: "zed-industries".into(),
repo: "zed".into(),
@@ -169,7 +153,7 @@ mod tests {
#[test]
fn test_build_bitbucket_permalink_with_single_line_selection() {
let permalink = Bitbucket::public_instance().build_permalink(
let permalink = Bitbucket.build_permalink(
ParsedGitRemote {
owner: "zed-industries".into(),
repo: "zed".into(),
@@ -187,7 +171,7 @@ mod tests {
#[test]
fn test_build_bitbucket_permalink_with_multi_line_selection() {
let permalink = Bitbucket::public_instance().build_permalink(
let permalink = Bitbucket.build_permalink(
ParsedGitRemote {
owner: "zed-industries".into(),
repo: "zed".into(),

View File

@@ -45,24 +45,19 @@ struct User {
pub avatar_url: String,
}
#[derive(Debug)]
pub struct Github {
name: String,
base_url: Url,
}
impl Github {
pub fn new(name: impl Into<String>, base_url: Url) -> Self {
pub fn new() -> Self {
Self {
name: name.into(),
base_url,
name: "GitHub".to_string(),
base_url: Url::parse("https://github.com").unwrap(),
}
}
pub fn public_instance() -> Self {
Self::new("GitHub", Url::parse("https://github.com").unwrap())
}
pub fn from_remote_url(remote_url: &str) -> Result<Self> {
let host = get_host_from_git_remote_url(remote_url)?;
if host == "github.com" {
@@ -76,10 +71,10 @@ impl Github {
bail!("not a GitHub URL");
}
Ok(Self::new(
"GitHub Self-Hosted",
Url::parse(&format!("https://{}", host))?,
))
Ok(Self {
name: "GitHub Self-Hosted".to_string(),
base_url: Url::parse(&format!("https://{}", host))?,
})
}
async fn fetch_github_commit_author(
@@ -313,7 +308,7 @@ mod tests {
#[test]
fn test_parse_remote_url_given_ssh_url() {
let parsed_remote = Github::public_instance()
let parsed_remote = Github::new()
.parse_remote_url("git@github.com:zed-industries/zed.git")
.unwrap();
@@ -328,7 +323,7 @@ mod tests {
#[test]
fn test_parse_remote_url_given_https_url() {
let parsed_remote = Github::public_instance()
let parsed_remote = Github::new()
.parse_remote_url("https://github.com/zed-industries/zed.git")
.unwrap();
@@ -343,7 +338,7 @@ mod tests {
#[test]
fn test_parse_remote_url_given_https_url_with_username() {
let parsed_remote = Github::public_instance()
let parsed_remote = Github::new()
.parse_remote_url("https://jlannister@github.com/some-org/some-repo.git")
.unwrap();
@@ -362,7 +357,7 @@ mod tests {
owner: "zed-industries".into(),
repo: "zed".into(),
};
let permalink = Github::public_instance().build_permalink(
let permalink = Github::new().build_permalink(
remote,
BuildPermalinkParams {
sha: "e6ebe7974deb6bb6cc0e2595c8ec31f0c71084b7",
@@ -377,7 +372,7 @@ mod tests {
#[test]
fn test_build_github_permalink() {
let permalink = Github::public_instance().build_permalink(
let permalink = Github::new().build_permalink(
ParsedGitRemote {
owner: "zed-industries".into(),
repo: "zed".into(),
@@ -395,7 +390,7 @@ mod tests {
#[test]
fn test_build_github_permalink_with_single_line_selection() {
let permalink = Github::public_instance().build_permalink(
let permalink = Github::new().build_permalink(
ParsedGitRemote {
owner: "zed-industries".into(),
repo: "zed".into(),
@@ -413,7 +408,7 @@ mod tests {
#[test]
fn test_build_github_permalink_with_multi_line_selection() {
let permalink = Github::public_instance().build_permalink(
let permalink = Github::new().build_permalink(
ParsedGitRemote {
owner: "zed-industries".into(),
repo: "zed".into(),
@@ -436,7 +431,7 @@ mod tests {
repo: "zed".into(),
};
let github = Github::public_instance();
let github = Github::new();
let message = "This does not contain a pull request";
assert!(github.extract_pull_request(&remote, message).is_none());

View File

@@ -17,17 +17,13 @@ pub struct Gitlab {
}
impl Gitlab {
pub fn new(name: impl Into<String>, base_url: Url) -> Self {
pub fn new() -> Self {
Self {
name: name.into(),
base_url,
name: "GitLab".to_string(),
base_url: Url::parse("https://gitlab.com").unwrap(),
}
}
pub fn public_instance() -> Self {
Self::new("GitLab", Url::parse("https://gitlab.com").unwrap())
}
pub fn from_remote_url(remote_url: &str) -> Result<Self> {
let host = get_host_from_git_remote_url(remote_url)?;
if host == "gitlab.com" {
@@ -41,10 +37,10 @@ impl Gitlab {
bail!("not a GitLab URL");
}
Ok(Self::new(
"GitLab Self-Hosted",
Url::parse(&format!("https://{}", host))?,
))
Ok(Self {
name: "GitLab Self-Hosted".to_string(),
base_url: Url::parse(&format!("https://{}", host))?,
})
}
}
@@ -139,7 +135,7 @@ mod tests {
#[test]
fn test_parse_remote_url_given_ssh_url() {
let parsed_remote = Gitlab::public_instance()
let parsed_remote = Gitlab::new()
.parse_remote_url("git@gitlab.com:zed-industries/zed.git")
.unwrap();
@@ -154,7 +150,7 @@ mod tests {
#[test]
fn test_parse_remote_url_given_https_url() {
let parsed_remote = Gitlab::public_instance()
let parsed_remote = Gitlab::new()
.parse_remote_url("https://gitlab.com/zed-industries/zed.git")
.unwrap();
@@ -204,7 +200,7 @@ mod tests {
#[test]
fn test_build_gitlab_permalink() {
let permalink = Gitlab::public_instance().build_permalink(
let permalink = Gitlab::new().build_permalink(
ParsedGitRemote {
owner: "zed-industries".into(),
repo: "zed".into(),
@@ -222,7 +218,7 @@ mod tests {
#[test]
fn test_build_gitlab_permalink_with_single_line_selection() {
let permalink = Gitlab::public_instance().build_permalink(
let permalink = Gitlab::new().build_permalink(
ParsedGitRemote {
owner: "zed-industries".into(),
repo: "zed".into(),
@@ -240,7 +236,7 @@ mod tests {
#[test]
fn test_build_gitlab_permalink_with_multi_line_selection() {
let permalink = Gitlab::public_instance().build_permalink(
let permalink = Gitlab::new().build_permalink(
ParsedGitRemote {
owner: "zed-industries".into(),
repo: "zed".into(),

View File

@@ -11,7 +11,7 @@ pub struct Sourcehut;
impl GitHostingProvider for Sourcehut {
fn name(&self) -> String {
"SourceHut".to_string()
"Gitee".to_string()
}
fn base_url(&self) -> Url {

View File

@@ -1,84 +0,0 @@
use std::sync::Arc;
use anyhow::Result;
use git::GitHostingProviderRegistry;
use gpui::App;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use url::Url;
use util::ResultExt as _;
use crate::{Bitbucket, Github, Gitlab};
pub(crate) fn init(cx: &mut App) {
GitHostingProviderSettings::register(cx);
init_git_hosting_provider_settings(cx);
}
fn init_git_hosting_provider_settings(cx: &mut App) {
update_git_hosting_providers_from_settings(cx);
cx.observe_global::<SettingsStore>(update_git_hosting_providers_from_settings)
.detach();
}
fn update_git_hosting_providers_from_settings(cx: &mut App) {
let settings = GitHostingProviderSettings::get_global(cx);
let provider_registry = GitHostingProviderRegistry::global(cx);
for provider in settings.git_hosting_providers.iter() {
let Some(url) = Url::parse(&provider.base_url).log_err() else {
continue;
};
let provider = match provider.provider {
GitHostingProviderKind::Bitbucket => Arc::new(Bitbucket::new(&provider.name, url)) as _,
GitHostingProviderKind::Github => Arc::new(Github::new(&provider.name, url)) as _,
GitHostingProviderKind::Gitlab => Arc::new(Gitlab::new(&provider.name, url)) as _,
};
provider_registry.register_hosting_provider(provider);
}
}
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum GitHostingProviderKind {
Github,
Gitlab,
Bitbucket,
}
/// A custom Git hosting provider.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct GitHostingProviderConfig {
/// The type of the provider.
///
/// Must be one of `github`, `gitlab`, or `bitbucket`.
pub provider: GitHostingProviderKind,
/// The base URL for the provider (e.g., "https://code.corp.big.com").
pub base_url: String,
/// The display name for the provider (e.g., "BigCorp GitHub").
pub name: String,
}
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct GitHostingProviderSettings {
/// The list of custom Git hosting providers.
#[serde(default)]
pub git_hosting_providers: Vec<GitHostingProviderConfig>,
}
impl Settings for GitHostingProviderSettings {
const KEY: Option<&'static str> = None;
type FileContent = Self;
fn load(sources: settings::SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
sources.json_merge()
}
}

View File

@@ -7,7 +7,7 @@ use gpui::{
InteractiveElement, IntoElement, Modifiers, ModifiersChangedEvent, ParentElement, Render,
SharedString, Styled, Subscription, Task, Window,
};
use picker::{Picker, PickerDelegate, PickerEditorPosition};
use picker::{Picker, PickerDelegate};
use project::git::Repository;
use std::sync::Arc;
use time::OffsetDateTime;
@@ -17,10 +17,13 @@ use util::ResultExt;
use workspace::notifications::DetachAndPromptErr;
use workspace::{ModalView, Workspace};
pub fn register(workspace: &mut Workspace) {
workspace.register_action(open);
workspace.register_action(switch);
workspace.register_action(checkout_branch);
pub fn init(cx: &mut App) {
cx.observe_new(|workspace: &mut Workspace, _, _| {
workspace.register_action(open);
workspace.register_action(switch);
workspace.register_action(checkout_branch);
})
.detach();
}
pub fn checkout_branch(
@@ -222,13 +225,6 @@ impl PickerDelegate for BranchListDelegate {
"Select branch...".into()
}
fn editor_position(&self) -> PickerEditorPosition {
match self.style {
BranchListStyle::Modal => PickerEditorPosition::Start,
BranchListStyle::Popover => PickerEditorPosition::End,
}
}
fn match_count(&self) -> usize {
self.matches.len()
}

View File

@@ -54,6 +54,16 @@ impl ModalContainerProperties {
}
}
pub fn init(cx: &mut App) {
cx.observe_new(|workspace: &mut Workspace, window, cx| {
let Some(window) = window else {
return;
};
CommitModal::register(workspace, window, cx)
})
.detach();
}
pub struct CommitModal {
git_panel: Entity<GitPanel>,
commit_editor: Entity<Editor>,
@@ -98,7 +108,7 @@ struct RestoreDock {
}
impl CommitModal {
pub fn register(workspace: &mut Workspace) {
pub fn register(workspace: &mut Workspace, _: &mut Window, _cx: &mut Context<Workspace>) {
workspace.register_action(|workspace, _: &Commit, window, cx| {
CommitModal::toggle(workspace, window, cx);
});

View File

@@ -127,13 +127,18 @@ const GIT_PANEL_KEY: &str = "GitPanel";
const UPDATE_DEBOUNCE: Duration = Duration::from_millis(50);
pub fn register(workspace: &mut Workspace) {
workspace.register_action(|workspace, _: &ToggleFocus, window, cx| {
workspace.toggle_panel_focus::<GitPanel>(window, cx);
});
workspace.register_action(|workspace, _: &ExpandCommitEditor, window, cx| {
CommitModal::toggle(workspace, window, cx)
});
pub fn init(cx: &mut App) {
cx.observe_new(
|workspace: &mut Workspace, _window, _: &mut Context<Workspace>| {
workspace.register_action(|workspace, _: &ToggleFocus, window, cx| {
workspace.toggle_panel_focus::<GitPanel>(window, cx);
});
workspace.register_action(|workspace, _: &ExpandCommitEditor, window, cx| {
CommitModal::toggle(workspace, window, cx)
});
},
)
.detach();
}
#[derive(Debug, Clone)]
@@ -2001,7 +2006,7 @@ impl GitPanel {
}
fn can_push_and_pull(&self, cx: &App) -> bool {
!self.project.read(cx).is_via_collab()
crate::can_push_and_pull(&self.project, cx)
}
fn get_current_remote(
@@ -2237,14 +2242,7 @@ impl GitPanel {
fn update_visible_entries(&mut self, cx: &mut Context<Self>) {
self.entries.clear();
self.single_staged_entry.take();
self.conflicted_count = 0;
self.conflicted_staged_count = 0;
self.new_count = 0;
self.tracked_count = 0;
self.new_staged_count = 0;
self.tracked_staged_count = 0;
self.entry_count = 0;
self.single_staged_entry.take();
let mut changed_entries = Vec::new();
let mut new_entries = Vec::new();
let mut conflict_entries = Vec::new();
@@ -2824,36 +2822,6 @@ impl GitPanel {
)
}
pub(crate) fn render_remote_button(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
let branch = self
.active_repository
.as_ref()?
.read(cx)
.current_branch()
.cloned();
if !self.can_push_and_pull(cx) {
return None;
}
let spinner = self.render_spinner();
Some(
h_flex()
.gap_1()
.flex_shrink_0()
.children(spinner)
.when_some(branch, |this, branch| {
let focus_handle = Some(self.focus_handle(cx));
this.children(render_remote_button(
"remote-button",
&branch,
focus_handle,
true,
))
})
.into_any_element(),
)
}
pub fn render_footer(
&self,
window: &mut Window,
@@ -2893,7 +2861,12 @@ impl GitPanel {
});
let footer = v_flex()
.child(PanelRepoFooter::new(display_name, branch, Some(git_panel)))
.child(PanelRepoFooter::new(
"footer-button",
display_name,
branch,
Some(git_panel),
))
.child(
panel_editor_container(window, cx)
.id("commit-editor-container")
@@ -2944,10 +2917,6 @@ impl GitPanel {
.disabled(!can_commit || self.modal_open)
.on_click({
cx.listener(move |this, _: &ClickEvent, window, cx| {
telemetry::event!(
"Git Committed",
source = "Git Panel"
);
this.commit_changes(window, cx)
})
}),
@@ -3835,7 +3804,7 @@ impl Render for GitPanel {
deferred(
anchored()
.position(*position)
.anchor(Corner::TopLeft)
.anchor(gpui::Corner::TopLeft)
.child(menu.clone()),
)
.with_priority(1)
@@ -3990,6 +3959,7 @@ impl Render for GitPanelMessageTooltip {
#[derive(IntoElement, IntoComponent)]
#[component(scope = "Version Control")]
pub struct PanelRepoFooter {
id: SharedString,
active_repository: SharedString,
branch: Option<Branch>,
// Getting a GitPanel in previews will be difficult.
@@ -4000,19 +3970,26 @@ pub struct PanelRepoFooter {
impl PanelRepoFooter {
pub fn new(
id: impl Into<SharedString>,
active_repository: SharedString,
branch: Option<Branch>,
git_panel: Option<Entity<GitPanel>>,
) -> Self {
Self {
id: id.into(),
active_repository,
branch,
git_panel,
}
}
pub fn new_preview(active_repository: SharedString, branch: Option<Branch>) -> Self {
pub fn new_preview(
id: impl Into<SharedString>,
active_repository: SharedString,
branch: Option<Branch>,
) -> Self {
Self {
id: id.into(),
active_repository,
branch,
git_panel: None,
@@ -4093,14 +4070,14 @@ impl RenderOnce for PanelRepoFooter {
let project = project.clone();
move |window, cx| {
let project = project.clone()?;
Some(cx.new(|cx| RepositorySelector::new(project, rems(16.), window, cx)))
Some(cx.new(|cx| RepositorySelector::new(project, window, cx)))
}
})
.trigger_with_tooltip(
repo_selector_trigger.disabled(single_repo).truncate(true),
Tooltip::text("Switch active repository"),
)
.anchor(Corner::BottomLeft)
.attach(gpui::Corner::BottomLeft)
.into_any_element();
let branch_selector_button = Button::new("branch-selector", truncated_branch_name)
@@ -4122,12 +4099,17 @@ impl RenderOnce for PanelRepoFooter {
branch_selector_button,
Tooltip::for_action_title("Switch Branch", &zed_actions::git::Branch),
)
.anchor(Corner::BottomLeft)
.anchor(Corner::TopLeft)
.offset(gpui::Point {
x: px(0.0),
y: px(-2.0),
});
let spinner = self
.git_panel
.as_ref()
.and_then(|git_panel| git_panel.read(cx).render_spinner());
h_flex()
.w_full()
.px_2()
@@ -4162,11 +4144,28 @@ impl RenderOnce for PanelRepoFooter {
})
.child(branch_selector),
)
.children(if let Some(git_panel) = self.git_panel {
git_panel.update(cx, |git_panel, cx| git_panel.render_remote_button(cx))
} else {
None
})
.child(
h_flex()
.gap_1()
.flex_shrink_0()
.children(spinner)
.when_some(branch, |this, branch| {
let mut focus_handle = None;
if let Some(git_panel) = self.git_panel.as_ref() {
if !git_panel.read(cx).can_push_and_pull(cx) {
return this;
}
focus_handle = Some(git_panel.focus_handle(cx));
}
this.children(render_remote_button(
self.id.clone(),
&branch,
focus_handle,
true,
))
}),
)
}
}
@@ -4257,6 +4256,7 @@ impl ComponentPreview for PanelRepoFooter {
.w(example_width)
.overflow_hidden()
.child(PanelRepoFooter::new_preview(
"no-branch",
active_repository(1).clone(),
None,
))
@@ -4269,6 +4269,7 @@ impl ComponentPreview for PanelRepoFooter {
.w(example_width)
.overflow_hidden()
.child(PanelRepoFooter::new_preview(
"unknown-upstream",
active_repository(2).clone(),
Some(branch(unknown_upstream)),
))
@@ -4281,6 +4282,7 @@ impl ComponentPreview for PanelRepoFooter {
.w(example_width)
.overflow_hidden()
.child(PanelRepoFooter::new_preview(
"no-remote-upstream",
active_repository(3).clone(),
Some(branch(no_remote_upstream)),
))
@@ -4293,6 +4295,7 @@ impl ComponentPreview for PanelRepoFooter {
.w(example_width)
.overflow_hidden()
.child(PanelRepoFooter::new_preview(
"not-ahead-or-behind",
active_repository(4).clone(),
Some(branch(not_ahead_or_behind_upstream)),
))
@@ -4305,6 +4308,7 @@ impl ComponentPreview for PanelRepoFooter {
.w(example_width)
.overflow_hidden()
.child(PanelRepoFooter::new_preview(
"behind-remote",
active_repository(5).clone(),
Some(branch(behind_upstream)),
))
@@ -4317,6 +4321,7 @@ impl ComponentPreview for PanelRepoFooter {
.w(example_width)
.overflow_hidden()
.child(PanelRepoFooter::new_preview(
"ahead-of-remote",
active_repository(6).clone(),
Some(branch(ahead_of_upstream)),
))
@@ -4329,6 +4334,7 @@ impl ComponentPreview for PanelRepoFooter {
.w(example_width)
.overflow_hidden()
.child(PanelRepoFooter::new_preview(
"ahead-and-behind",
active_repository(7).clone(),
Some(branch(ahead_and_behind_upstream)),
))
@@ -4348,6 +4354,7 @@ impl ComponentPreview for PanelRepoFooter {
.w(example_width)
.overflow_hidden()
.child(PanelRepoFooter::new_preview(
"short-branch",
SharedString::from("zed"),
Some(custom("main", behind_upstream)),
))
@@ -4360,6 +4367,7 @@ impl ComponentPreview for PanelRepoFooter {
.w(example_width)
.overflow_hidden()
.child(PanelRepoFooter::new_preview(
"long-branch",
SharedString::from("zed"),
Some(custom(
"redesign-and-update-git-ui-list-entry-style",
@@ -4375,6 +4383,7 @@ impl ComponentPreview for PanelRepoFooter {
.w(example_width)
.overflow_hidden()
.child(PanelRepoFooter::new_preview(
"long-repo",
SharedString::from("zed-industries-community-examples"),
Some(custom("gpui", ahead_of_upstream)),
))
@@ -4387,6 +4396,7 @@ impl ComponentPreview for PanelRepoFooter {
.w(example_width)
.overflow_hidden()
.child(PanelRepoFooter::new_preview(
"long-repo-and-branch",
SharedString::from("zed-industries-community-examples"),
Some(custom(
"redesign-and-update-git-ui-list-entry-style",
@@ -4402,6 +4412,7 @@ impl ComponentPreview for PanelRepoFooter {
.w(example_width)
.overflow_hidden()
.child(PanelRepoFooter::new_preview(
"uppercase-repo",
SharedString::from("LICENSES"),
Some(custom("main", ahead_of_upstream)),
))
@@ -4414,6 +4425,7 @@ impl ComponentPreview for PanelRepoFooter {
.w(example_width)
.overflow_hidden()
.child(PanelRepoFooter::new_preview(
"uppercase-branch",
SharedString::from("zed"),
Some(custom("update-README", behind_upstream)),
))

View File

@@ -2,14 +2,14 @@ use std::any::Any;
use ::settings::Settings;
use command_palette_hooks::CommandPaletteFilter;
use commit_modal::CommitModal;
use git::{
repository::{Branch, Upstream, UpstreamTracking, UpstreamTrackingStatus},
status::{FileStatus, StatusCode, UnmergedStatus, UnmergedStatusCode},
};
use git_panel_settings::GitPanelSettings;
use gpui::{actions, App, FocusHandle};
use gpui::{actions, App, Entity, FocusHandle};
use onboarding::{clear_dismissed, GitOnboardingModal};
use project::Project;
use project_diff::ProjectDiff;
use ui::prelude::*;
use workspace::Workspace;
@@ -29,14 +29,12 @@ actions!(git, [ResetOnboarding]);
pub fn init(cx: &mut App) {
GitPanelSettings::register(cx);
branch_picker::init(cx);
cx.observe_new(ProjectDiff::register).detach();
commit_modal::init(cx);
git_panel::init(cx);
cx.observe_new(|workspace: &mut Workspace, _, cx| {
ProjectDiff::register(workspace, cx);
CommitModal::register(workspace);
git_panel::register(workspace);
repository_selector::register(workspace);
branch_picker::register(workspace);
let project = workspace.project().read(cx);
if project.is_read_only(cx) {
return;
@@ -122,6 +120,10 @@ pub fn git_status_icon(status: FileStatus) -> impl IntoElement {
GitStatusIcon::new(status)
}
fn can_push_and_pull(project: &Entity<Project>, cx: &App) -> bool {
!project.read(cx).is_via_collab()
}
fn render_remote_button(
id: impl Into<SharedString>,
branch: &Branch,

View File

@@ -61,16 +61,21 @@ struct DiffBuffer {
file_status: FileStatus,
}
const CONFLICT_NAMESPACE: &'static str = "0";
const TRACKED_NAMESPACE: &'static str = "1";
const NEW_NAMESPACE: &'static str = "2";
const CONFLICT_NAMESPACE: usize = 0;
const TRACKED_NAMESPACE: usize = 1;
const NEW_NAMESPACE: usize = 2;
impl ProjectDiff {
pub(crate) fn register(workspace: &mut Workspace, cx: &mut Context<Workspace>) {
pub(crate) fn register(
workspace: &mut Workspace,
_window: Option<&mut Window>,
cx: &mut Context<Workspace>,
) {
workspace.register_action(Self::deploy);
workspace.register_action(|workspace, _: &Add, window, cx| {
Self::deploy(workspace, &Diff, window, cx);
});
workspace::register_serializable_item::<ProjectDiff>(cx);
}
@@ -668,6 +673,8 @@ impl Render for ProjectDiff {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let is_empty = self.multibuffer.read(cx).is_empty();
let can_push_and_pull = crate::can_push_and_pull(&self.project, cx);
div()
.track_focus(&self.focus_handle)
.key_context(if is_empty { "EmptyPane" } else { "GitDiff" })
@@ -677,16 +684,6 @@ impl Render for ProjectDiff {
.justify_center()
.size_full()
.when(is_empty, |el| {
let remote_button = if let Some(panel) = self
.workspace
.upgrade()
.and_then(|workspace| workspace.read(cx).panel::<GitPanel>(cx))
{
panel.update(cx, |panel, cx| panel.render_remote_button(cx))
} else {
None
};
let keybinding_focus_handle = self.focus_handle(cx).clone();
el.child(
v_flex()
.gap_1()
@@ -695,33 +692,52 @@ impl Render for ProjectDiff {
.justify_around()
.child(Label::new("No uncommitted changes")),
)
.map(|el| match remote_button {
Some(button) => el.child(h_flex().justify_around().child(button)),
None => el.child(
h_flex()
.justify_around()
.child(Label::new("Remote up to date")),
),
.when(can_push_and_pull, |this_div| {
let keybinding_focus_handle = self.focus_handle(cx);
this_div.when_some(self.current_branch.as_ref(), |this_div, branch| {
let remote_button = crate::render_remote_button(
"project-diff-remote-button",
branch,
Some(keybinding_focus_handle.clone()),
false,
);
match remote_button {
Some(button) => {
this_div.child(h_flex().justify_around().child(button))
}
None => this_div.child(
h_flex()
.justify_around()
.child(Label::new("Remote up to date")),
),
}
})
})
.child(
h_flex().justify_around().mt_1().child(
Button::new("project-diff-close-button", "Close")
// .style(ButtonStyle::Transparent)
.key_binding(KeyBinding::for_action_in(
&CloseActiveItem::default(),
&keybinding_focus_handle,
window,
cx,
))
.on_click(move |_, window, cx| {
window.focus(&keybinding_focus_handle);
window.dispatch_action(
Box::new(CloseActiveItem::default()),
.map(|this| {
let keybinding_focus_handle = self.focus_handle(cx).clone();
this.child(
h_flex().justify_around().mt_1().child(
Button::new("project-diff-close-button", "Close")
// .style(ButtonStyle::Transparent)
.key_binding(KeyBinding::for_action_in(
&CloseActiveItem::default(),
&keybinding_focus_handle,
window,
cx,
);
}),
),
),
))
.on_click(move |_, window, cx| {
window.focus(&keybinding_focus_handle);
window.dispatch_action(
Box::new(CloseActiveItem::default()),
cx,
);
}),
),
)
}),
)
})
.when(!is_empty, |el| el.child(self.editor.clone()))
@@ -790,7 +806,6 @@ impl ProjectDiffToolbar {
fn project_diff(&self, _: &App) -> Option<Entity<ProjectDiff>> {
self.project_diff.as_ref()?.upgrade()
}
fn dispatch_action(&self, action: &dyn Action, window: &mut Window, cx: &mut Context<Self>) {
if let Some(project_diff) = self.project_diff(cx) {
project_diff.focus_handle(cx).focus(window);
@@ -904,6 +919,12 @@ impl Render for ProjectDiffToolbar {
&StageAndNext,
&focus_handle,
))
// don't actually disable the button so it's mashable
.color(if button_states.stage {
Color::Default
} else {
Color::Disabled
})
.on_click(cx.listener(|this, _, window, cx| {
this.dispatch_action(&StageAndNext, window, cx)
})),
@@ -915,6 +936,11 @@ impl Render for ProjectDiffToolbar {
&UnstageAndNext,
&focus_handle,
))
.color(if button_states.unstage {
Color::Default
} else {
Color::Disabled
})
.on_click(cx.listener(|this, _, window, cx| {
this.dispatch_action(&UnstageAndNext, window, cx)
})),

View File

@@ -9,33 +9,14 @@ use project::{
};
use std::sync::Arc;
use ui::{prelude::*, ListItem, ListItemSpacing};
use workspace::{ModalView, Workspace};
pub fn register(workspace: &mut Workspace) {
workspace.register_action(open);
}
pub fn open(
workspace: &mut Workspace,
_: &zed_actions::git::SelectRepo,
window: &mut Window,
cx: &mut Context<Workspace>,
) {
let project = workspace.project().clone();
workspace.toggle_modal(window, cx, |window, cx| {
RepositorySelector::new(project, rems(34.), window, cx)
})
}
pub struct RepositorySelector {
width: Rems,
picker: Entity<Picker<RepositorySelectorDelegate>>,
}
impl RepositorySelector {
pub fn new(
project_handle: Entity<Project>,
width: Rems,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -67,7 +48,7 @@ impl RepositorySelector {
.max_height(Some(rems(20.).into()))
});
RepositorySelector { picker, width }
RepositorySelector { picker }
}
}
@@ -110,12 +91,10 @@ impl Focusable for RepositorySelector {
impl Render for RepositorySelector {
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
div().w(self.width).child(self.picker.clone())
self.picker.clone()
}
}
impl ModalView for RepositorySelector {}
pub struct RepositorySelectorDelegate {
project: WeakEntity<Project>,
repository_selector: WeakEntity<RepositorySelector>,

Some files were not shown because too many files have changed in this diff Show More