Compare commits

..

112 Commits

Author SHA1 Message Date
Junkui Zhang
d103a3ecb2 refactor 2025-06-04 20:24:17 +08:00
Junkui Zhang
c921212c59 fix merge conflicts 2025-06-04 20:19:11 +08:00
Junkui Zhang
5bf3f5be50 clippy 2025-06-04 17:27:41 +08:00
Junkui Zhang
10ce1111f2 fix merge conflicts 2025-06-04 17:03:14 +08:00
Junkui Zhang
3c84069df2 clippy 2025-06-04 16:53:32 +08:00
Junkui Zhang
ea35b55ff5 remove debug print 2025-06-04 16:53:32 +08:00
Junkui Zhang
33bb07125f use atomic write 2025-06-04 16:53:32 +08:00
Junkui Zhang
109471a768 fix tests 2025-06-04 16:53:32 +08:00
Junkui Zhang
78e5134b7d fix ci 2025-06-04 16:53:32 +08:00
Junkui Zhang
6211ad1e61 fix all 2025-06-04 16:53:32 +08:00
Junkui Zhang
9d203ba0d8 rename static 2025-06-04 16:53:31 +08:00
Junkui Zhang
2a1ec794eb fix ci 2025-06-04 16:53:31 +08:00
Junkui Zhang
44e9f7f0a0 use xvfb 2025-06-04 16:53:31 +08:00
Junkui Zhang
1002a04583 test 2025-06-04 16:53:31 +08:00
Junkui Zhang
1bd09837c8 try 2025-06-04 16:53:31 +08:00
Junkui Zhang
616f932145 fix 2025-06-04 16:53:31 +08:00
Junkui Zhang
c8d1de87d6 fix 2025-06-04 16:53:31 +08:00
Junkui Zhang
17f35e88ac test 2025-06-04 16:53:31 +08:00
Junkui Zhang
b96b6cf5ef test 2025-06-04 16:53:31 +08:00
Junkui Zhang
2714fc6d5b test 2025-06-04 16:53:30 +08:00
Junkui Zhang
6064799099 test 2025-06-04 16:53:18 +08:00
Junkui Zhang
fa54137c05 test 2025-06-04 16:53:18 +08:00
Junkui Zhang
2d531fab6a test 2025-06-04 16:53:18 +08:00
Junkui Zhang
f4a24167fd fix 2025-06-04 16:53:18 +08:00
Junkui Zhang
4e46aa5057 trigger ci 2025-06-04 16:53:18 +08:00
Junkui Zhang
b23efbf62f test 2025-06-04 16:53:17 +08:00
Junkui Zhang
e54780b2da test 2025-06-04 16:53:17 +08:00
Junkui Zhang
e3bd1a0e41 test 2025-06-04 16:53:17 +08:00
Junkui Zhang
2b0b9b337d test 2025-06-04 16:53:17 +08:00
Junkui Zhang
c65c6f20e1 try fix linux 2025-06-04 16:53:17 +08:00
Junkui Zhang
5e4959728e fix Russian layout on macOS 2025-06-04 16:53:17 +08:00
Junkui Zhang
cb893ccd84 use is_immutable_key again 2025-06-04 16:53:17 +08:00
Junkui Zhang
d37b7f5d21 fix linux 2025-06-04 16:53:17 +08:00
Junkui Zhang
7c9c2363ab fix linux 2025-06-04 16:53:16 +08:00
Junkui Zhang
c00ce9f327 Fix all tests for linux 2025-06-04 16:53:16 +08:00
Junkui Zhang
a1fb04b49f fix linux 2025-06-04 16:53:16 +08:00
Junkui Zhang
bd4557503f use the new serializer 2025-06-04 16:53:16 +08:00
Junkui Zhang
8af8e69870 introduce KeymapSerializer 2025-06-04 16:53:16 +08:00
Junkui Zhang
91d56b815a remove is_immutable_keys 2025-06-04 16:53:16 +08:00
Junkui Zhang
7f06102e7e fix 2025-06-04 16:53:15 +08:00
Junkui Zhang
ba6dc2040e disblae tests for linux 2025-06-04 16:53:15 +08:00
Junkui Zhang
0f8327f662 disable for linux 2025-06-04 16:53:15 +08:00
Junkui Zhang
7a84135aac Fix macOS and Linux 2025-06-04 16:53:15 +08:00
Junkui Zhang
1393fefb8f use Result<Option<key>> instead 2025-06-04 16:53:15 +08:00
Junkui Zhang
8f74bdd091 fix 2025-06-04 16:53:15 +08:00
Junkui Zhang
010f345963 basic impl for linux 2025-06-04 16:53:15 +08:00
Junkui Zhang
4bbc56d7bd try fix all 2025-06-04 16:53:14 +08:00
Junkui Zhang
e5c6ef8361 fmt 2025-06-04 16:53:14 +08:00
Junkui Zhang
7f11dea9f2 try fix 2025-06-04 16:53:14 +08:00
Junkui Zhang
36fc38cf87 fix 2025-06-04 16:53:14 +08:00
Junkui Zhang
7d1f5d4718 fix tests 2025-06-04 16:53:14 +08:00
Junkui Zhang
3027225767 remove unused function 2025-06-04 16:53:14 +08:00
Junkui Zhang
0c6f50032b clippy 2025-06-04 16:53:14 +08:00
Junkui Zhang
b26332f514 trigger ci 2025-06-04 16:53:13 +08:00
Junkui Zhang
1c5fd70a04 fix all 2025-06-04 16:53:13 +08:00
Junkui Zhang
451e7bed60 fix tests 2025-06-04 16:53:01 +08:00
Junkui Zhang
854f55b4d8 manual serialize 2025-06-04 16:53:01 +08:00
Junkui Zhang
c8b6020804 refactor 2025-06-04 16:53:01 +08:00
Junkui Zhang
622f6ffe2d CHECKPOINT 2025-06-04 16:53:01 +08:00
Junkui Zhang
49a85e59ac rename some variables 2025-06-04 16:53:01 +08:00
Junkui Zhang
43f0e6dedd checkpoint 2025-06-04 16:53:01 +08:00
Junkui Zhang
091d2cfcb3 checkpoint 2025-06-04 16:53:01 +08:00
Junkui Zhang
de217bef18 wip 2025-06-04 16:53:01 +08:00
Junkui Zhang
991922ce54 wip 2025-06-04 16:53:00 +08:00
Junkui Zhang
68ca64a310 checkpoint 2025-06-04 16:53:00 +08:00
Junkui Zhang
8ef6c77934 add debugger 2025-06-04 16:53:00 +08:00
Junkui Zhang
13811202cb wip 2025-06-04 16:53:00 +08:00
Junkui Zhang
595be0135c wip 2025-06-04 16:53:00 +08:00
Junkui Zhang
2f1f231f0b add to_gpui_style 2025-06-04 16:53:00 +08:00
Junkui Zhang
4de395f933 rename some parameters 2025-06-04 16:52:59 +08:00
Junkui Zhang
f34e6f127d fix tests 2025-06-04 16:52:59 +08:00
Junkui Zhang
90bbc49b0c fix test 2025-06-04 16:52:59 +08:00
Junkui Zhang
76ffcb3c77 refactor 2025-06-04 16:52:59 +08:00
Junkui Zhang
6a929b7dc5 add more test 2025-06-04 16:52:59 +08:00
Junkui Zhang
94c78851c3 fix windows oem check 2025-06-04 16:52:59 +08:00
Junkui Zhang
23f68c9ffc simplify 2025-06-04 16:52:58 +08:00
Junkui Zhang
1ddb4126c8 fix macOS 2025-06-04 16:52:58 +08:00
Junkui Zhang
92440054ce fix scan code 2025-06-04 16:52:58 +08:00
Junkui Zhang
493c3a6084 fix macOS 2025-06-04 16:52:58 +08:00
Junkui Zhang
16c92944df test passes 2025-06-04 16:52:58 +08:00
Junkui Zhang
8b569d7dc2 fix tests 2025-06-04 16:52:58 +08:00
Junkui Zhang
8a7ab5c3d4 add test 2025-06-04 16:52:58 +08:00
Junkui Zhang
fc1947b97c fix all 2025-06-04 16:52:57 +08:00
Junkui Zhang
6fc96e6b7f parse ScanCode 2025-06-04 16:52:57 +08:00
Junkui Zhang
fb89852586 add ScanCode 2025-06-04 16:52:57 +08:00
Junkui Zhang
2c04d7d118 add keycodes 2025-06-04 16:52:57 +08:00
Junkui Zhang
424ce07c35 fix deserializing command 2025-06-04 16:52:57 +08:00
Junkui Zhang
2489e4cea4 fix 2025-06-04 16:52:57 +08:00
Junkui Zhang
eaa1821673 fix macOS 2025-06-04 16:52:57 +08:00
Junkui Zhang
b11688a5e5 fix macOS 2025-06-04 16:52:01 +08:00
Junkui Zhang
8510cea452 checkpoint 2025-06-04 16:52:00 +08:00
Junkui Zhang
77ce5a7a7c tests 2025-06-04 16:51:15 +08:00
Junkui Zhang
9a6c7d5c41 windows impl 2025-06-04 16:51:15 +08:00
Junkui Zhang
126ba040e8 rename method 2025-06-04 16:51:15 +08:00
Junkui Zhang
d49f75ab3f add tests 2025-06-04 16:51:14 +08:00
Junkui Zhang
2fbb5e5db4 fix 2025-06-04 16:51:14 +08:00
Junkui Zhang
f4c0d52530 macOS checkpoint 2025-06-04 16:51:14 +08:00
Junkui Zhang
1d7b61bdd0 refactor 2025-06-04 16:51:14 +08:00
Junkui Zhang
b4b0c58822 add scan codes 2025-06-04 16:51:14 +08:00
Junkui Zhang
932d776efd macOS impl 2025-06-04 16:51:14 +08:00
Junkui Zhang
b0ae7e16f6 fix 2025-06-04 16:51:05 +08:00
Junkui Zhang
b44e7a82c1 Allow setting other separators 2025-06-04 16:51:04 +08:00
Junkui Zhang
9730a36dd2 parse_shortcuts 2025-06-04 16:50:38 +08:00
Junkui Zhang
3a7d186726 update App 2025-06-04 16:50:37 +08:00
Junkui Zhang
567af9455d Fix windows and test 2025-06-04 16:50:18 +08:00
Junkui Zhang
8537b72597 init test 2025-06-04 16:49:39 +08:00
Junkui Zhang
b60f185bcf register action 2025-06-04 16:49:39 +08:00
Junkui Zhang
46b17ef148 basci impl 2025-06-04 16:47:29 +08:00
Junkui Zhang
8f7d5ecf81 add new trait 2025-06-04 16:47:05 +08:00
Junkui Zhang
fec7620eee Add docs 2025-06-04 16:47:05 +08:00
Junkui Zhang
97319c4fe3 add vscode_shortcuts_file 2025-06-04 16:47:05 +08:00
Junkui Zhang
da422e7dcd Fix vscode_settings_file for Windows 2025-06-04 16:39:10 +08:00
138 changed files with 4403 additions and 5123 deletions

View File

@@ -1,26 +0,0 @@
name: "Build docs"
description: "Build the docs"
runs:
using: "composite"
steps:
- name: Setup mdBook
uses: peaceiris/actions-mdbook@ee69d230fe19748b7abf22df32acaa93833fad08 # v2
with:
mdbook-version: "0.4.37"
- name: Cache dependencies
uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2
with:
save-if: ${{ github.ref == 'refs/heads/main' }}
cache-provider: "buildjet"
- name: Install Linux dependencies
shell: bash -euxo pipefail {0}
run: ./script/linux
- name: Build book
shell: bash -euxo pipefail {0}
run: |
mkdir -p target/deploy
mdbook build ./docs --dest-dir=../target/deploy/docs/

View File

@@ -1,6 +1,12 @@
name: "Run tests"
description: "Runs the tests"
inputs:
use-xvfb:
description: "Whether to run tests with xvfb"
required: false
default: "false"
runs:
using: "composite"
steps:
@@ -20,4 +26,9 @@ runs:
- name: Run tests
shell: bash -euxo pipefail {0}
run: cargo nextest run --workspace --no-fail-fast
run: |
if [ "${{ inputs.use-xvfb }}" == "true" ]; then
xvfb-run --auto-servernum --server-args="-screen 0 1024x768x24 -nolisten tcp" cargo nextest run --workspace --no-fail-fast
else
cargo nextest run --workspace --no-fail-fast
fi

View File

@@ -191,27 +191,6 @@ jobs:
with:
config: ./typos.toml
check_docs:
timeout-minutes: 60
name: Check docs
needs: [job_spec]
if: github.repository_owner == 'zed-industries'
runs-on:
- buildjet-8vcpu-ubuntu-2204
steps:
- name: Checkout repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
with:
clean: false
- name: Configure CI
run: |
mkdir -p ./../.cargo
cp ./.cargo/ci-config.toml ./../.cargo/config.toml
- name: Build docs
uses: ./.github/actions/build_docs
macos_tests:
timeout-minutes: 60
name: (macOS) Run Clippy and tests
@@ -316,6 +295,8 @@ jobs:
- name: Run tests
uses: ./.github/actions/run_tests
with:
use-xvfb: true
- name: Build other binaries and features
run: |

View File

@@ -9,7 +9,7 @@ jobs:
deploy-docs:
name: Deploy Docs
if: github.repository_owner == 'zed-industries'
runs-on: buildjet-16vcpu-ubuntu-2204
runs-on: ubuntu-latest
steps:
- name: Checkout repo
@@ -17,11 +17,24 @@ jobs:
with:
clean: false
- name: Setup mdBook
uses: peaceiris/actions-mdbook@ee69d230fe19748b7abf22df32acaa93833fad08 # v2
with:
mdbook-version: "0.4.37"
- name: Set up default .cargo/config.toml
run: cp ./.cargo/collab-config.toml ./.cargo/config.toml
- name: Build docs
uses: ./.github/actions/build_docs
- name: Install system dependencies
run: |
sudo apt-get update
sudo apt-get install libxkbcommon-dev libxkbcommon-x11-dev
- name: Build book
run: |
set -euo pipefail
mkdir -p target/deploy
mdbook build ./docs --dest-dir=../target/deploy/docs/
- name: Deploy Docs
uses: cloudflare/wrangler-action@da0e0dfe58b7a431659754fdf3f186c529afbe65 # v3

View File

@@ -1,85 +0,0 @@
name: Run Unit Evals
on:
schedule:
# GitHub might drop jobs at busy times, so we choose a random time in the middle of the night.
- cron: "47 1 * * *"
workflow_dispatch:
concurrency:
# Allow only one workflow per any non-`main` branch.
group: ${{ github.workflow }}-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }}
cancel-in-progress: true
env:
CARGO_TERM_COLOR: always
CARGO_INCREMENTAL: 0
RUST_BACKTRACE: 1
ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
jobs:
unit_evals:
timeout-minutes: 60
name: Run unit evals
runs-on:
- buildjet-16vcpu-ubuntu-2204
steps:
- name: Add Rust to the PATH
run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH
- name: Checkout repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
with:
clean: false
- name: Cache dependencies
uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2
with:
save-if: ${{ github.ref == 'refs/heads/main' }}
cache-provider: "buildjet"
- name: Install Linux dependencies
run: ./script/linux
- name: Configure CI
run: |
mkdir -p ./../.cargo
cp ./.cargo/ci-config.toml ./../.cargo/config.toml
- name: Install Rust
shell: bash -euxo pipefail {0}
run: |
cargo install cargo-nextest --locked
- name: Install Node
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4
with:
node-version: "18"
- name: Limit target directory size
shell: bash -euxo pipefail {0}
run: script/clear-target-dir-if-larger-than 100
- name: Run unit evals
shell: bash -euxo pipefail {0}
run: cargo nextest run --workspace --no-fail-fast --features eval --no-capture -E 'test(::eval_)' --test-threads 1
env:
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
- name: Send the pull request link into the Slack channel
if: ${{ failure() }}
uses: slackapi/slack-github-action@b0fa283ad8fea605de13dc3f449259339835fc52
with:
method: chat.postMessage
token: ${{ secrets.SLACK_APP_ZED_UNIT_EVALS_BOT_TOKEN }}
payload: |
channel: C04UDRNNJFQ
text: "Unit Evals Failed: https://github.com/zed-industries/zed/actions/runs/${{ github.run_id }}"
# Even the Linux runner is not stateful, in theory there is no need to do this cleanup.
# But, to avoid potential issues in the future if we choose to use a stateful Linux runner and forget to add code
# to clean up the config file, Ive included the cleanup code here as a precaution.
# While its not strictly necessary at this moment, I believe its better to err on the side of caution.
- name: Clean CI config file
if: always()
run: rm -rf ./../.cargo

6
.rules
View File

@@ -5,12 +5,6 @@
* Prefer implementing functionality in existing files unless it is a new logical component. Avoid creating many small files.
* Avoid using functions that panic like `unwrap()`, instead use mechanisms like `?` to propagate errors.
* Be careful with operations like indexing which may panic if the indexes are out of bounds.
* Never silently discard errors with `let _ =` on fallible operations. Always handle errors appropriately:
- Propagate errors with `?` when the calling function should handle them
- Use `.log_err()` or similar when you need to ignore errors but want visibility
- Use explicit error handling with `match` or `if let Err(...)` when you need custom logic
- Example: avoid `let _ = client.request(...).await?;` - use `client.request(...).await?;` instead
* When implementing async operations that may fail, ensure errors propagate to the UI layer so users get meaningful feedback.
* Never create files with `mod.rs` paths - prefer `src/some_module.rs` instead of `src/some_module/mod.rs`.
# GPUI

15
Cargo.lock generated
View File

@@ -631,7 +631,6 @@ name = "assistant_tool"
version = "0.1.0"
dependencies = [
"anyhow",
"async-watch",
"buffer_diff",
"clock",
"collections",
@@ -4543,8 +4542,6 @@ version = "0.1.0"
dependencies = [
"anyhow",
"clap",
"command_palette",
"gpui",
"mdbook",
"regex",
"serde",
@@ -4552,7 +4549,6 @@ dependencies = [
"settings",
"util",
"workspace-hack",
"zed",
]
[[package]]
@@ -12117,7 +12113,6 @@ dependencies = [
"unindent",
"url",
"util",
"uuid",
"which 6.0.3",
"workspace-hack",
"worktree",
@@ -16513,9 +16508,9 @@ dependencies = [
[[package]]
name = "tree-sitter"
version = "0.25.6"
version = "0.25.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7cf18d43cbf0bfca51f657132cc616a5097edc4424d538bae6fa60142eaf9f0"
checksum = "ac5fff5c47490dfdf473b5228039bfacad9d765d9b6939d26bf7cc064c1c7822"
dependencies = [
"cc",
"regex",
@@ -16528,9 +16523,9 @@ dependencies = [
[[package]]
name = "tree-sitter-bash"
version = "0.25.0"
version = "0.23.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "871b0606e667e98a1237ebdc1b0d7056e0aebfdc3141d12b399865d4cb6ed8a6"
checksum = "329a4d48623ac337d42b1df84e81a1c9dbb2946907c102ca72db158c1964a52e"
dependencies = [
"cc",
"tree-sitter-language",
@@ -19712,7 +19707,7 @@ dependencies = [
[[package]]
name = "zed"
version = "0.191.0"
version = "0.190.0"
dependencies = [
"activity_indicator",
"agent",

View File

@@ -574,8 +574,8 @@ tokio = { version = "1" }
tokio-tungstenite = { version = "0.26", features = ["__rustls-tls"] }
toml = "0.8"
tower-http = "0.4.4"
tree-sitter = { version = "0.25.6", features = ["wasm"] }
tree-sitter-bash = "0.25.0"
tree-sitter = { version = "0.25.5", features = ["wasm"] }
tree-sitter-bash = "0.23"
tree-sitter-c = "0.23"
tree-sitter-cpp = "0.23"
tree-sitter-css = "0.23"

View File

@@ -120,7 +120,7 @@
"ctrl-'": "editor::ToggleSelectedDiffHunks",
"ctrl-\"": "editor::ExpandAllDiffHunks",
"ctrl-i": "editor::ShowSignatureHelp",
"alt-g b": "git::Blame",
"alt-g b": "editor::ToggleGitBlame",
"menu": "editor::OpenContextMenu",
"shift-f10": "editor::OpenContextMenu",
"ctrl-shift-e": "editor::ToggleEditPrediction",
@@ -512,14 +512,14 @@
{
"context": "Workspace",
"bindings": {
"alt-open": ["projects::OpenRecent", { "create_new_window": false }],
// Change the default action on `menu::Confirm` by setting the parameter
// "alt-ctrl-o": ["projects::OpenRecent", { "create_new_window": true }],
"alt-ctrl-o": ["projects::OpenRecent", { "create_new_window": false }],
"alt-shift-open": ["projects::OpenRemote", { "from_existing_connection": false, "create_new_window": false }],
"alt-open": "projects::OpenRecent",
"alt-ctrl-o": "projects::OpenRecent",
"alt-shift-open": "projects::OpenRemote",
"alt-ctrl-shift-o": "projects::OpenRemote",
// Change to open path modal for existing remote connection by setting the parameter
// "alt-ctrl-shift-o": "["projects::OpenRemote", { "from_existing_connection": true }]",
"alt-ctrl-shift-o": ["projects::OpenRemote", { "from_existing_connection": false, "create_new_window": false }],
"alt-ctrl-shift-b": "branches::OpenRecent",
"alt-shift-enter": "toast::RunAction",
"ctrl-~": "workspace::NewTerminal",
@@ -911,9 +911,7 @@
"context": "CollabPanel && not_editing",
"bindings": {
"ctrl-backspace": "collab_panel::Remove",
"space": "menu::Confirm",
"ctrl-up": "collab_panel::MoveChannelUp",
"ctrl-down": "collab_panel::MoveChannelDown"
"space": "menu::Confirm"
}
},
{

View File

@@ -138,7 +138,7 @@
"cmd-;": "editor::ToggleLineNumbers",
"cmd-'": "editor::ToggleSelectedDiffHunks",
"cmd-\"": "editor::ExpandAllDiffHunks",
"cmd-alt-g b": "git::Blame",
"cmd-alt-g b": "editor::ToggleGitBlame",
"cmd-i": "editor::ShowSignatureHelp",
"f9": "editor::ToggleBreakpoint",
"shift-f9": "editor::EditLogBreakpoint",
@@ -584,9 +584,9 @@
"bindings": {
// Change the default action on `menu::Confirm` by setting the parameter
// "alt-cmd-o": ["projects::OpenRecent", {"create_new_window": true }],
"alt-cmd-o": ["projects::OpenRecent", { "create_new_window": false }],
"ctrl-cmd-o": ["projects::OpenRemote", { "from_existing_connection": false, "create_new_window": false }],
"ctrl-cmd-shift-o": ["projects::OpenRemote", { "from_existing_connection": true, "create_new_window": false }],
"alt-cmd-o": "projects::OpenRecent",
"ctrl-cmd-o": "projects::OpenRemote",
"ctrl-cmd-shift-o": ["projects::OpenRemote", { "from_existing_connection": true }],
"alt-cmd-b": "branches::OpenRecent",
"ctrl-~": "workspace::NewTerminal",
"cmd-s": "workspace::Save",
@@ -967,9 +967,7 @@
"use_key_equivalents": true,
"bindings": {
"ctrl-backspace": "collab_panel::Remove",
"space": "menu::Confirm",
"cmd-up": "collab_panel::MoveChannelUp",
"cmd-down": "collab_panel::MoveChannelDown"
"space": "menu::Confirm"
}
},
{

View File

@@ -198,8 +198,6 @@
"9": ["vim::Number", 9],
"ctrl-w d": "editor::GoToDefinitionSplit",
"ctrl-w g d": "editor::GoToDefinitionSplit",
"ctrl-w ]": "editor::GoToDefinitionSplit",
"ctrl-w ctrl-]": "editor::GoToDefinitionSplit",
"ctrl-w shift-d": "editor::GoToTypeDefinitionSplit",
"ctrl-w g shift-d": "editor::GoToTypeDefinitionSplit",
"ctrl-w space": "editor::OpenExcerptsSplit",

View File

@@ -17,13 +17,13 @@ You are a highly skilled software engineer with extensive knowledge in many prog
4. Use only the tools that are currently available.
5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off.
6. NEVER run commands that don't terminate on their own such as web servers (like `npm run start`, `npm run dev`, `python -m http.server`, etc) or file watchers.
7. Avoid HTML entity escaping - use plain characters instead.
## Searching and Reading
If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions.
{{! TODO: If there are files, we should mention it but otherwise omit that fact }}
{{#if has_tools}}
If appropriate, use tool calls to explore the current project, which contains the following root directories:
{{#each worktrees}}
@@ -38,6 +38,7 @@ If appropriate, use tool calls to explore the current project, which contains th
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
- The user might specify a partial file path. If you don't know the full path, use `find_path` (not `grep`) before you read the file.
{{/if}}
{{/if}}
{{else}}
You are being tasked with providing a response, but you have no ability to use tools or to read or write any aspect of the user's system (other than any context the user might have provided to you).

View File

@@ -533,9 +533,6 @@
"function": false
}
},
// Whether to resize all the panels in a dock when resizing the dock.
// Can be a combination of "left", "right" and "bottom".
"resize_all_panels_in_dock": ["left"],
"project_panel": {
// Whether to show the project panel button in the status bar
"button": true,
@@ -1528,7 +1525,7 @@
"allow_rewrap": "anywhere"
},
"Ruby": {
"language_servers": ["solargraph", "!ruby-lsp", "!rubocop", "!sorbet", "!steep", "..."]
"language_servers": ["solargraph", "!ruby-lsp", "!rubocop", "..."]
},
"SCSS": {
"prettier": {

View File

@@ -1,3 +1,4 @@
use crate::AgentPanel;
use crate::context::{AgentContextHandle, RULES_ICON};
use crate::context_picker::{ContextPicker, MentionLink};
use crate::context_store::ContextStore;
@@ -12,7 +13,6 @@ use crate::tool_use::{PendingToolUseStatus, ToolUse};
use crate::ui::{
AddedContext, AgentNotification, AgentNotificationEvent, AnimatedLabel, ContextPill,
};
use crate::{AgentPanel, ModelUsageContext};
use agent_settings::{AgentSettings, NotifyWhenAgentWaiting};
use anyhow::Context as _;
use assistant_tool::ToolUseStatus;
@@ -1348,7 +1348,6 @@ impl ActiveThread {
Some(self.text_thread_store.downgrade()),
context_picker_menu_handle.clone(),
SuggestContextKind::File,
ModelUsageContext::Thread(self.thread.clone()),
window,
cx,
)
@@ -1518,7 +1517,31 @@ impl ActiveThread {
}
fn paste(&mut self, _: &Paste, _window: &mut Window, cx: &mut Context<Self>) {
attach_pasted_images_as_context(&self.context_store, cx);
let images = cx
.read_from_clipboard()
.map(|item| {
item.into_entries()
.filter_map(|entry| {
if let ClipboardEntry::Image(image) = entry {
Some(image)
} else {
None
}
})
.collect::<Vec<_>>()
})
.unwrap_or_default();
if images.is_empty() {
return;
}
cx.stop_propagation();
self.context_store.update(cx, |store, cx| {
for image in images {
store.add_image_instance(Arc::new(image), cx);
}
});
}
fn cancel_editing_message(
@@ -1803,10 +1826,9 @@ impl ActiveThread {
// Get all the data we need from thread before we start using it in closures
let checkpoint = thread.checkpoint_for_message(message_id);
let configured_model = thread.configured_model().map(|m| m.model);
let added_context = thread
.context_for_message(message_id)
.map(|context| AddedContext::new_attached(context, configured_model.as_ref(), cx))
.map(|context| AddedContext::new_attached(context, cx))
.collect::<Vec<_>>();
let tool_uses = thread.tool_uses_for_message(message_id, cx);
@@ -3629,38 +3651,6 @@ pub(crate) fn open_context(
}
}
pub(crate) fn attach_pasted_images_as_context(
context_store: &Entity<ContextStore>,
cx: &mut App,
) -> bool {
let images = cx
.read_from_clipboard()
.map(|item| {
item.into_entries()
.filter_map(|entry| {
if let ClipboardEntry::Image(image) = entry {
Some(image)
} else {
None
}
})
.collect::<Vec<_>>()
})
.unwrap_or_default();
if images.is_empty() {
return false;
}
cx.stop_propagation();
context_store.update(cx, |store, cx| {
for image in images {
store.add_image_instance(Arc::new(image), cx);
}
});
true
}
fn open_editor_at_position(
project_path: project::ProjectPath,
target_position: Point,

View File

@@ -33,11 +33,9 @@ use assistant_slash_command::SlashCommandRegistry;
use client::Client;
use feature_flags::FeatureFlagAppExt as _;
use fs::Fs;
use gpui::{App, Entity, actions, impl_actions};
use gpui::{App, actions, impl_actions};
use language::LanguageRegistry;
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelId, LanguageModelProviderId, LanguageModelRegistry,
};
use language_model::{LanguageModelId, LanguageModelProviderId, LanguageModelRegistry};
use prompt_store::PromptBuilder;
use schemars::JsonSchema;
use serde::Deserialize;
@@ -117,28 +115,6 @@ impl ManageProfiles {
impl_actions!(agent, [NewThread, ManageProfiles]);
#[derive(Clone)]
pub(crate) enum ModelUsageContext {
Thread(Entity<Thread>),
InlineAssistant,
}
impl ModelUsageContext {
pub fn configured_model(&self, cx: &App) -> Option<ConfiguredModel> {
match self {
Self::Thread(thread) => thread.read(cx).configured_model(),
Self::InlineAssistant => {
LanguageModelRegistry::read_global(cx).inline_assistant_model()
}
}
}
pub fn language_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
self.configured_model(cx)
.map(|configured_model| configured_model.model)
}
}
/// Initializes the `agent` crate.
pub fn init(
fs: Arc<dyn Fs>,

View File

@@ -1086,7 +1086,7 @@ impl Render for AgentDiffToolbar {
.child(vertical_divider())
.when_some(editor.read(cx).workspace(), |this, _workspace| {
this.child(
IconButton::new("review", IconName::ListTodo)
IconButton::new("review", IconName::ListCollapse)
.icon_size(IconSize::Small)
.tooltip(Tooltip::for_action_title_in(
"Review All Files",
@@ -1116,13 +1116,8 @@ impl Render for AgentDiffToolbar {
return Empty.into_any();
};
let has_pending_edit_tool_use = agent_diff
.read(cx)
.thread
.read(cx)
.has_pending_edit_tool_uses();
if has_pending_edit_tool_use {
let is_generating = agent_diff.read(cx).thread.read(cx).is_generating();
if is_generating {
return div().px_2().child(spinner_icon).into_any();
}
@@ -1512,7 +1507,7 @@ impl AgentDiff {
multibuffer.add_diff(diff_handle.clone(), cx);
});
let new_state = if thread.read(cx).has_pending_edit_tool_uses() {
let new_state = if thread.read(cx).is_generating() {
EditorState::Generating
} else {
EditorState::Reviewing

View File

@@ -3,7 +3,7 @@ use fs::Fs;
use gpui::{Entity, FocusHandle, SharedString};
use picker::popover_menu::PickerPopoverMenu;
use crate::ModelUsageContext;
use crate::Thread;
use assistant_context_editor::language_model_selector::{
LanguageModelSelector, ToggleModelSelector, language_model_selector,
};
@@ -12,6 +12,12 @@ use settings::update_settings_file;
use std::sync::Arc;
use ui::{PopoverMenuHandle, Tooltip, prelude::*};
#[derive(Clone)]
pub enum ModelType {
Default(Entity<Thread>),
InlineAssistant,
}
pub struct AgentModelSelector {
selector: Entity<LanguageModelSelector>,
menu_handle: PopoverMenuHandle<LanguageModelSelector>,
@@ -23,7 +29,7 @@ impl AgentModelSelector {
fs: Arc<dyn Fs>,
menu_handle: PopoverMenuHandle<LanguageModelSelector>,
focus_handle: FocusHandle,
model_usage_context: ModelUsageContext,
model_type: ModelType,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -32,14 +38,19 @@ impl AgentModelSelector {
let fs = fs.clone();
language_model_selector(
{
let model_context = model_usage_context.clone();
move |cx| model_context.configured_model(cx)
let model_type = model_type.clone();
move |cx| match &model_type {
ModelType::Default(thread) => thread.read(cx).configured_model(),
ModelType::InlineAssistant => {
LanguageModelRegistry::read_global(cx).inline_assistant_model()
}
}
},
move |model, cx| {
let provider = model.provider_id().0.to_string();
let model_id = model.id().0.to_string();
match &model_usage_context {
ModelUsageContext::Thread(thread) => {
match &model_type {
ModelType::Default(thread) => {
thread.update(cx, |thread, cx| {
let registry = LanguageModelRegistry::read_global(cx);
if let Some(provider) = registry.provider(&model.provider_id())
@@ -61,7 +72,7 @@ impl AgentModelSelector {
},
);
}
ModelUsageContext::InlineAssistant => {
ModelType::InlineAssistant => {
update_settings_file::<AgentSettings>(
fs.clone(),
cx,

View File

@@ -745,7 +745,6 @@ pub struct ImageContext {
pub enum ImageStatus {
Loading,
Error,
Warning,
Ready,
}
@@ -762,17 +761,11 @@ impl ImageContext {
self.image_task.clone().now_or_never().flatten()
}
pub fn status(&self, model: Option<&Arc<dyn language_model::LanguageModel>>) -> ImageStatus {
pub fn status(&self) -> ImageStatus {
match self.image_task.clone().now_or_never() {
None => ImageStatus::Loading,
Some(None) => ImageStatus::Error,
Some(Some(_)) => {
if model.is_some_and(|model| !model.supports_images()) {
ImageStatus::Warning
} else {
ImageStatus::Ready
}
}
Some(Some(_)) => ImageStatus::Ready,
}
}

View File

@@ -23,7 +23,7 @@ use crate::thread_store::{TextThreadStore, ThreadStore};
use crate::ui::{AddedContext, ContextPill};
use crate::{
AcceptSuggestedContext, AgentPanel, FocusDown, FocusLeft, FocusRight, FocusUp,
ModelUsageContext, RemoveAllContext, RemoveFocusedContext, ToggleContextPicker,
RemoveAllContext, RemoveFocusedContext, ToggleContextPicker,
};
pub struct ContextStrip {
@@ -37,7 +37,6 @@ pub struct ContextStrip {
_subscriptions: Vec<Subscription>,
focused_index: Option<usize>,
children_bounds: Option<Vec<Bounds<Pixels>>>,
model_usage_context: ModelUsageContext,
}
impl ContextStrip {
@@ -48,7 +47,6 @@ impl ContextStrip {
text_thread_store: Option<WeakEntity<TextThreadStore>>,
context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
suggest_context_kind: SuggestContextKind,
model_usage_context: ModelUsageContext,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -83,7 +81,6 @@ impl ContextStrip {
_subscriptions: subscriptions,
focused_index: None,
children_bounds: None,
model_usage_context,
}
}
@@ -101,20 +98,11 @@ impl ContextStrip {
.as_ref()
.and_then(|thread_store| thread_store.upgrade())
.and_then(|thread_store| thread_store.read(cx).prompt_store().as_ref());
let current_model = self.model_usage_context.language_model(cx);
self.context_store
.read(cx)
.context()
.flat_map(|context| {
AddedContext::new_pending(
context.clone(),
prompt_store,
project,
current_model.as_ref(),
cx,
)
AddedContext::new_pending(context.clone(), prompt_store, project, cx)
})
.collect::<Vec<_>>()
} else {

View File

@@ -1,4 +1,4 @@
use crate::agent_model_selector::AgentModelSelector;
use crate::agent_model_selector::{AgentModelSelector, ModelType};
use crate::buffer_codegen::BufferCodegen;
use crate::context::ContextCreasesAddon;
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
@@ -7,13 +7,12 @@ use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::message_editor::{extract_message_creases, insert_message_creases};
use crate::terminal_codegen::TerminalCodegen;
use crate::thread_store::{TextThreadStore, ThreadStore};
use crate::{CycleNextInlineAssist, CyclePreviousInlineAssist, ModelUsageContext};
use crate::{CycleNextInlineAssist, CyclePreviousInlineAssist};
use crate::{RemoveAllContext, ToggleContextPicker};
use assistant_context_editor::language_model_selector::ToggleModelSelector;
use client::ErrorExt;
use collections::VecDeque;
use db::kvp::Dismissable;
use editor::actions::Paste;
use editor::display_map::EditorMargins;
use editor::{
ContextMenuOptions, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle, MultiBuffer,
@@ -100,7 +99,6 @@ impl<T: 'static> Render for PromptEditor<T> {
v_flex()
.key_context("PromptEditor")
.capture_action(cx.listener(Self::paste))
.bg(cx.theme().colors().editor_background)
.block_mouse_except_scroll()
.gap_0p5()
@@ -305,10 +303,6 @@ impl<T: 'static> PromptEditor<T> {
self.editor.read(cx).text(cx)
}
fn paste(&mut self, _: &Paste, _window: &mut Window, cx: &mut Context<Self>) {
crate::active_thread::attach_pasted_images_as_context(&self.context_store, cx);
}
fn toggle_rate_limit_notice(
&mut self,
_: &ClickEvent,
@@ -918,7 +912,6 @@ impl PromptEditor<BufferCodegen> {
text_thread_store.clone(),
context_picker_menu_handle.clone(),
SuggestContextKind::Thread,
ModelUsageContext::InlineAssistant,
window,
cx,
)
@@ -937,7 +930,7 @@ impl PromptEditor<BufferCodegen> {
fs,
model_selector_menu_handle,
prompt_editor.focus_handle(cx),
ModelUsageContext::InlineAssistant,
ModelType::InlineAssistant,
window,
cx,
)
@@ -1090,7 +1083,6 @@ impl PromptEditor<TerminalCodegen> {
text_thread_store.clone(),
context_picker_menu_handle.clone(),
SuggestContextKind::Thread,
ModelUsageContext::InlineAssistant,
window,
cx,
)
@@ -1109,7 +1101,7 @@ impl PromptEditor<TerminalCodegen> {
fs,
model_selector_menu_handle.clone(),
prompt_editor.focus_handle(cx),
ModelUsageContext::InlineAssistant,
ModelType::InlineAssistant,
window,
cx,
)

View File

@@ -2,7 +2,7 @@ use std::collections::BTreeMap;
use std::rc::Rc;
use std::sync::Arc;
use crate::agent_model_selector::AgentModelSelector;
use crate::agent_model_selector::{AgentModelSelector, ModelType};
use crate::context::{AgentContextKey, ContextCreasesAddon, ContextLoadResult, load_context};
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
use crate::ui::{
@@ -24,8 +24,8 @@ use fs::Fs;
use futures::future::Shared;
use futures::{FutureExt as _, future};
use gpui::{
Animation, AnimationExt, App, Entity, EventEmitter, Focusable, Subscription, Task, TextStyle,
WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
Animation, AnimationExt, App, ClipboardEntry, Entity, EventEmitter, Focusable, Subscription,
Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
};
use language::{Buffer, Language, Point};
use language_model::{
@@ -52,8 +52,8 @@ use crate::thread::{MessageCrease, Thread, TokenUsageRatio};
use crate::thread_store::{TextThreadStore, ThreadStore};
use crate::{
ActiveThread, AgentDiffPane, Chat, ChatWithFollow, ExpandMessageEditor, Follow, KeepAll,
ModelUsageContext, NewThread, OpenAgentDiff, RejectAll, RemoveAllContext, ToggleBurnMode,
ToggleContextPicker, ToggleProfileSelector, register_agent_preview,
NewThread, OpenAgentDiff, RejectAll, RemoveAllContext, ToggleBurnMode, ToggleContextPicker,
ToggleProfileSelector, register_agent_preview,
};
#[derive(RegisterComponent)]
@@ -169,7 +169,6 @@ impl MessageEditor {
Some(text_thread_store.clone()),
context_picker_menu_handle.clone(),
SuggestContextKind::File,
ModelUsageContext::Thread(thread.clone()),
window,
cx,
)
@@ -198,7 +197,7 @@ impl MessageEditor {
fs.clone(),
model_selector_menu_handle,
editor.focus_handle(cx),
ModelUsageContext::Thread(thread.clone()),
ModelType::Default(thread.clone()),
window,
cx,
)
@@ -432,7 +431,31 @@ impl MessageEditor {
}
fn paste(&mut self, _: &Paste, _: &mut Window, cx: &mut Context<Self>) {
crate::active_thread::attach_pasted_images_as_context(&self.context_store, cx);
let images = cx
.read_from_clipboard()
.map(|item| {
item.into_entries()
.filter_map(|entry| {
if let ClipboardEntry::Image(image) = entry {
Some(image)
} else {
None
}
})
.collect::<Vec<_>>()
})
.unwrap_or_default();
if images.is_empty() {
return;
}
cx.stop_propagation();
self.context_store.update(cx, |store, cx| {
for image in images {
store.add_image_instance(Arc::new(image), cx);
}
});
}
fn handle_review_click(&mut self, window: &mut Window, cx: &mut Context<Self>) {

View File

@@ -93,9 +93,20 @@ impl ContextPill {
Self::Suggested {
icon_path: Some(icon_path),
..
}
| Self::Added {
context:
AddedContext {
icon_path: Some(icon_path),
..
},
..
} => Icon::from_path(icon_path),
Self::Suggested { kind, .. } => Icon::new(kind.icon()),
Self::Added { context, .. } => context.icon(),
Self::Suggested { kind, .. }
| Self::Added {
context: AddedContext { kind, .. },
..
} => Icon::new(kind.icon()),
}
}
}
@@ -122,7 +133,6 @@ impl RenderOnce for ContextPill {
on_click,
} => {
let status_is_error = matches!(context.status, ContextStatus::Error { .. });
let status_is_warning = matches!(context.status, ContextStatus::Warning { .. });
base_pill
.pr(if on_remove.is_some() { px(2.) } else { px(4.) })
@@ -130,9 +140,6 @@ impl RenderOnce for ContextPill {
if status_is_error {
pill.bg(cx.theme().status().error_background)
.border_color(cx.theme().status().error_border)
} else if status_is_warning {
pill.bg(cx.theme().status().warning_background)
.border_color(cx.theme().status().warning_border)
} else if *focused {
pill.bg(color.element_background)
.border_color(color.border_focused)
@@ -188,8 +195,7 @@ impl RenderOnce for ContextPill {
|label, delta| label.opacity(delta),
)
.into_any_element(),
ContextStatus::Warning { message }
| ContextStatus::Error { message } => element
ContextStatus::Error { message } => element
.tooltip(ui::Tooltip::text(message.clone()))
.into_any_element(),
}),
@@ -264,7 +270,6 @@ pub enum ContextStatus {
Ready,
Loading { message: SharedString },
Error { message: SharedString },
Warning { message: SharedString },
}
#[derive(RegisterComponent)]
@@ -280,19 +285,6 @@ pub struct AddedContext {
}
impl AddedContext {
pub fn icon(&self) -> Icon {
match &self.status {
ContextStatus::Warning { .. } => Icon::new(IconName::Warning).color(Color::Warning),
ContextStatus::Error { .. } => Icon::new(IconName::XCircle).color(Color::Error),
_ => {
if let Some(icon_path) = &self.icon_path {
Icon::from_path(icon_path)
} else {
Icon::new(self.kind.icon())
}
}
}
}
/// Creates an `AddedContext` by retrieving relevant details of `AgentContext`. This returns a
/// `None` if `DirectoryContext` or `RulesContext` no longer exist.
///
@@ -301,7 +293,6 @@ impl AddedContext {
handle: AgentContextHandle,
prompt_store: Option<&Entity<PromptStore>>,
project: &Project,
model: Option<&Arc<dyn language_model::LanguageModel>>,
cx: &App,
) -> Option<AddedContext> {
match handle {
@@ -313,15 +304,11 @@ impl AddedContext {
AgentContextHandle::Thread(handle) => Some(Self::pending_thread(handle, cx)),
AgentContextHandle::TextThread(handle) => Some(Self::pending_text_thread(handle, cx)),
AgentContextHandle::Rules(handle) => Self::pending_rules(handle, prompt_store, cx),
AgentContextHandle::Image(handle) => Some(Self::image(handle, model, cx)),
AgentContextHandle::Image(handle) => Some(Self::image(handle, cx)),
}
}
pub fn new_attached(
context: &AgentContext,
model: Option<&Arc<dyn language_model::LanguageModel>>,
cx: &App,
) -> AddedContext {
pub fn new_attached(context: &AgentContext, cx: &App) -> AddedContext {
match context {
AgentContext::File(context) => Self::attached_file(context, cx),
AgentContext::Directory(context) => Self::attached_directory(context),
@@ -331,7 +318,7 @@ impl AddedContext {
AgentContext::Thread(context) => Self::attached_thread(context),
AgentContext::TextThread(context) => Self::attached_text_thread(context),
AgentContext::Rules(context) => Self::attached_rules(context),
AgentContext::Image(context) => Self::image(context.clone(), model, cx),
AgentContext::Image(context) => Self::image(context.clone(), cx),
}
}
@@ -606,11 +593,7 @@ impl AddedContext {
}
}
fn image(
context: ImageContext,
model: Option<&Arc<dyn language_model::LanguageModel>>,
cx: &App,
) -> AddedContext {
fn image(context: ImageContext, cx: &App) -> AddedContext {
let (name, parent, icon_path) = if let Some(full_path) = context.full_path.as_ref() {
let full_path_string: SharedString = full_path.to_string_lossy().into_owned().into();
let (name, parent) =
@@ -621,30 +604,21 @@ impl AddedContext {
("Image".into(), None, None)
};
let status = match context.status(model) {
ImageStatus::Loading => ContextStatus::Loading {
message: "Loading…".into(),
},
ImageStatus::Error => ContextStatus::Error {
message: "Failed to load Image".into(),
},
ImageStatus::Warning => ContextStatus::Warning {
message: format!(
"{} doesn't support attaching Images as Context",
model.map(|m| m.name().0).unwrap_or_else(|| "Model".into())
)
.into(),
},
ImageStatus::Ready => ContextStatus::Ready,
};
AddedContext {
kind: ContextKind::Image,
name,
parent,
tooltip: None,
icon_path,
status,
status: match context.status() {
ImageStatus::Loading => ContextStatus::Loading {
message: "Loading…".into(),
},
ImageStatus::Error => ContextStatus::Error {
message: "Failed to load image".into(),
},
ImageStatus::Ready => ContextStatus::Ready,
},
render_hover: Some(Rc::new({
let image = context.original_image.clone();
move |_, cx| {
@@ -813,7 +787,6 @@ impl Component for AddedContext {
original_image: Arc::new(Image::empty()),
image_task: Task::ready(Some(LanguageModelImage::empty())).shared(),
},
None,
cx,
),
);
@@ -833,7 +806,6 @@ impl Component for AddedContext {
})
.shared(),
},
None,
cx,
),
);
@@ -848,7 +820,6 @@ impl Component for AddedContext {
original_image: Arc::new(Image::empty()),
image_task: Task::ready(None).shared(),
},
None,
cx,
),
);
@@ -870,60 +841,3 @@ impl Component for AddedContext {
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::App;
use language_model::{LanguageModel, fake_provider::FakeLanguageModel};
use std::sync::Arc;
#[gpui::test]
fn test_image_context_warning_for_unsupported_model(cx: &mut App) {
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel::default());
assert!(!model.supports_images());
let image_context = ImageContext {
context_id: ContextId::zero(),
project_path: None,
original_image: Arc::new(Image::empty()),
image_task: Task::ready(Some(LanguageModelImage::empty())).shared(),
full_path: None,
};
let added_context = AddedContext::image(image_context, Some(&model), cx);
assert!(matches!(
added_context.status,
ContextStatus::Warning { .. }
));
assert!(matches!(added_context.kind, ContextKind::Image));
assert_eq!(added_context.name.as_ref(), "Image");
assert!(added_context.parent.is_none());
assert!(added_context.icon_path.is_none());
}
#[gpui::test]
fn test_image_context_ready_for_no_model(cx: &mut App) {
let image_context = ImageContext {
context_id: ContextId::zero(),
project_path: None,
original_image: Arc::new(Image::empty()),
image_task: Task::ready(Some(LanguageModelImage::empty())).shared(),
full_path: None,
};
let added_context = AddedContext::image(image_context, None, cx);
assert!(
matches!(added_context.status, ContextStatus::Ready),
"Expected ready status when no model provided"
);
assert!(matches!(added_context.kind, ContextKind::Image));
assert_eq!(added_context.name.as_ref(), "Image");
assert!(added_context.parent.is_none());
assert!(added_context.icon_path.is_none());
}
}

View File

@@ -13,7 +13,6 @@ path = "src/assistant_tool.rs"
[dependencies]
anyhow.workspace = true
async-watch.workspace = true
buffer_diff.workspace = true
clock.workspace = true
collections.workspace = true

View File

@@ -1,7 +1,7 @@
use anyhow::{Context as _, Result};
use buffer_diff::BufferDiff;
use collections::BTreeMap;
use futures::{FutureExt, StreamExt, channel::mpsc};
use futures::{StreamExt, channel::mpsc};
use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task, WeakEntity};
use language::{Anchor, Buffer, BufferEvent, DiskState, Point, ToPoint};
use project::{Project, ProjectItem, lsp_store::OpenLspBufferHandle};
@@ -92,21 +92,21 @@ impl ActionLog {
let diff = cx.new(|cx| BufferDiff::new(&text_snapshot, cx));
let (diff_update_tx, diff_update_rx) = mpsc::unbounded();
let diff_base;
let unreviewed_edits;
let unreviewed_changes;
if is_created {
diff_base = Rope::default();
unreviewed_edits = Patch::new(vec![Edit {
unreviewed_changes = Patch::new(vec![Edit {
old: 0..1,
new: 0..text_snapshot.max_point().row + 1,
}])
} else {
diff_base = buffer.read(cx).as_rope().clone();
unreviewed_edits = Patch::default();
unreviewed_changes = Patch::default();
}
TrackedBuffer {
buffer: buffer.clone(),
diff_base,
unreviewed_edits: unreviewed_edits,
unreviewed_changes,
snapshot: text_snapshot.clone(),
status,
version: buffer.read(cx).version(),
@@ -175,7 +175,7 @@ impl ActionLog {
.map_or(false, |file| file.disk_state() != DiskState::Deleted)
{
// If the buffer had been deleted by a tool, but it got
// resurrected externally, we want to clear the edits we
// resurrected externally, we want to clear the changes we
// were tracking and reset the buffer's state.
self.tracked_buffers.remove(&buffer);
self.track_buffer_internal(buffer, false, cx);
@@ -188,274 +188,108 @@ impl ActionLog {
async fn maintain_diff(
this: WeakEntity<Self>,
buffer: Entity<Buffer>,
mut buffer_updates: mpsc::UnboundedReceiver<(ChangeAuthor, text::BufferSnapshot)>,
mut diff_update: mpsc::UnboundedReceiver<(ChangeAuthor, text::BufferSnapshot)>,
cx: &mut AsyncApp,
) -> Result<()> {
let git_store = this.read_with(cx, |this, cx| this.project.read(cx).git_store().clone())?;
let git_diff = this
.update(cx, |this, cx| {
this.project.update(cx, |project, cx| {
project.open_uncommitted_diff(buffer.clone(), cx)
})
})?
.await
.ok();
let buffer_repo = git_store.read_with(cx, |git_store, cx| {
git_store.repository_and_path_for_buffer_id(buffer.read(cx).remote_id(), cx)
})?;
while let Some((author, buffer_snapshot)) = diff_update.next().await {
let (rebase, diff, language, language_registry) =
this.read_with(cx, |this, cx| {
let tracked_buffer = this
.tracked_buffers
.get(&buffer)
.context("buffer not tracked")?;
let (git_diff_updates_tx, mut git_diff_updates_rx) = async_watch::channel(());
let _repo_subscription =
if let Some((git_diff, (buffer_repo, _))) = git_diff.as_ref().zip(buffer_repo) {
cx.update(|cx| {
let mut old_head = buffer_repo.read(cx).head_commit.clone();
Some(cx.subscribe(git_diff, move |_, event, cx| match event {
buffer_diff::BufferDiffEvent::DiffChanged { .. } => {
let new_head = buffer_repo.read(cx).head_commit.clone();
if new_head != old_head {
old_head = new_head;
git_diff_updates_tx.send(()).ok();
let rebase = cx.background_spawn({
let mut base_text = tracked_buffer.diff_base.clone();
let old_snapshot = tracked_buffer.snapshot.clone();
let new_snapshot = buffer_snapshot.clone();
let unreviewed_changes = tracked_buffer.unreviewed_changes.clone();
async move {
let edits = diff_snapshots(&old_snapshot, &new_snapshot);
if let ChangeAuthor::User = author {
apply_non_conflicting_edits(
&unreviewed_changes,
edits,
&mut base_text,
new_snapshot.as_rope(),
);
}
(Arc::new(base_text.to_string()), base_text)
}
_ => {}
}))
})?
} else {
None
};
});
loop {
futures::select_biased! {
buffer_update = buffer_updates.next() => {
if let Some((author, buffer_snapshot)) = buffer_update {
Self::track_edits(&this, &buffer, author, buffer_snapshot, cx).await?;
} else {
break;
}
}
_ = git_diff_updates_rx.changed().fuse() => {
if let Some(git_diff) = git_diff.as_ref() {
Self::keep_committed_edits(&this, &buffer, &git_diff, cx).await?;
}
}
anyhow::Ok((
rebase,
tracked_buffer.diff.clone(),
tracked_buffer.buffer.read(cx).language().cloned(),
tracked_buffer.buffer.read(cx).language_registry(),
))
})??;
let (new_base_text, new_diff_base) = rebase.await;
let diff_snapshot = BufferDiff::update_diff(
diff.clone(),
buffer_snapshot.clone(),
Some(new_base_text),
true,
false,
language,
language_registry,
cx,
)
.await;
let mut unreviewed_changes = Patch::default();
if let Ok(diff_snapshot) = diff_snapshot {
unreviewed_changes = cx
.background_spawn({
let diff_snapshot = diff_snapshot.clone();
let buffer_snapshot = buffer_snapshot.clone();
let new_diff_base = new_diff_base.clone();
async move {
let mut unreviewed_changes = Patch::default();
for hunk in diff_snapshot.hunks_intersecting_range(
Anchor::MIN..Anchor::MAX,
&buffer_snapshot,
) {
let old_range = new_diff_base
.offset_to_point(hunk.diff_base_byte_range.start)
..new_diff_base.offset_to_point(hunk.diff_base_byte_range.end);
let new_range = hunk.range.start..hunk.range.end;
unreviewed_changes.push(point_to_row_edit(
Edit {
old: old_range,
new: new_range,
},
&new_diff_base,
&buffer_snapshot.as_rope(),
));
}
unreviewed_changes
}
})
.await;
diff.update(cx, |diff, cx| {
diff.set_snapshot(diff_snapshot, &buffer_snapshot, cx)
})?;
}
this.update(cx, |this, cx| {
let tracked_buffer = this
.tracked_buffers
.get_mut(&buffer)
.context("buffer not tracked")?;
tracked_buffer.diff_base = new_diff_base;
tracked_buffer.snapshot = buffer_snapshot;
tracked_buffer.unreviewed_changes = unreviewed_changes;
cx.notify();
anyhow::Ok(())
})??;
}
Ok(())
}
async fn track_edits(
this: &WeakEntity<ActionLog>,
buffer: &Entity<Buffer>,
author: ChangeAuthor,
buffer_snapshot: text::BufferSnapshot,
cx: &mut AsyncApp,
) -> Result<()> {
let rebase = this.read_with(cx, |this, cx| {
let tracked_buffer = this
.tracked_buffers
.get(buffer)
.context("buffer not tracked")?;
let rebase = cx.background_spawn({
let mut base_text = tracked_buffer.diff_base.clone();
let old_snapshot = tracked_buffer.snapshot.clone();
let new_snapshot = buffer_snapshot.clone();
let unreviewed_edits = tracked_buffer.unreviewed_edits.clone();
async move {
let edits = diff_snapshots(&old_snapshot, &new_snapshot);
if let ChangeAuthor::User = author {
apply_non_conflicting_edits(
&unreviewed_edits,
edits,
&mut base_text,
new_snapshot.as_rope(),
);
}
(Arc::new(base_text.to_string()), base_text)
}
});
anyhow::Ok(rebase)
})??;
let (new_base_text, new_diff_base) = rebase.await;
Self::update_diff(
this,
buffer,
buffer_snapshot,
new_base_text,
new_diff_base,
cx,
)
.await
}
async fn keep_committed_edits(
this: &WeakEntity<ActionLog>,
buffer: &Entity<Buffer>,
git_diff: &Entity<BufferDiff>,
cx: &mut AsyncApp,
) -> Result<()> {
let buffer_snapshot = this.read_with(cx, |this, _cx| {
let tracked_buffer = this
.tracked_buffers
.get(buffer)
.context("buffer not tracked")?;
anyhow::Ok(tracked_buffer.snapshot.clone())
})??;
let (new_base_text, new_diff_base) = this
.read_with(cx, |this, cx| {
let tracked_buffer = this
.tracked_buffers
.get(buffer)
.context("buffer not tracked")?;
let old_unreviewed_edits = tracked_buffer.unreviewed_edits.clone();
let agent_diff_base = tracked_buffer.diff_base.clone();
let git_diff_base = git_diff.read(cx).base_text().as_rope().clone();
let buffer_text = tracked_buffer.snapshot.as_rope().clone();
anyhow::Ok(cx.background_spawn(async move {
let mut old_unreviewed_edits = old_unreviewed_edits.into_iter().peekable();
let committed_edits = language::line_diff(
&agent_diff_base.to_string(),
&git_diff_base.to_string(),
)
.into_iter()
.map(|(old, new)| Edit { old, new });
let mut new_agent_diff_base = agent_diff_base.clone();
let mut row_delta = 0i32;
for committed in committed_edits {
while let Some(unreviewed) = old_unreviewed_edits.peek() {
// If the committed edit matches the unreviewed
// edit, assume the user wants to keep it.
if committed.old == unreviewed.old {
let unreviewed_new =
buffer_text.slice_rows(unreviewed.new.clone()).to_string();
let committed_new =
git_diff_base.slice_rows(committed.new.clone()).to_string();
if unreviewed_new == committed_new {
let old_byte_start =
new_agent_diff_base.point_to_offset(Point::new(
(unreviewed.old.start as i32 + row_delta) as u32,
0,
));
let old_byte_end =
new_agent_diff_base.point_to_offset(cmp::min(
Point::new(
(unreviewed.old.end as i32 + row_delta) as u32,
0,
),
new_agent_diff_base.max_point(),
));
new_agent_diff_base
.replace(old_byte_start..old_byte_end, &unreviewed_new);
row_delta +=
unreviewed.new_len() as i32 - unreviewed.old_len() as i32;
}
} else if unreviewed.old.start >= committed.old.end {
break;
}
old_unreviewed_edits.next().unwrap();
}
}
(
Arc::new(new_agent_diff_base.to_string()),
new_agent_diff_base,
)
}))
})??
.await;
Self::update_diff(
this,
buffer,
buffer_snapshot,
new_base_text,
new_diff_base,
cx,
)
.await
}
async fn update_diff(
this: &WeakEntity<ActionLog>,
buffer: &Entity<Buffer>,
buffer_snapshot: text::BufferSnapshot,
new_base_text: Arc<String>,
new_diff_base: Rope,
cx: &mut AsyncApp,
) -> Result<()> {
let (diff, language, language_registry) = this.read_with(cx, |this, cx| {
let tracked_buffer = this
.tracked_buffers
.get(buffer)
.context("buffer not tracked")?;
anyhow::Ok((
tracked_buffer.diff.clone(),
buffer.read(cx).language().cloned(),
buffer.read(cx).language_registry().clone(),
))
})??;
let diff_snapshot = BufferDiff::update_diff(
diff.clone(),
buffer_snapshot.clone(),
Some(new_base_text),
true,
false,
language,
language_registry,
cx,
)
.await;
let mut unreviewed_edits = Patch::default();
if let Ok(diff_snapshot) = diff_snapshot {
unreviewed_edits = cx
.background_spawn({
let diff_snapshot = diff_snapshot.clone();
let buffer_snapshot = buffer_snapshot.clone();
let new_diff_base = new_diff_base.clone();
async move {
let mut unreviewed_edits = Patch::default();
for hunk in diff_snapshot
.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer_snapshot)
{
let old_range = new_diff_base
.offset_to_point(hunk.diff_base_byte_range.start)
..new_diff_base.offset_to_point(hunk.diff_base_byte_range.end);
let new_range = hunk.range.start..hunk.range.end;
unreviewed_edits.push(point_to_row_edit(
Edit {
old: old_range,
new: new_range,
},
&new_diff_base,
&buffer_snapshot.as_rope(),
));
}
unreviewed_edits
}
})
.await;
diff.update(cx, |diff, cx| {
diff.set_snapshot(diff_snapshot, &buffer_snapshot, cx);
})?;
}
this.update(cx, |this, cx| {
let tracked_buffer = this
.tracked_buffers
.get_mut(buffer)
.context("buffer not tracked")?;
tracked_buffer.diff_base = new_diff_base;
tracked_buffer.snapshot = buffer_snapshot;
tracked_buffer.unreviewed_edits = unreviewed_edits;
cx.notify();
anyhow::Ok(())
})?
}
/// Track a buffer as read, so we can notify the model about user edits.
pub fn buffer_read(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
self.track_buffer_internal(buffer, false, cx);
@@ -516,7 +350,7 @@ impl ActionLog {
buffer_range.start.to_point(buffer)..buffer_range.end.to_point(buffer);
let mut delta = 0i32;
tracked_buffer.unreviewed_edits.retain_mut(|edit| {
tracked_buffer.unreviewed_changes.retain_mut(|edit| {
edit.old.start = (edit.old.start as i32 + delta) as u32;
edit.old.end = (edit.old.end as i32 + delta) as u32;
@@ -627,7 +461,7 @@ impl ActionLog {
.project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx));
// Clear all tracked edits for this buffer and start over as if we just read it.
// Clear all tracked changes for this buffer and start over as if we just read it.
self.tracked_buffers.remove(&buffer);
self.buffer_read(buffer.clone(), cx);
cx.notify();
@@ -643,7 +477,7 @@ impl ActionLog {
.peekable();
let mut edits_to_revert = Vec::new();
for edit in tracked_buffer.unreviewed_edits.edits() {
for edit in tracked_buffer.unreviewed_changes.edits() {
let new_range = tracked_buffer
.snapshot
.anchor_before(Point::new(edit.new.start, 0))
@@ -695,7 +529,7 @@ impl ActionLog {
.retain(|_buffer, tracked_buffer| match tracked_buffer.status {
TrackedBufferStatus::Deleted => false,
_ => {
tracked_buffer.unreviewed_edits.clear();
tracked_buffer.unreviewed_changes.clear();
tracked_buffer.diff_base = tracked_buffer.snapshot.as_rope().clone();
tracked_buffer.schedule_diff_update(ChangeAuthor::User, cx);
true
@@ -704,11 +538,11 @@ impl ActionLog {
cx.notify();
}
/// Returns the set of buffers that contain edits that haven't been reviewed by the user.
/// Returns the set of buffers that contain changes that haven't been reviewed by the user.
pub fn changed_buffers(&self, cx: &App) -> BTreeMap<Entity<Buffer>, Entity<BufferDiff>> {
self.tracked_buffers
.iter()
.filter(|(_, tracked)| tracked.has_edits(cx))
.filter(|(_, tracked)| tracked.has_changes(cx))
.map(|(buffer, tracked)| (buffer.clone(), tracked.diff.clone()))
.collect()
}
@@ -828,7 +662,11 @@ fn point_to_row_edit(edit: Edit<Point>, old_text: &Rope, new_text: &Rope) -> Edi
old: edit.old.start.row + 1..edit.old.end.row + 1,
new: edit.new.start.row + 1..edit.new.end.row + 1,
}
} else if edit.old.start.column == 0 && edit.old.end.column == 0 && edit.new.end.column == 0 {
} else if edit.old.start.column == 0
&& edit.old.end.column == 0
&& edit.new.end.column == 0
&& edit.old.end != old_text.max_point()
{
Edit {
old: edit.old.start.row..edit.old.end.row,
new: edit.new.start.row..edit.new.end.row,
@@ -856,7 +694,7 @@ enum TrackedBufferStatus {
struct TrackedBuffer {
buffer: Entity<Buffer>,
diff_base: Rope,
unreviewed_edits: Patch<u32>,
unreviewed_changes: Patch<u32>,
status: TrackedBufferStatus,
version: clock::Global,
diff: Entity<BufferDiff>,
@@ -868,7 +706,7 @@ struct TrackedBuffer {
}
impl TrackedBuffer {
fn has_edits(&self, cx: &App) -> bool {
fn has_changes(&self, cx: &App) -> bool {
self.diff
.read(cx)
.hunks(&self.buffer.read(cx), cx)
@@ -889,6 +727,8 @@ pub struct ChangedBuffer {
#[cfg(test)]
mod tests {
use std::env;
use super::*;
use buffer_diff::DiffHunkStatusKind;
use gpui::TestAppContext;
@@ -897,7 +737,6 @@ mod tests {
use rand::prelude::*;
use serde_json::json;
use settings::SettingsStore;
use std::env;
use util::{RandomCharIter, path};
#[ctor::ctor]
@@ -1912,15 +1751,15 @@ mod tests {
.unwrap();
}
_ => {
let is_agent_edit = rng.gen_bool(0.5);
if is_agent_edit {
let is_agent_change = rng.gen_bool(0.5);
if is_agent_change {
log::info!("agent edit");
} else {
log::info!("user edit");
}
cx.update(|cx| {
buffer.update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
if is_agent_edit {
if is_agent_change {
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
}
});
@@ -1945,7 +1784,7 @@ mod tests {
let tracked_buffer = log.tracked_buffers.get(&buffer).unwrap();
let mut old_text = tracked_buffer.diff_base.clone();
let new_text = buffer.read(cx).as_rope();
for edit in tracked_buffer.unreviewed_edits.edits() {
for edit in tracked_buffer.unreviewed_changes.edits() {
let old_start = old_text.point_to_offset(Point::new(edit.new.start, 0));
let old_end = old_text.point_to_offset(cmp::min(
Point::new(edit.new.start + edit.old_len(), 0),
@@ -1961,171 +1800,6 @@ mod tests {
}
}
#[gpui::test]
async fn test_keep_edits_on_commit(cx: &mut gpui::TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.background_executor.clone());
fs.insert_tree(
path!("/project"),
json!({
".git": {},
"file.txt": "a\nb\nc\nd\ne\nf\ng\nh\ni\nj",
}),
)
.await;
fs.set_head_for_repo(
path!("/project/.git").as_ref(),
&[("file.txt".into(), "a\nb\nc\nd\ne\nf\ng\nh\ni\nj".into())],
"0000000",
);
cx.run_until_parked();
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let file_path = project
.read_with(cx, |project, cx| {
project.find_project_path(path!("/project/file.txt"), cx)
})
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(file_path, cx))
.await
.unwrap();
cx.update(|cx| {
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| {
buffer.edit(
[
// Edit at the very start: a -> A
(Point::new(0, 0)..Point::new(0, 1), "A"),
// Deletion in the middle: remove lines d and e
(Point::new(3, 0)..Point::new(5, 0), ""),
// Modification: g -> GGG
(Point::new(6, 0)..Point::new(6, 1), "GGG"),
// Addition: insert new line after h
(Point::new(7, 1)..Point::new(7, 1), "\nNEW"),
// Edit the very last character: j -> J
(Point::new(9, 0)..Point::new(9, 1), "J"),
],
None,
cx,
);
});
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
});
cx.run_until_parked();
assert_eq!(
unreviewed_hunks(&action_log, cx),
vec![(
buffer.clone(),
vec![
HunkStatus {
range: Point::new(0, 0)..Point::new(1, 0),
diff_status: DiffHunkStatusKind::Modified,
old_text: "a\n".into()
},
HunkStatus {
range: Point::new(3, 0)..Point::new(3, 0),
diff_status: DiffHunkStatusKind::Deleted,
old_text: "d\ne\n".into()
},
HunkStatus {
range: Point::new(4, 0)..Point::new(5, 0),
diff_status: DiffHunkStatusKind::Modified,
old_text: "g\n".into()
},
HunkStatus {
range: Point::new(6, 0)..Point::new(7, 0),
diff_status: DiffHunkStatusKind::Added,
old_text: "".into()
},
HunkStatus {
range: Point::new(8, 0)..Point::new(8, 1),
diff_status: DiffHunkStatusKind::Modified,
old_text: "j".into()
}
]
)]
);
// Simulate a git commit that matches some edits but not others:
// - Accepts the first edit (a -> A)
// - Accepts the deletion (remove d and e)
// - Makes a different change to g (g -> G instead of GGG)
// - Ignores the NEW line addition
// - Ignores the last line edit (j stays as j)
fs.set_head_for_repo(
path!("/project/.git").as_ref(),
&[("file.txt".into(), "A\nb\nc\nf\nG\nh\ni\nj".into())],
"0000001",
);
cx.run_until_parked();
assert_eq!(
unreviewed_hunks(&action_log, cx),
vec![(
buffer.clone(),
vec![
HunkStatus {
range: Point::new(4, 0)..Point::new(5, 0),
diff_status: DiffHunkStatusKind::Modified,
old_text: "g\n".into()
},
HunkStatus {
range: Point::new(6, 0)..Point::new(7, 0),
diff_status: DiffHunkStatusKind::Added,
old_text: "".into()
},
HunkStatus {
range: Point::new(8, 0)..Point::new(8, 1),
diff_status: DiffHunkStatusKind::Modified,
old_text: "j".into()
}
]
)]
);
// Make another commit that accepts the NEW line but with different content
fs.set_head_for_repo(
path!("/project/.git").as_ref(),
&[(
"file.txt".into(),
"A\nb\nc\nf\nGGG\nh\nDIFFERENT\ni\nj".into(),
)],
"0000002",
);
cx.run_until_parked();
assert_eq!(
unreviewed_hunks(&action_log, cx),
vec![(
buffer.clone(),
vec![
HunkStatus {
range: Point::new(6, 0)..Point::new(7, 0),
diff_status: DiffHunkStatusKind::Added,
old_text: "".into()
},
HunkStatus {
range: Point::new(8, 0)..Point::new(8, 1),
diff_status: DiffHunkStatusKind::Modified,
old_text: "j".into()
}
]
)]
);
// Final commit that accepts all remaining edits
fs.set_head_for_repo(
path!("/project/.git").as_ref(),
&[("file.txt".into(), "A\nb\nc\nf\nGGG\nh\nNEW\ni\nJ".into())],
"0000003",
);
cx.run_until_parked();
assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct HunkStatus {
range: Range<Point>,

View File

@@ -16,24 +16,11 @@ pub fn adapt_schema_to_format(
}
match format {
LanguageModelToolSchemaFormat::JsonSchema => preprocess_json_schema(json),
LanguageModelToolSchemaFormat::JsonSchema => Ok(()),
LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
}
}
fn preprocess_json_schema(json: &mut Value) -> Result<()> {
// `additionalProperties` defaults to `false` unless explicitly specified.
// This prevents models from hallucinating tool parameters.
if let Value::Object(obj) = json {
if let Some(Value::String(type_str)) = obj.get("type") {
if type_str == "object" && !obj.contains_key("additionalProperties") {
obj.insert("additionalProperties".to_string(), Value::Bool(false));
}
}
}
Ok(())
}
/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
if let Value::Object(obj) = json {
@@ -250,59 +237,4 @@ mod tests {
assert!(adapt_to_json_schema_subset(&mut json).is_err());
}
#[test]
fn test_preprocess_json_schema_adds_additional_properties() {
let mut json = json!({
"type": "object",
"properties": {
"name": {
"type": "string"
}
}
});
preprocess_json_schema(&mut json).unwrap();
assert_eq!(
json,
json!({
"type": "object",
"properties": {
"name": {
"type": "string"
}
},
"additionalProperties": false
})
);
}
#[test]
fn test_preprocess_json_schema_preserves_additional_properties() {
let mut json = json!({
"type": "object",
"properties": {
"name": {
"type": "string"
}
},
"additionalProperties": true
});
preprocess_json_schema(&mut json).unwrap();
assert_eq!(
json,
json!({
"type": "object",
"properties": {
"name": {
"type": "string"
}
},
"additionalProperties": true
})
);
}
}

View File

@@ -37,13 +37,13 @@ use crate::diagnostics_tool::DiagnosticsTool;
use crate::edit_file_tool::EditFileTool;
use crate::fetch_tool::FetchTool;
use crate::find_path_tool::FindPathTool;
use crate::grep_tool::GrepTool;
use crate::list_directory_tool::ListDirectoryTool;
use crate::now_tool::NowTool;
use crate::thinking_tool::ThinkingTool;
pub use edit_file_tool::{EditFileMode, EditFileToolInput};
pub use find_path_tool::FindPathToolInput;
pub use grep_tool::{GrepTool, GrepToolInput};
pub use open_tool::OpenTool;
pub use read_file_tool::{ReadFileTool, ReadFileToolInput};
pub use terminal_tool::TerminalTool;
@@ -126,7 +126,6 @@ mod tests {
}
},
"required": ["location"],
"additionalProperties": false
})
);
}

View File

@@ -54,7 +54,6 @@ impl Template for EditFilePromptTemplate {
pub enum EditAgentOutputEvent {
ResolvingEditRange(Range<Anchor>),
UnresolvedEditRange,
AmbiguousEditRange(Vec<Range<usize>>),
Edited,
}
@@ -270,29 +269,16 @@ impl EditAgent {
}
}
let (edit_events_, mut resolved_old_text) = resolve_old_text.await?;
let (edit_events_, resolved_old_text) = resolve_old_text.await?;
edit_events = edit_events_;
// If we can't resolve the old text, restart the loop waiting for a
// new edit (or for the stream to end).
let resolved_old_text = match resolved_old_text.len() {
1 => resolved_old_text.pop().unwrap(),
0 => {
output_events
.unbounded_send(EditAgentOutputEvent::UnresolvedEditRange)
.ok();
continue;
}
_ => {
let ranges = resolved_old_text
.into_iter()
.map(|text| text.range)
.collect();
output_events
.unbounded_send(EditAgentOutputEvent::AmbiguousEditRange(ranges))
.ok();
continue;
}
let Some(resolved_old_text) = resolved_old_text else {
output_events
.unbounded_send(EditAgentOutputEvent::UnresolvedEditRange)
.ok();
continue;
};
// Compute edits in the background and apply them as they become
@@ -419,7 +405,7 @@ impl EditAgent {
mut edit_events: T,
cx: &mut AsyncApp,
) -> (
Task<Result<(T, Vec<ResolvedOldText>)>>,
Task<Result<(T, Option<ResolvedOldText>)>>,
async_watch::Receiver<Option<Range<usize>>>,
)
where
@@ -439,29 +425,21 @@ impl EditAgent {
}
}
let matches = matcher.finish();
let old_range = if matches.len() == 1 {
matches.first()
let old_range = matcher.finish();
old_range_tx.send(old_range.clone())?;
if let Some(old_range) = old_range {
let line_indent =
LineIndent::from_iter(matcher.query_lines().first().unwrap().chars());
Ok((
edit_events,
Some(ResolvedOldText {
range: old_range,
indent: line_indent,
}),
))
} else {
// No matches or multiple ambiguous matches
None
};
old_range_tx.send(old_range.cloned())?;
let indent = LineIndent::from_iter(
matcher
.query_lines()
.first()
.unwrap_or(&String::new())
.chars(),
);
let resolved_old_texts = matches
.into_iter()
.map(|range| ResolvedOldText { range, indent })
.collect::<Vec<_>>();
Ok((edit_events, resolved_old_texts))
Ok((edit_events, None))
}
});
(task, old_range_rx)
@@ -1344,76 +1322,6 @@ mod tests {
EditAgent::new(model, project, action_log, Templates::new())
}
#[gpui::test(iterations = 10)]
async fn test_non_unique_text_error(cx: &mut TestAppContext, mut rng: StdRng) {
let agent = init_test(cx).await;
let original_text = indoc! {"
function foo() {
return 42;
}
function bar() {
return 42;
}
function baz() {
return 42;
}
"};
let buffer = cx.new(|cx| Buffer::local(original_text, cx));
let (apply, mut events) = agent.edit(
buffer.clone(),
String::new(),
&LanguageModelRequest::default(),
&mut cx.to_async(),
);
cx.run_until_parked();
// When <old_text> matches text in more than one place
simulate_llm_output(
&agent,
indoc! {"
<old_text>
return 42;
</old_text>
<new_text>
return 100;
</new_text>
"},
&mut rng,
cx,
);
apply.await.unwrap();
// Then the text should remain unchanged
let result_text = buffer.read_with(cx, |buffer, _| buffer.snapshot().text());
assert_eq!(
result_text,
indoc! {"
function foo() {
return 42;
}
function bar() {
return 42;
}
function baz() {
return 42;
}
"},
"Text should remain unchanged when there are multiple matches"
);
// And AmbiguousEditRange even should be emitted
let events = drain_events(&mut events);
let ambiguous_ranges = vec![17..31, 52..66, 87..101];
assert!(
events.contains(&EditAgentOutputEvent::AmbiguousEditRange(ambiguous_ranges)),
"Should emit AmbiguousEditRange for non-unique text"
);
}
fn drain_events(
stream: &mut UnboundedReceiver<EditAgentOutputEvent>,
) -> Vec<EditAgentOutputEvent> {

View File

@@ -1351,7 +1351,7 @@ fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
let mismatched_tag_ratio =
cumulative_parser_metrics.mismatched_tags as f32 / cumulative_parser_metrics.tags as f32;
if mismatched_tag_ratio > 0.10 {
if mismatched_tag_ratio > 0.05 {
for eval_output in eval_outputs {
println!("{}", eval_output);
}

View File

@@ -11,7 +11,7 @@ pub struct StreamingFuzzyMatcher {
snapshot: TextBufferSnapshot,
query_lines: Vec<String>,
incomplete_line: String,
best_matches: Vec<Range<usize>>,
best_match: Option<Range<usize>>,
matrix: SearchMatrix,
}
@@ -22,7 +22,7 @@ impl StreamingFuzzyMatcher {
snapshot,
query_lines: Vec::new(),
incomplete_line: String::new(),
best_matches: Vec::new(),
best_match: None,
matrix: SearchMatrix::new(buffer_line_count + 1),
}
}
@@ -55,41 +55,31 @@ impl StreamingFuzzyMatcher {
self.incomplete_line.replace_range(..last_pos + 1, "");
self.best_matches = self.resolve_location_fuzzy();
if let Some(first_match) = self.best_matches.first() {
Some(first_match.clone())
} else {
None
}
} else {
if let Some(first_match) = self.best_matches.first() {
Some(first_match.clone())
} else {
None
}
self.best_match = self.resolve_location_fuzzy();
}
self.best_match.clone()
}
/// Finish processing and return the final best match(es).
/// Finish processing and return the final best match.
///
/// This processes any remaining incomplete line before returning the final
/// match result.
pub fn finish(&mut self) -> Vec<Range<usize>> {
pub fn finish(&mut self) -> Option<Range<usize>> {
// Process any remaining incomplete line
if !self.incomplete_line.is_empty() {
self.query_lines.push(self.incomplete_line.clone());
self.incomplete_line.clear();
self.best_matches = self.resolve_location_fuzzy();
self.best_match = self.resolve_location_fuzzy();
}
self.best_matches.clone()
self.best_match.clone()
}
fn resolve_location_fuzzy(&mut self) -> Vec<Range<usize>> {
fn resolve_location_fuzzy(&mut self) -> Option<Range<usize>> {
let new_query_line_count = self.query_lines.len();
let old_query_line_count = self.matrix.rows.saturating_sub(1);
if new_query_line_count == old_query_line_count {
return Vec::new();
return None;
}
self.matrix.resize_rows(new_query_line_count + 1);
@@ -142,61 +132,53 @@ impl StreamingFuzzyMatcher {
}
}
// Find all matches with the best cost
// Traceback to find the best match
let buffer_line_count = self.snapshot.max_point().row as usize + 1;
let mut buffer_row_end = buffer_line_count as u32;
let mut best_cost = u32::MAX;
let mut matches_with_best_cost = Vec::new();
for col in 1..=buffer_line_count {
let cost = self.matrix.get(new_query_line_count, col).cost;
if cost < best_cost {
best_cost = cost;
matches_with_best_cost.clear();
matches_with_best_cost.push(col as u32);
} else if cost == best_cost {
matches_with_best_cost.push(col as u32);
buffer_row_end = col as u32;
}
}
// Find ranges for the matches
let mut valid_matches = Vec::new();
for &buffer_row_end in &matches_with_best_cost {
let mut matched_lines = 0;
let mut query_row = new_query_line_count;
let mut buffer_row_start = buffer_row_end;
while query_row > 0 && buffer_row_start > 0 {
let current = self.matrix.get(query_row, buffer_row_start as usize);
match current.direction {
SearchDirection::Diagonal => {
query_row -= 1;
buffer_row_start -= 1;
matched_lines += 1;
}
SearchDirection::Up => {
query_row -= 1;
}
SearchDirection::Left => {
buffer_row_start -= 1;
}
let mut matched_lines = 0;
let mut query_row = new_query_line_count;
let mut buffer_row_start = buffer_row_end;
while query_row > 0 && buffer_row_start > 0 {
let current = self.matrix.get(query_row, buffer_row_start as usize);
match current.direction {
SearchDirection::Diagonal => {
query_row -= 1;
buffer_row_start -= 1;
matched_lines += 1;
}
SearchDirection::Up => {
query_row -= 1;
}
SearchDirection::Left => {
buffer_row_start -= 1;
}
}
let matched_buffer_row_count = buffer_row_end - buffer_row_start;
let matched_ratio = matched_lines as f32
/ (matched_buffer_row_count as f32).max(new_query_line_count as f32);
if matched_ratio >= 0.8 {
let buffer_start_ix = self
.snapshot
.point_to_offset(Point::new(buffer_row_start, 0));
let buffer_end_ix = self.snapshot.point_to_offset(Point::new(
buffer_row_end - 1,
self.snapshot.line_len(buffer_row_end - 1),
));
valid_matches.push((buffer_row_start, buffer_start_ix..buffer_end_ix));
}
}
valid_matches.into_iter().map(|(_, range)| range).collect()
let matched_buffer_row_count = buffer_row_end - buffer_row_start;
let matched_ratio = matched_lines as f32
/ (matched_buffer_row_count as f32).max(new_query_line_count as f32);
if matched_ratio >= 0.8 {
let buffer_start_ix = self
.snapshot
.point_to_offset(Point::new(buffer_row_start, 0));
let buffer_end_ix = self.snapshot.point_to_offset(Point::new(
buffer_row_end - 1,
self.snapshot.line_len(buffer_row_end - 1),
));
Some(buffer_start_ix..buffer_end_ix)
} else {
None
}
}
}
@@ -656,35 +638,28 @@ mod tests {
matcher.push(chunk);
}
let actual_ranges = matcher.finish();
let result = matcher.finish();
// If no expected ranges, we expect no match
if expected_ranges.is_empty() {
assert!(
actual_ranges.is_empty(),
assert_eq!(
result, None,
"Expected no match for query: {:?}, but found: {:?}",
query,
actual_ranges
query, result
);
} else {
let mut actual_ranges = Vec::new();
if let Some(range) = result {
actual_ranges.push(range);
}
let text_with_actual_range = generate_marked_text(&text, &actual_ranges, false);
pretty_assertions::assert_eq!(
text_with_actual_range,
text_with_expected_range,
indoc! {"
Query: {:?}
Chunks: {:?}
Expected marked text: {}
Actual marked text: {}
Expected ranges: {:?}
Actual ranges: {:?}"
},
"Query: {:?}, Chunks: {:?}",
query,
chunks,
text_with_expected_range,
text_with_actual_range,
expected_ranges,
actual_ranges
chunks
);
}
}
@@ -712,11 +687,8 @@ mod tests {
fn finish(mut finder: StreamingFuzzyMatcher) -> Option<String> {
let snapshot = finder.snapshot.clone();
let matches = finder.finish();
if let Some(range) = matches.first() {
Some(snapshot.text_for_range(range.clone()).collect::<String>())
} else {
None
}
finder
.finish()
.map(|range| snapshot.text_for_range(range).collect::<String>())
}
}

View File

@@ -239,7 +239,6 @@ impl Tool for EditFileTool {
};
let mut hallucinated_old_text = false;
let mut ambiguous_ranges = Vec::new();
while let Some(event) = events.next().await {
match event {
EditAgentOutputEvent::Edited => {
@@ -248,7 +247,6 @@ impl Tool for EditFileTool {
}
}
EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
EditAgentOutputEvent::ResolvingEditRange(range) => {
if let Some(card) = card_clone.as_ref() {
card.update(cx, |card, cx| card.reveal_range(range, cx))?;
@@ -331,17 +329,6 @@ impl Tool for EditFileTool {
I can perform the requested edits.
"}
);
anyhow::ensure!(
ambiguous_ranges.is_empty(),
// TODO: Include ambiguous_ranges, converted to line numbers.
// This would work best if we add `line_hint` parameter
// to edit_file_tool
formatdoc! {"
<old_text> matches more than one position in the file. Read the
relevant sections of {input_path} again and extend <old_text> so
that I can perform the requested edits.
"}
);
Ok(ToolResultOutput {
content: ToolResultContent::Text("No edits were made.".into()),
output: serde_json::to_value(output).ok(),

View File

@@ -6,12 +6,11 @@ use gpui::{AnyWindowHandle, App, Entity, Task};
use language::{OffsetRangeExt, ParseStatus, Point};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::{
Project, WorktreeSettings,
Project,
search::{SearchQuery, SearchResult},
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::{cmp, fmt::Write, sync::Arc};
use ui::IconName;
use util::RangeExt;
@@ -131,23 +130,6 @@ impl Tool for GrepTool {
}
};
// Exclude global file_scan_exclusions and private_files settings
let exclude_matcher = {
let global_settings = WorktreeSettings::get_global(cx);
let exclude_patterns = global_settings
.file_scan_exclusions
.sources()
.iter()
.chain(global_settings.private_files.sources().iter());
match PathMatcher::new(exclude_patterns) {
Ok(matcher) => matcher,
Err(error) => {
return Task::ready(Err(anyhow!("invalid exclude pattern: {error}"))).into();
}
}
};
let query = match SearchQuery::regex(
&input.regex,
false,
@@ -155,7 +137,7 @@ impl Tool for GrepTool {
false,
false,
include_matcher,
exclude_matcher,
PathMatcher::default(), // For now, keep it simple and don't enable an exclude pattern.
true, // Always match file include pattern against *full project paths* that start with a project root.
None,
) {
@@ -178,24 +160,12 @@ impl Tool for GrepTool {
continue;
}
let Ok((Some(path), mut parse_status)) = buffer.read_with(cx, |buffer, cx| {
let (Some(path), mut parse_status) = buffer.read_with(cx, |buffer, cx| {
(buffer.file().map(|file| file.full_path(cx)), buffer.parse_status())
}) else {
})? else {
continue;
};
// Check if this file should be excluded based on its worktree settings
if let Ok(Some(project_path)) = project.read_with(cx, |project, cx| {
project.find_project_path(&path, cx)
}) {
if cx.update(|cx| {
let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx);
worktree_settings.is_path_excluded(&project_path.path)
|| worktree_settings.is_path_private(&project_path.path)
}).unwrap_or(false) {
continue;
}
}
while *parse_status.borrow() != ParseStatus::Idle {
parse_status.changed().await?;
@@ -314,11 +284,10 @@ impl Tool for GrepTool {
mod tests {
use super::*;
use assistant_tool::Tool;
use gpui::{AppContext, TestAppContext, UpdateGlobal};
use gpui::{AppContext, TestAppContext};
use language::{Language, LanguageConfig, LanguageMatcher};
use language_model::fake_provider::FakeLanguageModel;
use project::{FakeFs, Project, WorktreeSettings};
use serde_json::json;
use project::{FakeFs, Project};
use settings::SettingsStore;
use unindent::Unindent;
use util::path;
@@ -330,7 +299,7 @@ mod tests {
let fs = FakeFs::new(cx.executor().clone());
fs.insert_tree(
path!("/root"),
"/root",
serde_json::json!({
"src": {
"main.rs": "fn main() {\n println!(\"Hello, world!\");\n}",
@@ -418,7 +387,7 @@ mod tests {
let fs = FakeFs::new(cx.executor().clone());
fs.insert_tree(
path!("/root"),
"/root",
serde_json::json!({
"case_test.txt": "This file has UPPERCASE and lowercase text.\nUPPERCASE patterns should match only with case_sensitive: true",
}),
@@ -499,7 +468,7 @@ mod tests {
// Create test file with syntax structures
fs.insert_tree(
path!("/root"),
"/root",
serde_json::json!({
"test_syntax.rs": r#"
fn top_level_function() {
@@ -820,488 +789,4 @@ mod tests {
.with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
.unwrap()
}
#[gpui::test]
async fn test_grep_security_boundaries(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/"),
json!({
"project_root": {
"allowed_file.rs": "fn main() { println!(\"This file is in the project\"); }",
".mysecrets": "SECRET_KEY=abc123\nfn secret() { /* private */ }",
".secretdir": {
"config": "fn special_configuration() { /* excluded */ }"
},
".mymetadata": "fn custom_metadata() { /* excluded */ }",
"subdir": {
"normal_file.rs": "fn normal_file_content() { /* Normal */ }",
"special.privatekey": "fn private_key_content() { /* private */ }",
"data.mysensitive": "fn sensitive_data() { /* private */ }"
}
},
"outside_project": {
"sensitive_file.rs": "fn outside_function() { /* This file is outside the project */ }"
}
}),
)
.await;
cx.update(|cx| {
use gpui::UpdateGlobal;
use project::WorktreeSettings;
use settings::SettingsStore;
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<WorktreeSettings>(cx, |settings| {
settings.file_scan_exclusions = Some(vec![
"**/.secretdir".to_string(),
"**/.mymetadata".to_string(),
]);
settings.private_files = Some(vec![
"**/.mysecrets".to_string(),
"**/*.privatekey".to_string(),
"**/*.mysensitive".to_string(),
]);
});
});
});
let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
// Searching for files outside the project worktree should return no results
let result = cx
.update(|cx| {
let input = json!({
"regex": "outside_function"
});
Arc::new(GrepTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
})
.await;
let results = result.unwrap();
let paths = extract_paths_from_results(&results.content.as_str().unwrap());
assert!(
paths.is_empty(),
"grep_tool should not find files outside the project worktree"
);
// Searching within the project should succeed
let result = cx
.update(|cx| {
let input = json!({
"regex": "main"
});
Arc::new(GrepTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
})
.await;
let results = result.unwrap();
let paths = extract_paths_from_results(&results.content.as_str().unwrap());
assert!(
paths.iter().any(|p| p.contains("allowed_file.rs")),
"grep_tool should be able to search files inside worktrees"
);
// Searching files that match file_scan_exclusions should return no results
let result = cx
.update(|cx| {
let input = json!({
"regex": "special_configuration"
});
Arc::new(GrepTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
})
.await;
let results = result.unwrap();
let paths = extract_paths_from_results(&results.content.as_str().unwrap());
assert!(
paths.is_empty(),
"grep_tool should not search files in .secretdir (file_scan_exclusions)"
);
let result = cx
.update(|cx| {
let input = json!({
"regex": "custom_metadata"
});
Arc::new(GrepTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
})
.await;
let results = result.unwrap();
let paths = extract_paths_from_results(&results.content.as_str().unwrap());
assert!(
paths.is_empty(),
"grep_tool should not search .mymetadata files (file_scan_exclusions)"
);
// Searching private files should return no results
let result = cx
.update(|cx| {
let input = json!({
"regex": "SECRET_KEY"
});
Arc::new(GrepTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
})
.await;
let results = result.unwrap();
let paths = extract_paths_from_results(&results.content.as_str().unwrap());
assert!(
paths.is_empty(),
"grep_tool should not search .mysecrets (private_files)"
);
let result = cx
.update(|cx| {
let input = json!({
"regex": "private_key_content"
});
Arc::new(GrepTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
})
.await;
let results = result.unwrap();
let paths = extract_paths_from_results(&results.content.as_str().unwrap());
assert!(
paths.is_empty(),
"grep_tool should not search .privatekey files (private_files)"
);
let result = cx
.update(|cx| {
let input = json!({
"regex": "sensitive_data"
});
Arc::new(GrepTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
})
.await;
let results = result.unwrap();
let paths = extract_paths_from_results(&results.content.as_str().unwrap());
assert!(
paths.is_empty(),
"grep_tool should not search .mysensitive files (private_files)"
);
// Searching a normal file should still work, even with private_files configured
let result = cx
.update(|cx| {
let input = json!({
"regex": "normal_file_content"
});
Arc::new(GrepTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
})
.await;
let results = result.unwrap();
let paths = extract_paths_from_results(&results.content.as_str().unwrap());
assert!(
paths.iter().any(|p| p.contains("normal_file.rs")),
"Should be able to search normal files"
);
// Path traversal attempts with .. in include_pattern should not escape project
let result = cx
.update(|cx| {
let input = json!({
"regex": "outside_function",
"include_pattern": "../outside_project/**/*.rs"
});
Arc::new(GrepTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
})
.await;
let results = result.unwrap();
let paths = extract_paths_from_results(&results.content.as_str().unwrap());
assert!(
paths.is_empty(),
"grep_tool should not allow escaping project boundaries with relative paths"
);
}
#[gpui::test]
async fn test_grep_with_multiple_worktree_settings(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
// Create first worktree with its own private files
fs.insert_tree(
path!("/worktree1"),
json!({
".zed": {
"settings.json": r#"{
"file_scan_exclusions": ["**/fixture.*"],
"private_files": ["**/secret.rs"]
}"#
},
"src": {
"main.rs": "fn main() { let secret_key = \"hidden\"; }",
"secret.rs": "const API_KEY: &str = \"secret_value\";",
"utils.rs": "pub fn get_config() -> String { \"config\".to_string() }"
},
"tests": {
"test.rs": "fn test_secret() { assert!(true); }",
"fixture.sql": "SELECT * FROM secret_table;"
}
}),
)
.await;
// Create second worktree with different private files
fs.insert_tree(
path!("/worktree2"),
json!({
".zed": {
"settings.json": r#"{
"file_scan_exclusions": ["**/internal.*"],
"private_files": ["**/private.js", "**/data.json"]
}"#
},
"lib": {
"public.js": "export function getSecret() { return 'public'; }",
"private.js": "const SECRET_KEY = \"private_value\";",
"data.json": "{\"secret_data\": \"hidden\"}"
},
"docs": {
"README.md": "# Documentation with secret info",
"internal.md": "Internal secret documentation"
}
}),
)
.await;
// Set global settings
cx.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<WorktreeSettings>(cx, |settings| {
settings.file_scan_exclusions =
Some(vec!["**/.git".to_string(), "**/node_modules".to_string()]);
settings.private_files = Some(vec!["**/.env".to_string()]);
});
});
});
let project = Project::test(
fs.clone(),
[path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
cx,
)
.await;
// Wait for worktrees to be fully scanned
cx.executor().run_until_parked();
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
// Search for "secret" - should exclude files based on worktree-specific settings
let result = cx
.update(|cx| {
let input = json!({
"regex": "secret",
"case_sensitive": false
});
Arc::new(GrepTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
})
.await
.unwrap();
let content = result.content.as_str().unwrap();
let paths = extract_paths_from_results(&content);
// Should find matches in non-private files
assert!(
paths.iter().any(|p| p.contains("main.rs")),
"Should find 'secret' in worktree1/src/main.rs"
);
assert!(
paths.iter().any(|p| p.contains("test.rs")),
"Should find 'secret' in worktree1/tests/test.rs"
);
assert!(
paths.iter().any(|p| p.contains("public.js")),
"Should find 'secret' in worktree2/lib/public.js"
);
assert!(
paths.iter().any(|p| p.contains("README.md")),
"Should find 'secret' in worktree2/docs/README.md"
);
// Should NOT find matches in private/excluded files based on worktree settings
assert!(
!paths.iter().any(|p| p.contains("secret.rs")),
"Should not search in worktree1/src/secret.rs (local private_files)"
);
assert!(
!paths.iter().any(|p| p.contains("fixture.sql")),
"Should not search in worktree1/tests/fixture.sql (local file_scan_exclusions)"
);
assert!(
!paths.iter().any(|p| p.contains("private.js")),
"Should not search in worktree2/lib/private.js (local private_files)"
);
assert!(
!paths.iter().any(|p| p.contains("data.json")),
"Should not search in worktree2/lib/data.json (local private_files)"
);
assert!(
!paths.iter().any(|p| p.contains("internal.md")),
"Should not search in worktree2/docs/internal.md (local file_scan_exclusions)"
);
// Test with `include_pattern` specific to one worktree
let result = cx
.update(|cx| {
let input = json!({
"regex": "secret",
"include_pattern": "worktree1/**/*.rs"
});
Arc::new(GrepTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
})
.await
.unwrap();
let content = result.content.as_str().unwrap();
let paths = extract_paths_from_results(&content);
// Should only find matches in worktree1 *.rs files (excluding private ones)
assert!(
paths.iter().any(|p| p.contains("main.rs")),
"Should find match in worktree1/src/main.rs"
);
assert!(
paths.iter().any(|p| p.contains("test.rs")),
"Should find match in worktree1/tests/test.rs"
);
assert!(
!paths.iter().any(|p| p.contains("secret.rs")),
"Should not find match in excluded worktree1/src/secret.rs"
);
assert!(
paths.iter().all(|p| !p.contains("worktree2")),
"Should not find any matches in worktree2"
);
}
// Helper function to extract file paths from grep results
fn extract_paths_from_results(results: &str) -> Vec<String> {
results
.lines()
.filter(|line| line.starts_with("## Matches in "))
.map(|line| {
line.strip_prefix("## Matches in ")
.unwrap()
.trim()
.to_string()
})
.collect()
}
}

View File

@@ -6,4 +6,3 @@ Searches the contents of files in the project with a regular expression
- Never use this tool to search for paths. Only search file contents with this tool.
- Use this tool when you need to find files containing specific patterns
- Results are paginated with 20 matches per page. Use the optional 'offset' parameter to request subsequent pages.
- DO NOT use HTML entities solely to escape characters in the tool parameters.

View File

@@ -3,10 +3,9 @@ use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, Entity, Task};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::{Project, WorktreeSettings};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::{fmt::Write, path::Path, sync::Arc};
use ui::IconName;
use util::markdown::MarkdownInlineCode;
@@ -120,80 +119,21 @@ impl Tool for ListDirectoryTool {
else {
return Task::ready(Err(anyhow!("Worktree not found"))).into();
};
let worktree = worktree.read(cx);
// Check if the directory whose contents we're listing is itself excluded or private
let global_settings = WorktreeSettings::get_global(cx);
if global_settings.is_path_excluded(&project_path.path) {
return Task::ready(Err(anyhow!(
"Cannot list directory because its path matches the user's global `file_scan_exclusions` setting: {}",
&input.path
)))
.into();
}
if global_settings.is_path_private(&project_path.path) {
return Task::ready(Err(anyhow!(
"Cannot list directory because its path matches the user's global `private_files` setting: {}",
&input.path
)))
.into();
}
let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx);
if worktree_settings.is_path_excluded(&project_path.path) {
return Task::ready(Err(anyhow!(
"Cannot list directory because its path matches the user's worktree`file_scan_exclusions` setting: {}",
&input.path
)))
.into();
}
if worktree_settings.is_path_private(&project_path.path) {
return Task::ready(Err(anyhow!(
"Cannot list directory because its path matches the user's worktree `private_paths` setting: {}",
&input.path
)))
.into();
}
let worktree_snapshot = worktree.read(cx).snapshot();
let worktree_root_name = worktree.read(cx).root_name().to_string();
let Some(entry) = worktree_snapshot.entry_for_path(&project_path.path) else {
let Some(entry) = worktree.entry_for_path(&project_path.path) else {
return Task::ready(Err(anyhow!("Path not found: {}", input.path))).into();
};
if !entry.is_dir() {
return Task::ready(Err(anyhow!("{} is not a directory.", input.path))).into();
}
let worktree_snapshot = worktree.read(cx).snapshot();
let mut folders = Vec::new();
let mut files = Vec::new();
for entry in worktree_snapshot.child_entries(&project_path.path) {
// Skip private and excluded files and directories
if global_settings.is_path_private(&entry.path)
|| global_settings.is_path_excluded(&entry.path)
{
continue;
}
if project
.read(cx)
.find_project_path(&entry.path, cx)
.map(|project_path| {
let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx);
worktree_settings.is_path_excluded(&project_path.path)
|| worktree_settings.is_path_private(&project_path.path)
})
.unwrap_or(false)
{
continue;
}
let full_path = Path::new(&worktree_root_name)
for entry in worktree.child_entries(&project_path.path) {
let full_path = Path::new(worktree.root_name())
.join(&entry.path)
.display()
.to_string();
@@ -226,10 +166,10 @@ impl Tool for ListDirectoryTool {
mod tests {
use super::*;
use assistant_tool::Tool;
use gpui::{AppContext, TestAppContext, UpdateGlobal};
use gpui::{AppContext, TestAppContext};
use indoc::indoc;
use language_model::fake_provider::FakeLanguageModel;
use project::{FakeFs, Project, WorktreeSettings};
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use util::path;
@@ -257,7 +197,7 @@ mod tests {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/project"),
"/project",
json!({
"src": {
"main.rs": "fn main() {}",
@@ -387,7 +327,7 @@ mod tests {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/project"),
"/project",
json!({
"empty_dir": {}
}),
@@ -419,7 +359,7 @@ mod tests {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/project"),
"/project",
json!({
"file.txt": "content"
}),
@@ -472,394 +412,4 @@ mod tests {
.contains("is not a directory")
);
}
#[gpui::test]
async fn test_list_directory_security(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/project"),
json!({
"normal_dir": {
"file1.txt": "content",
"file2.txt": "content"
},
".mysecrets": "SECRET_KEY=abc123",
".secretdir": {
"config": "special configuration",
"secret.txt": "secret content"
},
".mymetadata": "custom metadata",
"visible_dir": {
"normal.txt": "normal content",
"special.privatekey": "private key content",
"data.mysensitive": "sensitive data",
".hidden_subdir": {
"hidden_file.txt": "hidden content"
}
}
}),
)
.await;
// Configure settings explicitly
cx.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<WorktreeSettings>(cx, |settings| {
settings.file_scan_exclusions = Some(vec![
"**/.secretdir".to_string(),
"**/.mymetadata".to_string(),
"**/.hidden_subdir".to_string(),
]);
settings.private_files = Some(vec![
"**/.mysecrets".to_string(),
"**/*.privatekey".to_string(),
"**/*.mysensitive".to_string(),
]);
});
});
});
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
let tool = Arc::new(ListDirectoryTool);
// Listing root directory should exclude private and excluded files
let input = json!({
"path": "project"
});
let result = cx
.update(|cx| {
tool.clone().run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
})
.output
.await
.unwrap();
let content = result.content.as_str().unwrap();
// Should include normal directories
assert!(content.contains("normal_dir"), "Should list normal_dir");
assert!(content.contains("visible_dir"), "Should list visible_dir");
// Should NOT include excluded or private files
assert!(
!content.contains(".secretdir"),
"Should not list .secretdir (file_scan_exclusions)"
);
assert!(
!content.contains(".mymetadata"),
"Should not list .mymetadata (file_scan_exclusions)"
);
assert!(
!content.contains(".mysecrets"),
"Should not list .mysecrets (private_files)"
);
// Trying to list an excluded directory should fail
let input = json!({
"path": "project/.secretdir"
});
let result = cx
.update(|cx| {
tool.clone().run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
})
.output
.await;
assert!(
result.is_err(),
"Should not be able to list excluded directory"
);
assert!(
result
.unwrap_err()
.to_string()
.contains("file_scan_exclusions"),
"Error should mention file_scan_exclusions"
);
// Listing a directory should exclude private files within it
let input = json!({
"path": "project/visible_dir"
});
let result = cx
.update(|cx| {
tool.clone().run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
})
.output
.await
.unwrap();
let content = result.content.as_str().unwrap();
// Should include normal files
assert!(content.contains("normal.txt"), "Should list normal.txt");
// Should NOT include private files
assert!(
!content.contains("privatekey"),
"Should not list .privatekey files (private_files)"
);
assert!(
!content.contains("mysensitive"),
"Should not list .mysensitive files (private_files)"
);
// Should NOT include subdirectories that match exclusions
assert!(
!content.contains(".hidden_subdir"),
"Should not list .hidden_subdir (file_scan_exclusions)"
);
}
#[gpui::test]
async fn test_list_directory_with_multiple_worktree_settings(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
// Create first worktree with its own private files
fs.insert_tree(
path!("/worktree1"),
json!({
".zed": {
"settings.json": r#"{
"file_scan_exclusions": ["**/fixture.*"],
"private_files": ["**/secret.rs", "**/config.toml"]
}"#
},
"src": {
"main.rs": "fn main() { println!(\"Hello from worktree1\"); }",
"secret.rs": "const API_KEY: &str = \"secret_key_1\";",
"config.toml": "[database]\nurl = \"postgres://localhost/db1\""
},
"tests": {
"test.rs": "mod tests { fn test_it() {} }",
"fixture.sql": "CREATE TABLE users (id INT, name VARCHAR(255));"
}
}),
)
.await;
// Create second worktree with different private files
fs.insert_tree(
path!("/worktree2"),
json!({
".zed": {
"settings.json": r#"{
"file_scan_exclusions": ["**/internal.*"],
"private_files": ["**/private.js", "**/data.json"]
}"#
},
"lib": {
"public.js": "export function greet() { return 'Hello from worktree2'; }",
"private.js": "const SECRET_TOKEN = \"private_token_2\";",
"data.json": "{\"api_key\": \"json_secret_key\"}"
},
"docs": {
"README.md": "# Public Documentation",
"internal.md": "# Internal Secrets and Configuration"
}
}),
)
.await;
// Set global settings
cx.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<WorktreeSettings>(cx, |settings| {
settings.file_scan_exclusions =
Some(vec!["**/.git".to_string(), "**/node_modules".to_string()]);
settings.private_files = Some(vec!["**/.env".to_string()]);
});
});
});
let project = Project::test(
fs.clone(),
[path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
cx,
)
.await;
// Wait for worktrees to be fully scanned
cx.executor().run_until_parked();
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
let tool = Arc::new(ListDirectoryTool);
// Test listing worktree1/src - should exclude secret.rs and config.toml based on local settings
let input = json!({
"path": "worktree1/src"
});
let result = cx
.update(|cx| {
tool.clone().run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
})
.output
.await
.unwrap();
let content = result.content.as_str().unwrap();
assert!(content.contains("main.rs"), "Should list main.rs");
assert!(
!content.contains("secret.rs"),
"Should not list secret.rs (local private_files)"
);
assert!(
!content.contains("config.toml"),
"Should not list config.toml (local private_files)"
);
// Test listing worktree1/tests - should exclude fixture.sql based on local settings
let input = json!({
"path": "worktree1/tests"
});
let result = cx
.update(|cx| {
tool.clone().run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
})
.output
.await
.unwrap();
let content = result.content.as_str().unwrap();
assert!(content.contains("test.rs"), "Should list test.rs");
assert!(
!content.contains("fixture.sql"),
"Should not list fixture.sql (local file_scan_exclusions)"
);
// Test listing worktree2/lib - should exclude private.js and data.json based on local settings
let input = json!({
"path": "worktree2/lib"
});
let result = cx
.update(|cx| {
tool.clone().run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
})
.output
.await
.unwrap();
let content = result.content.as_str().unwrap();
assert!(content.contains("public.js"), "Should list public.js");
assert!(
!content.contains("private.js"),
"Should not list private.js (local private_files)"
);
assert!(
!content.contains("data.json"),
"Should not list data.json (local private_files)"
);
// Test listing worktree2/docs - should exclude internal.md based on local settings
let input = json!({
"path": "worktree2/docs"
});
let result = cx
.update(|cx| {
tool.clone().run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
})
.output
.await
.unwrap();
let content = result.content.as_str().unwrap();
assert!(content.contains("README.md"), "Should list README.md");
assert!(
!content.contains("internal.md"),
"Should not list internal.md (local file_scan_exclusions)"
);
// Test trying to list an excluded directory directly
let input = json!({
"path": "worktree1/src/secret.rs"
});
let result = cx
.update(|cx| {
tool.clone().run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
})
.output
.await;
// This should fail because we're trying to list a file, not a directory
assert!(result.is_err(), "Should fail when trying to list a file");
}
}

View File

@@ -12,10 +12,9 @@ use language::{Anchor, Point};
use language_model::{
LanguageModel, LanguageModelImage, LanguageModelRequest, LanguageModelToolSchemaFormat,
};
use project::{AgentLocation, Project, WorktreeSettings};
use project::{AgentLocation, Project};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::sync::Arc;
use ui::IconName;
use util::markdown::MarkdownInlineCode;
@@ -108,48 +107,12 @@ impl Tool for ReadFileTool {
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into();
};
// Error out if this path is either excluded or private in global settings
let global_settings = WorktreeSettings::get_global(cx);
if global_settings.is_path_excluded(&project_path.path) {
return Task::ready(Err(anyhow!(
"Cannot read file because its path matches the global `file_scan_exclusions` setting: {}",
&input.path
)))
.into();
}
if global_settings.is_path_private(&project_path.path) {
return Task::ready(Err(anyhow!(
"Cannot read file because its path matches the global `private_files` setting: {}",
&input.path
)))
.into();
}
// Error out if this path is either excluded or private in worktree settings
let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx);
if worktree_settings.is_path_excluded(&project_path.path) {
return Task::ready(Err(anyhow!(
"Cannot read file because its path matches the worktree `file_scan_exclusions` setting: {}",
&input.path
)))
.into();
}
if worktree_settings.is_path_private(&project_path.path) {
return Task::ready(Err(anyhow!(
"Cannot read file because its path matches the worktree `private_files` setting: {}",
&input.path
)))
.into();
}
let file_path = input.path.clone();
if image_store::is_image_file(&project, &project_path, cx) {
if !model.supports_images() {
return Task::ready(Err(anyhow!(
"Attempted to read an image, but Zed doesn't currently support sending images to {}.",
"Attempted to read an image, but Zed doesn't currently sending images to {}.",
model.name().0
)))
.into();
@@ -289,10 +252,10 @@ impl Tool for ReadFileTool {
#[cfg(test)]
mod test {
use super::*;
use gpui::{AppContext, TestAppContext, UpdateGlobal};
use gpui::{AppContext, TestAppContext};
use language::{Language, LanguageConfig, LanguageMatcher};
use language_model::fake_provider::FakeLanguageModel;
use project::{FakeFs, Project, WorktreeSettings};
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use util::path;
@@ -302,7 +265,7 @@ mod test {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/root"), json!({})).await;
fs.insert_tree("/root", json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
@@ -336,7 +299,7 @@ mod test {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/root"),
"/root",
json!({
"small_file.txt": "This is a small file content"
}),
@@ -375,7 +338,7 @@ mod test {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/root"),
"/root",
json!({
"large_file.rs": (0..1000).map(|i| format!("struct Test{} {{\n a: u32,\n b: usize,\n}}", i)).collect::<Vec<_>>().join("\n")
}),
@@ -466,7 +429,7 @@ mod test {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/root"),
"/root",
json!({
"multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
}),
@@ -507,7 +470,7 @@ mod test {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/root"),
"/root",
json!({
"multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
}),
@@ -638,544 +601,4 @@ mod test {
)
.unwrap()
}
#[gpui::test]
async fn test_read_file_security(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/"),
json!({
"project_root": {
"allowed_file.txt": "This file is in the project",
".mysecrets": "SECRET_KEY=abc123",
".secretdir": {
"config": "special configuration"
},
".mymetadata": "custom metadata",
"subdir": {
"normal_file.txt": "Normal file content",
"special.privatekey": "private key content",
"data.mysensitive": "sensitive data"
}
},
"outside_project": {
"sensitive_file.txt": "This file is outside the project"
}
}),
)
.await;
cx.update(|cx| {
use gpui::UpdateGlobal;
use project::WorktreeSettings;
use settings::SettingsStore;
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<WorktreeSettings>(cx, |settings| {
settings.file_scan_exclusions = Some(vec![
"**/.secretdir".to_string(),
"**/.mymetadata".to_string(),
]);
settings.private_files = Some(vec![
"**/.mysecrets".to_string(),
"**/*.privatekey".to_string(),
"**/*.mysensitive".to_string(),
]);
});
});
});
let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
// Reading a file outside the project worktree should fail
let result = cx
.update(|cx| {
let input = json!({
"path": "/outside_project/sensitive_file.txt"
});
Arc::new(ReadFileTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
})
.await;
assert!(
result.is_err(),
"read_file_tool should error when attempting to read an absolute path outside a worktree"
);
// Reading a file within the project should succeed
let result = cx
.update(|cx| {
let input = json!({
"path": "project_root/allowed_file.txt"
});
Arc::new(ReadFileTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
})
.await;
assert!(
result.is_ok(),
"read_file_tool should be able to read files inside worktrees"
);
// Reading files that match file_scan_exclusions should fail
let result = cx
.update(|cx| {
let input = json!({
"path": "project_root/.secretdir/config"
});
Arc::new(ReadFileTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
})
.await;
assert!(
result.is_err(),
"read_file_tool should error when attempting to read files in .secretdir (file_scan_exclusions)"
);
let result = cx
.update(|cx| {
let input = json!({
"path": "project_root/.mymetadata"
});
Arc::new(ReadFileTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
})
.await;
assert!(
result.is_err(),
"read_file_tool should error when attempting to read .mymetadata files (file_scan_exclusions)"
);
// Reading private files should fail
let result = cx
.update(|cx| {
let input = json!({
"path": "project_root/.mysecrets"
});
Arc::new(ReadFileTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
})
.await;
assert!(
result.is_err(),
"read_file_tool should error when attempting to read .mysecrets (private_files)"
);
let result = cx
.update(|cx| {
let input = json!({
"path": "project_root/subdir/special.privatekey"
});
Arc::new(ReadFileTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
})
.await;
assert!(
result.is_err(),
"read_file_tool should error when attempting to read .privatekey files (private_files)"
);
let result = cx
.update(|cx| {
let input = json!({
"path": "project_root/subdir/data.mysensitive"
});
Arc::new(ReadFileTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
})
.await;
assert!(
result.is_err(),
"read_file_tool should error when attempting to read .mysensitive files (private_files)"
);
// Reading a normal file should still work, even with private_files configured
let result = cx
.update(|cx| {
let input = json!({
"path": "project_root/subdir/normal_file.txt"
});
Arc::new(ReadFileTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
})
.await;
assert!(result.is_ok(), "Should be able to read normal files");
assert_eq!(
result.unwrap().content.as_str().unwrap(),
"Normal file content"
);
// Path traversal attempts with .. should fail
let result = cx
.update(|cx| {
let input = json!({
"path": "project_root/../outside_project/sensitive_file.txt"
});
Arc::new(ReadFileTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
})
.await;
assert!(
result.is_err(),
"read_file_tool should error when attempting to read a relative path that resolves to outside a worktree"
);
}
#[gpui::test]
async fn test_read_file_with_multiple_worktree_settings(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
// Create first worktree with its own private_files setting
fs.insert_tree(
path!("/worktree1"),
json!({
"src": {
"main.rs": "fn main() { println!(\"Hello from worktree1\"); }",
"secret.rs": "const API_KEY: &str = \"secret_key_1\";",
"config.toml": "[database]\nurl = \"postgres://localhost/db1\""
},
"tests": {
"test.rs": "mod tests { fn test_it() {} }",
"fixture.sql": "CREATE TABLE users (id INT, name VARCHAR(255));"
},
".zed": {
"settings.json": r#"{
"file_scan_exclusions": ["**/fixture.*"],
"private_files": ["**/secret.rs", "**/config.toml"]
}"#
}
}),
)
.await;
// Create second worktree with different private_files setting
fs.insert_tree(
path!("/worktree2"),
json!({
"lib": {
"public.js": "export function greet() { return 'Hello from worktree2'; }",
"private.js": "const SECRET_TOKEN = \"private_token_2\";",
"data.json": "{\"api_key\": \"json_secret_key\"}"
},
"docs": {
"README.md": "# Public Documentation",
"internal.md": "# Internal Secrets and Configuration"
},
".zed": {
"settings.json": r#"{
"file_scan_exclusions": ["**/internal.*"],
"private_files": ["**/private.js", "**/data.json"]
}"#
}
}),
)
.await;
// Set global settings
cx.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<WorktreeSettings>(cx, |settings| {
settings.file_scan_exclusions =
Some(vec!["**/.git".to_string(), "**/node_modules".to_string()]);
settings.private_files = Some(vec!["**/.env".to_string()]);
});
});
});
let project = Project::test(
fs.clone(),
[path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
cx,
)
.await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
let tool = Arc::new(ReadFileTool);
// Test reading allowed files in worktree1
let input = json!({
"path": "worktree1/src/main.rs"
});
let result = cx
.update(|cx| {
tool.clone().run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
})
.output
.await
.unwrap();
assert_eq!(
result.content.as_str().unwrap(),
"fn main() { println!(\"Hello from worktree1\"); }"
);
// Test reading private file in worktree1 should fail
let input = json!({
"path": "worktree1/src/secret.rs"
});
let result = cx
.update(|cx| {
tool.clone().run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
})
.output
.await;
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("worktree `private_files` setting"),
"Error should mention worktree private_files setting"
);
// Test reading excluded file in worktree1 should fail
let input = json!({
"path": "worktree1/tests/fixture.sql"
});
let result = cx
.update(|cx| {
tool.clone().run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
})
.output
.await;
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("worktree `file_scan_exclusions` setting"),
"Error should mention worktree file_scan_exclusions setting"
);
// Test reading allowed files in worktree2
let input = json!({
"path": "worktree2/lib/public.js"
});
let result = cx
.update(|cx| {
tool.clone().run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
})
.output
.await
.unwrap();
assert_eq!(
result.content.as_str().unwrap(),
"export function greet() { return 'Hello from worktree2'; }"
);
// Test reading private file in worktree2 should fail
let input = json!({
"path": "worktree2/lib/private.js"
});
let result = cx
.update(|cx| {
tool.clone().run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
})
.output
.await;
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("worktree `private_files` setting"),
"Error should mention worktree private_files setting"
);
// Test reading excluded file in worktree2 should fail
let input = json!({
"path": "worktree2/docs/internal.md"
});
let result = cx
.update(|cx| {
tool.clone().run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
})
.output
.await;
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("worktree `file_scan_exclusions` setting"),
"Error should mention worktree file_scan_exclusions setting"
);
// Test that files allowed in one worktree but not in another are handled correctly
// (e.g., config.toml is private in worktree1 but doesn't exist in worktree2)
let input = json!({
"path": "worktree1/src/config.toml"
});
let result = cx
.update(|cx| {
tool.clone().run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
})
.output
.await;
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("worktree `private_files` setting"),
"Config.toml should be blocked by worktree1's private_files setting"
);
}
}

View File

@@ -56,7 +56,6 @@ pub struct Channel {
pub name: SharedString,
pub visibility: proto::ChannelVisibility,
pub parent_path: Vec<ChannelId>,
pub channel_order: i32,
}
#[derive(Default, Debug)]
@@ -615,24 +614,7 @@ impl ChannelStore {
to: to.0,
})
.await?;
Ok(())
})
}
pub fn reorder_channel(
&mut self,
channel_id: ChannelId,
direction: proto::reorder_channel::Direction,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let client = self.client.clone();
cx.spawn(async move |_, _| {
client
.request(proto::ReorderChannel {
channel_id: channel_id.0,
direction: direction.into(),
})
.await?;
Ok(())
})
}
@@ -1045,18 +1027,6 @@ impl ChannelStore {
});
}
#[cfg(any(test, feature = "test-support"))]
pub fn reset(&mut self) {
self.channel_invitations.clear();
self.channel_index.clear();
self.channel_participants.clear();
self.outgoing_invites.clear();
self.opened_buffers.clear();
self.opened_chats.clear();
self.disconnect_channel_buffers_task = None;
self.channel_states.clear();
}
pub(crate) fn update_channels(
&mut self,
payload: proto::UpdateChannels,
@@ -1081,7 +1051,6 @@ impl ChannelStore {
visibility: channel.visibility(),
name: channel.name.into(),
parent_path: channel.parent_path.into_iter().map(ChannelId).collect(),
channel_order: channel.channel_order,
}),
),
}

View File

@@ -61,13 +61,11 @@ impl ChannelPathsInsertGuard<'_> {
ret = existing_channel.visibility != channel_proto.visibility()
|| existing_channel.name != channel_proto.name
|| existing_channel.parent_path != parent_path
|| existing_channel.channel_order != channel_proto.channel_order;
|| existing_channel.parent_path != parent_path;
existing_channel.visibility = channel_proto.visibility();
existing_channel.name = channel_proto.name.into();
existing_channel.parent_path = parent_path;
existing_channel.channel_order = channel_proto.channel_order;
} else {
self.channels_by_id.insert(
ChannelId(channel_proto.id),
@@ -76,7 +74,6 @@ impl ChannelPathsInsertGuard<'_> {
visibility: channel_proto.visibility(),
name: channel_proto.name.into(),
parent_path,
channel_order: channel_proto.channel_order,
}),
);
self.insert_root(ChannelId(channel_proto.id));
@@ -103,18 +100,17 @@ impl Drop for ChannelPathsInsertGuard<'_> {
fn channel_path_sorting_key(
id: ChannelId,
channels_by_id: &BTreeMap<ChannelId, Arc<Channel>>,
) -> impl Iterator<Item = (i32, ChannelId)> {
let (parent_path, order_and_id) =
channels_by_id
.get(&id)
.map_or((&[] as &[_], None), |channel| {
(
channel.parent_path.as_slice(),
Some((channel.channel_order, channel.id)),
)
});
) -> impl Iterator<Item = (&str, ChannelId)> {
let (parent_path, name) = channels_by_id
.get(&id)
.map_or((&[] as &[_], None), |channel| {
(
channel.parent_path.as_slice(),
Some((channel.name.as_ref(), channel.id)),
)
});
parent_path
.iter()
.filter_map(|id| Some((channels_by_id.get(id)?.channel_order, *id)))
.chain(order_and_id)
.filter_map(|id| Some((channels_by_id.get(id)?.name.as_ref(), *id)))
.chain(name)
}

View File

@@ -21,14 +21,12 @@ fn test_update_channels(cx: &mut App) {
name: "b".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
parent_path: Vec::new(),
channel_order: 1,
},
proto::Channel {
id: 2,
name: "a".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
parent_path: Vec::new(),
channel_order: 2,
},
],
..Default::default()
@@ -39,8 +37,8 @@ fn test_update_channels(cx: &mut App) {
&channel_store,
&[
//
(0, "b".to_string()),
(0, "a".to_string()),
(0, "b".to_string()),
],
cx,
);
@@ -54,14 +52,12 @@ fn test_update_channels(cx: &mut App) {
name: "x".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
parent_path: vec![1],
channel_order: 1,
},
proto::Channel {
id: 4,
name: "y".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
parent_path: vec![2],
channel_order: 1,
},
],
..Default::default()
@@ -71,111 +67,15 @@ fn test_update_channels(cx: &mut App) {
assert_channels(
&channel_store,
&[
(0, "b".to_string()),
(1, "x".to_string()),
(0, "a".to_string()),
(1, "y".to_string()),
(0, "b".to_string()),
(1, "x".to_string()),
],
cx,
);
}
#[gpui::test]
fn test_update_channels_order_independent(cx: &mut App) {
/// Based on: https://stackoverflow.com/a/59939809
fn unique_permutations<T: Clone>(items: Vec<T>) -> Vec<Vec<T>> {
if items.len() == 1 {
vec![items]
} else {
let mut output: Vec<Vec<T>> = vec![];
for (ix, first) in items.iter().enumerate() {
let mut remaining_elements = items.clone();
remaining_elements.remove(ix);
for mut permutation in unique_permutations(remaining_elements) {
permutation.insert(0, first.clone());
output.push(permutation);
}
}
output
}
}
let test_data = vec![
proto::Channel {
id: 6,
name: "β".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
parent_path: vec![1, 3],
channel_order: 1,
},
proto::Channel {
id: 5,
name: "α".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
parent_path: vec![1],
channel_order: 2,
},
proto::Channel {
id: 3,
name: "x".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
parent_path: vec![1],
channel_order: 1,
},
proto::Channel {
id: 4,
name: "y".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
parent_path: vec![2],
channel_order: 1,
},
proto::Channel {
id: 1,
name: "b".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
parent_path: Vec::new(),
channel_order: 1,
},
proto::Channel {
id: 2,
name: "a".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
parent_path: Vec::new(),
channel_order: 2,
},
];
let channel_store = init_test(cx);
let permutations = unique_permutations(test_data);
for test_instance in permutations {
channel_store.update(cx, |channel_store, _| channel_store.reset());
update_channels(
&channel_store,
proto::UpdateChannels {
channels: test_instance,
..Default::default()
},
cx,
);
assert_channels(
&channel_store,
&[
(0, "b".to_string()),
(1, "x".to_string()),
(2, "β".to_string()),
(1, "α".to_string()),
(0, "a".to_string()),
(1, "y".to_string()),
],
cx,
);
}
}
#[gpui::test]
fn test_dangling_channel_paths(cx: &mut App) {
let channel_store = init_test(cx);
@@ -189,21 +89,18 @@ fn test_dangling_channel_paths(cx: &mut App) {
name: "a".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
parent_path: vec![],
channel_order: 1,
},
proto::Channel {
id: 1,
name: "b".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
parent_path: vec![0],
channel_order: 1,
},
proto::Channel {
id: 2,
name: "c".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
parent_path: vec![0, 1],
channel_order: 1,
},
],
..Default::default()
@@ -250,7 +147,6 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
name: "the-channel".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
parent_path: vec![],
channel_order: 1,
}],
..Default::default()
});

View File

@@ -266,14 +266,11 @@ CREATE TABLE "channels" (
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"visibility" VARCHAR NOT NULL,
"parent_path" TEXT NOT NULL,
"requires_zed_cla" BOOLEAN NOT NULL DEFAULT FALSE,
"channel_order" INTEGER NOT NULL DEFAULT 1
"requires_zed_cla" BOOLEAN NOT NULL DEFAULT FALSE
);
CREATE INDEX "index_channels_on_parent_path" ON "channels" ("parent_path");
CREATE INDEX "index_channels_on_parent_path_and_order" ON "channels" ("parent_path", "channel_order");
CREATE TABLE IF NOT EXISTS "channel_chat_participants" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT,
"user_id" INTEGER NOT NULL REFERENCES users (id),

View File

@@ -1,16 +0,0 @@
-- Add channel_order column to channels table with default value
ALTER TABLE channels ADD COLUMN channel_order INTEGER NOT NULL DEFAULT 1;
-- Update channel_order for existing channels using ROW_NUMBER for deterministic ordering
UPDATE channels
SET channel_order = (
SELECT ROW_NUMBER() OVER (
PARTITION BY parent_path
ORDER BY name, id
)
FROM channels c2
WHERE c2.id = channels.id
);
-- Create index for efficient ordering queries
CREATE INDEX "index_channels_on_parent_path_and_order" ON "channels" ("parent_path", "channel_order");

View File

@@ -582,7 +582,6 @@ pub struct Channel {
pub visibility: ChannelVisibility,
/// parent_path is the channel ids from the root to this one (not including this one)
pub parent_path: Vec<ChannelId>,
pub channel_order: i32,
}
impl Channel {
@@ -592,7 +591,6 @@ impl Channel {
visibility: value.visibility,
name: value.clone().name,
parent_path: value.ancestors().collect(),
channel_order: value.channel_order,
}
}
@@ -602,13 +600,8 @@ impl Channel {
name: self.name.clone(),
visibility: self.visibility.into(),
parent_path: self.parent_path.iter().map(|c| c.to_proto()).collect(),
channel_order: self.channel_order,
}
}
pub fn root_id(&self) -> ChannelId {
self.parent_path.first().copied().unwrap_or(self.id)
}
}
#[derive(Debug, PartialEq, Eq, Hash)]

View File

@@ -4,7 +4,7 @@ use rpc::{
ErrorCode, ErrorCodeExt,
proto::{ChannelBufferVersion, VectorClockEntry, channel_member::Kind},
};
use sea_orm::{ActiveValue, DbBackend, TryGetableMany};
use sea_orm::{DbBackend, TryGetableMany};
impl Database {
#[cfg(test)]
@@ -59,32 +59,16 @@ impl Database {
parent = Some(parent_channel);
}
let parent_path = parent
.as_ref()
.map_or(String::new(), |parent| parent.path());
// Find the maximum channel_order among siblings to set the new channel at the end
let max_order = if parent_path.is_empty() {
0
} else {
max_order(&parent_path, &tx).await?
};
log::info!(
"Creating channel '{}' with parent_path='{}', max_order={}, new_order={}",
name,
parent_path,
max_order,
max_order + 1
);
let channel = channel::ActiveModel {
id: ActiveValue::NotSet,
name: ActiveValue::Set(name.to_string()),
visibility: ActiveValue::Set(ChannelVisibility::Members),
parent_path: ActiveValue::Set(parent_path),
parent_path: ActiveValue::Set(
parent
.as_ref()
.map_or(String::new(), |parent| parent.path()),
),
requires_zed_cla: ActiveValue::NotSet,
channel_order: ActiveValue::Set(max_order + 1),
}
.insert(&*tx)
.await?;
@@ -547,7 +531,11 @@ impl Database {
.get_channel_descendants_excluding_self(channels.iter(), tx)
.await?;
descendants.extend(channels);
for channel in channels {
if let Err(ix) = descendants.binary_search_by_key(&channel.path(), |c| c.path()) {
descendants.insert(ix, channel);
}
}
let roles_by_channel_id = channel_memberships
.iter()
@@ -964,14 +952,11 @@ impl Database {
}
let root_id = channel.root_id();
let new_parent_path = new_parent.path();
let old_path = format!("{}{}/", channel.parent_path, channel.id);
let new_path = format!("{}{}/", &new_parent_path, channel.id);
let new_order = max_order(&new_parent_path, &tx).await? + 1;
let new_path = format!("{}{}/", new_parent.path(), channel.id);
let mut model = channel.into_active_model();
model.parent_path = ActiveValue::Set(new_parent.path());
model.channel_order = ActiveValue::Set(new_order);
let channel = model.update(&*tx).await?;
let descendent_ids =
@@ -1001,137 +986,6 @@ impl Database {
})
.await
}
pub async fn reorder_channel(
&self,
channel_id: ChannelId,
direction: proto::reorder_channel::Direction,
user_id: UserId,
) -> Result<Vec<Channel>> {
self.transaction(|tx| async move {
let mut channel = self.get_channel_internal(channel_id, &tx).await?;
if channel.is_root() {
log::info!("Skipping reorder of root channel {}", channel.id,);
return Ok(vec![]);
}
log::info!(
"Reordering channel {} (parent_path: '{}', order: {})",
channel.id,
channel.parent_path,
channel.channel_order
);
// Check if user is admin of the channel
self.check_user_is_channel_admin(&channel, user_id, &tx)
.await?;
// Find the sibling channel to swap with
let sibling_channel = match direction {
proto::reorder_channel::Direction::Up => {
log::info!(
"Looking for sibling with parent_path='{}' and order < {}",
channel.parent_path,
channel.channel_order
);
// Find channel with highest order less than current
channel::Entity::find()
.filter(
channel::Column::ParentPath
.eq(&channel.parent_path)
.and(channel::Column::ChannelOrder.lt(channel.channel_order)),
)
.order_by_desc(channel::Column::ChannelOrder)
.one(&*tx)
.await?
}
proto::reorder_channel::Direction::Down => {
log::info!(
"Looking for sibling with parent_path='{}' and order > {}",
channel.parent_path,
channel.channel_order
);
// Find channel with lowest order greater than current
channel::Entity::find()
.filter(
channel::Column::ParentPath
.eq(&channel.parent_path)
.and(channel::Column::ChannelOrder.gt(channel.channel_order)),
)
.order_by_asc(channel::Column::ChannelOrder)
.one(&*tx)
.await?
}
};
let mut sibling_channel = match sibling_channel {
Some(sibling) => {
log::info!(
"Found sibling {} (parent_path: '{}', order: {})",
sibling.id,
sibling.parent_path,
sibling.channel_order
);
sibling
}
None => {
log::warn!("No sibling found to swap with");
// No sibling to swap with
return Ok(vec![]);
}
};
let current_order = channel.channel_order;
let sibling_order = sibling_channel.channel_order;
channel::ActiveModel {
id: ActiveValue::Unchanged(sibling_channel.id),
channel_order: ActiveValue::Set(current_order),
..Default::default()
}
.update(&*tx)
.await?;
sibling_channel.channel_order = current_order;
channel::ActiveModel {
id: ActiveValue::Unchanged(channel.id),
channel_order: ActiveValue::Set(sibling_order),
..Default::default()
}
.update(&*tx)
.await?;
channel.channel_order = sibling_order;
log::info!(
"Reorder complete. Swapped channels {} and {}",
channel.id,
sibling_channel.id
);
let swapped_channels = vec![
Channel::from_model(channel),
Channel::from_model(sibling_channel),
];
Ok(swapped_channels)
})
.await
}
}
async fn max_order(parent_path: &str, tx: &TransactionHandle) -> Result<i32> {
let max_order = channel::Entity::find()
.filter(channel::Column::ParentPath.eq(parent_path))
.select_only()
.column_as(channel::Column::ChannelOrder.max(), "max_order")
.into_tuple::<Option<i32>>()
.one(&**tx)
.await?
.flatten()
.unwrap_or(0);
Ok(max_order)
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]

View File

@@ -10,9 +10,6 @@ pub struct Model {
pub visibility: ChannelVisibility,
pub parent_path: String,
pub requires_zed_cla: bool,
/// The order of this channel relative to its siblings within the same parent.
/// Lower values appear first. Channels are sorted by parent_path first, then by channel_order.
pub channel_order: i32,
}
impl Model {

View File

@@ -172,40 +172,16 @@ impl Drop for TestDb {
}
}
#[track_caller]
fn assert_channel_tree_matches(actual: Vec<Channel>, expected: Vec<Channel>) {
let expected_channels = expected.into_iter().collect::<HashSet<_>>();
let actual_channels = actual.into_iter().collect::<HashSet<_>>();
pretty_assertions::assert_eq!(expected_channels, actual_channels);
}
fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
use std::collections::HashMap;
let mut result = Vec::new();
let mut order_by_parent: HashMap<Vec<ChannelId>, i32> = HashMap::new();
for (id, parent_path, name) in channels {
let parent_key = parent_path.to_vec();
let order = if parent_key.is_empty() {
1
} else {
*order_by_parent
.entry(parent_key.clone())
.and_modify(|e| *e += 1)
.or_insert(1)
};
result.push(Channel {
channels
.iter()
.map(|(id, parent_path, name)| Channel {
id: *id,
name: name.to_string(),
visibility: ChannelVisibility::Members,
parent_path: parent_key,
channel_order: order,
});
}
result
parent_path: parent_path.to_vec(),
})
.collect()
}
static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);

View File

@@ -1,15 +1,15 @@
use crate::{
db::{
Channel, ChannelId, ChannelRole, Database, NewUserParams, RoomId, UserId,
tests::{assert_channel_tree_matches, channel_tree, new_test_connection, new_test_user},
tests::{channel_tree, new_test_connection, new_test_user},
},
test_both_dbs,
};
use rpc::{
ConnectionId,
proto::{self, reorder_channel},
proto::{self},
};
use std::{collections::HashSet, sync::Arc};
use std::sync::Arc;
test_both_dbs!(test_channels, test_channels_postgres, test_channels_sqlite);
@@ -59,28 +59,28 @@ async fn test_channels(db: &Arc<Database>) {
.unwrap();
let result = db.get_channels_for_user(a_id).await.unwrap();
assert_channel_tree_matches(
assert_eq!(
result.channels,
channel_tree(&[
(zed_id, &[], "zed"),
(crdb_id, &[zed_id], "crdb"),
(livestreaming_id, &[zed_id], "livestreaming"),
(livestreaming_id, &[zed_id], "livestreaming",),
(replace_id, &[zed_id], "replace"),
(rust_id, &[], "rust"),
(cargo_id, &[rust_id], "cargo"),
(cargo_ra_id, &[rust_id, cargo_id], "cargo-ra"),
]),
(cargo_ra_id, &[rust_id, cargo_id], "cargo-ra",)
],)
);
let result = db.get_channels_for_user(b_id).await.unwrap();
assert_channel_tree_matches(
assert_eq!(
result.channels,
channel_tree(&[
(zed_id, &[], "zed"),
(crdb_id, &[zed_id], "crdb"),
(livestreaming_id, &[zed_id], "livestreaming"),
(replace_id, &[zed_id], "replace"),
]),
(livestreaming_id, &[zed_id], "livestreaming",),
(replace_id, &[zed_id], "replace")
],)
);
// Update member permissions
@@ -94,14 +94,14 @@ async fn test_channels(db: &Arc<Database>) {
assert!(set_channel_admin.is_ok());
let result = db.get_channels_for_user(b_id).await.unwrap();
assert_channel_tree_matches(
assert_eq!(
result.channels,
channel_tree(&[
(zed_id, &[], "zed"),
(crdb_id, &[zed_id], "crdb"),
(livestreaming_id, &[zed_id], "livestreaming"),
(replace_id, &[zed_id], "replace"),
]),
(livestreaming_id, &[zed_id], "livestreaming",),
(replace_id, &[zed_id], "replace")
],)
);
// Remove a single channel
@@ -313,8 +313,8 @@ async fn test_channel_renames(db: &Arc<Database>) {
test_both_dbs!(
test_db_channel_moving,
test_db_channel_moving_postgres,
test_db_channel_moving_sqlite
test_channels_moving_postgres,
test_channels_moving_sqlite
);
async fn test_db_channel_moving(db: &Arc<Database>) {
@@ -343,14 +343,16 @@ async fn test_db_channel_moving(db: &Arc<Database>) {
.await
.unwrap();
let livestreaming_sub_id = db
.create_sub_channel("livestreaming_sub", livestreaming_id, a_id)
let livestreaming_dag_id = db
.create_sub_channel("livestreaming_dag", livestreaming_id, a_id)
.await
.unwrap();
// ========================================================================
// sanity check
// Initial DAG:
// /- gpui2
// zed -- crdb - livestreaming - livestreaming_sub
// zed -- crdb - livestreaming - livestreaming_dag
let result = db.get_channels_for_user(a_id).await.unwrap();
assert_channel_tree(
result.channels,
@@ -358,242 +360,10 @@ async fn test_db_channel_moving(db: &Arc<Database>) {
(zed_id, &[]),
(crdb_id, &[zed_id]),
(livestreaming_id, &[zed_id, crdb_id]),
(livestreaming_sub_id, &[zed_id, crdb_id, livestreaming_id]),
(livestreaming_dag_id, &[zed_id, crdb_id, livestreaming_id]),
(gpui2_id, &[zed_id]),
],
);
// Check that we can do a simple leaf -> leaf move
db.move_channel(livestreaming_sub_id, crdb_id, a_id)
.await
.unwrap();
// /- gpui2
// zed -- crdb -- livestreaming
// \- livestreaming_sub
let result = db.get_channels_for_user(a_id).await.unwrap();
assert_channel_tree(
result.channels,
&[
(zed_id, &[]),
(crdb_id, &[zed_id]),
(livestreaming_id, &[zed_id, crdb_id]),
(livestreaming_sub_id, &[zed_id, crdb_id]),
(gpui2_id, &[zed_id]),
],
);
// Check that we can move a whole subtree at once
db.move_channel(crdb_id, gpui2_id, a_id).await.unwrap();
// zed -- gpui2 -- crdb -- livestreaming
// \- livestreaming_sub
let result = db.get_channels_for_user(a_id).await.unwrap();
assert_channel_tree(
result.channels,
&[
(zed_id, &[]),
(gpui2_id, &[zed_id]),
(crdb_id, &[zed_id, gpui2_id]),
(livestreaming_id, &[zed_id, gpui2_id, crdb_id]),
(livestreaming_sub_id, &[zed_id, gpui2_id, crdb_id]),
],
);
}
test_both_dbs!(
test_channel_reordering,
test_channel_reordering_postgres,
test_channel_reordering_sqlite
);
async fn test_channel_reordering(db: &Arc<Database>) {
let admin_id = db
.create_user(
"admin@example.com",
None,
false,
NewUserParams {
github_login: "admin".into(),
github_user_id: 1,
},
)
.await
.unwrap()
.user_id;
let user_id = db
.create_user(
"user@example.com",
None,
false,
NewUserParams {
github_login: "user".into(),
github_user_id: 2,
},
)
.await
.unwrap()
.user_id;
// Create a root channel with some sub-channels
let root_id = db.create_root_channel("root", admin_id).await.unwrap();
// Invite user to root channel so they can see the sub-channels
db.invite_channel_member(root_id, user_id, admin_id, ChannelRole::Member)
.await
.unwrap();
db.respond_to_channel_invite(root_id, user_id, true)
.await
.unwrap();
let alpha_id = db
.create_sub_channel("alpha", root_id, admin_id)
.await
.unwrap();
let beta_id = db
.create_sub_channel("beta", root_id, admin_id)
.await
.unwrap();
let gamma_id = db
.create_sub_channel("gamma", root_id, admin_id)
.await
.unwrap();
// Initial order should be: root, alpha (order=1), beta (order=2), gamma (order=3)
let result = db.get_channels_for_user(admin_id).await.unwrap();
assert_channel_tree_order(
result.channels,
&[
(root_id, &[], 1),
(alpha_id, &[root_id], 1),
(beta_id, &[root_id], 2),
(gamma_id, &[root_id], 3),
],
);
// Test moving beta up (should swap with alpha)
let updated_channels = db
.reorder_channel(beta_id, reorder_channel::Direction::Up, admin_id)
.await
.unwrap();
// Verify that beta and alpha were returned as updated
assert_eq!(updated_channels.len(), 2);
let updated_ids: std::collections::HashSet<_> = updated_channels.iter().map(|c| c.id).collect();
assert!(updated_ids.contains(&alpha_id));
assert!(updated_ids.contains(&beta_id));
// Now order should be: root, beta (order=1), alpha (order=2), gamma (order=3)
let result = db.get_channels_for_user(admin_id).await.unwrap();
assert_channel_tree_order(
result.channels,
&[
(root_id, &[], 1),
(beta_id, &[root_id], 1),
(alpha_id, &[root_id], 2),
(gamma_id, &[root_id], 3),
],
);
// Test moving gamma down (should be no-op since it's already last)
let updated_channels = db
.reorder_channel(gamma_id, reorder_channel::Direction::Down, admin_id)
.await
.unwrap();
// Should return just nothing
assert_eq!(updated_channels.len(), 0);
// Test moving alpha down (should swap with gamma)
let updated_channels = db
.reorder_channel(alpha_id, reorder_channel::Direction::Down, admin_id)
.await
.unwrap();
// Verify that alpha and gamma were returned as updated
assert_eq!(updated_channels.len(), 2);
let updated_ids: std::collections::HashSet<_> = updated_channels.iter().map(|c| c.id).collect();
assert!(updated_ids.contains(&alpha_id));
assert!(updated_ids.contains(&gamma_id));
// Now order should be: root, beta (order=1), gamma (order=2), alpha (order=3)
let result = db.get_channels_for_user(admin_id).await.unwrap();
assert_channel_tree_order(
result.channels,
&[
(root_id, &[], 1),
(beta_id, &[root_id], 1),
(gamma_id, &[root_id], 2),
(alpha_id, &[root_id], 3),
],
);
// Test that non-admin cannot reorder
let reorder_result = db
.reorder_channel(beta_id, reorder_channel::Direction::Up, user_id)
.await;
assert!(reorder_result.is_err());
// Test moving beta up (should be no-op since it's already first)
let updated_channels = db
.reorder_channel(beta_id, reorder_channel::Direction::Up, admin_id)
.await
.unwrap();
// Should return nothing
assert_eq!(updated_channels.len(), 0);
// Adding a channel to an existing ordering should add it to the end
let delta_id = db
.create_sub_channel("delta", root_id, admin_id)
.await
.unwrap();
let result = db.get_channels_for_user(admin_id).await.unwrap();
assert_channel_tree_order(
result.channels,
&[
(root_id, &[], 1),
(beta_id, &[root_id], 1),
(gamma_id, &[root_id], 2),
(alpha_id, &[root_id], 3),
(delta_id, &[root_id], 4),
],
);
// And moving a channel into an existing ordering should add it to the end
let eta_id = db
.create_sub_channel("eta", delta_id, admin_id)
.await
.unwrap();
let result = db.get_channels_for_user(admin_id).await.unwrap();
assert_channel_tree_order(
result.channels,
&[
(root_id, &[], 1),
(beta_id, &[root_id], 1),
(gamma_id, &[root_id], 2),
(alpha_id, &[root_id], 3),
(delta_id, &[root_id], 4),
(eta_id, &[root_id, delta_id], 1),
],
);
db.move_channel(eta_id, root_id, admin_id).await.unwrap();
let result = db.get_channels_for_user(admin_id).await.unwrap();
assert_channel_tree_order(
result.channels,
&[
(root_id, &[], 1),
(beta_id, &[root_id], 1),
(gamma_id, &[root_id], 2),
(alpha_id, &[root_id], 3),
(delta_id, &[root_id], 4),
(eta_id, &[root_id], 5),
],
);
}
test_both_dbs!(
@@ -652,20 +422,6 @@ async fn test_db_channel_moving_bugs(db: &Arc<Database>) {
(livestreaming_id, &[zed_id, projects_id]),
],
);
// Can't un-root a root channel
db.move_channel(zed_id, livestreaming_id, user_id)
.await
.unwrap_err();
let result = db.get_channels_for_user(user_id).await.unwrap();
assert_channel_tree(
result.channels,
&[
(zed_id, &[]),
(projects_id, &[zed_id]),
(livestreaming_id, &[zed_id, projects_id]),
],
);
}
test_both_dbs!(
@@ -989,29 +745,10 @@ fn assert_channel_tree(actual: Vec<Channel>, expected: &[(ChannelId, &[ChannelId
let actual = actual
.iter()
.map(|channel| (channel.id, channel.parent_path.as_slice()))
.collect::<HashSet<_>>();
let expected = expected
.iter()
.map(|(id, parents)| (*id, *parents))
.collect::<HashSet<_>>();
pretty_assertions::assert_eq!(actual, expected, "wrong channel ids and parent paths");
}
#[track_caller]
fn assert_channel_tree_order(actual: Vec<Channel>, expected: &[(ChannelId, &[ChannelId], i32)]) {
let actual = actual
.iter()
.map(|channel| {
(
channel.id,
channel.parent_path.as_slice(),
channel.channel_order,
)
})
.collect::<HashSet<_>>();
let expected = expected
.iter()
.map(|(id, parents, order)| (*id, *parents, *order))
.collect::<HashSet<_>>();
pretty_assertions::assert_eq!(actual, expected, "wrong channel ids and parent paths");
.collect::<Vec<_>>();
pretty_assertions::assert_eq!(
actual,
expected.to_vec(),
"wrong channel ids and parent paths"
);
}

View File

@@ -384,7 +384,6 @@ impl Server {
.add_request_handler(get_notifications)
.add_request_handler(mark_notification_as_read)
.add_request_handler(move_channel)
.add_request_handler(reorder_channel)
.add_request_handler(follow)
.add_message_handler(unfollow)
.add_message_handler(update_followers)
@@ -3221,51 +3220,6 @@ async fn move_channel(
Ok(())
}
async fn reorder_channel(
request: proto::ReorderChannel,
response: Response<proto::ReorderChannel>,
session: Session,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let direction = request.direction();
let updated_channels = session
.db()
.await
.reorder_channel(channel_id, direction, session.user_id())
.await?;
if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
let connection_pool = session.connection_pool().await;
for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
let channels = updated_channels
.iter()
.filter_map(|channel| {
if role.can_see_channel(channel.visibility) {
Some(channel.to_proto())
} else {
None
}
})
.collect::<Vec<_>>();
if channels.is_empty() {
continue;
}
let update = proto::UpdateChannels {
channels,
..Default::default()
};
session.peer.send(connection_id, update.clone())?;
}
}
response.send(Ack {})?;
Ok(())
}
/// Get the list of channel members
async fn get_channel_members(
request: proto::GetChannelMembers,

View File

@@ -2624,7 +2624,6 @@ async fn test_git_diff_base_change(
client_a.fs().set_head_for_repo(
Path::new("/dir/.git"),
&[("a.txt".into(), committed_text.clone())],
"deadbeef",
);
// Create the buffer
@@ -2718,7 +2717,6 @@ async fn test_git_diff_base_change(
client_a.fs().set_head_for_repo(
Path::new("/dir/.git"),
&[("a.txt".into(), new_committed_text.clone())],
"deadbeef",
);
// Wait for buffer_local_a to receive it
@@ -3008,7 +3006,6 @@ async fn test_git_status_sync(
client_a.fs().set_head_for_repo(
path!("/dir/.git").as_ref(),
&[("b.txt".into(), "B".into()), ("c.txt".into(), "c".into())],
"deadbeef",
);
client_a.fs().set_index_for_repo(
path!("/dir/.git").as_ref(),

View File

@@ -14,9 +14,9 @@ use fuzzy::{StringMatchCandidate, match_strings};
use gpui::{
AnyElement, App, AsyncWindowContext, Bounds, ClickEvent, ClipboardItem, Context, DismissEvent,
Div, Entity, EventEmitter, FocusHandle, Focusable, FontStyle, InteractiveElement, IntoElement,
KeyContext, ListOffset, ListState, MouseDownEvent, ParentElement, Pixels, Point, PromptLevel,
Render, SharedString, Styled, Subscription, Task, TextStyle, WeakEntity, Window, actions,
anchored, canvas, deferred, div, fill, list, point, prelude::*, px,
ListOffset, ListState, MouseDownEvent, ParentElement, Pixels, Point, PromptLevel, Render,
SharedString, Styled, Subscription, Task, TextStyle, WeakEntity, Window, actions, anchored,
canvas, deferred, div, fill, list, point, prelude::*, px,
};
use menu::{Cancel, Confirm, SecondaryConfirm, SelectNext, SelectPrevious};
use project::{Fs, Project};
@@ -52,8 +52,6 @@ actions!(
StartMoveChannel,
MoveSelected,
InsertSpace,
MoveChannelUp,
MoveChannelDown,
]
);
@@ -1963,33 +1961,6 @@ impl CollabPanel {
})
}
fn move_channel_up(&mut self, _: &MoveChannelUp, window: &mut Window, cx: &mut Context<Self>) {
if let Some(channel) = self.selected_channel() {
self.channel_store.update(cx, |store, cx| {
store
.reorder_channel(channel.id, proto::reorder_channel::Direction::Up, cx)
.detach_and_prompt_err("Failed to move channel up", window, cx, |_, _, _| None)
});
}
}
fn move_channel_down(
&mut self,
_: &MoveChannelDown,
window: &mut Window,
cx: &mut Context<Self>,
) {
if let Some(channel) = self.selected_channel() {
self.channel_store.update(cx, |store, cx| {
store
.reorder_channel(channel.id, proto::reorder_channel::Direction::Down, cx)
.detach_and_prompt_err("Failed to move channel down", window, cx, |_, _, _| {
None
})
});
}
}
fn open_channel_notes(
&mut self,
channel_id: ChannelId,
@@ -2003,7 +1974,7 @@ impl CollabPanel {
fn show_inline_context_menu(
&mut self,
_: &Secondary,
_: &menu::SecondaryConfirm,
window: &mut Window,
cx: &mut Context<Self>,
) {
@@ -2032,21 +2003,6 @@ impl CollabPanel {
}
}
fn dispatch_context(&self, window: &Window, cx: &Context<Self>) -> KeyContext {
let mut dispatch_context = KeyContext::new_with_defaults();
dispatch_context.add("CollabPanel");
dispatch_context.add("menu");
let identifier = if self.channel_name_editor.focus_handle(cx).is_focused(window) {
"editing"
} else {
"not_editing"
};
dispatch_context.add(identifier);
dispatch_context
}
fn selected_channel(&self) -> Option<&Arc<Channel>> {
self.selection
.and_then(|ix| self.entries.get(ix))
@@ -3009,7 +2965,7 @@ fn render_tree_branch(
impl Render for CollabPanel {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
v_flex()
.key_context(self.dispatch_context(window, cx))
.key_context("CollabPanel")
.on_action(cx.listener(CollabPanel::cancel))
.on_action(cx.listener(CollabPanel::select_next))
.on_action(cx.listener(CollabPanel::select_previous))
@@ -3021,8 +2977,6 @@ impl Render for CollabPanel {
.on_action(cx.listener(CollabPanel::collapse_selected_channel))
.on_action(cx.listener(CollabPanel::expand_selected_channel))
.on_action(cx.listener(CollabPanel::start_move_selected_channel))
.on_action(cx.listener(CollabPanel::move_channel_up))
.on_action(cx.listener(CollabPanel::move_channel_down))
.track_focus(&self.focus_handle(cx))
.size_full()
.child(if self.user_store.read(cx).current_user().is_none() {

View File

@@ -448,7 +448,7 @@ impl PickerDelegate for CommandPaletteDelegate {
}
}
pub fn humanize_action_name(name: &str) -> String {
fn humanize_action_name(name: &str) -> String {
let capacity = name.len() + name.chars().filter(|c| c.is_uppercase()).count();
let mut result = String::with_capacity(capacity);
for char in name.chars() {

View File

@@ -161,7 +161,7 @@ impl ComponentMetadata {
}
/// Implement this trait to define a UI component. This will allow you to
/// derive `RegisterComponent` on it, in turn allowing you to preview the
/// derive `RegisterComponent` on it, in tutn allowing you to preview the
/// contents of the preview fn in `workspace: open component preview`.
///
/// This can be useful for visual debugging and testing, documenting UI

View File

@@ -15,9 +15,6 @@ settings.workspace = true
regex.workspace = true
util.workspace = true
workspace-hack.workspace = true
zed.workspace = true
gpui.workspace = true
command_palette.workspace = true
[lints]
workspace = true

View File

@@ -5,7 +5,6 @@ use mdbook::book::{Book, Chapter};
use mdbook::preprocess::CmdPreprocessor;
use regex::Regex;
use settings::KeymapFile;
use std::collections::HashSet;
use std::io::{self, Read};
use std::process;
use std::sync::LazyLock;
@@ -18,8 +17,6 @@ static KEYMAP_LINUX: LazyLock<KeymapFile> = LazyLock::new(|| {
load_keymap("keymaps/default-linux.json").expect("Failed to load Linux keymap")
});
static ALL_ACTIONS: LazyLock<Vec<ActionDef>> = LazyLock::new(dump_all_gpui_actions);
pub fn make_app() -> Command {
Command::new("zed-docs-preprocessor")
.about("Preprocesses Zed Docs content to provide rich action & keybinding support and more")
@@ -32,9 +29,6 @@ pub fn make_app() -> Command {
fn main() -> Result<()> {
let matches = make_app().get_matches();
// call a zed:: function so everything in `zed` crate is linked and
// all actions in the actual app are registered
zed::stdout_is_a_pty();
if let Some(sub_args) = matches.subcommand_matches("supports") {
handle_supports(sub_args);
@@ -45,43 +39,6 @@ fn main() -> Result<()> {
Ok(())
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum Error {
ActionNotFound { action_name: String },
DeprecatedActionUsed { used: String, should_be: String },
}
impl Error {
fn new_for_not_found_action(action_name: String) -> Self {
for action in &*ALL_ACTIONS {
for alias in action.deprecated_aliases {
if alias == &action_name {
return Error::DeprecatedActionUsed {
used: action_name.clone(),
should_be: action.name.to_string(),
};
}
}
}
Error::ActionNotFound {
action_name: action_name.to_string(),
}
}
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Error::ActionNotFound { action_name } => write!(f, "Action not found: {}", action_name),
Error::DeprecatedActionUsed { used, should_be } => write!(
f,
"Deprecated action used: {} should be {}",
used, should_be
),
}
}
}
fn handle_preprocessing() -> Result<()> {
let mut stdin = io::stdin();
let mut input = String::new();
@@ -89,19 +46,8 @@ fn handle_preprocessing() -> Result<()> {
let (_ctx, mut book) = CmdPreprocessor::parse_input(input.as_bytes())?;
let mut errors = HashSet::<Error>::new();
template_and_validate_keybindings(&mut book, &mut errors);
template_and_validate_actions(&mut book, &mut errors);
if !errors.is_empty() {
const ANSI_RED: &'static str = "\x1b[31m";
const ANSI_RESET: &'static str = "\x1b[0m";
for error in &errors {
eprintln!("{ANSI_RED}ERROR{ANSI_RESET}: {}", error);
}
return Err(anyhow::anyhow!("Found {} errors in docs", errors.len()));
}
template_keybinding(&mut book);
template_action(&mut book);
serde_json::to_writer(io::stdout(), &book)?;
@@ -120,17 +66,13 @@ fn handle_supports(sub_args: &ArgMatches) -> ! {
}
}
fn template_and_validate_keybindings(book: &mut Book, errors: &mut HashSet<Error>) {
fn template_keybinding(book: &mut Book) {
let regex = Regex::new(r"\{#kb (.*?)\}").unwrap();
for_each_chapter_mut(book, |chapter| {
chapter.content = regex
.replace_all(&chapter.content, |caps: &regex::Captures| {
let action = caps[1].trim();
if find_action_by_name(action).is_none() {
errors.insert(Error::new_for_not_found_action(action.to_string()));
return String::new();
}
let macos_binding = find_binding("macos", action).unwrap_or_default();
let linux_binding = find_binding("linux", action).unwrap_or_default();
@@ -144,30 +86,35 @@ fn template_and_validate_keybindings(book: &mut Book, errors: &mut HashSet<Error
});
}
fn template_and_validate_actions(book: &mut Book, errors: &mut HashSet<Error>) {
fn template_action(book: &mut Book) {
let regex = Regex::new(r"\{#action (.*?)\}").unwrap();
for_each_chapter_mut(book, |chapter| {
chapter.content = regex
.replace_all(&chapter.content, |caps: &regex::Captures| {
let name = caps[1].trim();
let Some(action) = find_action_by_name(name) else {
errors.insert(Error::new_for_not_found_action(name.to_string()));
return String::new();
};
format!("<code class=\"hljs\">{}</code>", &action.human_name)
let formatted_name = name
.chars()
.enumerate()
.map(|(i, c)| {
if i > 0 && c.is_uppercase() {
format!(" {}", c.to_lowercase())
} else {
c.to_string()
}
})
.collect::<String>()
.trim()
.to_string()
.replace("::", ":");
format!("<code class=\"hljs\">{}</code>", formatted_name)
})
.into_owned()
});
}
fn find_action_by_name(name: &str) -> Option<&ActionDef> {
ALL_ACTIONS
.binary_search_by(|action| action.name.cmp(name))
.ok()
.map(|index| &ALL_ACTIONS[index])
}
fn find_binding(os: &str, action: &str) -> Option<String> {
let keymap = match os {
"macos" => &KEYMAP_MACOS,
@@ -233,25 +180,3 @@ where
func(chapter);
});
}
#[derive(Debug, serde::Serialize)]
struct ActionDef {
name: &'static str,
human_name: String,
deprecated_aliases: &'static [&'static str],
}
fn dump_all_gpui_actions() -> Vec<ActionDef> {
let mut actions = gpui::generate_list_of_all_registered_actions()
.into_iter()
.map(|action| ActionDef {
name: action.name,
human_name: command_palette::humanize_action_name(action.name),
deprecated_aliases: action.aliases,
})
.collect::<Vec<ActionDef>>();
actions.sort_by_key(|a| a.name);
return actions;
}

View File

@@ -639,7 +639,6 @@ pub struct HighlightedChunk<'a> {
pub text: &'a str,
pub style: Option<HighlightStyle>,
pub is_tab: bool,
pub is_inlay: bool,
pub replacement: Option<ChunkReplacement>,
}
@@ -653,7 +652,6 @@ impl<'a> HighlightedChunk<'a> {
let style = self.style;
let is_tab = self.is_tab;
let renderer = self.replacement;
let is_inlay = self.is_inlay;
iter::from_fn(move || {
let mut prefix_len = 0;
while let Some(&ch) = chars.peek() {
@@ -669,7 +667,6 @@ impl<'a> HighlightedChunk<'a> {
text: prefix,
style,
is_tab,
is_inlay,
replacement: renderer.clone(),
});
}
@@ -696,7 +693,6 @@ impl<'a> HighlightedChunk<'a> {
text: prefix,
style: Some(invisible_style),
is_tab: false,
is_inlay,
replacement: Some(ChunkReplacement::Str(replacement.into())),
});
} else {
@@ -720,7 +716,6 @@ impl<'a> HighlightedChunk<'a> {
text: prefix,
style: Some(invisible_style),
is_tab: false,
is_inlay,
replacement: renderer.clone(),
});
}
@@ -733,7 +728,6 @@ impl<'a> HighlightedChunk<'a> {
text: remainder,
style,
is_tab,
is_inlay,
replacement: renderer.clone(),
})
} else {
@@ -967,10 +961,7 @@ impl DisplaySnapshot {
if chunk.is_unnecessary {
diagnostic_highlight.fade_out = Some(editor_style.unnecessary_code_fade);
}
if chunk.underline
&& editor_style.show_underlines
&& !(chunk.is_unnecessary && severity > lsp::DiagnosticSeverity::WARNING)
{
if chunk.underline && editor_style.show_underlines {
let diagnostic_color = super::diagnostic_style(severity, &editor_style.status);
diagnostic_highlight.underline = Some(UnderlineStyle {
color: Some(diagnostic_color),
@@ -990,7 +981,6 @@ impl DisplaySnapshot {
text: chunk.text,
style: highlight_style,
is_tab: chunk.is_tab,
is_inlay: chunk.is_inlay,
replacement: chunk.renderer.map(ChunkReplacement::Renderer),
}
.highlight_invisibles(editor_style)

View File

@@ -1259,8 +1259,6 @@ pub struct Chunk<'a> {
pub underline: bool,
/// Whether this chunk of text was originally a tab character.
pub is_tab: bool,
/// Whether this chunk of text was originally a tab character.
pub is_inlay: bool,
/// An optional recipe for how the chunk should be presented.
pub renderer: Option<ChunkRenderer>,
}
@@ -1426,7 +1424,6 @@ impl<'a> Iterator for FoldChunks<'a> {
diagnostic_severity: chunk.diagnostic_severity,
is_unnecessary: chunk.is_unnecessary,
is_tab: chunk.is_tab,
is_inlay: chunk.is_inlay,
underline: chunk.underline,
renderer: None,
});

View File

@@ -336,7 +336,6 @@ impl<'a> Iterator for InlayChunks<'a> {
Chunk {
text: chunk,
highlight_style,
is_inlay: true,
..Default::default()
}
}

View File

@@ -10873,54 +10873,14 @@ impl Editor {
pub fn rewrap_impl(&mut self, options: RewrapOptions, cx: &mut Context<Self>) {
let buffer = self.buffer.read(cx).snapshot(cx);
let selections = self.selections.all::<Point>(cx);
// Shrink and split selections to respect paragraph boundaries.
let ranges = selections.into_iter().flat_map(|selection| {
let language_settings = buffer.language_settings_at(selection.head(), cx);
let language_scope = buffer.language_scope_at(selection.head());
let Some(start_row) = (selection.start.row..=selection.end.row)
.find(|row| !buffer.is_line_blank(MultiBufferRow(*row)))
else {
return vec![];
};
let Some(end_row) = (selection.start.row..=selection.end.row)
.rev()
.find(|row| !buffer.is_line_blank(MultiBufferRow(*row)))
else {
return vec![];
};
let mut row = start_row;
let mut ranges = Vec::new();
while let Some(blank_row) =
(row..end_row).find(|row| buffer.is_line_blank(MultiBufferRow(*row)))
{
let next_paragraph_start = (blank_row + 1..=end_row)
.find(|row| !buffer.is_line_blank(MultiBufferRow(*row)))
.unwrap();
ranges.push((
language_settings.clone(),
language_scope.clone(),
Point::new(row, 0)..Point::new(blank_row - 1, 0),
));
row = next_paragraph_start;
}
ranges.push((
language_settings.clone(),
language_scope.clone(),
Point::new(row, 0)..Point::new(end_row, 0),
));
ranges
});
let mut selections = selections.iter().peekable();
let mut edits = Vec::new();
let mut rewrapped_row_ranges = Vec::<RangeInclusive<u32>>::new();
for (language_settings, language_scope, range) in ranges {
let mut start_row = range.start.row;
let mut end_row = range.end.row;
while let Some(selection) = selections.next() {
let mut start_row = selection.start.row;
let mut end_row = selection.end.row;
// Skip selections that overlap with a range that has already been rewrapped.
let selection_range = start_row..end_row;
@@ -10931,7 +10891,7 @@ impl Editor {
continue;
}
let tab_size = language_settings.tab_size;
let tab_size = buffer.language_settings_at(selection.head(), cx).tab_size;
// Since not all lines in the selection may be at the same indent
// level, choose the indent size that is the most common between all
@@ -10962,20 +10922,25 @@ impl Editor {
let mut line_prefix = indent_size.chars().collect::<String>();
let mut inside_comment = false;
if let Some(comment_prefix) = language_scope.and_then(|language| {
language
.line_comment_prefixes()
.iter()
.find(|prefix| buffer.contains_str_at(indent_end, prefix))
.cloned()
}) {
if let Some(comment_prefix) =
buffer
.language_scope_at(selection.head())
.and_then(|language| {
language
.line_comment_prefixes()
.iter()
.find(|prefix| buffer.contains_str_at(indent_end, prefix))
.cloned()
})
{
line_prefix.push_str(&comment_prefix);
inside_comment = true;
}
let language_settings = buffer.language_settings_at(selection.head(), cx);
let allow_rewrap_based_on_language = match language_settings.allow_rewrap {
RewrapBehavior::InComments => inside_comment,
RewrapBehavior::InSelections => !range.is_empty(),
RewrapBehavior::InSelections => !selection.is_empty(),
RewrapBehavior::Anywhere => true,
};
@@ -10986,12 +10951,11 @@ impl Editor {
continue;
}
if range.is_empty() {
if selection.is_empty() {
'expand_upwards: while start_row > 0 {
let prev_row = start_row - 1;
if buffer.contains_str_at(Point::new(prev_row, 0), &line_prefix)
&& buffer.line_len(MultiBufferRow(prev_row)) as usize > line_prefix.len()
&& !buffer.is_line_blank(MultiBufferRow(prev_row))
{
start_row = prev_row;
} else {
@@ -11003,7 +10967,6 @@ impl Editor {
let next_row = end_row + 1;
if buffer.contains_str_at(Point::new(next_row, 0), &line_prefix)
&& buffer.line_len(MultiBufferRow(next_row)) as usize > line_prefix.len()
&& !buffer.is_line_blank(MultiBufferRow(next_row))
{
end_row = next_row;
} else {

View File

@@ -1912,19 +1912,19 @@ fn test_prev_next_word_boundary(cx: &mut TestAppContext) {
assert_selection_ranges("use std::ˇstr::{foo, bar}\n\n {ˇbaz.qux()}", editor, cx);
editor.move_to_previous_word_start(&MoveToPreviousWordStart, window, cx);
assert_selection_ranges("use stdˇ::str::{foo, bar}\n\nˇ {baz.qux()}", editor, cx);
assert_selection_ranges("use stdˇ::str::{foo, bar}\n\n ˇ{baz.qux()}", editor, cx);
editor.move_to_previous_word_start(&MoveToPreviousWordStart, window, cx);
assert_selection_ranges("use ˇstd::str::{foo, bar}\nˇ\n {baz.qux()}", editor, cx);
assert_selection_ranges("use ˇstd::str::{foo, bar}\n\nˇ {baz.qux()}", editor, cx);
editor.move_to_previous_word_start(&MoveToPreviousWordStart, window, cx);
assert_selection_ranges("ˇuse std::str::{foo, bar}\nˇ\n {baz.qux()}", editor, cx);
editor.move_to_previous_word_start(&MoveToPreviousWordStart, window, cx);
assert_selection_ranges("ˇuse std::str::{foo, barˇ}\n\n {baz.qux()}", editor, cx);
editor.move_to_previous_word_start(&MoveToPreviousWordStart, window, cx);
assert_selection_ranges("ˇuse std::str::{foo, ˇbar}\n\n {baz.qux()}", editor, cx);
editor.move_to_next_word_end(&MoveToNextWordEnd, window, cx);
assert_selection_ranges("useˇ std::str::{foo, barˇ}\n\n {baz.qux()}", editor, cx);
assert_selection_ranges("useˇ std::str::{foo, bar}ˇ\n\n {baz.qux()}", editor, cx);
editor.move_to_next_word_end(&MoveToNextWordEnd, window, cx);
assert_selection_ranges("use stdˇ::str::{foo, bar}\nˇ\n {baz.qux()}", editor, cx);
@@ -1942,7 +1942,7 @@ fn test_prev_next_word_boundary(cx: &mut TestAppContext) {
editor.select_to_previous_word_start(&SelectToPreviousWordStart, window, cx);
assert_selection_ranges(
"use std«ˇ::s»tr::{foo, bar}\n\n«ˇ {b»az.qux()}",
"use std«ˇ::s»tr::{foo, bar}\n\n «ˇ{b»az.qux()}",
editor,
cx,
);
@@ -5111,7 +5111,7 @@ async fn test_rewrap(cx: &mut TestAppContext) {
nisl venenatis tempus. Donec molestie blandit quam, et porta nunc laoreet in.
Integer sit amet scelerisque nisi.
"},
plaintext_language.clone(),
plaintext_language,
&mut cx,
);
@@ -5174,69 +5174,6 @@ async fn test_rewrap(cx: &mut TestAppContext) {
&mut cx,
);
assert_rewrap(
indoc! {"
«ˇone one one one one one one one one one one one one one one one one one one one one one one one one
two»
three
«ˇ\t
four four four four four four four four four four four four four four four four four four four four»
«ˇfive five five five five five five five five five five five five five five five five five five five
\t»
six six six six six six six six six six six six six six six six six six six six six six six six six
"},
indoc! {"
«ˇone one one one one one one one one one one one one one one one one one one one
one one one one one
two»
three
«ˇ\t
four four four four four four four four four four four four four four four four
four four four four»
«ˇfive five five five five five five five five five five five five five five five
five five five five
\t»
six six six six six six six six six six six six six six six six six six six six six six six six six
"},
plaintext_language.clone(),
&mut cx,
);
assert_rewrap(
indoc! {"
//ˇ long long long long long long long long long long long long long long long long long long long long long long long long long long long long
//ˇ
//ˇ long long long long long long long long long long long long long long long long long long long long long long long long long long long long
//ˇ short short short
int main(void) {
return 17;
}
"},
indoc! {"
//ˇ long long long long long long long long long long long long long long long
// long long long long long long long long long long long long long
//ˇ
//ˇ long long long long long long long long long long long long long long long
//ˇ long long long long long long long long long long long long long short short
// short
int main(void) {
return 17;
}
"},
language_with_c_comments,
&mut cx,
);
#[track_caller]
fn assert_rewrap(
unwrapped_text: &str,
@@ -17923,7 +17860,6 @@ async fn test_display_diff_hunks(cx: &mut TestAppContext) {
("file-2".into(), "two\n".into()),
("file-3".into(), "three\n".into()),
],
"deadbeef",
);
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

View File

@@ -6871,7 +6871,6 @@ impl LineWithInvisibles {
text: "\n",
style: None,
is_tab: false,
is_inlay: false,
replacement: None,
}]) {
if let Some(replacement) = highlighted_chunk.replacement {
@@ -7005,7 +7004,7 @@ impl LineWithInvisibles {
strikethrough: text_style.strikethrough,
});
if editor_mode.is_full() && !highlighted_chunk.is_inlay {
if editor_mode.is_full() {
// Line wrap pads its contents with fake whitespaces,
// avoid printing them
let is_soft_wrapped = is_row_soft_wrapped(row);

View File

@@ -264,18 +264,7 @@ pub fn previous_word_start(map: &DisplaySnapshot, point: DisplayPoint) -> Displa
let raw_point = point.to_point(map);
let classifier = map.buffer_snapshot.char_classifier_at(raw_point);
let mut is_first_iteration = true;
find_preceding_boundary_display_point(map, point, FindRange::MultiLine, |left, right| {
// Make alt-left skip punctuation on Mac OS to respect Mac VSCode behaviour. For example: hello.| goes to |hello.
if is_first_iteration
&& classifier.is_punctuation(right)
&& !classifier.is_punctuation(left)
{
is_first_iteration = false;
return false;
}
is_first_iteration = false;
(classifier.kind(left) != classifier.kind(right) && !classifier.is_whitespace(right))
|| left == '\n'
})
@@ -316,18 +305,8 @@ pub fn previous_subword_start(map: &DisplaySnapshot, point: DisplayPoint) -> Dis
pub fn next_word_end(map: &DisplaySnapshot, point: DisplayPoint) -> DisplayPoint {
let raw_point = point.to_point(map);
let classifier = map.buffer_snapshot.char_classifier_at(raw_point);
let mut is_first_iteration = true;
find_boundary(map, point, FindRange::MultiLine, |left, right| {
// Make alt-right skip punctuation on Mac OS to respect the Mac behaviour. For example: |.hello goes to .hello|
if is_first_iteration
&& classifier.is_punctuation(left)
&& !classifier.is_punctuation(right)
{
is_first_iteration = false;
return false;
}
is_first_iteration = false;
find_boundary(map, point, FindRange::MultiLine, |left, right| {
(classifier.kind(left) != classifier.kind(right) && !classifier.is_whitespace(left))
|| right == '\n'
})
@@ -803,15 +782,10 @@ mod tests {
fn assert(marked_text: &str, cx: &mut gpui::App) {
let (snapshot, display_points) = marked_display_snapshot(marked_text, cx);
let actual = previous_word_start(&snapshot, display_points[1]);
let expected = display_points[0];
if actual != expected {
eprintln!(
"previous_word_start mismatch for '{}': actual={:?}, expected={:?}",
marked_text, actual, expected
);
}
assert_eq!(actual, expected);
assert_eq!(
previous_word_start(&snapshot, display_points[1]),
display_points[0]
);
}
assert("\nˇ ˇlorem", cx);
@@ -822,17 +796,12 @@ mod tests {
assert("\nlorem\nˇ ˇipsum", cx);
assert("\n\nˇ\nˇ", cx);
assert(" ˇlorem ˇipsum", cx);
assert("ˇlorem-ˇipsum", cx);
assert("loremˇ-ˇipsum", cx);
assert("loremˇ-#$@ˇipsum", cx);
assert("ˇlorem_ˇipsum", cx);
assert(" ˇdefγˇ", cx);
assert(" ˇbcΔˇ", cx);
// Test punctuation skipping behavior
assert("ˇhello.ˇ", cx);
assert("helloˇ...ˇ", cx);
assert("helloˇ.---..ˇtest", cx);
assert("test ˇ.--ˇtest", cx);
assert("oneˇ,;:!?ˇtwo", cx);
assert(" abˇ——ˇcd", cx);
}
#[gpui::test]
@@ -986,15 +955,10 @@ mod tests {
fn assert(marked_text: &str, cx: &mut gpui::App) {
let (snapshot, display_points) = marked_display_snapshot(marked_text, cx);
let actual = next_word_end(&snapshot, display_points[0]);
let expected = display_points[1];
if actual != expected {
eprintln!(
"next_word_end mismatch for '{}': actual={:?}, expected={:?}",
marked_text, actual, expected
);
}
assert_eq!(actual, expected);
assert_eq!(
next_word_end(&snapshot, display_points[0]),
display_points[1]
);
}
assert("\nˇ loremˇ", cx);
@@ -1003,18 +967,11 @@ mod tests {
assert(" loremˇ ˇ\nipsum\n", cx);
assert("\nˇ\nˇ\n\n", cx);
assert("loremˇ ipsumˇ ", cx);
assert("loremˇ-ipsumˇ", cx);
assert("loremˇ-ˇipsum", cx);
assert("loremˇ#$@-ˇipsum", cx);
assert("loremˇ_ipsumˇ", cx);
assert(" ˇbcΔˇ", cx);
assert(" abˇ——ˇcd", cx);
// Test punctuation skipping behavior
assert("ˇ.helloˇ", cx);
assert("display_pointsˇ[0ˇ]", cx);
assert("ˇ...ˇhello", cx);
assert("helloˇ.---..ˇtest", cx);
assert("testˇ.--ˇ test", cx);
assert("oneˇ,;:!?ˇtwo", cx);
}
#[gpui::test]

View File

@@ -240,7 +240,8 @@ impl EditorTestContext {
// unlike cx.simulate_keystrokes(), this does not run_until_parked
// so you can use it to test detailed timing
pub fn simulate_keystroke(&mut self, keystroke_text: &str) {
let keystroke = Keystroke::parse(keystroke_text).unwrap();
let keyboard_mapper = self.keyboard_mapper();
let keystroke = Keystroke::parse(keystroke_text, keyboard_mapper.as_ref()).unwrap();
self.cx.dispatch_keystroke(self.window, keystroke);
}
@@ -304,7 +305,6 @@ impl EditorTestContext {
fs.set_head_for_repo(
&Self::root_path().join(".git"),
&[(path.into(), diff_base.to_string())],
"deadbeef",
);
self.cx.run_until_parked();
}

View File

@@ -1,59 +0,0 @@
use agent_settings::AgentProfileId;
use anyhow::Result;
use assistant_tools::GrepToolInput;
use async_trait::async_trait;
use crate::example::{Example, ExampleContext, ExampleMetadata};
pub struct GrepParamsEscapementExample;
/*
This eval checks that the model doesn't use HTML escapement for characters like `<` and
`>` in tool parameters.
original +system_prompt change +tool description
claude-opus-4 89% 92% 97%+
claude-sonnet-4 100%
gpt-4.1-mini 100%
gemini-2.5-pro 98%
*/
#[async_trait(?Send)]
impl Example for GrepParamsEscapementExample {
fn meta(&self) -> ExampleMetadata {
ExampleMetadata {
name: "grep_params_escapement".to_string(),
url: "https://github.com/octocat/hello-world".to_string(),
revision: "7fd1a60b01f91b314f59955a4e4d4e80d8edf11d".to_string(),
language_server: None,
max_assertions: Some(1),
profile_id: AgentProfileId::default(),
existing_thread_json: None,
max_turns: Some(2),
}
}
async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
// cx.push_user_message("How does the precedence/specificity work with Keymap contexts? I am seeing that `MessageEditor > Editor` is lower precendence than `Editor` which is surprising to me, but might be how it works");
cx.push_user_message("Search for files containing the characters `>` or `<`");
let response = cx.run_turns(2).await?;
let grep_input = response
.find_tool_call("grep")
.and_then(|tool_use| tool_use.parse_input::<GrepToolInput>().ok());
cx.assert_some(grep_input.as_ref(), "`grep` tool should be called")?;
cx.assert(
!contains_html_entities(&grep_input.unwrap().regex),
"Tool parameters should not be escaped",
)
}
}
fn contains_html_entities(pattern: &str) -> bool {
regex::Regex::new(r"&[a-zA-Z]+;|&#[0-9]+;|&#x[0-9a-fA-F]+;")
.unwrap()
.is_match(pattern)
}

View File

@@ -16,7 +16,6 @@ mod add_arg_to_trait_method;
mod code_block_citations;
mod comment_translation;
mod file_search;
mod grep_params_escapement;
mod overwrite_file;
mod planets;
@@ -28,7 +27,6 @@ pub fn all(examples_dir: &Path) -> Vec<Rc<dyn Example>> {
Rc::new(planets::Planets),
Rc::new(comment_translation::CommentTranslation),
Rc::new(overwrite_file::FileOverwriteExample),
Rc::new(grep_params_escapement::GrepParamsEscapementExample),
];
for example_path in list_declarative_examples(examples_dir).unwrap() {

View File

@@ -1456,12 +1456,7 @@ impl FakeFs {
.unwrap();
}
pub fn set_head_for_repo(
&self,
dot_git: &Path,
head_state: &[(RepoPath, String)],
sha: impl Into<String>,
) {
pub fn set_head_for_repo(&self, dot_git: &Path, head_state: &[(RepoPath, String)]) {
self.with_git_state(dot_git, true, |state| {
state.head_contents.clear();
state.head_contents.extend(
@@ -1469,7 +1464,6 @@ impl FakeFs {
.iter()
.map(|(path, content)| (path.clone(), content.clone())),
);
state.refs.insert("HEAD".into(), sha.into());
})
.unwrap();
}

View File

@@ -1387,7 +1387,6 @@ mod tests {
fs.set_head_for_repo(
path!("/project/.git").as_ref(),
&[("foo.txt".into(), "foo\n".into())],
"deadbeef",
);
fs.set_index_for_repo(
path!("/project/.git").as_ref(),
@@ -1524,7 +1523,6 @@ mod tests {
fs.set_head_for_repo(
path!("/project/.git").as_ref(),
&[("foo".into(), "original\n".into())],
"deadbeef",
);
cx.run_until_parked();

View File

@@ -288,18 +288,6 @@ impl ActionRegistry {
}
}
/// Generate a list of all the registered actions.
/// Useful for transforming the list of available actions into a
/// format suited for static analysis such as in validating keymaps, or
/// generating documentation.
pub fn generate_list_of_all_registered_actions() -> Vec<MacroActionData> {
let mut actions = Vec::new();
for builder in inventory::iter::<MacroActionBuilder> {
actions.push(builder.0());
}
actions
}
/// Defines and registers unit structs that can be used as actions.
///
/// To use more complex data types as actions, use `impl_actions!`
@@ -345,6 +333,7 @@ macro_rules! action_as {
::std::clone::Clone, ::std::default::Default, ::std::fmt::Debug, ::std::cmp::PartialEq,
)]
pub struct $name;
gpui::__impl_action!(
$namespace,
$name,

View File

@@ -37,10 +37,10 @@ use crate::{
AssetSource, BackgroundExecutor, Bounds, ClipboardItem, CursorStyle, DispatchPhase, DisplayId,
EventEmitter, FocusHandle, FocusMap, ForegroundExecutor, Global, KeyBinding, KeyContext,
Keymap, Keystroke, LayoutId, Menu, MenuItem, OwnedMenu, PathPromptOptions, Pixels, Platform,
PlatformDisplay, PlatformKeyboardLayout, Point, PromptBuilder, PromptButton, PromptHandle,
PromptLevel, Render, RenderImage, RenderablePromptHandle, Reservation, ScreenCaptureSource,
SharedString, SubscriberSet, Subscription, SvgRenderer, Task, TextSystem, Window,
WindowAppearance, WindowHandle, WindowId, WindowInvalidator,
PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper, Point, PromptBuilder,
PromptButton, PromptHandle, PromptLevel, Render, RenderImage, RenderablePromptHandle,
Reservation, ScreenCaptureSource, SharedString, SubscriberSet, Subscription, SvgRenderer, Task,
TextSystem, Window, WindowAppearance, WindowHandle, WindowId, WindowInvalidator,
colors::{Colors, GlobalColors},
current_platform, hash, init_app_menus,
};
@@ -262,6 +262,7 @@ pub struct App {
pub(crate) window_handles: FxHashMap<WindowId, AnyWindowHandle>,
pub(crate) focus_handles: Arc<FocusMap>,
pub(crate) keymap: Rc<RefCell<Keymap>>,
pub(crate) keyboard_mapper: Box<dyn PlatformKeyboardMapper>,
pub(crate) keyboard_layout: Box<dyn PlatformKeyboardLayout>,
pub(crate) global_action_listeners:
FxHashMap<TypeId, Vec<Rc<dyn Fn(&dyn Any, DispatchPhase, &mut Self)>>>,
@@ -308,6 +309,7 @@ impl App {
let text_system = Arc::new(TextSystem::new(platform.text_system()));
let entities = EntityMap::new();
let keyboard_mapper = platform.keyboard_mapper();
let keyboard_layout = platform.keyboard_layout();
let app = Rc::new_cyclic(|this| AppCell {
@@ -333,6 +335,7 @@ impl App {
window_handles: FxHashMap::default(),
focus_handles: Arc::new(RwLock::new(SlotMap::with_key())),
keymap: Rc::new(RefCell::new(Keymap::default())),
keyboard_mapper,
keyboard_layout,
global_action_listeners: FxHashMap::default(),
pending_effects: VecDeque::new(),
@@ -369,6 +372,7 @@ impl App {
move || {
if let Some(app) = app.upgrade() {
let cx = &mut app.borrow_mut();
cx.keyboard_mapper = cx.platform.keyboard_mapper();
cx.keyboard_layout = cx.platform.keyboard_layout();
cx.keyboard_layout_observers
.clone()
@@ -413,6 +417,11 @@ impl App {
self.quitting = false;
}
/// Get the keyboard mapper of current keyboard layout
pub fn keyboard_mapper(&self) -> &dyn PlatformKeyboardMapper {
self.keyboard_mapper.as_ref()
}
/// Get the id of the current keyboard layout
pub fn keyboard_layout(&self) -> &dyn PlatformKeyboardLayout {
self.keyboard_layout.as_ref()

View File

@@ -3,9 +3,9 @@ use crate::{
BackgroundExecutor, BorrowAppContext, Bounds, ClipboardItem, DrawPhase, Drawable, Element,
Empty, EventEmitter, ForegroundExecutor, Global, InputEvent, Keystroke, Modifiers,
ModifiersChangedEvent, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Pixels,
Platform, Point, Render, Result, Size, Task, TestDispatcher, TestPlatform,
TestScreenCaptureSource, TestWindow, TextSystem, VisualContext, Window, WindowBounds,
WindowHandle, WindowOptions,
Platform, PlatformKeyboardMapper, Point, Render, Result, Size, Task, TestDispatcher,
TestPlatform, TestScreenCaptureSource, TestWindow, TextSystem, VisualContext, Window,
WindowBounds, WindowHandle, WindowOptions,
};
use anyhow::{anyhow, bail};
use futures::{Stream, StreamExt, channel::oneshot};
@@ -397,14 +397,20 @@ impl TestAppContext {
self.background_executor.run_until_parked()
}
/// Returns the current keyboard mapper for this platform.
pub fn keyboard_mapper(&self) -> Box<dyn PlatformKeyboardMapper> {
self.test_platform.keyboard_mapper()
}
/// simulate_keystrokes takes a space-separated list of keys to type.
/// cx.simulate_keystrokes("cmd-shift-p b k s p enter")
/// in Zed, this will run backspace on the current editor through the command palette.
/// This will also run the background executor until it's parked.
pub fn simulate_keystrokes(&mut self, window: AnyWindowHandle, keystrokes: &str) {
let keyboard_mapper = self.test_platform.keyboard_mapper();
for keystroke in keystrokes
.split(' ')
.map(Keystroke::parse)
.map(|source| Keystroke::parse(source, keyboard_mapper.as_ref()))
.map(Result::unwrap)
{
self.dispatch_keystroke(window, keystroke);
@@ -418,7 +424,12 @@ impl TestAppContext {
/// will type abc into your current editor
/// This will also run the background executor until it's parked.
pub fn simulate_input(&mut self, window: AnyWindowHandle, input: &str) {
for keystroke in input.split("").map(Keystroke::parse).map(Result::unwrap) {
let keyboard_mapper = self.test_platform.keyboard_mapper();
for keystroke in input
.split("")
.map(|source| Keystroke::parse(source, keyboard_mapper.as_ref()))
.map(Result::unwrap)
{
self.dispatch_keystroke(window, keystroke);
}

View File

@@ -538,8 +538,15 @@ mod test {
})
.unwrap();
cx.dispatch_keystroke(*window, Keystroke::parse("a").unwrap());
cx.dispatch_keystroke(*window, Keystroke::parse("ctrl-g").unwrap());
let keyboard_mapper = cx.keyboard_mapper();
cx.dispatch_keystroke(
*window,
Keystroke::parse("a", keyboard_mapper.as_ref()).unwrap(),
);
cx.dispatch_keystroke(
*window,
Keystroke::parse("ctrl-g", keyboard_mapper.as_ref()).unwrap(),
);
window
.update(cx, |test_view, _, _| {

View File

@@ -256,7 +256,7 @@ impl Keymap {
#[cfg(test)]
mod tests {
use super::*;
use crate as gpui;
use crate::{self as gpui, TestKeyboardMapper};
use gpui::{NoAction, actions};
actions!(
@@ -292,6 +292,7 @@ mod tests {
#[test]
fn test_keymap_disabled() {
let keyboard_mapper = TestKeyboardMapper::new();
let bindings = [
KeyBinding::new("ctrl-a", ActionAlpha {}, Some("editor")),
KeyBinding::new("ctrl-b", ActionAlpha {}, Some("editor")),
@@ -306,7 +307,7 @@ mod tests {
assert!(
keymap
.bindings_for_input(
&[Keystroke::parse("ctrl-a").unwrap()],
&[Keystroke::parse("ctrl-a", &keyboard_mapper).unwrap()],
&[KeyContext::parse("barf").unwrap()],
)
.0
@@ -315,7 +316,7 @@ mod tests {
assert!(
!keymap
.bindings_for_input(
&[Keystroke::parse("ctrl-a").unwrap()],
&[Keystroke::parse("ctrl-a", &keyboard_mapper).unwrap()],
&[KeyContext::parse("editor").unwrap()],
)
.0
@@ -326,7 +327,7 @@ mod tests {
assert!(
keymap
.bindings_for_input(
&[Keystroke::parse("ctrl-a").unwrap()],
&[Keystroke::parse("ctrl-a", &keyboard_mapper).unwrap()],
&[KeyContext::parse("editor mode=full").unwrap()],
)
.0
@@ -337,7 +338,7 @@ mod tests {
assert!(
keymap
.bindings_for_input(
&[Keystroke::parse("ctrl-b").unwrap()],
&[Keystroke::parse("ctrl-b", &keyboard_mapper).unwrap()],
&[KeyContext::parse("barf").unwrap()],
)
.0
@@ -356,8 +357,9 @@ mod tests {
let mut keymap = Keymap::default();
keymap.add_bindings(bindings.clone());
let space = || Keystroke::parse("space").unwrap();
let w = || Keystroke::parse("w").unwrap();
let keyboard_mapper = TestKeyboardMapper::new();
let space = || Keystroke::parse("space", &keyboard_mapper).unwrap();
let w = || Keystroke::parse("w", &keyboard_mapper).unwrap();
let space_w = [space(), w()];
let space_w_w = [space(), w(), w()];

View File

@@ -2,7 +2,9 @@ use std::rc::Rc;
use collections::HashMap;
use crate::{Action, InvalidKeystrokeError, KeyBindingContextPredicate, Keystroke};
use crate::{
Action, InvalidKeystrokeError, KeyBindingContextPredicate, Keystroke, PlatformKeyboardMapper,
};
use smallvec::SmallVec;
/// A keybinding and its associated metadata, from the keymap.
@@ -30,7 +32,14 @@ impl KeyBinding {
} else {
None
};
Self::load(keystrokes, Box::new(action), context_predicate, None).unwrap()
Self::load(
keystrokes,
Box::new(action),
context_predicate,
None,
&crate::EmptyKeyboardMapper,
)
.unwrap()
}
/// Load a keybinding from the given raw data.
@@ -39,10 +48,11 @@ impl KeyBinding {
action: Box<dyn Action>,
context_predicate: Option<Rc<KeyBindingContextPredicate>>,
key_equivalents: Option<&HashMap<char, char>>,
keyboard_mapper: &dyn PlatformKeyboardMapper,
) -> std::result::Result<Self, InvalidKeystrokeError> {
let mut keystrokes: SmallVec<[Keystroke; 2]> = keystrokes
.split_whitespace()
.map(Keystroke::parse)
.map(|source| Keystroke::parse(source, keyboard_mapper))
.collect::<std::result::Result<_, _>>()?;
if let Some(equivalents) = key_equivalents {

View File

@@ -1,5 +1,6 @@
mod app_menu;
mod keyboard;
mod keycodes;
mod keystroke;
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
@@ -66,6 +67,7 @@ use uuid::Uuid;
pub use app_menu::*;
pub use keyboard::*;
pub use keycodes::*;
pub use keystroke::*;
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
@@ -194,7 +196,6 @@ pub(crate) trait Platform: 'static {
fn on_quit(&self, callback: Box<dyn FnMut()>);
fn on_reopen(&self, callback: Box<dyn FnMut()>);
fn on_keyboard_layout_change(&self, callback: Box<dyn FnMut()>);
fn set_menus(&self, menus: Vec<Menu>, keymap: &Keymap);
fn get_menus(&self) -> Option<Vec<OwnedMenu>> {
@@ -214,7 +215,6 @@ pub(crate) trait Platform: 'static {
fn on_app_menu_action(&self, callback: Box<dyn FnMut(&dyn Action)>);
fn on_will_open_app_menu(&self, callback: Box<dyn FnMut()>);
fn on_validate_app_menu_command(&self, callback: Box<dyn FnMut(&dyn Action) -> bool>);
fn keyboard_layout(&self) -> Box<dyn PlatformKeyboardLayout>;
fn compositor_name(&self) -> &'static str {
""
@@ -235,6 +235,10 @@ pub(crate) trait Platform: 'static {
fn write_credentials(&self, url: &str, username: &str, password: &[u8]) -> Task<Result<()>>;
fn read_credentials(&self, url: &str) -> Task<Result<Option<(String, Vec<u8>)>>>;
fn delete_credentials(&self, url: &str) -> Task<Result<()>>;
fn keyboard_mapper(&self) -> Box<dyn PlatformKeyboardMapper>;
fn keyboard_layout(&self) -> Box<dyn PlatformKeyboardLayout>;
fn on_keyboard_layout_change(&self, callback: Box<dyn FnMut()>);
}
/// A handle to a platform's display, e.g. a monitor or laptop screen.

View File

@@ -1,3 +1,7 @@
use anyhow::Result;
use super::ScanCode;
/// A trait for platform-specific keyboard layouts
pub trait PlatformKeyboardLayout {
/// Get the keyboard layout ID, which should be unique to the layout
@@ -5,3 +9,150 @@ pub trait PlatformKeyboardLayout {
/// Get the keyboard layout display name
fn name(&self) -> &str;
}
/// TODO:
pub trait PlatformKeyboardMapper {
/// TODO:
fn scan_code_to_key(&self, scan_code: ScanCode) -> Result<String>;
/// TODO:
fn get_shifted_key(&self, key: &str) -> Result<Option<String>>;
}
/// TODO:
pub struct TestKeyboardMapper {
#[cfg(target_os = "windows")]
mapper: super::WindowsKeyboardMapper,
#[cfg(target_os = "macos")]
mapper: super::MacKeyboardMapper,
#[cfg(target_os = "linux")]
mapper: super::LinuxKeyboardMapper,
}
impl PlatformKeyboardMapper for TestKeyboardMapper {
fn scan_code_to_key(&self, scan_code: ScanCode) -> Result<String> {
self.mapper.scan_code_to_key(scan_code)
}
fn get_shifted_key(&self, key: &str) -> Result<Option<String>> {
self.mapper.get_shifted_key(key)
}
}
impl TestKeyboardMapper {
/// TODO:
pub fn new() -> Self {
Self {
#[cfg(target_os = "windows")]
mapper: super::WindowsKeyboardMapper::new(),
#[cfg(target_os = "macos")]
mapper: super::MacKeyboardMapper::new(),
#[cfg(target_os = "linux")]
mapper: super::LinuxKeyboardMapper::new(),
}
}
}
/// A dummy keyboard mapper that does not support any key mappings
pub struct EmptyKeyboardMapper;
impl PlatformKeyboardMapper for EmptyKeyboardMapper {
fn scan_code_to_key(&self, _scan_code: ScanCode) -> Result<String> {
anyhow::bail!("EmptyKeyboardMapper does not support scan codes")
}
fn get_shifted_key(&self, _key: &str) -> Result<Option<String>> {
anyhow::bail!("EmptyKeyboardMapper does not support shifted keys")
}
}
pub(crate) fn is_alphabetic_key(key: &str) -> bool {
matches!(
key,
"a" | "b"
| "c"
| "d"
| "e"
| "f"
| "g"
| "h"
| "i"
| "j"
| "k"
| "l"
| "m"
| "n"
| "o"
| "p"
| "q"
| "r"
| "s"
| "t"
| "u"
| "v"
| "w"
| "x"
| "y"
| "z"
)
}
#[cfg(test)]
mod tests {
use strum::IntoEnumIterator;
use crate::ScanCode;
use super::{PlatformKeyboardMapper, TestKeyboardMapper};
#[test]
fn test_get_shifted_key() {
let mapper = TestKeyboardMapper::new();
for ch in 'a'..='z' {
let key = ch.to_string();
let shifted_key = key.to_uppercase();
assert_eq!(mapper.get_shifted_key(&key).unwrap().unwrap(), shifted_key);
}
let shift_pairs = [
("1", "!"),
("2", "@"),
("3", "#"),
("4", "$"),
("5", "%"),
("6", "^"),
("7", "&"),
("8", "*"),
("9", "("),
("0", ")"),
("`", "~"),
("-", "_"),
("=", "+"),
("[", "{"),
("]", "}"),
("\\", "|"),
(";", ":"),
("'", "\""),
(",", "<"),
(".", ">"),
("/", "?"),
];
for (key, shifted_key) in shift_pairs {
assert_eq!(mapper.get_shifted_key(key).unwrap().unwrap(), shifted_key);
}
let immutable_keys = ["backspace", "space", "tab", "enter", "f1"];
for key in immutable_keys {
assert_eq!(mapper.get_shifted_key(key).unwrap(), None);
}
}
#[test]
fn test_scan_code_to_key() {
let mapper = TestKeyboardMapper::new();
for scan_code in ScanCode::iter() {
let key = mapper.scan_code_to_key(scan_code).unwrap();
assert_eq!(key, scan_code.to_key());
}
}
}

View File

@@ -0,0 +1,464 @@
use strum::EnumIter;
/// Scan codes for the keyboard, which are used to identify keys in a keyboard layout-independent way.
/// Currently, we only support a limited set of scan codes here:
/// https://code.visualstudio.com/docs/configure/keybindings#_keyboard-layoutindependent-bindings
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EnumIter)]
pub enum ScanCode {
/// F1 key
F1,
/// F1 key
F2,
/// F1 key
F3,
/// F1 key
F4,
/// F1 key
F5,
/// F1 key
F6,
/// F1 key
F7,
/// F1 key
F8,
/// F1 key
F9,
/// F1 key
F10,
/// F1 key
F11,
/// F1 key
F12,
/// F1 key
F13,
/// F1 key
F14,
/// F1 key
F15,
/// F1 key
F16,
/// F1 key
F17,
/// F1 key
F18,
/// F1 key
F19,
/// F20 key
F20,
/// F20 key
F21,
/// F20 key
F22,
/// F20 key
F23,
/// F20 key
F24,
/// A key on the main keyboard.
A,
/// B key on the main keyboard.
B,
/// C key on the main keyboard.
C,
/// D key on the main keyboard.
D,
/// E key on the main keyboard.
E,
/// F key on the main keyboard.
F,
/// G key on the main keyboard.
G,
/// H key on the main keyboard.
H,
/// I key on the main keyboard.
I,
/// J key on the main keyboard.
J,
/// K key on the main keyboard.
K,
/// L key on the main keyboard.
L,
/// M key on the main keyboard.
M,
/// N key on the main keyboard.
N,
/// O key on the main keyboard.
O,
/// P key on the main keyboard.
P,
/// Q key on the main keyboard.
Q,
/// R key on the main keyboard.
R,
/// S key on the main keyboard.
S,
/// T key on the main keyboard.
T,
/// U key on the main keyboard.
U,
/// V key on the main keyboard.
V,
/// W key on the main keyboard.
W,
/// X key on the main keyboard.
X,
/// Y key on the main keyboard.
Y,
/// Z key on the main keyboard.
Z,
/// 0 key on the main keyboard.
Digit0,
/// 1 key on the main keyboard.
Digit1,
/// 2 key on the main keyboard.
Digit2,
/// 3 key on the main keyboard.
Digit3,
/// 4 key on the main keyboard.
Digit4,
/// 5 key on the main keyboard.
Digit5,
/// 6 key on the main keyboard.
Digit6,
/// 7 key on the main keyboard.
Digit7,
/// 8 key on the main keyboard.
Digit8,
/// 9 key on the main keyboard.
Digit9,
/// Backquote key on the main keyboard: `
Backquote,
/// Minus key on the main keyboard: -
Minus,
/// Equal key on the main keyboard: =
Equal,
/// BracketLeft key on the main keyboard: [
BracketLeft,
/// BracketRight key on the main keyboard: ]
BracketRight,
/// Backslash key on the main keyboard: \
Backslash,
/// Semicolon key on the main keyboard: ;
Semicolon,
/// Quote key on the main keyboard: '
Quote,
/// Comma key on the main keyboard: ,
Comma,
/// Period key on the main keyboard: .
Period,
/// Slash key on the main keyboard: /
Slash,
/// Left arrow key
Left,
/// Up arrow key
Up,
/// Right arrow key
Right,
/// Down arrow key
Down,
/// PAGE UP key
PageUp,
/// PAGE DOWN key
PageDown,
/// END key
End,
/// HOME key
Home,
/// TAB key
Tab,
/// ENTER key, also known as RETURN key
/// This does not distinguish between the main Enter key and the numeric keypad Enter key.
Enter,
/// ESCAPE key
Escape,
/// SPACE key
Space,
/// BACKSPACE key
Backspace,
/// DELETE key
Delete,
// Pause, not supported yet
// CapsLock, not supported yet
/// INSERT key
Insert,
// The following keys are not supported yet:
// Numpad0,
// Numpad1,
// Numpad2,
// Numpad3,
// Numpad4,
// Numpad5,
// Numpad6,
// Numpad7,
// Numpad8,
// Numpad9,
// NumpadMultiply,
// NumpadAdd,
// NumpadComma,
// NumpadSubtract,
// NumpadDecimal,
// NumpadDivide,
}
impl ScanCode {
/// Parse a scan code from a string.
pub fn parse(source: &str) -> Option<Self> {
match source {
"[f1]" => Some(Self::F1),
"[f2]" => Some(Self::F2),
"[f3]" => Some(Self::F3),
"[f4]" => Some(Self::F4),
"[f5]" => Some(Self::F5),
"[f6]" => Some(Self::F6),
"[f7]" => Some(Self::F7),
"[f8]" => Some(Self::F8),
"[f9]" => Some(Self::F9),
"[f10]" => Some(Self::F10),
"[f11]" => Some(Self::F11),
"[f12]" => Some(Self::F12),
"[f13]" => Some(Self::F13),
"[f14]" => Some(Self::F14),
"[f15]" => Some(Self::F15),
"[f16]" => Some(Self::F16),
"[f17]" => Some(Self::F17),
"[f18]" => Some(Self::F18),
"[f19]" => Some(Self::F19),
"[f20]" => Some(Self::F20),
"[f21]" => Some(Self::F21),
"[f22]" => Some(Self::F22),
"[f23]" => Some(Self::F23),
"[f24]" => Some(Self::F24),
"[a]" | "[keya]" => Some(Self::A),
"[b]" | "[keyb]" => Some(Self::B),
"[c]" | "[keyc]" => Some(Self::C),
"[d]" | "[keyd]" => Some(Self::D),
"[e]" | "[keye]" => Some(Self::E),
"[f]" | "[keyf]" => Some(Self::F),
"[g]" | "[keyg]" => Some(Self::G),
"[h]" | "[keyh]" => Some(Self::H),
"[i]" | "[keyi]" => Some(Self::I),
"[j]" | "[keyj]" => Some(Self::J),
"[k]" | "[keyk]" => Some(Self::K),
"[l]" | "[keyl]" => Some(Self::L),
"[m]" | "[keym]" => Some(Self::M),
"[n]" | "[keyn]" => Some(Self::N),
"[o]" | "[keyo]" => Some(Self::O),
"[p]" | "[keyp]" => Some(Self::P),
"[q]" | "[keyq]" => Some(Self::Q),
"[r]" | "[keyr]" => Some(Self::R),
"[s]" | "[keys]" => Some(Self::S),
"[t]" | "[keyt]" => Some(Self::T),
"[u]" | "[keyu]" => Some(Self::U),
"[v]" | "[keyv]" => Some(Self::V),
"[w]" | "[keyw]" => Some(Self::W),
"[x]" | "[keyx]" => Some(Self::X),
"[y]" | "[keyy]" => Some(Self::Y),
"[z]" | "[keyz]" => Some(Self::Z),
"[0]" | "[digit0]" => Some(Self::Digit0),
"[1]" | "[digit1]" => Some(Self::Digit1),
"[2]" | "[digit2]" => Some(Self::Digit2),
"[3]" | "[digit3]" => Some(Self::Digit3),
"[4]" | "[digit4]" => Some(Self::Digit4),
"[5]" | "[digit5]" => Some(Self::Digit5),
"[6]" | "[digit6]" => Some(Self::Digit6),
"[7]" | "[digit7]" => Some(Self::Digit7),
"[8]" | "[digit8]" => Some(Self::Digit8),
"[9]" | "[digit9]" => Some(Self::Digit9),
"[backquote]" => Some(Self::Backquote),
"[minus]" => Some(Self::Minus),
"[equal]" => Some(Self::Equal),
"[bracketleft]" => Some(Self::BracketLeft),
"[bracketright]" => Some(Self::BracketRight),
"[backslash]" => Some(Self::Backslash),
"[semicolon]" => Some(Self::Semicolon),
"[quote]" => Some(Self::Quote),
"[comma]" => Some(Self::Comma),
"[period]" => Some(Self::Period),
"[slash]" => Some(Self::Slash),
"[left]" | "[arrowleft]" => Some(Self::Left),
"[up]" | "[arrowup]" => Some(Self::Up),
"[right]" | "[arrowright]" => Some(Self::Right),
"[down]" | "[arrowdown]" => Some(Self::Down),
"[pageup]" => Some(Self::PageUp),
"[pagedown]" => Some(Self::PageDown),
"[end]" => Some(Self::End),
"[home]" => Some(Self::Home),
"[tab]" => Some(Self::Tab),
"[enter]" => Some(Self::Enter),
"[escape]" => Some(Self::Escape),
"[space]" => Some(Self::Space),
"[backspace]" => Some(Self::Backspace),
"[delete]" => Some(Self::Delete),
// "[pause]" => Some(Self::Pause),
// "[capslock]" => Some(Self::CapsLock),
"[insert]" => Some(Self::Insert),
// "[numpad0]" => Some(Self::Numpad0),
// "[numpad1]" => Some(Self::Numpad1),
// "[numpad2]" => Some(Self::Numpad2),
// "[numpad3]" => Some(Self::Numpad3),
// "[numpad4]" => Some(Self::Numpad4),
// "[numpad5]" => Some(Self::Numpad5),
// "[numpad6]" => Some(Self::Numpad6),
// "[numpad7]" => Some(Self::Numpad7),
// "[numpad8]" => Some(Self::Numpad8),
// "[numpad9]" => Some(Self::Numpad9),
// "[numpadmultiply]" => Some(Self::NumpadMultiply),
// "[numpadadd]" => Some(Self::NumpadAdd),
// "[numpadcomma]" => Some(Self::NumpadComma),
// "[numpadsubtract]" => Some(Self::NumpadSubtract),
// "[numpaddecimal]" => Some(Self::NumpadDecimal),
// "[numpaddivide]" => Some(Self::NumpadDivide),
_ => None,
}
}
/// Convert the scan code to its key face for immutable keys.
pub fn try_to_key(&self) -> Option<String> {
Some(
match self {
ScanCode::F1 => "f1",
ScanCode::F2 => "f2",
ScanCode::F3 => "f3",
ScanCode::F4 => "f4",
ScanCode::F5 => "f5",
ScanCode::F6 => "f6",
ScanCode::F7 => "f7",
ScanCode::F8 => "f8",
ScanCode::F9 => "f9",
ScanCode::F10 => "f10",
ScanCode::F11 => "f11",
ScanCode::F12 => "f12",
ScanCode::F13 => "f13",
ScanCode::F14 => "f14",
ScanCode::F15 => "f15",
ScanCode::F16 => "f16",
ScanCode::F17 => "f17",
ScanCode::F18 => "f18",
ScanCode::F19 => "f19",
ScanCode::F20 => "f20",
ScanCode::F21 => "f21",
ScanCode::F22 => "f22",
ScanCode::F23 => "f23",
ScanCode::F24 => "f24",
ScanCode::Left => "left",
ScanCode::Up => "up",
ScanCode::Right => "right",
ScanCode::Down => "down",
ScanCode::PageUp => "pageup",
ScanCode::PageDown => "pagedown",
ScanCode::End => "end",
ScanCode::Home => "home",
ScanCode::Tab => "tab",
ScanCode::Enter => "enter",
ScanCode::Escape => "escape",
ScanCode::Space => "space",
ScanCode::Backspace => "backspace",
ScanCode::Delete => "delete",
ScanCode::Insert => "insert",
_ => return None,
}
.to_string(),
)
}
/// This function is used to convert the scan code to its key face on US keyboard layout.
/// Only used for tests and Linux.
pub fn to_key(&self) -> &str {
match self {
ScanCode::F1 => "f1",
ScanCode::F2 => "f2",
ScanCode::F3 => "f3",
ScanCode::F4 => "f4",
ScanCode::F5 => "f5",
ScanCode::F6 => "f6",
ScanCode::F7 => "f7",
ScanCode::F8 => "f8",
ScanCode::F9 => "f9",
ScanCode::F10 => "f10",
ScanCode::F11 => "f11",
ScanCode::F12 => "f12",
ScanCode::F13 => "f13",
ScanCode::F14 => "f14",
ScanCode::F15 => "f15",
ScanCode::F16 => "f16",
ScanCode::F17 => "f17",
ScanCode::F18 => "f18",
ScanCode::F19 => "f19",
ScanCode::F20 => "f20",
ScanCode::F21 => "f21",
ScanCode::F22 => "f22",
ScanCode::F23 => "f23",
ScanCode::F24 => "f24",
ScanCode::A => "a",
ScanCode::B => "b",
ScanCode::C => "c",
ScanCode::D => "d",
ScanCode::E => "e",
ScanCode::F => "f",
ScanCode::G => "g",
ScanCode::H => "h",
ScanCode::I => "i",
ScanCode::J => "j",
ScanCode::K => "k",
ScanCode::L => "l",
ScanCode::M => "m",
ScanCode::N => "n",
ScanCode::O => "o",
ScanCode::P => "p",
ScanCode::Q => "q",
ScanCode::R => "r",
ScanCode::S => "s",
ScanCode::T => "t",
ScanCode::U => "u",
ScanCode::V => "v",
ScanCode::W => "w",
ScanCode::X => "x",
ScanCode::Y => "y",
ScanCode::Z => "z",
ScanCode::Digit0 => "0",
ScanCode::Digit1 => "1",
ScanCode::Digit2 => "2",
ScanCode::Digit3 => "3",
ScanCode::Digit4 => "4",
ScanCode::Digit5 => "5",
ScanCode::Digit6 => "6",
ScanCode::Digit7 => "7",
ScanCode::Digit8 => "8",
ScanCode::Digit9 => "9",
ScanCode::Backquote => "`",
ScanCode::Minus => "-",
ScanCode::Equal => "=",
ScanCode::BracketLeft => "[",
ScanCode::BracketRight => "]",
ScanCode::Backslash => "\\",
ScanCode::Semicolon => ";",
ScanCode::Quote => "'",
ScanCode::Comma => ",",
ScanCode::Period => ".",
ScanCode::Slash => "/",
ScanCode::Left => "left",
ScanCode::Up => "up",
ScanCode::Right => "right",
ScanCode::Down => "down",
ScanCode::PageUp => "pageup",
ScanCode::PageDown => "pagedown",
ScanCode::End => "end",
ScanCode::Home => "home",
ScanCode::Tab => "tab",
ScanCode::Enter => "enter",
ScanCode::Escape => "escape",
ScanCode::Space => "space",
ScanCode::Backspace => "backspace",
ScanCode::Delete => "delete",
ScanCode::Insert => "insert",
}
}
}

View File

@@ -4,6 +4,9 @@ use std::{
error::Error,
fmt::{Display, Write},
};
use util::ResultExt;
use super::{PlatformKeyboardMapper, ScanCode, is_alphabetic_key};
/// A keystroke and associated metadata generated by the platform
#[derive(Clone, Debug, Eq, PartialEq, Default, Deserialize, Hash)]
@@ -93,12 +96,41 @@ impl Keystroke {
/// key_char syntax is only used for generating test events,
/// secondary means "cmd" on macOS and "ctrl" on other platforms
/// when matching a key with an key_char set will be matched without it.
pub fn parse(source: &str) -> std::result::Result<Self, InvalidKeystrokeError> {
pub fn parse(
source: &str,
keyboard_mapper: &dyn PlatformKeyboardMapper,
) -> std::result::Result<Self, InvalidKeystrokeError> {
let mut keystroke = Keystroke::parse_keystroke_components(source, '-')?;
// Create error once for reuse
let error = || InvalidKeystrokeError {
keystroke: source.to_owned(),
};
if keystroke.key.starts_with("oem") {
// The oem_key will be handled after https://github.com/zed-industries/zed/pull/29144
return Err(error());
}
if keystroke.key.starts_with('[') && keystroke.key.ends_with(']') {
let scan_code = ScanCode::parse(&keystroke.key).ok_or_else(error)?;
keystroke.key = keyboard_mapper
.scan_code_to_key(scan_code)
.map_err(|_| error())?;
}
Ok(keystroke.into_gpui_style(keyboard_mapper))
}
/// Parses a keystroke string representation into a `Keystroke` struct using a specified separator character.
/// This is the low-level parsing function that handles the basic string format without additional
/// platform-specific mapping or transformations.
pub fn parse_keystroke_components(
source: &str,
separator: char,
) -> std::result::Result<Self, InvalidKeystrokeError> {
let mut modifiers = Modifiers::none();
let mut key = None;
let mut key_char = None;
let mut components = source.split('-').peekable();
let mut components = source.split(separator).peekable();
while let Some(component) = components.next() {
if component.eq_ignore_ascii_case("ctrl") {
modifiers.control = true;
@@ -137,8 +169,8 @@ impl Keystroke {
let mut key_str = component.to_string();
if let Some(next) = components.peek() {
if next.is_empty() && source.ends_with('-') {
key = Some(String::from("-"));
if next.is_empty() && source.ends_with(separator) {
key = Some(String::from(separator));
break;
} else if next.len() > 1 && next.starts_with('>') {
key = Some(key_str);
@@ -187,7 +219,6 @@ impl Keystroke {
let key = key.ok_or_else(|| InvalidKeystrokeError {
keystroke: source.to_owned(),
})?;
Ok(Keystroke {
modifiers,
key,
@@ -195,6 +226,22 @@ impl Keystroke {
})
}
/// Converts this keystroke to a GPUI style keystroke.
/// For example, `ctrl-shift-[` becomes `ctrl-{`, `ctrl-shift-=` becomes `ctrl-+`.
pub fn into_gpui_style(mut self, keyboard_mapper: &dyn PlatformKeyboardMapper) -> Keystroke {
if self.modifiers.shift && !is_alphabetic_key(&self.key) && self.key.len() == 1 {
if let Some(shifted_key) = keyboard_mapper
.get_shifted_key(&self.key)
.log_err()
.flatten()
{
self.modifiers.shift = false;
self.key = shifted_key;
}
}
self
}
/// Produces a representation of this key that Parse can understand.
pub fn unparse(&self) -> String {
let mut str = String::new();
@@ -538,3 +585,200 @@ impl Modifiers {
&& (other.function || !self.function)
}
}
#[cfg(test)]
mod tests {
use crate::{Keystroke, Modifiers, TestKeyboardMapper};
#[test]
fn test_different_separators() {
assert_eq!(
Keystroke::parse_keystroke_components("ctrl-alt--", '-').unwrap(),
Keystroke::parse_keystroke_components("ctrl+alt+-", '+').unwrap(),
);
assert_eq!(
Keystroke::parse_keystroke_components("ctrl-alt-+", '-').unwrap(),
Keystroke::parse_keystroke_components("ctrl+alt++", '+').unwrap(),
);
assert_eq!(
Keystroke::parse_keystroke_components("ctrl-alt-[Minus]", '-').unwrap(),
Keystroke::parse_keystroke_components("ctrl+alt+[Minus]", '+').unwrap(),
);
assert_eq!(
Keystroke::parse_keystroke_components("ctrl-alt-[张小白]", '-').unwrap(),
Keystroke::parse_keystroke_components("ctrl+alt+[张小白]", '+').unwrap(),
);
}
#[test]
fn test_parse_scan_code() {
let keyboard_mapper = TestKeyboardMapper::new();
for letter in 'a'..='z' {
let key1 = format!("[Key{}]", letter.to_uppercase());
let key2 = format!("[{}]", letter.to_uppercase());
let keystroke1 = Keystroke::parse(&key1, &keyboard_mapper).unwrap();
let keystroke2 = Keystroke::parse(&key2, &keyboard_mapper).unwrap();
assert_eq!(
keystroke1,
Keystroke {
modifiers: Modifiers::default(),
key: letter.to_string(),
key_char: None,
}
);
assert_eq!(keystroke1, keystroke2);
let source1 = format!("ctrl-{}", key1);
let source2 = format!("ctrl-{}", key2);
let keystroke1 = Keystroke::parse(&source1, &keyboard_mapper).unwrap();
let keystroke2 = Keystroke::parse(&source2, &keyboard_mapper).unwrap();
assert_eq!(
keystroke1,
Keystroke {
modifiers: Modifiers::control(),
key: letter.to_string(),
key_char: None,
}
);
assert_eq!(keystroke1, keystroke2);
let source1 = format!("ctrl-alt-{}", key1);
let source2 = format!("ctrl-alt-{}", key2);
let keystroke1 = Keystroke::parse(&source1, &keyboard_mapper).unwrap();
let keystroke2 = Keystroke::parse(&source2, &keyboard_mapper).unwrap();
assert_eq!(
keystroke1,
Keystroke {
modifiers: Modifiers {
control: true,
alt: true,
..Default::default()
},
key: letter.to_string(),
key_char: None,
}
);
assert_eq!(keystroke1, keystroke2);
let source1 = format!("ctrl-shift-{}", key1);
let source2 = format!("ctrl-shift-{}", key2);
let keystroke1 = Keystroke::parse(&source1, &keyboard_mapper).unwrap();
let keystroke2 = Keystroke::parse(&source2, &keyboard_mapper).unwrap();
assert_eq!(
keystroke1,
Keystroke {
modifiers: Modifiers::control_shift(),
key: letter.to_string(),
key_char: None,
}
);
assert_eq!(keystroke1, keystroke2);
}
let other_keys = [
("[Backquote]", "`", "~"),
("[Minus]", "-", "_"),
("[Equal]", "=", "+"),
("[BracketLeft]", "[", "{"),
("[BracketRight]", "]", "}"),
("[Backslash]", "\\", "|"),
("[Semicolon]", ";", ":"),
("[Quote]", "'", "\""),
("[Comma]", ",", "<"),
("[Period]", ".", ">"),
("[Slash]", "/", "?"),
];
for (code, key, shifted_key) in other_keys {
assert_eq!(
Keystroke::parse(code, &keyboard_mapper).unwrap(),
Keystroke {
modifiers: Modifiers::default(),
key: key.to_string(),
key_char: None,
}
);
assert_eq!(
Keystroke::parse(&format!("shift-{}", code), &keyboard_mapper).unwrap(),
Keystroke {
modifiers: Modifiers::default(),
key: shifted_key.to_string(),
key_char: None,
}
);
assert_eq!(
Keystroke::parse(&format!("shift-{}", key), &keyboard_mapper).unwrap(),
Keystroke {
modifiers: Modifiers::default(),
key: shifted_key.to_string(),
key_char: None,
}
);
assert_eq!(
Keystroke::parse(&format!("ctrl-{}", code), &keyboard_mapper).unwrap(),
Keystroke {
modifiers: Modifiers::control(),
key: key.to_string(),
key_char: None,
}
);
assert_eq!(
Keystroke::parse(&format!("ctrl-alt-{}", code), &keyboard_mapper).unwrap(),
Keystroke {
modifiers: Modifiers {
control: true,
alt: true,
..Default::default()
},
key: key.to_string(),
key_char: None,
}
);
assert_eq!(
Keystroke::parse(&format!("ctrl-alt-{}", key), &keyboard_mapper).unwrap(),
Keystroke {
modifiers: Modifiers {
control: true,
alt: true,
..Default::default()
},
key: key.to_string(),
key_char: None,
}
);
assert_eq!(
Keystroke::parse(&format!("ctrl-shift-{}", code), &keyboard_mapper).unwrap(),
Keystroke {
modifiers: Modifiers::control(),
key: shifted_key.to_string(),
key_char: None,
}
);
assert_eq!(
Keystroke::parse(&format!("ctrl-alt-shift-{}", code), &keyboard_mapper).unwrap(),
Keystroke {
modifiers: Modifiers {
control: true,
alt: true,
..Default::default()
},
key: shifted_key.to_string(),
key_char: None,
}
);
assert_eq!(
Keystroke::parse(&format!("ctrl-alt-shift-{}", key), &keyboard_mapper).unwrap(),
Keystroke {
modifiers: Modifiers {
control: true,
alt: true,
..Default::default()
},
key: shifted_key.to_string(),
key_char: None,
}
);
}
}
}

View File

@@ -1,4 +1,20 @@
use crate::PlatformKeyboardLayout;
#[cfg(any(feature = "wayland", feature = "x11"))]
use std::sync::LazyLock;
#[cfg(any(feature = "wayland", feature = "x11"))]
use collections::HashMap;
#[cfg(any(feature = "wayland", feature = "x11"))]
use x11rb::{protocol::xkb::ConnectionExt, xcb_ffi::XCBConnection};
#[cfg(any(feature = "wayland", feature = "x11"))]
use xkbcommon::xkb::{
Keycode,
x11::ffi::{XKB_X11_MIN_MAJOR_XKB_VERSION, XKB_X11_MIN_MINOR_XKB_VERSION},
};
use crate::{PlatformKeyboardLayout, PlatformKeyboardMapper, ScanCode};
#[cfg(any(feature = "wayland", feature = "x11"))]
use crate::is_alphabetic_key;
pub(crate) struct LinuxKeyboardLayout {
id: String,
@@ -19,3 +35,332 @@ impl LinuxKeyboardLayout {
Self { id }
}
}
#[cfg(any(feature = "wayland", feature = "x11"))]
pub(crate) struct LinuxKeyboardMapper {
code_to_key: HashMap<Keycode, String>,
key_to_code: HashMap<String, Keycode>,
code_to_shifted_key: HashMap<Keycode, String>,
}
#[cfg(any(feature = "wayland", feature = "x11"))]
impl PlatformKeyboardMapper for LinuxKeyboardMapper {
fn scan_code_to_key(&self, gpui_scan_code: ScanCode) -> anyhow::Result<String> {
if let Some(key) = gpui_scan_code.try_to_key() {
return Ok(key);
}
let Some(scan_code) = get_scan_code(gpui_scan_code) else {
return Err(anyhow::anyhow!("Scan code not found: {:?}", gpui_scan_code));
};
if let Some(key) = self.code_to_key.get(&Keycode::new(scan_code)) {
Ok(key.clone())
} else {
Err(anyhow::anyhow!(
"Key not found for input scan code: {:?}, scan code: {}",
gpui_scan_code,
scan_code
))
}
}
fn get_shifted_key(&self, key: &str) -> anyhow::Result<Option<String>> {
if is_immutable_key(key) {
return Ok(None);
}
if is_alphabetic_key(key) {
return Ok(Some(key.to_uppercase()));
}
let Some(scan_code) = self.key_to_code.get(key) else {
return Err(anyhow::anyhow!("Key not found: {}", key));
};
if let Some(shifted_key) = self.code_to_shifted_key.get(scan_code) {
Ok(Some(shifted_key.clone()))
} else {
Err(anyhow::anyhow!(
"Shifted key not found for key {} with scan code: {:?}",
key,
scan_code
))
}
}
}
#[cfg(any(feature = "wayland", feature = "x11"))]
static XCB_CONNECTION: LazyLock<XCBConnection> =
LazyLock::new(|| XCBConnection::connect(None).unwrap().0);
#[cfg(any(feature = "wayland", feature = "x11"))]
impl LinuxKeyboardMapper {
pub(crate) fn new() -> Self {
let _ = XCB_CONNECTION
.xkb_use_extension(XKB_X11_MIN_MAJOR_XKB_VERSION, XKB_X11_MIN_MINOR_XKB_VERSION)
.unwrap()
.reply()
.unwrap();
let xkb_context = xkbcommon::xkb::Context::new(xkbcommon::xkb::CONTEXT_NO_FLAGS);
let xkb_device_id = xkbcommon::xkb::x11::get_core_keyboard_device_id(&*XCB_CONNECTION);
let xkb_state = {
let xkb_keymap = xkbcommon::xkb::x11::keymap_new_from_device(
&xkb_context,
&*XCB_CONNECTION,
xkb_device_id,
xkbcommon::xkb::KEYMAP_COMPILE_NO_FLAGS,
);
xkbcommon::xkb::x11::state_new_from_device(&xkb_keymap, &*XCB_CONNECTION, xkb_device_id)
};
let mut code_to_key = HashMap::default();
let mut key_to_code = HashMap::default();
let mut code_to_shifted_key = HashMap::default();
let keymap = xkb_state.get_keymap();
let mut shifted_state = xkbcommon::xkb::State::new(&keymap);
let shift_mod = keymap.mod_get_index(xkbcommon::xkb::MOD_NAME_SHIFT);
let shift_mask = 1 << shift_mod;
shifted_state.update_mask(shift_mask, 0, 0, 0, 0, 0);
for &scan_code in TYPEABLE_CODES {
let keycode = Keycode::new(scan_code);
let key = xkb_state.key_get_utf8(keycode);
code_to_key.insert(keycode, key.clone());
key_to_code.insert(key, keycode);
let shifted_key = shifted_state.key_get_utf8(keycode);
code_to_shifted_key.insert(keycode, shifted_key);
}
Self {
code_to_key,
key_to_code,
code_to_shifted_key,
}
}
}
// All typeable scan codes for the standard US keyboard layout, ANSI104
#[cfg(any(feature = "wayland", feature = "x11"))]
const TYPEABLE_CODES: &[u32] = &[
0x0026, // a
0x0038, // b
0x0036, // c
0x0028, // d
0x001a, // e
0x0029, // f
0x002a, // g
0x002b, // h
0x001f, // i
0x002c, // j
0x002d, // k
0x002e, // l
0x003a, // m
0x0039, // n
0x0020, // o
0x0021, // p
0x0018, // q
0x001b, // r
0x0027, // s
0x001c, // t
0x001e, // u
0x0037, // v
0x0019, // w
0x0035, // x
0x001d, // y
0x0034, // z
0x0013, // Digit 0
0x000a, // Digit 1
0x000b, // Digit 2
0x000c, // Digit 3
0x000d, // Digit 4
0x000e, // Digit 5
0x000f, // Digit 6
0x0010, // Digit 7
0x0011, // Digit 8
0x0012, // Digit 9
0x0031, // ` Backquote
0x0014, // - Minus
0x0015, // = Equal
0x0022, // [ Left bracket
0x0023, // ] Right bracket
0x0033, // \ Backslash
0x002f, // ; Semicolon
0x0030, // ' Quote
0x003b, // , Comma
0x003c, // . Period
0x003d, // / Slash
];
#[cfg(any(feature = "wayland", feature = "x11"))]
fn is_immutable_key(key: &str) -> bool {
matches!(
key,
"f1" | "f2"
| "f3"
| "f4"
| "f5"
| "f6"
| "f7"
| "f8"
| "f9"
| "f10"
| "f11"
| "f12"
| "f13"
| "f14"
| "f15"
| "f16"
| "f17"
| "f18"
| "f19"
| "f20"
| "f21"
| "f22"
| "f23"
| "f24"
| "backspace"
| "delete"
| "left"
| "right"
| "up"
| "down"
| "pageup"
| "pagedown"
| "insert"
| "home"
| "end"
| "back"
| "forward"
| "escape"
| "space"
| "tab"
| "enter"
| "shift"
| "control"
| "alt"
| "platform"
| "cmd"
| "super"
| "win"
| "fn"
| "menu"
| "copy"
| "paste"
| "cut"
| "find"
| "open"
| "save"
)
}
#[cfg(any(feature = "wayland", feature = "x11"))]
fn get_scan_code(scan_code: ScanCode) -> Option<u32> {
// https://github.com/microsoft/node-native-keymap/blob/main/deps/chromium/dom_code_data.inc
Some(match scan_code {
ScanCode::F1 => 0x0043,
ScanCode::F2 => 0x0044,
ScanCode::F3 => 0x0045,
ScanCode::F4 => 0x0046,
ScanCode::F5 => 0x0047,
ScanCode::F6 => 0x0048,
ScanCode::F7 => 0x0049,
ScanCode::F8 => 0x004a,
ScanCode::F9 => 0x004b,
ScanCode::F10 => 0x004c,
ScanCode::F11 => 0x005f,
ScanCode::F12 => 0x0060,
ScanCode::F13 => 0x00bf,
ScanCode::F14 => 0x00c0,
ScanCode::F15 => 0x00c1,
ScanCode::F16 => 0x00c2,
ScanCode::F17 => 0x00c3,
ScanCode::F18 => 0x00c4,
ScanCode::F19 => 0x00c5,
ScanCode::F20 => 0x00c6,
ScanCode::F21 => 0x00c7,
ScanCode::F22 => 0x00c8,
ScanCode::F23 => 0x00c9,
ScanCode::F24 => 0x00ca,
ScanCode::A => 0x0026,
ScanCode::B => 0x0038,
ScanCode::C => 0x0036,
ScanCode::D => 0x0028,
ScanCode::E => 0x001a,
ScanCode::F => 0x0029,
ScanCode::G => 0x002a,
ScanCode::H => 0x002b,
ScanCode::I => 0x001f,
ScanCode::J => 0x002c,
ScanCode::K => 0x002d,
ScanCode::L => 0x002e,
ScanCode::M => 0x003a,
ScanCode::N => 0x0039,
ScanCode::O => 0x0020,
ScanCode::P => 0x0021,
ScanCode::Q => 0x0018,
ScanCode::R => 0x001b,
ScanCode::S => 0x0027,
ScanCode::T => 0x001c,
ScanCode::U => 0x001e,
ScanCode::V => 0x0037,
ScanCode::W => 0x0019,
ScanCode::X => 0x0035,
ScanCode::Y => 0x001d,
ScanCode::Z => 0x0034,
ScanCode::Digit0 => 0x0013,
ScanCode::Digit1 => 0x000a,
ScanCode::Digit2 => 0x000b,
ScanCode::Digit3 => 0x000c,
ScanCode::Digit4 => 0x000d,
ScanCode::Digit5 => 0x000e,
ScanCode::Digit6 => 0x000f,
ScanCode::Digit7 => 0x0010,
ScanCode::Digit8 => 0x0011,
ScanCode::Digit9 => 0x0012,
ScanCode::Backquote => 0x0031,
ScanCode::Minus => 0x0014,
ScanCode::Equal => 0x0015,
ScanCode::BracketLeft => 0x0022,
ScanCode::BracketRight => 0x0023,
ScanCode::Backslash => 0x0033,
ScanCode::Semicolon => 0x002f,
ScanCode::Quote => 0x0030,
ScanCode::Comma => 0x003b,
ScanCode::Period => 0x003c,
ScanCode::Slash => 0x003d,
ScanCode::Left => 0x0071,
ScanCode::Up => 0x006f,
ScanCode::Right => 0x0072,
ScanCode::Down => 0x0074,
ScanCode::PageUp => 0x0070,
ScanCode::PageDown => 0x0075,
ScanCode::End => 0x0073,
ScanCode::Home => 0x006e,
ScanCode::Tab => 0x0017,
ScanCode::Enter => 0x0024,
ScanCode::Escape => 0x0009,
ScanCode::Space => 0x0041,
ScanCode::Backspace => 0x0016,
ScanCode::Delete => 0x0077,
ScanCode::Insert => 0x0076,
})
}
#[cfg(not(any(feature = "wayland", feature = "x11")))]
pub(crate) struct LinuxKeyboardMapper;
#[cfg(not(any(feature = "wayland", feature = "x11")))]
impl PlatformKeyboardMapper for LinuxKeyboardMapper {
fn scan_code_to_key(&self, _scan_code: ScanCode) -> anyhow::Result<String> {
Err(anyhow::anyhow!("LinuxKeyboardMapper not supported"))
}
fn get_shifted_key(&self, _key: &str) -> anyhow::Result<Option<String>> {
Err(anyhow::anyhow!("LinuxKeyboardMapper not supported"))
}
}
#[cfg(not(any(feature = "wayland", feature = "x11")))]
impl LinuxKeyboardMapper {
pub(crate) fn new() -> Self {
Self
}
}

View File

@@ -25,8 +25,9 @@ use xkbcommon::xkb::{self, Keycode, Keysym, State};
use crate::{
Action, AnyWindowHandle, BackgroundExecutor, ClipboardItem, CursorStyle, DisplayId,
ForegroundExecutor, Keymap, LinuxDispatcher, Menu, MenuItem, OwnedMenu, PathPromptOptions,
Pixels, Platform, PlatformDisplay, PlatformKeyboardLayout, PlatformTextSystem, PlatformWindow,
Point, Result, ScreenCaptureSource, Task, WindowAppearance, WindowParams, px,
Pixels, Platform, PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper,
PlatformTextSystem, PlatformWindow, Point, Result, ScreenCaptureSource, Task, WindowAppearance,
WindowParams, px,
};
#[cfg(any(feature = "wayland", feature = "x11"))]
@@ -138,6 +139,10 @@ impl<P: LinuxClient + 'static> Platform for P {
self.with_common(|common| common.text_system.clone())
}
fn keyboard_mapper(&self) -> Box<dyn PlatformKeyboardMapper> {
Box::new(super::LinuxKeyboardMapper::new())
}
fn keyboard_layout(&self) -> Box<dyn PlatformKeyboardLayout> {
self.keyboard_layout()
}

View File

@@ -635,8 +635,12 @@ impl WaylandWindowStatePtr {
let mut bounds: Option<Bounds<Pixels>> = None;
if let Some(mut input_handler) = state.input_handler.take() {
drop(state);
if let Some(selection) = input_handler.marked_text_range() {
bounds = input_handler.bounds_for_range(selection.start..selection.start);
if let Some(selection) = input_handler.selected_text_range(true) {
bounds = input_handler.bounds_for_range(if selection.reversed {
selection.range.start..selection.range.start
} else {
selection.range.end..selection.range.end
});
}
self.state.borrow_mut().input_handler = Some(input_handler);
}

View File

@@ -1,5 +1,4 @@
mod client;
mod clipboard;
mod display;
mod event;
mod window;

View File

@@ -1,3 +1,4 @@
use crate::platform::scap_screen_capture::scap_screen_sources;
use core::str;
use std::{
cell::RefCell,
@@ -40,9 +41,8 @@ use xkbc::x11::ffi::{XKB_X11_MIN_MAJOR_XKB_VERSION, XKB_X11_MIN_MINOR_XKB_VERSIO
use xkbcommon::xkb::{self as xkbc, LayoutIndex, ModMask, STATE_LAYOUT_EFFECTIVE};
use super::{
ButtonOrScroll, ScrollDirection, button_or_scroll_from_event_detail,
clipboard::{self, Clipboard},
get_valuator_axis_index, modifiers_from_state, pressed_button_from_mask,
ButtonOrScroll, ScrollDirection, button_or_scroll_from_event_detail, get_valuator_axis_index,
modifiers_from_state, pressed_button_from_mask,
};
use super::{X11Display, X11WindowStatePtr, XcbAtoms};
use super::{XimCallbackEvent, XimHandler};
@@ -56,7 +56,6 @@ use crate::platform::{
reveal_path_internal,
xdg_desktop_portal::{Event as XDPEvent, XDPEventSource},
},
scap_screen_capture::scap_screen_sources,
};
use crate::{
AnyWindowHandle, Bounds, ClipboardItem, CursorStyle, DisplayId, FileDropEvent, Keystroke,
@@ -202,7 +201,7 @@ pub struct X11ClientState {
pointer_device_states: BTreeMap<xinput::DeviceId, PointerDeviceState>,
pub(crate) common: LinuxCommon,
pub(crate) clipboard: Clipboard,
pub(crate) clipboard: x11_clipboard::Clipboard,
pub(crate) clipboard_item: Option<ClipboardItem>,
pub(crate) xdnd_state: Xdnd,
}
@@ -389,7 +388,7 @@ impl X11Client {
.reply()
.unwrap();
let clipboard = Clipboard::new().unwrap();
let clipboard = x11_clipboard::Clipboard::new().unwrap();
let xcb_connection = Rc::new(xcb_connection);
@@ -1505,36 +1504,39 @@ impl LinuxClient for X11Client {
let state = self.0.borrow_mut();
state
.clipboard
.set_text(
std::borrow::Cow::Owned(item.text().unwrap_or_default()),
clipboard::ClipboardKind::Primary,
clipboard::WaitConfig::None,
.store(
state.clipboard.setter.atoms.primary,
state.clipboard.setter.atoms.utf8_string,
item.text().unwrap_or_default().as_bytes(),
)
.context("Failed to write to clipboard (primary)")
.log_with_level(log::Level::Debug);
.ok();
}
fn write_to_clipboard(&self, item: crate::ClipboardItem) {
let mut state = self.0.borrow_mut();
state
.clipboard
.set_text(
std::borrow::Cow::Owned(item.text().unwrap_or_default()),
clipboard::ClipboardKind::Clipboard,
clipboard::WaitConfig::None,
.store(
state.clipboard.setter.atoms.clipboard,
state.clipboard.setter.atoms.utf8_string,
item.text().unwrap_or_default().as_bytes(),
)
.context("Failed to write to clipboard (clipboard)")
.log_with_level(log::Level::Debug);
.ok();
state.clipboard_item.replace(item);
}
fn read_from_primary(&self) -> Option<crate::ClipboardItem> {
let state = self.0.borrow_mut();
return state
state
.clipboard
.get_any(clipboard::ClipboardKind::Primary)
.context("Failed to read from clipboard (primary)")
.log_with_level(log::Level::Debug);
.load(
state.clipboard.getter.atoms.primary,
state.clipboard.getter.atoms.utf8_string,
state.clipboard.getter.atoms.property,
Duration::from_secs(3),
)
.map(|text| crate::ClipboardItem::new_string(String::from_utf8(text).unwrap()))
.ok()
}
fn read_from_clipboard(&self) -> Option<crate::ClipboardItem> {
@@ -1543,15 +1545,26 @@ impl LinuxClient for X11Client {
// which has metadata attached.
if state
.clipboard
.is_owner(clipboard::ClipboardKind::Clipboard)
.setter
.connection
.get_selection_owner(state.clipboard.setter.atoms.clipboard)
.ok()
.and_then(|r| r.reply().ok())
.map(|reply| reply.owner == state.clipboard.setter.window)
.unwrap_or(false)
{
return state.clipboard_item.clone();
}
return state
state
.clipboard
.get_any(clipboard::ClipboardKind::Clipboard)
.context("Failed to read from clipboard (clipboard)")
.log_with_level(log::Level::Debug);
.load(
state.clipboard.getter.atoms.clipboard,
state.clipboard.getter.atoms.utf8_string,
state.clipboard.getter.atoms.property,
Duration::from_secs(3),
)
.map(|text| crate::ClipboardItem::new_string(String::from_utf8(text).unwrap()))
.ok()
}
fn run(&self) {

View File

@@ -200,7 +200,7 @@ struct ClipboardData {
}
enum ReadSelNotifyResult {
GotData(ClipboardData),
GotData(Vec<u8>),
IncrStarted,
EventNotRecognized,
}
@@ -297,83 +297,30 @@ impl Inner {
}
let reader = XContext::new()?;
let highest_precedence_format =
match self.read_single(&reader, selection, self.atoms.TARGETS) {
Err(err) => {
log::trace!("Clipboard TARGETS query failed with {err:?}");
None
}
Ok(ClipboardData { bytes, format }) => {
if format == self.atoms.ATOM {
let available_formats = Self::parse_formats(&bytes);
formats
.iter()
.find(|format| available_formats.contains(format))
} else {
log::trace!(
"Unexpected clipboard TARGETS format {}",
self.atom_name(format)
);
None
}
}
};
if let Some(&format) = highest_precedence_format {
let data = self.read_single(&reader, selection, format)?;
if !formats.contains(&data.format) {
// This shouldn't happen since the format is from the TARGETS list.
log::trace!(
"Conversion to {} responded with {} which is not supported",
self.atom_name(format),
self.atom_name(data.format),
);
return Err(Error::ConversionFailure);
}
return Ok(data);
}
log::trace!("Falling back on attempting to convert clipboard to each format.");
log::trace!("Trying to get the clipboard data.");
for format in formats {
match self.read_single(&reader, selection, *format) {
Ok(data) => {
if formats.contains(&data.format) {
return Ok(data);
} else {
log::trace!(
"Conversion to {} responded with {} which is not supported",
self.atom_name(*format),
self.atom_name(data.format),
);
continue;
}
Ok(bytes) => {
return Ok(ClipboardData {
bytes,
format: *format,
});
}
Err(Error::ContentNotAvailable) => {
continue;
}
Err(e) => {
log::trace!("Conversion to {} failed: {}", self.atom_name(*format), e);
return Err(e);
}
Err(e) => return Err(e),
}
}
log::trace!("All conversions to supported formats failed.");
Err(Error::ContentNotAvailable)
}
fn parse_formats(bytes: &[u8]) -> Vec<Atom> {
bytes
.chunks_exact(4)
.map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect()
}
fn read_single(
&self,
reader: &XContext,
selection: ClipboardKind,
target_format: Atom,
) -> Result<ClipboardData> {
) -> Result<Vec<u8>> {
// Delete the property so that we can detect (using property notify)
// when the selection owner receives our request.
reader
@@ -445,16 +392,10 @@ impl Inner {
event,
)?;
if result {
return Ok(ClipboardData {
bytes: incr_data,
format: target_format,
});
return Ok(incr_data);
}
}
_ => log::trace!(
"An unexpected event arrived while reading the clipboard: {:?}",
event
),
_ => log::trace!("An unexpected event arrived while reading the clipboard."),
}
}
log::info!("Time-out hit while reading the clipboard.");
@@ -499,7 +440,7 @@ impl Inner {
Ok(current == self.server.win_id)
}
fn query_atom_name(&self, atom: x11rb::protocol::xproto::Atom) -> Result<String> {
fn atom_name(&self, atom: x11rb::protocol::xproto::Atom) -> Result<String> {
String::from_utf8(
self.server
.conn
@@ -512,14 +453,14 @@ impl Inner {
.map_err(into_unknown)
}
fn atom_name(&self, atom: x11rb::protocol::xproto::Atom) -> &'static str {
fn atom_name_dbg(&self, atom: x11rb::protocol::xproto::Atom) -> &'static str {
ATOM_NAME_CACHE.with(|cache| {
let mut cache = cache.borrow_mut();
match cache.entry(atom) {
Entry::Occupied(entry) => *entry.get(),
Entry::Vacant(entry) => {
let s = self
.query_atom_name(atom)
.atom_name(atom)
.map(|s| Box::leak(s.into_boxed_str()) as &str)
.unwrap_or("FAILED-TO-GET-THE-ATOM-NAME");
entry.insert(s);
@@ -555,12 +496,6 @@ impl Inner {
log::warn!("Received a SelectionNotify while already expecting INCR segments.");
return Ok(ReadSelNotifyResult::EventNotRecognized);
}
// Accept any property type. The property type will typically match the format type except
// when it is `TARGETS` in which case it is `ATOM`. `ANY` is provided to handle the case
// where the clipboard is not convertible to the requested format. In this case
// `reply.type_` will have format information, but `bytes` will only be non-empty if `ANY`
// is provided.
let property_type = AtomEnum::ANY;
// request the selection
let mut reply = reader
.conn
@@ -568,7 +503,7 @@ impl Inner {
true,
event.requestor,
event.property,
property_type,
event.target,
0,
u32::MAX / 4,
)
@@ -576,8 +511,12 @@ impl Inner {
.reply()
.map_err(into_unknown)?;
//log::trace!("Property.type: {:?}", self.atom_name(reply.type_));
// we found something
if reply.type_ == self.atoms.INCR {
if reply.type_ == target_format {
Ok(ReadSelNotifyResult::GotData(reply.value))
} else if reply.type_ == self.atoms.INCR {
// Note that we call the get_property again because we are
// indicating that we are ready to receive the data by deleting the
// property, however deleting only works if the type matches the
@@ -606,10 +545,8 @@ impl Inner {
}
Ok(ReadSelNotifyResult::IncrStarted)
} else {
Ok(ReadSelNotifyResult::GotData(ClipboardData {
bytes: reply.value,
format: reply.type_,
}))
// this should never happen, we have sent a request only for supported types
Err(Error::unknown("incorrect type received from clipboard"))
}
}
@@ -637,11 +574,7 @@ impl Inner {
true,
event.window,
event.atom,
if target_format == self.atoms.TARGETS {
self.atoms.ATOM
} else {
target_format
},
target_format,
0,
u32::MAX / 4,
)
@@ -679,7 +612,7 @@ impl Inner {
if event.target == self.atoms.TARGETS {
log::trace!(
"Handling TARGETS, dst property is {}",
self.atom_name(event.property)
self.atom_name_dbg(event.property)
);
let mut targets = Vec::with_capacity(10);
targets.push(self.atoms.TARGETS);
@@ -879,8 +812,8 @@ fn serve_requests(context: Arc<Inner>) -> Result<(), Box<dyn std::error::Error>>
Event::SelectionRequest(event) => {
log::trace!(
"SelectionRequest - selection is: {}, target is {}",
context.atom_name(event.selection),
context.atom_name(event.target),
context.atom_name_dbg(event.selection),
context.atom_name_dbg(event.target),
);
// Someone is requesting the clipboard content from us.
context
@@ -1054,11 +987,6 @@ impl Clipboard {
let result = self.inner.read(&format_atoms, selection)?;
log::trace!(
"read clipboard as format {:?}",
self.inner.atom_name(result.format)
);
for (format_atom, image_format) in image_format_atoms.into_iter().zip(image_formats) {
if result.format == format_atom {
let bytes = result.bytes;

View File

@@ -12,7 +12,7 @@ use std::ffi::c_void;
use util::ResultExt;
pub struct DisplayLink {
display_link: Option<sys::DisplayLink>,
display_link: sys::DisplayLink,
frame_requests: dispatch_source_t,
}
@@ -59,7 +59,7 @@ impl DisplayLink {
)?;
Ok(Self {
display_link: Some(display_link),
display_link,
frame_requests,
})
}
@@ -70,7 +70,7 @@ impl DisplayLink {
dispatch_resume(crate::dispatch_sys::dispatch_object_t {
_ds: self.frame_requests,
});
self.display_link.as_mut().unwrap().start()?;
self.display_link.start()?;
}
Ok(())
}
@@ -80,7 +80,7 @@ impl DisplayLink {
dispatch_suspend(crate::dispatch_sys::dispatch_object_t {
_ds: self.frame_requests,
});
self.display_link.as_mut().unwrap().stop()?;
self.display_link.stop()?;
}
Ok(())
}
@@ -89,14 +89,6 @@ impl DisplayLink {
impl Drop for DisplayLink {
fn drop(&mut self) {
self.stop().log_err();
// We see occasional segfaults on the CVDisplayLink thread.
//
// It seems possible that this happens because CVDisplayLinkRelease releases the CVDisplayLink
// on the main thread immediately, but the background thread that CVDisplayLink uses for timers
// is still accessing it.
//
// We might also want to upgrade to CADisplayLink, but that requires dropping old macOS support.
std::mem::forget(self.display_link.take());
unsafe {
dispatch_source_cancel(self.frame_requests);
}

View File

@@ -1,21 +1,14 @@
use crate::{
KeyDownEvent, KeyUpEvent, Keystroke, Modifiers, ModifiersChangedEvent, MouseButton,
MouseDownEvent, MouseExitEvent, MouseMoveEvent, MouseUpEvent, NavigationDirection, Pixels,
PlatformInput, ScrollDelta, ScrollWheelEvent, TouchPhase,
platform::mac::{
LMGetKbdType, NSStringExt, TISCopyCurrentKeyboardLayoutInputSource,
TISGetInputSourceProperty, UCKeyTranslate, kTISPropertyUnicodeKeyLayoutData,
},
point, px,
CMD_MOD, KeyDownEvent, KeyUpEvent, Keystroke, Modifiers, ModifiersChangedEvent, MouseButton,
MouseDownEvent, MouseExitEvent, MouseMoveEvent, MouseUpEvent, NO_MOD, NavigationDirection,
OPTION_MOD, Pixels, PlatformInput, SHIFT_MOD, ScrollDelta, ScrollWheelEvent, TouchPhase,
always_use_command_layout, chars_for_modified_key, platform::mac::NSStringExt, point, px,
};
use cocoa::{
appkit::{NSEvent, NSEventModifierFlags, NSEventPhase, NSEventType},
base::{YES, id},
};
use core_foundation::data::{CFDataGetBytePtr, CFDataRef};
use core_graphics::event::CGKeyCode;
use objc::{msg_send, sel, sel_impl};
use std::{borrow::Cow, ffi::c_void};
use std::borrow::Cow;
const BACKSPACE_KEY: u16 = 0x7f;
const SPACE_KEY: u16 = b' ' as u16;
@@ -452,80 +445,3 @@ unsafe fn parse_keystroke(native_event: id) -> Keystroke {
}
}
}
fn always_use_command_layout() -> bool {
if chars_for_modified_key(0, NO_MOD).is_ascii() {
return false;
}
chars_for_modified_key(0, CMD_MOD).is_ascii()
}
const NO_MOD: u32 = 0;
const CMD_MOD: u32 = 1;
const SHIFT_MOD: u32 = 2;
const OPTION_MOD: u32 = 8;
fn chars_for_modified_key(code: CGKeyCode, modifiers: u32) -> String {
// Values from: https://github.com/phracker/MacOSX-SDKs/blob/master/MacOSX10.6.sdk/System/Library/Frameworks/Carbon.framework/Versions/A/Frameworks/HIToolbox.framework/Versions/A/Headers/Events.h#L126
// shifted >> 8 for UCKeyTranslate
const CG_SPACE_KEY: u16 = 49;
// https://github.com/phracker/MacOSX-SDKs/blob/master/MacOSX10.6.sdk/System/Library/Frameworks/CoreServices.framework/Versions/A/Frameworks/CarbonCore.framework/Versions/A/Headers/UnicodeUtilities.h#L278
#[allow(non_upper_case_globals)]
const kUCKeyActionDown: u16 = 0;
#[allow(non_upper_case_globals)]
const kUCKeyTranslateNoDeadKeysMask: u32 = 0;
let keyboard_type = unsafe { LMGetKbdType() as u32 };
const BUFFER_SIZE: usize = 4;
let mut dead_key_state = 0;
let mut buffer: [u16; BUFFER_SIZE] = [0; BUFFER_SIZE];
let mut buffer_size: usize = 0;
let keyboard = unsafe { TISCopyCurrentKeyboardLayoutInputSource() };
if keyboard.is_null() {
return "".to_string();
}
let layout_data = unsafe {
TISGetInputSourceProperty(keyboard, kTISPropertyUnicodeKeyLayoutData as *const c_void)
as CFDataRef
};
if layout_data.is_null() {
unsafe {
let _: () = msg_send![keyboard, release];
}
return "".to_string();
}
let keyboard_layout = unsafe { CFDataGetBytePtr(layout_data) };
unsafe {
UCKeyTranslate(
keyboard_layout as *const c_void,
code,
kUCKeyActionDown,
modifiers,
keyboard_type,
kUCKeyTranslateNoDeadKeysMask,
&mut dead_key_state,
BUFFER_SIZE,
&mut buffer_size as *mut usize,
&mut buffer as *mut u16,
);
if dead_key_state != 0 {
UCKeyTranslate(
keyboard_layout as *const c_void,
CG_SPACE_KEY,
kUCKeyActionDown,
modifiers,
keyboard_type,
kUCKeyTranslateNoDeadKeysMask,
&mut dead_key_state,
BUFFER_SIZE,
&mut buffer_size as *mut usize,
&mut buffer as *mut u16,
);
}
let _: () = msg_send![keyboard, release];
}
String::from_utf16(&buffer[..buffer_size]).unwrap_or_default()
}

View File

@@ -1,8 +1,14 @@
use std::ffi::{CStr, c_void};
use collections::HashMap;
use core_foundation::data::{CFDataGetBytePtr, CFDataRef};
use core_graphics::event::CGKeyCode;
use objc::{msg_send, runtime::Object, sel, sel_impl};
use crate::PlatformKeyboardLayout;
use crate::{
PlatformKeyboardLayout, PlatformKeyboardMapper, ScanCode, is_alphabetic_key,
platform::mac::{LMGetKbdType, UCKeyTranslate, kTISPropertyUnicodeKeyLayoutData},
};
use super::{
TISCopyCurrentKeyboardLayoutInputSource, TISGetInputSourceProperty, kTISPropertyInputSourceID,
@@ -26,24 +32,338 @@ impl PlatformKeyboardLayout for MacKeyboardLayout {
impl MacKeyboardLayout {
pub(crate) fn new() -> Self {
unsafe {
let current_keyboard = TISCopyCurrentKeyboardLayoutInputSource();
let id: *mut Object = TISGetInputSourceProperty(
current_keyboard,
kTISPropertyInputSourceID as *const c_void,
);
let id: *const std::os::raw::c_char = msg_send![id, UTF8String];
let id = CStr::from_ptr(id).to_str().unwrap().to_string();
let name: *mut Object = TISGetInputSourceProperty(
current_keyboard,
kTISPropertyLocalizedName as *const c_void,
);
let (keyboard, id) = get_keyboard_layout_id();
let name = unsafe {
let name: *mut Object =
TISGetInputSourceProperty(keyboard, kTISPropertyLocalizedName as *const c_void);
let name: *const std::os::raw::c_char = msg_send![name, UTF8String];
let name = CStr::from_ptr(name).to_str().unwrap().to_string();
CStr::from_ptr(name).to_str().unwrap().to_string()
};
Self { id, name }
Self { id, name }
}
}
fn get_keyboard_layout_id() -> (*mut Object, String) {
unsafe {
let current_keyboard = TISCopyCurrentKeyboardLayoutInputSource();
let id: *mut Object =
TISGetInputSourceProperty(current_keyboard, kTISPropertyInputSourceID as *const c_void);
let id: *const std::os::raw::c_char = msg_send![id, UTF8String];
(
current_keyboard,
CStr::from_ptr(id).to_str().unwrap().to_string(),
)
}
}
pub(crate) struct MacKeyboardMapper {
code_to_key: HashMap<u16, String>,
key_to_code: HashMap<String, u16>,
code_to_shifted_key: HashMap<u16, String>,
}
impl MacKeyboardMapper {
pub(crate) fn new() -> Self {
let mut code_to_key = HashMap::default();
let mut key_to_code = HashMap::default();
let mut code_to_shifted_key = HashMap::default();
let always_use_cmd_layout = always_use_command_layout();
for &scan_code in TYPEABLE_CODES.iter() {
let (key, shifted_key) = generate_key_pairs(scan_code, always_use_cmd_layout);
code_to_key.insert(scan_code, key.clone());
key_to_code.insert(key, scan_code);
code_to_shifted_key.insert(scan_code, shifted_key);
}
Self {
code_to_key,
key_to_code,
code_to_shifted_key,
}
}
}
impl PlatformKeyboardMapper for MacKeyboardMapper {
fn scan_code_to_key(&self, gpui_scan_code: ScanCode) -> anyhow::Result<String> {
if let Some(key) = gpui_scan_code.try_to_key() {
return Ok(key);
}
let Some(scan_code) = get_scan_code(gpui_scan_code) else {
return Err(anyhow::anyhow!("Scan code not found: {:?}", gpui_scan_code));
};
if let Some(key) = self.code_to_key.get(&scan_code) {
Ok(key.clone())
} else {
Err(anyhow::anyhow!(
"Key not found for input scan code: {:?}, scan code: {}",
gpui_scan_code,
scan_code
))
}
}
fn get_shifted_key(&self, key: &str) -> anyhow::Result<Option<String>> {
if key.chars().count() != 1 {
return Ok(None);
}
if is_alphabetic_key(key) {
return Ok(Some(key.to_uppercase()));
}
let Some(scan_code) = self.key_to_code.get(key) else {
return Err(anyhow::anyhow!("Key not found: {}", key));
};
if let Some(shifted_key) = self.code_to_shifted_key.get(scan_code) {
Ok(Some(shifted_key.clone()))
} else {
Err(anyhow::anyhow!(
"Shifted key not found for key {} with scan code: {}",
key,
scan_code
))
}
}
}
pub(crate) const NO_MOD: u32 = 0;
pub(crate) const CMD_MOD: u32 = 1;
pub(crate) const SHIFT_MOD: u32 = 2;
pub(crate) const OPTION_MOD: u32 = 8;
pub(crate) fn chars_for_modified_key(code: CGKeyCode, modifiers: u32) -> String {
// Values from: https://github.com/phracker/MacOSX-SDKs/blob/master/MacOSX10.6.sdk/System/Library/Frameworks/Carbon.framework/Versions/A/Frameworks/HIToolbox.framework/Versions/A/Headers/Events.h#L126
// shifted >> 8 for UCKeyTranslate
const CG_SPACE_KEY: u16 = 49;
// https://github.com/phracker/MacOSX-SDKs/blob/master/MacOSX10.6.sdk/System/Library/Frameworks/CoreServices.framework/Versions/A/Frameworks/CarbonCore.framework/Versions/A/Headers/UnicodeUtilities.h#L278
#[allow(non_upper_case_globals)]
const kUCKeyActionDown: u16 = 0;
#[allow(non_upper_case_globals)]
const kUCKeyTranslateNoDeadKeysMask: u32 = 0;
let keyboard_type = unsafe { LMGetKbdType() as u32 };
const BUFFER_SIZE: usize = 4;
let mut dead_key_state = 0;
let mut buffer: [u16; BUFFER_SIZE] = [0; BUFFER_SIZE];
let mut buffer_size: usize = 0;
let keyboard = unsafe { TISCopyCurrentKeyboardLayoutInputSource() };
if keyboard.is_null() {
return "".to_string();
}
let layout_data = unsafe {
TISGetInputSourceProperty(keyboard, kTISPropertyUnicodeKeyLayoutData as *const c_void)
as CFDataRef
};
if layout_data.is_null() {
unsafe {
let _: () = msg_send![keyboard, release];
}
return "".to_string();
}
let keyboard_layout = unsafe { CFDataGetBytePtr(layout_data) };
unsafe {
UCKeyTranslate(
keyboard_layout as *const c_void,
code,
kUCKeyActionDown,
modifiers,
keyboard_type,
kUCKeyTranslateNoDeadKeysMask,
&mut dead_key_state,
BUFFER_SIZE,
&mut buffer_size as *mut usize,
&mut buffer as *mut u16,
);
if dead_key_state != 0 {
UCKeyTranslate(
keyboard_layout as *const c_void,
CG_SPACE_KEY,
kUCKeyActionDown,
modifiers,
keyboard_type,
kUCKeyTranslateNoDeadKeysMask,
&mut dead_key_state,
BUFFER_SIZE,
&mut buffer_size as *mut usize,
&mut buffer as *mut u16,
);
}
let _: () = msg_send![keyboard, release];
}
String::from_utf16(&buffer[..buffer_size]).unwrap_or_default()
}
pub(crate) fn always_use_command_layout() -> bool {
if chars_for_modified_key(0, NO_MOD).is_ascii() {
return false;
}
chars_for_modified_key(0, CMD_MOD).is_ascii()
}
fn generate_key_pairs(scan_code: u16, always_use_cmd_layout: bool) -> (String, String) {
let mut chars_ignoring_modifiers = chars_for_modified_key(scan_code, NO_MOD);
let mut chars_with_shift = chars_for_modified_key(scan_code, SHIFT_MOD);
// Handle Dvorak+QWERTY / Russian / Armenian
if always_use_cmd_layout {
let chars_with_cmd = chars_for_modified_key(scan_code, CMD_MOD);
let chars_with_both = chars_for_modified_key(scan_code, CMD_MOD | SHIFT_MOD);
// We don't do this in the case that the shifted command key generates
// the same character as the unshifted command key (Norwegian, e.g.)
if chars_with_both != chars_with_cmd {
chars_with_shift = chars_with_both;
// Handle edge-case where cmd-shift-s reports cmd-s instead of
// cmd-shift-s (Ukrainian, etc.)
} else if chars_with_cmd.to_ascii_uppercase() != chars_with_cmd {
chars_with_shift = chars_with_cmd.to_ascii_uppercase();
}
chars_ignoring_modifiers = chars_with_cmd;
}
(chars_ignoring_modifiers, chars_with_shift)
}
// All typeable scan codes for the standard US keyboard layout, ANSI104
const TYPEABLE_CODES: &[u16] = &[
0x0000, // a
0x000b, // b
0x0008, // c
0x0002, // d
0x000e, // e
0x0003, // f
0x0005, // g
0x0004, // h
0x0022, // i
0x0026, // j
0x0028, // k
0x0025, // l
0x002e, // m
0x002d, // n
0x001f, // o
0x0023, // p
0x000c, // q
0x000f, // r
0x0001, // s
0x0011, // t
0x0020, // u
0x0009, // v
0x000d, // w
0x0007, // x
0x0010, // y
0x0006, // z
0x001d, // Digit 0
0x0012, // Digit 1
0x0013, // Digit 2
0x0014, // Digit 3
0x0015, // Digit 4
0x0017, // Digit 5
0x0016, // Digit 6
0x001a, // Digit 7
0x001c, // Digit 8
0x0019, // Digit 9
0x0032, // ` Tilde
0x001b, // - Minus
0x0018, // = Equal
0x0021, // [ Left bracket
0x001e, // ] Right bracket
0x002a, // \ Backslash
0x0029, // ; Semicolon
0x0027, // ' Quote
0x002b, // , Comma
0x002f, // . Period
0x002c, // / Slash
];
fn get_scan_code(scan_code: ScanCode) -> Option<u16> {
// https://github.com/microsoft/node-native-keymap/blob/main/deps/chromium/dom_code_data.inc
Some(match scan_code {
ScanCode::F1 => 0x007a,
ScanCode::F2 => 0x0078,
ScanCode::F3 => 0x0063,
ScanCode::F4 => 0x0076,
ScanCode::F5 => 0x0060,
ScanCode::F6 => 0x0061,
ScanCode::F7 => 0x0062,
ScanCode::F8 => 0x0064,
ScanCode::F9 => 0x0065,
ScanCode::F10 => 0x006d,
ScanCode::F11 => 0x0067,
ScanCode::F12 => 0x006f,
ScanCode::F13 => 0x0069,
ScanCode::F14 => 0x006b,
ScanCode::F15 => 0x0071,
ScanCode::F16 => 0x006a,
ScanCode::F17 => 0x0040,
ScanCode::F18 => 0x004f,
ScanCode::F19 => 0x0050,
ScanCode::F20 => 0x005a,
ScanCode::F21 | ScanCode::F22 | ScanCode::F23 | ScanCode::F24 => return None,
ScanCode::A => 0x0000,
ScanCode::B => 0x000b,
ScanCode::C => 0x0008,
ScanCode::D => 0x0002,
ScanCode::E => 0x000e,
ScanCode::F => 0x0003,
ScanCode::G => 0x0005,
ScanCode::H => 0x0004,
ScanCode::I => 0x0022,
ScanCode::J => 0x0026,
ScanCode::K => 0x0028,
ScanCode::L => 0x0025,
ScanCode::M => 0x002e,
ScanCode::N => 0x002d,
ScanCode::O => 0x001f,
ScanCode::P => 0x0023,
ScanCode::Q => 0x000c,
ScanCode::R => 0x000f,
ScanCode::S => 0x0001,
ScanCode::T => 0x0011,
ScanCode::U => 0x0020,
ScanCode::V => 0x0009,
ScanCode::W => 0x000d,
ScanCode::X => 0x0007,
ScanCode::Y => 0x0010,
ScanCode::Z => 0x0006,
ScanCode::Digit0 => 0x001d,
ScanCode::Digit1 => 0x0012,
ScanCode::Digit2 => 0x0013,
ScanCode::Digit3 => 0x0014,
ScanCode::Digit4 => 0x0015,
ScanCode::Digit5 => 0x0017,
ScanCode::Digit6 => 0x0016,
ScanCode::Digit7 => 0x001a,
ScanCode::Digit8 => 0x001c,
ScanCode::Digit9 => 0x0019,
ScanCode::Backquote => 0x0032,
ScanCode::Minus => 0x001b,
ScanCode::Equal => 0x0018,
ScanCode::BracketLeft => 0x0021,
ScanCode::BracketRight => 0x001e,
ScanCode::Backslash => 0x002a,
ScanCode::Semicolon => 0x0029,
ScanCode::Quote => 0x0027,
ScanCode::Comma => 0x002b,
ScanCode::Period => 0x002f,
ScanCode::Slash => 0x002c,
ScanCode::Left => 0x007b,
ScanCode::Up => 0x007e,
ScanCode::Right => 0x007c,
ScanCode::Down => 0x007d,
ScanCode::PageUp => 0x0074,
ScanCode::PageDown => 0x0079,
ScanCode::End => 0x0077,
ScanCode::Home => 0x0073,
ScanCode::Tab => 0x0030,
ScanCode::Enter => 0x0024,
ScanCode::Escape => 0x0035,
ScanCode::Space => 0x0031,
ScanCode::Backspace => 0x0033,
ScanCode::Delete => 0x0075,
ScanCode::Insert => 0x0072,
})
}

View File

@@ -1,5 +1,5 @@
use super::{
BoolExt, MacKeyboardLayout,
BoolExt, MacKeyboardLayout, MacKeyboardMapper,
attributed_string::{NSAttributedString, NSMutableAttributedString},
events::key_to_native,
is_macos_version_at_least, renderer, screen_capture,
@@ -8,8 +8,8 @@ use crate::{
Action, AnyWindowHandle, BackgroundExecutor, ClipboardEntry, ClipboardItem, ClipboardString,
CursorStyle, ForegroundExecutor, Image, ImageFormat, KeyContext, Keymap, MacDispatcher,
MacDisplay, MacWindow, Menu, MenuItem, PathPromptOptions, Platform, PlatformDisplay,
PlatformKeyboardLayout, PlatformTextSystem, PlatformWindow, Result, ScreenCaptureSource,
SemanticVersion, Task, WindowAppearance, WindowParams, hash,
PlatformKeyboardLayout, PlatformKeyboardMapper, PlatformTextSystem, PlatformWindow, Result,
ScreenCaptureSource, SemanticVersion, Task, WindowAppearance, WindowParams, hash,
};
use anyhow::{Context as _, anyhow};
use block::ConcreteBlock;
@@ -843,6 +843,10 @@ impl Platform for MacPlatform {
self.0.lock().validate_menu_command = Some(callback);
}
fn keyboard_mapper(&self) -> Box<dyn PlatformKeyboardMapper> {
Box::new(MacKeyboardMapper::new())
}
fn keyboard_layout(&self) -> Box<dyn PlatformKeyboardLayout> {
Box::new(MacKeyboardLayout::new())
}

View File

@@ -1,8 +1,9 @@
use crate::{
AnyWindowHandle, BackgroundExecutor, ClipboardItem, CursorStyle, DevicePixels,
ForegroundExecutor, Keymap, NoopTextSystem, Platform, PlatformDisplay, PlatformKeyboardLayout,
PlatformTextSystem, PromptButton, ScreenCaptureFrame, ScreenCaptureSource, ScreenCaptureStream,
Size, Task, TestDisplay, TestWindow, WindowAppearance, WindowParams, size,
PlatformKeyboardMapper, PlatformTextSystem, PromptButton, ScreenCaptureFrame,
ScreenCaptureSource, ScreenCaptureStream, Size, Task, TestDisplay, TestKeyboardMapper,
TestWindow, WindowAppearance, WindowParams, size,
};
use anyhow::Result;
use collections::VecDeque;
@@ -223,6 +224,10 @@ impl Platform for TestPlatform {
self.text_system.clone()
}
fn keyboard_mapper(&self) -> Box<dyn PlatformKeyboardMapper> {
Box::new(TestKeyboardMapper::new())
}
fn keyboard_layout(&self) -> Box<dyn PlatformKeyboardLayout> {
Box::new(TestKeyboardLayout)
}

View File

@@ -679,14 +679,6 @@ fn handle_ime_composition_inner(
lparam: LPARAM,
state_ptr: Rc<WindowsWindowStatePtr>,
) -> Option<isize> {
if lparam.0 == 0 {
// Japanese IME may send this message with lparam = 0, which indicates that
// there is no composition string.
with_input_handler(&state_ptr, |input_handler| {
input_handler.replace_text_in_range(None, "");
})?;
return Some(0);
}
let mut ime_input = None;
if lparam.0 as u32 & GCS_COMPSTR.0 > 0 {
let comp_string = parse_ime_compostion_string(ctx)?;

View File

@@ -1,16 +1,19 @@
use anyhow::Result;
use anyhow::{Context, Result};
use windows::Win32::UI::{
Input::KeyboardAndMouse::{
GetKeyboardLayoutNameW, MAPVK_VK_TO_CHAR, MapVirtualKeyW, ToUnicode, VIRTUAL_KEY, VK_0,
VK_1, VK_2, VK_3, VK_4, VK_5, VK_6, VK_7, VK_8, VK_9, VK_ABNT_C1, VK_CONTROL, VK_MENU,
VK_OEM_1, VK_OEM_2, VK_OEM_3, VK_OEM_4, VK_OEM_5, VK_OEM_6, VK_OEM_7, VK_OEM_8, VK_OEM_102,
VK_OEM_COMMA, VK_OEM_MINUS, VK_OEM_PERIOD, VK_OEM_PLUS, VK_SHIFT,
GetKeyboardLayoutNameW, MAPVK_VK_TO_CHAR, MAPVK_VK_TO_VSC, MAPVK_VSC_TO_VK, MapVirtualKeyW,
ToUnicode, VIRTUAL_KEY, VK_0, VK_1, VK_2, VK_3, VK_4, VK_5, VK_6, VK_7, VK_8, VK_9,
VK_ABNT_C1, VK_CONTROL, VK_MENU, VK_OEM_1, VK_OEM_2, VK_OEM_3, VK_OEM_4, VK_OEM_5,
VK_OEM_6, VK_OEM_7, VK_OEM_8, VK_OEM_102, VK_OEM_COMMA, VK_OEM_MINUS, VK_OEM_PERIOD,
VK_OEM_PLUS, VK_SHIFT, VkKeyScanW,
},
WindowsAndMessaging::KL_NAMELENGTH,
};
use windows_core::HSTRING;
use crate::{Modifiers, PlatformKeyboardLayout};
use crate::{
Modifiers, PlatformKeyboardLayout, PlatformKeyboardMapper, ScanCode, is_alphabetic_key,
};
pub(crate) struct WindowsKeyboardLayout {
id: String,
@@ -48,6 +51,38 @@ impl WindowsKeyboardLayout {
}
}
pub(crate) struct WindowsKeyboardMapper;
impl PlatformKeyboardMapper for WindowsKeyboardMapper {
fn scan_code_to_key(&self, scan_code: ScanCode) -> Result<String> {
if let Some(key) = scan_code.try_to_key() {
return Ok(key);
}
let vkey = get_virtual_key_from_scan_code(scan_code)?;
let (key, _) = vkey_to_key(vkey).context(format!(
"Failed to get key from scan code: {:?}, vkey: {:?}",
scan_code, vkey
))?;
Ok(key)
}
fn get_shifted_key(&self, key: &str) -> Result<Option<String>> {
if key.chars().count() != 1 {
return Ok(None);
}
if is_alphabetic_key(key) {
return Ok(Some(key.to_uppercase()));
}
Ok(Some(get_shifted_character(key)?))
}
}
impl WindowsKeyboardMapper {
pub(crate) fn new() -> Self {
Self
}
}
pub(crate) fn get_keystroke_key(
vkey: VIRTUAL_KEY,
scan_code: u32,
@@ -138,3 +173,200 @@ pub(crate) fn generate_key_char(
}
None
}
fn get_vkey_from_char(key: &str, modifiers: &mut Modifiers) -> Result<VIRTUAL_KEY> {
if key.chars().count() != 1 {
anyhow::bail!("Key must be a single character, but got: {}", key);
}
let key_char = key
.encode_utf16()
.next()
.context("Empty key in keystorke")?;
let result = unsafe { VkKeyScanW(key_char) };
if result == -1 {
anyhow::bail!("Failed to get vkey from char: {}", key);
}
let high = (result >> 8) as i8;
let low = result as u8;
let (shift, ctrl, alt) = get_modifiers(high);
if ctrl {
if modifiers.control {
anyhow::bail!(
"Error parsing: {}, Ctrl modifier already set, but ctrl is required for this key: {}, you may be unable to use this shortcut.",
display_keystroke(key, modifiers),
key
);
}
modifiers.control = true;
}
if alt {
if modifiers.alt {
anyhow::bail!(
"Error parsing: {}, Alt modifier already set, but alt is required for this key: {}, you may be unable to use this shortcut.",
display_keystroke(key, modifiers),
key
);
}
modifiers.alt = true;
}
if shift {
if modifiers.shift {
anyhow::bail!(
"Error parsing: {}, Shift modifier already set, but shift is required for this key: {}, you may be unable to use this shortcut.",
display_keystroke(key, modifiers),
key
);
}
modifiers.shift = true;
}
Ok(VIRTUAL_KEY(low as u16))
}
fn get_modifiers(high: i8) -> (bool, bool, bool) {
let shift = high & 1;
let ctrl = (high >> 1) & 1;
let alt = (high >> 2) & 1;
(shift != 0, ctrl != 0, alt != 0)
}
fn get_shifted_character(key: &str) -> Result<String> {
let mut modifiers = Modifiers::default();
let virtual_key = get_vkey_from_char(key, &mut modifiers).context(format!(
"Failed to get virtual key from char while key_to_shifted: {}",
key
))?;
if modifiers != Modifiers::default() {
return Err(anyhow::anyhow!(
"Key is not a single character or has modifiers: {}",
key
));
}
let mut state = [0; 256];
state[VK_SHIFT.0 as usize] = 0x80;
let scan_code = unsafe { MapVirtualKeyW(virtual_key.0 as u32, MAPVK_VK_TO_VSC) };
let mut buffer = [0; 4];
let len = unsafe {
ToUnicode(
virtual_key.0 as u32,
scan_code,
Some(&state),
&mut buffer,
0,
)
};
if len > 0 {
let candidate = String::from_utf16_lossy(&buffer[..len as usize]);
if !candidate.is_empty() && !candidate.chars().next().unwrap().is_control() {
return Ok(candidate);
}
}
Err(anyhow::anyhow!("Failed to get shifted key for: {}", key))
}
/// Converts a Windows virtual key code to its corresponding character and dead key status.
///
/// # Parameters
/// * `vkey` - The virtual key code to convert
///
/// # Returns
/// * `Some((String, bool))` - The character as a string and a boolean indicating if it's a dead key.
/// A dead key is a key that doesn't produce a character by itself but modifies the next key pressed
/// (e.g., accent keys like ^ or `).
/// * `None` - If the virtual key code doesn't map to a character
pub fn vkey_to_key(vkey: VIRTUAL_KEY) -> Option<(String, bool)> {
let key_data = unsafe { MapVirtualKeyW(vkey.0 as u32, MAPVK_VK_TO_CHAR) };
if key_data == 0 {
return None;
}
// The high word contains dead key flag, the low word contains the character
let is_dead_key = (key_data >> 16) > 0;
let key = char::from_u32(key_data & 0xFFFF)?;
Some((key.to_ascii_lowercase().to_string(), is_dead_key))
}
fn display_keystroke(key: &str, modifiers: &Modifiers) -> String {
let mut display = String::new();
if modifiers.platform {
display.push_str("win-");
}
if modifiers.control {
display.push_str("ctrl-");
}
if modifiers.shift {
display.push_str("shift-");
}
if modifiers.alt {
display.push_str("alt-");
}
display.push_str(key);
display
}
fn get_virtual_key_from_scan_code(gpui_scan_code: ScanCode) -> Result<VIRTUAL_KEY> {
// https://github.com/microsoft/node-native-keymap/blob/main/deps/chromium/dom_code_data.inc
let scan_code = match gpui_scan_code {
ScanCode::A => 0x001e,
ScanCode::B => 0x0030,
ScanCode::C => 0x002e,
ScanCode::D => 0x0020,
ScanCode::E => 0x0012,
ScanCode::F => 0x0021,
ScanCode::G => 0x0022,
ScanCode::H => 0x0023,
ScanCode::I => 0x0017,
ScanCode::J => 0x0024,
ScanCode::K => 0x0025,
ScanCode::L => 0x0026,
ScanCode::M => 0x0032,
ScanCode::N => 0x0031,
ScanCode::O => 0x0018,
ScanCode::P => 0x0019,
ScanCode::Q => 0x0010,
ScanCode::R => 0x0013,
ScanCode::S => 0x001f,
ScanCode::T => 0x0014,
ScanCode::U => 0x0016,
ScanCode::V => 0x002f,
ScanCode::W => 0x0011,
ScanCode::X => 0x002d,
ScanCode::Y => 0x0015,
ScanCode::Z => 0x002c,
ScanCode::Digit0 => 0x000b,
ScanCode::Digit1 => 0x0002,
ScanCode::Digit2 => 0x0003,
ScanCode::Digit3 => 0x0004,
ScanCode::Digit4 => 0x0005,
ScanCode::Digit5 => 0x0006,
ScanCode::Digit6 => 0x0007,
ScanCode::Digit7 => 0x0008,
ScanCode::Digit8 => 0x0009,
ScanCode::Digit9 => 0x000a,
ScanCode::Backquote => 0x0029,
ScanCode::Minus => 0x000c,
ScanCode::Equal => 0x000d,
ScanCode::BracketLeft => 0x001a,
ScanCode::BracketRight => 0x001b,
ScanCode::Backslash => 0x002b,
ScanCode::Semicolon => 0x0027,
ScanCode::Quote => 0x0028,
ScanCode::Comma => 0x0033,
ScanCode::Period => 0x0034,
ScanCode::Slash => 0x0035,
_ => anyhow::bail!("Unsupported scan code: {:?}", gpui_scan_code),
};
let virtual_key = unsafe { MapVirtualKeyW(scan_code, MAPVK_VSC_TO_VK) };
if virtual_key == 0 {
anyhow::bail!(
"Failed to get virtual key from scan code: {:?}, {}",
gpui_scan_code,
scan_code
);
}
Ok(VIRTUAL_KEY(virtual_key as u16))
}

View File

@@ -313,6 +313,10 @@ impl Platform for WindowsPlatform {
self.text_system.clone()
}
fn keyboard_mapper(&self) -> Box<dyn PlatformKeyboardMapper> {
Box::new(WindowsKeyboardMapper::new())
}
fn keyboard_layout(&self) -> Box<dyn PlatformKeyboardLayout> {
Box::new(
WindowsKeyboardLayout::new()

View File

@@ -485,8 +485,6 @@ pub struct Chunk<'a> {
pub is_unnecessary: bool,
/// Whether this chunk of text was originally a tab character.
pub is_tab: bool,
/// Whether this chunk of text was originally a tab character.
pub is_inlay: bool,
/// Whether to underline the corresponding text range in the editor.
pub underline: bool,
}

View File

@@ -374,6 +374,7 @@ pub trait LanguageModelProvider: 'static {
fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
Vec::new()
}
fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &App) {}
fn is_authenticated(&self, cx: &App) -> bool;
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;

View File

@@ -15,7 +15,7 @@ use language_model::{
LanguageModelRequest, RateLimiter, Role,
};
use lmstudio::{
ChatCompletionRequest, ChatMessage, ModelType, ResponseStreamEvent, get_models,
ChatCompletionRequest, ChatMessage, ModelType, ResponseStreamEvent, get_models, preload_model,
stream_chat_completion,
};
use schemars::JsonSchema;
@@ -216,6 +216,15 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
.collect()
}
fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
let id = model.id().0.to_string();
cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
.detach_and_log_err(cx);
}
fn is_authenticated(&self, cx: &App) -> bool {
self.state.read(cx).is_authenticated()
}

View File

@@ -12,7 +12,7 @@ use language_model::{
};
use ollama::{
ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool,
OllamaToolCall, get_models, show_model, stream_chat_completion,
OllamaToolCall, get_models, preload_model, show_model, stream_chat_completion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -243,6 +243,15 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
models
}
fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
let id = model.id().0.to_string();
cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
.detach_and_log_err(cx);
}
fn is_authenticated(&self, cx: &App) -> bool {
self.state.read(cx).is_authenticated()
}

View File

@@ -3,7 +3,7 @@ use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::B
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{convert::TryFrom, time::Duration};
use std::{convert::TryFrom, sync::Arc, time::Duration};
pub const LMSTUDIO_API_URL: &str = "http://localhost:1234/api/v0";
@@ -391,3 +391,34 @@ pub async fn get_models(
serde_json::from_str(&body).context("Unable to parse LM Studio models response")?;
Ok(response.data)
}
/// Sends an empty request to LM Studio to trigger loading the model
pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
let uri = format!("{api_url}/completions");
let request = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json")
.body(AsyncBody::from(serde_json::to_string(
&serde_json::json!({
"model": model,
"messages": [],
"stream": false,
"max_tokens": 0,
}),
)?))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
Ok(())
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
anyhow::bail!(
"Failed to connect to LM Studio API: {} {}",
response.status(),
body,
);
}
}

View File

@@ -3,7 +3,7 @@ use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::B
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::time::Duration;
use std::{sync::Arc, time::Duration};
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
@@ -357,6 +357,36 @@ pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) ->
Ok(details)
}
/// Sends an empty request to Ollama to trigger loading the model
pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
let uri = format!("{api_url}/api/generate");
let request = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json")
.body(AsyncBody::from(
serde_json::json!({
"model": model,
"keep_alive": "15m",
})
.to_string(),
))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
Ok(())
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
anyhow::bail!(
"Failed to connect to Ollama API: {} {}",
response.status(),
body,
);
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -426,19 +426,58 @@ pub fn global_ssh_config_file() -> &'static Path {
Path::new("/etc/ssh/ssh_config")
}
/// Returns the path to the vscode user settings file
/// Returns the path to the vscode user settings file.
/// Note: This returns the `Default` profile settings file.
pub fn vscode_settings_file() -> &'static PathBuf {
static LOGS_DIR: OnceLock<PathBuf> = OnceLock::new();
let rel_path = "Code/User/settings.json";
LOGS_DIR.get_or_init(|| {
if cfg!(target_os = "macos") {
#[cfg(target_os = "macos")]
{
LOGS_DIR.get_or_init(|| {
home_dir()
.join("Library/Application Support")
.join(rel_path)
} else {
home_dir().join(".config").join(rel_path)
}
})
})
}
#[cfg(target_os = "windows")]
{
LOGS_DIR.get_or_init(|| {
dirs::config_dir()
.expect("failed to determine RoamingAppData directory")
.join(rel_path)
})
}
#[cfg(not(any(target_os = "macos", target_os = "windows")))]
{
LOGS_DIR.get_or_init(|| home_dir().join(".config").join(rel_path))
}
}
/// Returns the path to the vscode user keymap file.
/// Note: This returns the `Default` profile keymap file.
pub fn vscode_shortcuts_file() -> &'static PathBuf {
static RESULT: OnceLock<PathBuf> = OnceLock::new();
let rel_path = "Code/User/keybindings.json";
#[cfg(target_os = "macos")]
{
RESULT.get_or_init(|| {
home_dir()
.join("Library/Application Support")
.join(rel_path)
})
}
#[cfg(target_os = "windows")]
{
RESULT.get_or_init(|| {
dirs::config_dir()
.expect("failed to determine RoamingAppData directory")
.join(rel_path)
})
}
#[cfg(not(any(target_os = "macos", target_os = "windows")))]
{
RESULT.get_or_init(|| home_dir().join(".config").join(rel_path))
}
}
/// Returns the path to the cursor user settings file

View File

@@ -82,7 +82,6 @@ text.workspace = true
toml.workspace = true
url.workspace = true
util.workspace = true
uuid.workspace = true
which.workspace = true
worktree.workspace = true
zlog.workspace = true

View File

@@ -8,7 +8,6 @@ use task::{
BuildTaskDefinition, DebugScenario, RevealStrategy, RevealTarget, Shell, SpawnInTerminal,
TaskTemplate,
};
use uuid::Uuid;
pub(crate) struct GoLocator;
@@ -32,7 +31,11 @@ impl DapLocator for GoLocator {
match go_action.as_str() {
"test" => {
let binary_path = format!("__debug_{}", Uuid::new_v4().simple());
let binary_path = if build_config.env.contains_key("OUT_DIR") {
"${OUT_DIR}/__debug".to_string()
} else {
"__debug".to_string()
};
let build_task = TaskTemplate {
label: "go test debug".into(),
@@ -130,15 +133,14 @@ impl DapLocator for GoLocator {
match go_action.as_str() {
"test" => {
let binary_arg = build_config
.args
.get(4)
.ok_or_else(|| anyhow::anyhow!("can't locate debug binary"))?;
let program = PathBuf::from(&cwd)
.join(binary_arg)
.to_string_lossy()
.into_owned();
let program = if let Some(out_dir) = build_config.env.get("OUT_DIR") {
format!("{}/__debug", out_dir)
} else {
PathBuf::from(&cwd)
.join("__debug")
.to_string_lossy()
.to_string()
};
Ok(DebugRequest::Launch(task::LaunchRequest {
program,
@@ -169,7 +171,7 @@ impl DapLocator for GoLocator {
#[cfg(test)]
mod tests {
use super::*;
use task::{HideStrategy, RevealStrategy, RevealTarget, Shell, TaskId, TaskTemplate};
use task::{HideStrategy, RevealStrategy, RevealTarget, Shell, TaskTemplate};
#[test]
fn test_create_scenario_for_go_run() {
@@ -316,12 +318,7 @@ mod tests {
.contains(&"-gcflags \"all=-N -l\"".into())
);
assert!(task_template.args.contains(&"-o".into()));
assert!(
task_template
.args
.iter()
.any(|arg| arg.starts_with("__debug_"))
);
assert!(task_template.args.contains(&"__debug".into()));
} else {
panic!("Expected BuildTaskDefinition::Template");
}
@@ -333,14 +330,16 @@ mod tests {
}
#[test]
fn test_create_scenario_for_go_test_with_cwd_binary() {
fn test_create_scenario_for_go_test_with_out_dir() {
let locator = GoLocator;
let mut env = FxHashMap::default();
env.insert("OUT_DIR".to_string(), "/tmp/build".to_string());
let task = TaskTemplate {
label: "go test".into(),
command: "go".into(),
args: vec!["test".into(), ".".into()],
env: Default::default(),
env,
cwd: Some("${ZED_WORKTREE_ROOT}".into()),
use_new_terminal: false,
allow_concurrent_runs: false,
@@ -360,12 +359,7 @@ mod tests {
let scenario = scenario.unwrap();
if let Some(BuildTaskDefinition::Template { task_template, .. }) = &scenario.build {
assert!(
task_template
.args
.iter()
.any(|arg| arg.starts_with("__debug_"))
);
assert!(task_template.args.contains(&"${OUT_DIR}/__debug".into()));
} else {
panic!("Expected BuildTaskDefinition::Template");
}
@@ -395,42 +389,4 @@ mod tests {
locator.create_scenario(&task, "test label", DebugAdapterName("Delve".into()));
assert!(scenario.is_none());
}
#[test]
fn test_run_go_test_missing_binary_path() {
let locator = GoLocator;
let build_config = SpawnInTerminal {
id: TaskId("test_task".to_string()),
full_label: "go test".to_string(),
label: "go test".to_string(),
command: "go".into(),
args: vec![
"test".into(),
"-c".into(),
"-gcflags \"all=-N -l\"".into(),
"-o".into(),
], // Missing the binary path (arg 4)
command_label: "go test -c -gcflags \"all=-N -l\" -o".to_string(),
env: Default::default(),
cwd: Some(PathBuf::from("/test/path")),
use_new_terminal: false,
allow_concurrent_runs: false,
reveal: RevealStrategy::Always,
reveal_target: RevealTarget::Dock,
hide: HideStrategy::Never,
shell: Shell::System,
show_summary: true,
show_command: true,
show_rerun: true,
};
let result = futures::executor::block_on(locator.run(build_config));
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("can't locate debug binary")
);
}
}

View File

@@ -171,8 +171,7 @@ impl ConflictSet {
let mut conflicts = Vec::new();
let mut line_pos = 0;
let buffer_len = buffer.len();
let mut lines = buffer.text_for_range(0..buffer_len).lines();
let mut lines = buffer.text_for_range(0..buffer.len()).lines();
let mut conflict_start: Option<usize> = None;
let mut ours_start: Option<usize> = None;
@@ -213,7 +212,7 @@ impl ConflictSet {
&& theirs_start.is_some()
{
let theirs_end = line_pos;
let conflict_end = (line_end + 1).min(buffer_len);
let conflict_end = line_end + 1;
let range = buffer.anchor_after(conflict_start.unwrap())
..buffer.anchor_before(conflict_end);
@@ -391,22 +390,6 @@ mod tests {
assert_eq!(their_text, "This is their version in a nested conflict\n");
}
#[test]
fn test_conflict_markers_at_eof() {
let test_content = r#"
<<<<<<< ours
=======
This is their version
>>>>>>> "#
.unindent();
let buffer_id = BufferId::new(1).unwrap();
let buffer = Buffer::new(0, buffer_id, test_content.to_string());
let snapshot = buffer.snapshot();
let conflict_snapshot = ConflictSet::parse(&snapshot);
assert_eq!(conflict_snapshot.conflicts.len(), 1);
}
#[test]
fn test_conflicts_in_range() {
// Create a buffer with conflict markers

View File

@@ -741,7 +741,6 @@ mod tests {
("a.txt".into(), "".into()),
("b/c.txt".into(), "something-else".into()),
],
"deadbeef",
);
cx.executor().run_until_parked();
cx.executor().advance_clock(Duration::from_secs(1));

View File

@@ -2308,7 +2308,7 @@ impl LocalLspStore {
});
(false, lsp_delegate, servers)
});
let servers_and_adapters = servers
let servers = servers
.into_iter()
.filter_map(|server_node| {
if reused && server_node.server_id().is_none() {
@@ -2384,14 +2384,14 @@ impl LocalLspStore {
},
)?;
let server_state = self.language_servers.get(&server_id)?;
if let LanguageServerState::Running { server, adapter, .. } = server_state {
Some((server.clone(), adapter.clone()))
if let LanguageServerState::Running { server, .. } = server_state {
Some(server.clone())
} else {
None
}
})
.collect::<Vec<_>>();
for (server, adapter) in servers_and_adapters {
for server in servers {
buffer_handle.update(cx, |buffer, cx| {
buffer.set_completion_triggers(
server.server_id(),
@@ -2409,26 +2409,47 @@ impl LocalLspStore {
cx,
);
});
let snapshot = LspBufferSnapshot {
version: 0,
snapshot: initial_snapshot.clone(),
}
for adapter in self.languages.lsp_adapters(&language.name()) {
let servers = self
.language_server_ids
.get(&(worktree_id, adapter.name.clone()))
.map(|ids| {
ids.iter().flat_map(|id| {
self.language_servers.get(id).and_then(|server_state| {
if let LanguageServerState::Running { server, .. } = server_state {
Some(server.clone())
} else {
None
}
})
})
});
let servers = match servers {
Some(server) => server,
None => continue,
};
self.buffer_snapshots
.entry(buffer_id)
.or_default()
.entry(server.server_id())
.or_insert_with(|| {
server.register_buffer(
uri.clone(),
adapter.language_id(&language.name()),
0,
initial_snapshot.text(),
);
for server in servers {
let snapshot = LspBufferSnapshot {
version: 0,
snapshot: initial_snapshot.clone(),
};
self.buffer_snapshots
.entry(buffer_id)
.or_default()
.entry(server.server_id())
.or_insert_with(|| {
server.register_buffer(
uri.clone(),
adapter.language_id(&language.name()),
0,
initial_snapshot.text(),
);
vec![snapshot]
});
vec![snapshot]
});
}
}
}
@@ -3960,15 +3981,6 @@ impl LspStore {
let buffer_id = buffer.read(cx).remote_id();
let handle = cx.new(|_| buffer.clone());
if let Some(local) = self.as_local_mut() {
let refcount = local.registered_buffers.entry(buffer_id).or_insert(0);
if !ignore_refcounts {
*refcount += 1;
}
// We run early exits on non-existing buffers AFTER we mark the buffer as registered in order to handle buffer saving.
// When a new unnamed buffer is created and saved, we will start loading it's language. Once the language is loaded, we go over all "language-less" buffers and try to fit that new language
// with them. However, we do that only for the buffers that we think are open in at least one editor; thus, we need to keep tab of unnamed buffers as well, even though they're not actually registered with any language
// servers in practice (we don't support non-file URI schemes in our LSP impl).
let Some(file) = File::from_dyn(buffer.read(cx).file()) else {
return handle;
};
@@ -3976,6 +3988,11 @@ impl LspStore {
return handle;
}
let refcount = local.registered_buffers.entry(buffer_id).or_insert(0);
if !ignore_refcounts {
*refcount += 1;
}
if ignore_refcounts || *refcount == 1 {
local.register_buffer_with_language_servers(buffer, cx);
}

View File

@@ -3584,86 +3584,6 @@ async fn test_save_file(cx: &mut gpui::TestAppContext) {
assert_eq!(new_text, buffer.update(cx, |buffer, _| buffer.text()));
}
#[gpui::test(iterations = 10)]
async fn test_save_file_spawns_language_server(cx: &mut gpui::TestAppContext) {
// Issue: #24349
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/dir"), json!({})).await;
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
language_registry.add(rust_lang());
let mut fake_rust_servers = language_registry.register_fake_lsp(
"Rust",
FakeLspAdapter {
name: "the-rust-language-server",
capabilities: lsp::ServerCapabilities {
completion_provider: Some(lsp::CompletionOptions {
trigger_characters: Some(vec![".".to_string(), "::".to_string()]),
..Default::default()
}),
text_document_sync: Some(lsp::TextDocumentSyncCapability::Options(
lsp::TextDocumentSyncOptions {
save: Some(lsp::TextDocumentSyncSaveOptions::Supported(true)),
..Default::default()
},
)),
..Default::default()
},
..Default::default()
},
);
let buffer = project
.update(cx, |this, cx| this.create_buffer(cx))
.unwrap()
.await;
project.update(cx, |this, cx| {
this.register_buffer_with_language_servers(&buffer, cx);
buffer.update(cx, |buffer, cx| {
assert!(!this.has_language_servers_for(buffer, cx));
})
});
project
.update(cx, |this, cx| {
let worktree_id = this.worktrees(cx).next().unwrap().read(cx).id();
this.save_buffer_as(
buffer.clone(),
ProjectPath {
worktree_id,
path: Arc::from("file.rs".as_ref()),
},
cx,
)
})
.await
.unwrap();
// A server is started up, and it is notified about Rust files.
let mut fake_rust_server = fake_rust_servers.next().await.unwrap();
assert_eq!(
fake_rust_server
.receive_notification::<lsp::notification::DidOpenTextDocument>()
.await
.text_document,
lsp::TextDocumentItem {
uri: lsp::Url::from_file_path(path!("/dir/file.rs")).unwrap(),
version: 0,
text: "".to_string(),
language_id: "rust".to_string(),
}
);
project.update(cx, |this, cx| {
buffer.update(cx, |buffer, cx| {
assert!(this.has_language_servers_for(buffer, cx));
})
});
}
#[gpui::test(iterations = 30)]
async fn test_file_changes_multiple_times_on_disk(cx: &mut gpui::TestAppContext) {
init_test(cx);
@@ -6579,7 +6499,6 @@ async fn test_uncommitted_diff_for_buffer(cx: &mut gpui::TestAppContext) {
("src/modification.rs".into(), committed_contents),
("src/deletion.rs".into(), "// the-deleted-contents\n".into()),
],
"deadbeef",
);
fs.set_index_for_repo(
Path::new("/dir/.git"),
@@ -6646,7 +6565,6 @@ async fn test_uncommitted_diff_for_buffer(cx: &mut gpui::TestAppContext) {
("src/modification.rs".into(), committed_contents.clone()),
("src/deletion.rs".into(), "// the-deleted-contents\n".into()),
],
"deadbeef",
);
// Buffer now has an unstaged hunk.
@@ -7093,7 +7011,6 @@ async fn test_staging_hunks_with_delayed_fs_event(cx: &mut gpui::TestAppContext)
fs.set_head_for_repo(
"/dir/.git".as_ref(),
&[("file.txt".into(), committed_contents.clone())],
"deadbeef",
);
fs.set_index_for_repo(
"/dir/.git".as_ref(),
@@ -7290,7 +7207,6 @@ async fn test_staging_random_hunks(
fs.set_head_for_repo(
path!("/dir/.git").as_ref(),
&[("file.txt".into(), committed_text.clone())],
"deadbeef",
);
fs.set_index_for_repo(
path!("/dir/.git").as_ref(),
@@ -7402,7 +7318,6 @@ async fn test_single_file_diffs(cx: &mut gpui::TestAppContext) {
fs.set_head_for_repo(
Path::new("/dir/.git"),
&[("src/main.rs".into(), committed_contents.clone())],
"deadbeef",
);
fs.set_index_for_repo(
Path::new("/dir/.git"),

Some files were not shown because too many files have changed in this diff Show More