Compare commits

..

1 Commits

Author SHA1 Message Date
Conrad Irwin
0b81c19fa1 vim: Add zH/zL/zh/zl 2025-03-14 12:41:11 -06:00
598 changed files with 25614 additions and 53549 deletions

View File

@@ -19,10 +19,6 @@
# https://github.com/zed-industries/zed/pull/2394
eca93c124a488b4e538946cd2d313bd571aa2b86
# 2024-02-15 Format YAML files
# https://github.com/zed-industries/zed/pull/7887
a161a7d0c95ca7505bf9218bfae640ee5444c88b
# 2024-02-25 Format JSON files in assets/
# https://github.com/zed-industries/zed/pull/8405
ffdda588b41f7d9d270ffe76cab116f828ad545e

View File

@@ -10,7 +10,7 @@ runs:
cargo install cargo-nextest --locked
- name: Install Node
uses: actions/setup-node@cdca7365b2dadb8aad0a33bc7601856ffabcc48e # v4
uses: actions/setup-node@1d0ff469b7ec7b3cb9d8673fde0c81c44821de2a # v4
with:
node-version: "18"

View File

@@ -16,7 +16,7 @@ runs:
run: cargo install cargo-nextest --locked
- name: Install Node
uses: actions/setup-node@cdca7365b2dadb8aad0a33bc7601856ffabcc48e # v4
uses: actions/setup-node@1d0ff469b7ec7b3cb9d8673fde0c81c44821de2a # v4
with:
node-version: "18"

View File

@@ -209,6 +209,7 @@ jobs:
cargo check -p workspace
cargo build -p remote_server
cargo check -p gpui --examples
script/check-rust-livekit-macos
# Since the macOS runners are stateful, so we need to remove the config file to prevent potential bug.
- name: Clean CI config file
@@ -234,7 +235,7 @@ jobs:
clean: false
- name: Cache dependencies
uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2
uses: swatinem/rust-cache@f0deed1e0edfc6a9be95417288c0e1099b1eeec3 # v2
with:
save-if: ${{ github.ref == 'refs/heads/main' }}
cache-provider: "buildjet"
@@ -286,7 +287,7 @@ jobs:
clean: false
- name: Cache dependencies
uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2
uses: swatinem/rust-cache@f0deed1e0edfc6a9be95417288c0e1099b1eeec3 # v2
with:
save-if: ${{ github.ref == 'refs/heads/main' }}
cache-provider: "buildjet"
@@ -333,7 +334,7 @@ jobs:
Copy-Item -Path "${{ github.workspace }}" -Destination "${{ env.ZED_WORKSPACE }}" -Recurse
- name: Cache dependencies
uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2
uses: swatinem/rust-cache@f0deed1e0edfc6a9be95417288c0e1099b1eeec3 # v2
with:
save-if: ${{ github.ref == 'refs/heads/main' }}
workspaces: ${{ env.ZED_WORKSPACE }}
@@ -392,7 +393,7 @@ jobs:
Copy-Item -Path "${{ github.workspace }}" -Destination "${{ env.ZED_WORKSPACE }}" -Recurse
- name: Cache dependencies
uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2
uses: swatinem/rust-cache@f0deed1e0edfc6a9be95417288c0e1099b1eeec3 # v2
with:
save-if: ${{ github.ref == 'refs/heads/main' }}
workspaces: ${{ env.ZED_WORKSPACE }}
@@ -452,6 +453,7 @@ jobs:
[[ "${{ needs.linux_tests.result }}" != 'success' ]] && { RET_CODE=1; echo "Linux tests failed"; }
[[ "${{ needs.windows_tests.result }}" != 'success' ]] && { RET_CODE=1; echo "Windows tests failed"; }
[[ "${{ needs.windows_clippy.result }}" != 'success' ]] && { RET_CODE=1; echo "Windows clippy failed"; }
[[ "${{ needs.migration_checks.result }}" != 'success' ]] && { RET_CODE=1; echo "Migration checks failed"; }
[[ "${{ needs.build_remote_server.result }}" != 'success' ]] && { RET_CODE=1; echo "Remote server build failed"; }
fi
if [[ "$RET_CODE" -eq 0 ]]; then
@@ -481,7 +483,7 @@ jobs:
DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }}
steps:
- name: Install Node
uses: actions/setup-node@cdca7365b2dadb8aad0a33bc7601856ffabcc48e # v4
uses: actions/setup-node@1d0ff469b7ec7b3cb9d8673fde0c81c44821de2a # v4
with:
node-version: "18"
@@ -525,14 +527,14 @@ jobs:
mv target/x86_64-apple-darwin/release/Zed.dmg target/x86_64-apple-darwin/release/Zed-x86_64.dmg
- name: Upload app bundle (aarch64) to workflow run if main branch or specific label
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
uses: actions/upload-artifact@4cec3d8aa04e39d1a68397de0c4cd6fb9dce8ec1 # v4
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
with:
name: Zed_${{ github.event.pull_request.head.sha || github.sha }}-aarch64.dmg
path: target/aarch64-apple-darwin/release/Zed-aarch64.dmg
- name: Upload app bundle (x86_64) to workflow run if main branch or specific label
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
uses: actions/upload-artifact@4cec3d8aa04e39d1a68397de0c4cd6fb9dce8ec1 # v4
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
with:
name: Zed_${{ github.event.pull_request.head.sha || github.sha }}-x86_64.dmg
@@ -585,7 +587,7 @@ jobs:
run: script/bundle-linux
- name: Upload Linux bundle to workflow run if main branch or specific label
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
uses: actions/upload-artifact@4cec3d8aa04e39d1a68397de0c4cd6fb9dce8ec1 # v4
if: |
github.ref == 'refs/heads/main'
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
@@ -594,7 +596,7 @@ jobs:
path: target/release/zed-*.tar.gz
- name: Upload Linux remote server to workflow run if main branch or specific label
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
uses: actions/upload-artifact@4cec3d8aa04e39d1a68397de0c4cd6fb9dce8ec1 # v4
if: |
github.ref == 'refs/heads/main'
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
@@ -646,7 +648,7 @@ jobs:
run: script/bundle-linux
- name: Upload Linux bundle to workflow run if main branch or specific label
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
uses: actions/upload-artifact@4cec3d8aa04e39d1a68397de0c4cd6fb9dce8ec1 # v4
if: |
github.ref == 'refs/heads/main'
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
@@ -655,7 +657,7 @@ jobs:
path: target/release/zed-*.tar.gz
- name: Upload Linux remote server to workflow run if main branch or specific label
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
uses: actions/upload-artifact@4cec3d8aa04e39d1a68397de0c4cd6fb9dce8ec1 # v4
if: |
github.ref == 'refs/heads/main'
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')

View File

@@ -1,7 +1,7 @@
name: "Close Stale Issues"
on:
schedule:
- cron: "0 7,9,11 * * 2"
- cron: "0 11 * * 2"
workflow_dispatch:
jobs:

View File

@@ -13,12 +13,11 @@ jobs:
id: get-release-url
run: |
if [ "${{ github.event.release.prerelease }}" == "true" ]; then
URL="https://zed.dev/releases/preview/latest"
URL="https://zed.dev/releases/preview/latest"
else
URL="https://zed.dev/releases/stable/latest"
URL="https://zed.dev/releases/stable/latest"
fi
echo "URL=$URL" >> $GITHUB_OUTPUT
echo "::set-output name=URL::$URL"
- name: Get content
uses: 2428392/gh-truncate-string-action@b3ff790d21cf42af3ca7579146eedb93c8fb0757 # v1.4.1
id: get-content
@@ -34,35 +33,3 @@ jobs:
with:
webhook-url: ${{ secrets.DISCORD_WEBHOOK_URL }}
content: ${{ steps.get-content.outputs.string }}
send_release_notes_email:
if: github.repository_owner == 'zed-industries' && !github.event.release.prerelease
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
with:
fetch-depth: 0
- name: Check if release was promoted from preview
id: check-promotion-from-preview
run: |
VERSION="${{ github.event.release.tag_name }}"
PREVIEW_TAG="${VERSION}-pre"
if git rev-parse "$PREVIEW_TAG" > /dev/null 2>&1; then
echo "was_promoted_from_preview=true" >> $GITHUB_OUTPUT
else
echo "was_promoted_from_preview=false" >> $GITHUB_OUTPUT
fi
- name: Send release notes email
if: steps.check-promotion-from-preview.outputs.was_promoted_from_preview == 'true'
run: |
TAG="${{ github.event.release.tag_name }}"
echo \"${{ toJSON(github.event.release.body) }}\" > release_body.txt
jq -n --arg tag "$TAG" --rawfile body release_body.txt '{version: $tag, markdown_body: $body}' \
> release_data.json
curl -X POST "https://zed.dev/api/send_release_notes_email" \
-H "Authorization: Bearer ${{ secrets.RELEASE_NOTES_API_TOKEN }}" \
-H "Content-Type: application/json" \
-d @release_data.json

View File

@@ -22,7 +22,7 @@ jobs:
version: 9
- name: Setup Node
uses: actions/setup-node@cdca7365b2dadb8aad0a33bc7601856ffabcc48e # v4
uses: actions/setup-node@1d0ff469b7ec7b3cb9d8673fde0c81c44821de2a # v4
with:
node-version: "20"
cache: "pnpm"

View File

@@ -37,35 +37,35 @@ jobs:
mdbook build ./docs --dest-dir=../target/deploy/docs/
- name: Deploy Docs
uses: cloudflare/wrangler-action@da0e0dfe58b7a431659754fdf3f186c529afbe65 # v3
uses: cloudflare/wrangler-action@392082e81ffbcb9ebdde27400634aa004b35ea37 # v3
with:
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
command: pages deploy target/deploy --project-name=docs
- name: Deploy Install
uses: cloudflare/wrangler-action@da0e0dfe58b7a431659754fdf3f186c529afbe65 # v3
uses: cloudflare/wrangler-action@392082e81ffbcb9ebdde27400634aa004b35ea37 # v3
with:
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
command: r2 object put -f script/install.sh zed-open-source-website-assets/install.sh
- name: Deploy Docs Workers
uses: cloudflare/wrangler-action@da0e0dfe58b7a431659754fdf3f186c529afbe65 # v3
uses: cloudflare/wrangler-action@392082e81ffbcb9ebdde27400634aa004b35ea37 # v3
with:
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
command: deploy .cloudflare/docs-proxy/src/worker.js
- name: Deploy Install Workers
uses: cloudflare/wrangler-action@da0e0dfe58b7a431659754fdf3f186c529afbe65 # v3
uses: cloudflare/wrangler-action@392082e81ffbcb9ebdde27400634aa004b35ea37 # v3
with:
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
command: deploy .cloudflare/docs-proxy/src/worker.js
- name: Preserve Wrangler logs
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
uses: actions/upload-artifact@4cec3d8aa04e39d1a68397de0c4cd6fb9dce8ec1 # v4
if: always()
with:
name: wrangler_logs

View File

@@ -18,7 +18,7 @@ jobs:
version: 9
- name: Setup Node
uses: actions/setup-node@cdca7365b2dadb8aad0a33bc7601856ffabcc48e # v4
uses: actions/setup-node@1d0ff469b7ec7b3cb9d8673fde0c81c44821de2a # v4
with:
node-version: "20"
cache: "pnpm"

View File

@@ -22,7 +22,7 @@ jobs:
clean: false
- name: Cache dependencies
uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2
uses: swatinem/rust-cache@f0deed1e0edfc6a9be95417288c0e1099b1eeec3 # v2
with:
save-if: ${{ github.ref == 'refs/heads/main' }}
cache-provider: "github"

View File

@@ -23,7 +23,7 @@ jobs:
- buildjet-16vcpu-ubuntu-2204
steps:
- name: Install Node
uses: actions/setup-node@cdca7365b2dadb8aad0a33bc7601856ffabcc48e # v4
uses: actions/setup-node@1d0ff469b7ec7b3cb9d8673fde0c81c44821de2a # v4
with:
node-version: "18"

View File

@@ -71,7 +71,7 @@ jobs:
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
steps:
- name: Install Node
uses: actions/setup-node@cdca7365b2dadb8aad0a33bc7601856ffabcc48e # v4
uses: actions/setup-node@1d0ff469b7ec7b3cb9d8673fde0c81c44821de2a # v4
with:
node-version: "18"
@@ -170,58 +170,6 @@ jobs:
- name: Upload Zed Nightly
run: script/upload-nightly linux-targz
bundle-nix:
timeout-minutes: 60
name: (${{ matrix.system.os }}) Nix Build
continue-on-error: true
strategy:
fail-fast: false
matrix:
system:
- os: x86 Linux
runner: buildjet-16vcpu-ubuntu-2204
install_nix: true
- os: arm Mac
# TODO: once other macs are provisioned for nix, remove that constraint from the runner
runner: [macOS, ARM64, nix]
install_nix: false
- os: arm Linux
runner: buildjet-16vcpu-ubuntu-2204-arm
install_nix: true
if: github.repository_owner == 'zed-industries'
runs-on: ${{ matrix.system.runner }}
needs: tests
env:
ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
GIT_LFS_SKIP_SMUDGE: 1 # breaks the livekit rust sdk examples which we don't actually depend on
steps:
- name: Checkout repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
with:
clean: false
# on our macs we manually install nix. for some reason the cachix action is running
# under a non-login /bin/bash shell which doesn't source the proper script to add the
# nix profile to PATH, so we manually add them here
- name: Set path
if: ${{ ! matrix.system.install_nix }}
run: |
echo "/nix/var/nix/profiles/default/bin" >> $GITHUB_PATH
echo "/Users/administrator/.nix-profile/bin" >> $GITHUB_PATH
- uses: cachix/install-nix-action@02a151ada4993995686f9ed4f1be7cfbb229e56f # v31
if: ${{ matrix.system.install_nix }}
with:
github_access_token: ${{ secrets.GITHUB_TOKEN }}
- uses: cachix/cachix-action@0fc020193b5a1fa3ac4575aa3a7d3aa6a35435ad # v16
with:
name: zed-industries
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- run: nix build
- run: nix-collect-garbage -d
update-nightly-tag:
name: Update nightly tag
if: github.repository_owner == 'zed-industries'

View File

@@ -1,19 +0,0 @@
[
{
"label": "Debug Zed with LLDB",
"adapter": "lldb",
"program": "$ZED_WORKTREE_ROOT/target/debug/zed",
"request": "launch",
"cwd": "$ZED_WORKTREE_ROOT"
},
{
"label": "Debug Zed with GDB",
"adapter": "gdb",
"program": "$ZED_WORKTREE_ROOT/target/debug/zed",
"request": "launch",
"cwd": "$ZED_WORKTREE_ROOT",
"initialize_args": {
"stopAtBeginningOfMainSubprogram": true
}
}
]

617
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -8,7 +8,6 @@ members = [
"crates/assistant",
"crates/assistant2",
"crates/assistant_context_editor",
"crates/assistant_eval",
"crates/assistant_settings",
"crates/assistant_slash_command",
"crates/assistant_slash_commands",
@@ -37,10 +36,6 @@ members = [
"crates/context_server_settings",
"crates/copilot",
"crates/credentials_provider",
"crates/dap",
"crates/dap_adapters",
"crates/debugger_tools",
"crates/debugger_ui",
"crates/db",
"crates/deepseek",
"crates/diagnostics",
@@ -69,7 +64,6 @@ members = [
"crates/gpui_tokio",
"crates/html_to_markdown",
"crates/http_client",
"crates/http_client_tls",
"crates/image_viewer",
"crates/indexed_docs",
"crates/inline_completion",
@@ -86,6 +80,7 @@ members = [
"crates/languages",
"crates/livekit_api",
"crates/livekit_client",
"crates/livekit_client_macos",
"crates/lmstudio",
"crates/lsp",
"crates/markdown",
@@ -159,7 +154,6 @@ members = [
"crates/ui",
"crates/ui_input",
"crates/ui_macros",
"crates/ui_prompt",
"crates/util",
"crates/util_macros",
"crates/vim",
@@ -212,7 +206,6 @@ assets = { path = "crates/assets" }
assistant = { path = "crates/assistant" }
assistant2 = { path = "crates/assistant2" }
assistant_context_editor = { path = "crates/assistant_context_editor" }
assistant_eval = { path = "crates/assistant_eval" }
assistant_settings = { path = "crates/assistant_settings" }
assistant_slash_command = { path = "crates/assistant_slash_command" }
assistant_slash_commands = { path = "crates/assistant_slash_commands" }
@@ -240,11 +233,7 @@ context_server = { path = "crates/context_server" }
context_server_settings = { path = "crates/context_server_settings" }
copilot = { path = "crates/copilot" }
credentials_provider = { path = "crates/credentials_provider" }
dap = { path = "crates/dap" }
dap_adapters = { path = "crates/dap_adapters" }
db = { path = "crates/db" }
debugger_ui = { path = "crates/debugger_ui" }
debugger_tools = { path = "crates/debugger_tools" }
deepseek = { path = "crates/deepseek" }
diagnostics = { path = "crates/diagnostics" }
buffer_diff = { path = "crates/buffer_diff" }
@@ -271,7 +260,6 @@ gpui_macros = { path = "crates/gpui_macros" }
gpui_tokio = { path = "crates/gpui_tokio" }
html_to_markdown = { path = "crates/html_to_markdown" }
http_client = { path = "crates/http_client" }
http_client_tls = { path = "crates/http_client_tls" }
image_viewer = { path = "crates/image_viewer" }
indexed_docs = { path = "crates/indexed_docs" }
inline_completion = { path = "crates/inline_completion" }
@@ -288,6 +276,7 @@ language_tools = { path = "crates/language_tools" }
languages = { path = "crates/languages" }
livekit_api = { path = "crates/livekit_api" }
livekit_client = { path = "crates/livekit_client" }
livekit_client_macos = { path = "crates/livekit_client_macos" }
lmstudio = { path = "crates/lmstudio" }
lsp = { path = "crates/lsp" }
markdown = { path = "crates/markdown" }
@@ -361,7 +350,6 @@ toolchain_selector = { path = "crates/toolchain_selector" }
ui = { path = "crates/ui" }
ui_input = { path = "crates/ui_input" }
ui_macros = { path = "crates/ui_macros" }
ui_prompt = { path = "crates/ui_prompt" }
util = { path = "crates/util" }
util_macros = { path = "crates/util_macros" }
vim = { path = "crates/vim" }
@@ -407,20 +395,17 @@ blade-util = { git = "https://github.com/kvark/blade", rev = "b16f5c7bd873c7126f
naga = { version = "23.1.0", features = ["wgsl-in"] }
blake3 = "1.5.3"
bytes = "1.0"
cargo_metadata = { git = "https://github.com/zed-industries/cargo_metadata", rev = "ce8171bad673923d61a77b6761d0dc4aff63398a"}
cargo_metadata = "0.19"
cargo_toml = "0.21"
chrono = { version = "0.4", features = ["serde"] }
circular-buffer = "1.0"
clap = { version = "4.4", features = ["derive"] }
cocoa = "0.26"
cocoa-foundation = "0.2.0"
core-video = { version = "0.4.3", features = ["metal"] }
convert_case = "0.8.0"
core-foundation = "0.10.0"
core-foundation = "0.9.3"
core-foundation-sys = "0.8.6"
ctor = "0.4.0"
dashmap = "6.0"
dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "bfd4af0" }
derive_more = "0.99.17"
dirs = "4.0"
ec4rs = "1.1"
@@ -432,7 +417,8 @@ fork = "0.2.0"
futures = "0.3"
futures-batch = "0.6.1"
futures-lite = "1.13"
git2 = { version = "0.20.1", default-features = false }
# TODO: get back to regular versions when https://github.com/rust-lang/git2-rs/pull/1120 is released
git2 = { git = "https://github.com/rust-lang/git2-rs", rev = "a3b90cb3756c1bb63e2317bf9cfa57838178de5c", default-features = false }
globset = "0.4"
handlebars = "4.3"
heed = { version = "0.21.0", features = ["read-txn-no-tls"] }
@@ -454,6 +440,11 @@ libc = "0.2"
libsqlite3-sys = { version = "0.30.1", features = ["bundled"] }
linkify = "0.10.0"
linkme = "0.3.31"
livekit = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "811ceae29fabee455f110c56cd66b3f49a7e5003", features = [
"dispatcher",
"services-dispatcher",
"rustls-tls-native-roots",
], default-features = false }
log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] }
markup5ever_rcdom = "0.3.0"
mlua = { version = "0.10", features = ["lua54", "vendored", "async", "send"] }
@@ -529,7 +520,7 @@ sys-locale = "0.3.1"
sysinfo = "0.31.0"
take-until = "0.2.0"
tempfile = "3.9.0"
thiserror = "2.0.12"
thiserror = "1.0.29"
tiktoken-rs = "0.6.0"
time = { version = "0.3", features = [
"macros",
@@ -541,7 +532,6 @@ time = { version = "0.3", features = [
tiny_http = "0.8"
toml = "0.8"
tokio = { version = "1" }
tokio-tungstenite = { version = "0.26", features = ["__rustls-tls"]}
tower-http = "0.4.4"
tree-sitter = { version = "0.25.3", features = ["wasm"] }
tree-sitter-bash = "0.23"
@@ -571,7 +561,6 @@ unindent = "0.2.0"
unicode-segmentation = "1.10"
unicode-script = "0.5.7"
url = "2.2"
urlencoding = "2.1.2"
uuid = { version = "1.1.2", features = ["v4", "v5", "v7", "serde"] }
wasmparser = "0.221"
wasm-encoder = "0.221"
@@ -587,7 +576,7 @@ which = "6.0.0"
wit-component = "0.221"
zed_llm_client = "0.4"
zstd = "0.11"
metal = "0.29"
metal = "0.31"
[workspace.dependencies.async-stripe]
git = "https://github.com/zed-industries/async-stripe"

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-bug"><path d="m8 2 1.88 1.88"/><path d="M14.12 3.88 16 2"/><path d="M9 7.13v-1a3.003 3.003 0 1 1 6 0v1"/><path d="M12 20c-3.3 0-6-2.7-6-6v-3a4 4 0 0 1 4-4h4a4 4 0 0 1 4 4v3c0 3.3-2.7 6-6 6"/><path d="M12 20v-9"/><path d="M6.53 9C4.6 8.8 3 7.1 3 5"/><path d="M6 13H2"/><path d="M3 21c0-2.1 1.7-3.9 3.8-4"/><path d="M20.97 5c0 2.1-1.6 3.8-3.5 4"/><path d="M22 13h-4"/><path d="M17.2 17c2.1.1 3.8 1.9 3.8 4"/></svg>

Before

Width:  |  Height:  |  Size: 615 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="currentColor" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-circle"><circle cx="12" cy="12" r="10"/></svg>

Before

Width:  |  Height:  |  Size: 257 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-step-forward"><line x1="6" x2="6" y1="4" y2="20"/><polygon points="10,4 20,12 10,20"/></svg>

Before

Width:  |  Height:  |  Size: 295 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-unplug"><path d="m19 5 3-3"/><path d="m2 22 3-3"/><path d="M6.3 20.3a2.4 2.4 0 0 0 3.4 0L12 18l-6-6-2.3 2.3a2.4 2.4 0 0 0 0 3.4Z"/><path d="M7.5 13.5 10 11"/><path d="M10.5 16.5 13 14"/><path d="m12 6 6 6 2.3-2.3a2.4 2.4 0 0 0 0-3.4l-2.6-2.6a2.4 2.4 0 0 0-3.4 0Z"/></svg>

Before

Width:  |  Height:  |  Size: 474 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-circle-off"><path d="m2 2 20 20"/><path d="M8.35 2.69A10 10 0 0 1 21.3 15.65"/><path d="M19.08 19.08A10 10 0 1 1 4.92 4.92"/></svg>

Before

Width:  |  Height:  |  Size: 334 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="currentColor" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-message-circle"><path d="M7.9 20A9 9 0 1 0 4 16.1L2 22Z"/></svg>

Before

Width:  |  Height:  |  Size: 275 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-pause"><rect x="14" y="4" width="4" height="16" rx="1"/><rect x="6" y="4" width="4" height="16" rx="1"/></svg>

Before

Width:  |  Height:  |  Size: 313 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-rotate-ccw"><path d="M3 12a9 9 0 1 0 9-9 9.75 9.75 0 0 0-6.74 2.74L3 8"/><path d="M3 3v5h5"/></svg>

Before

Width:  |  Height:  |  Size: 302 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-undo-dot"><path d="M21 17a9 9 0 0 0-15-6.7L3 13"/><path d="M3 7v6h6"/><circle cx="12" cy="17" r="1"/></svg>

Before

Width:  |  Height:  |  Size: 310 B

View File

@@ -1,5 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-arrow-up-from-dot">
<path d="m5 15 7 7 7-7"/>
<path d="M12 8v14"/>
<circle cx="12" cy="3" r="1"/>
</svg>

Before

Width:  |  Height:  |  Size: 313 B

View File

@@ -1,5 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-arrow-up-from-dot">
<path d="m3 10 9-8 9 8"/>
<path d="M12 17V2"/>
<circle cx="12" cy="21" r="1"/>
</svg>

Before

Width:  |  Height:  |  Size: 314 B

View File

@@ -1,5 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-redo-dot">
<circle cx="12" cy="17" r="1"/>
<path d="M21 7v6h-6"/>
<path d="M3 17a9 9 0 0 1 9-9 9 9 0 0 1 6 2.3l3 2.7"/>
</svg>

Before

Width:  |  Height:  |  Size: 335 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-square"><rect width="18" height="18" x="3" y="3" rx="2"/></svg>

Before

Width:  |  Height:  |  Size: 266 B

View File

@@ -30,13 +30,6 @@
"ctrl-0": ["zed::ResetBufferFontSize", { "persist": false }],
"ctrl-,": "zed::OpenSettings",
"ctrl-q": "zed::Quit",
"f4": "debugger::Start",
"f5": "debugger::Continue",
"shift-f5": "debugger::Stop",
"f6": "debugger::Pause",
"f7": "debugger::StepOver",
"cmd-f11": "debugger::StepInto",
"shift-f11": "debugger::StepOut",
"f11": "zed::ToggleFullScreen",
"ctrl-alt-z": "edit_prediction::RateCompletions",
"ctrl-shift-i": "edit_prediction::ToggleMenu"
@@ -53,9 +46,7 @@
"context": "Prompt",
"bindings": {
"left": "menu::SelectPrevious",
"right": "menu::SelectNext",
"h": "menu::SelectPrevious",
"l": "menu::SelectNext"
"right": "menu::SelectNext"
}
},
{
@@ -116,7 +107,6 @@
"ctrl-a": "editor::SelectAll",
"ctrl-l": "editor::SelectLine",
"ctrl-shift-i": "editor::Format",
"alt-shift-o": "editor::OrganizeImports",
// "cmd-shift-left": ["editor::SelectToBeginningOfLine", {"stop_at_soft_wraps": true, "stop_at_indent": true }],
// "ctrl-shift-a": ["editor::SelectToBeginningOfLine", { "stop_at_soft_wraps": true, "stop_at_indent": true }],
"shift-home": ["editor::SelectToBeginningOfLine", { "stop_at_soft_wraps": true, "stop_at_indent": true }],
@@ -133,9 +123,7 @@
"alt-g b": "editor::ToggleGitBlame",
"menu": "editor::OpenContextMenu",
"shift-f10": "editor::OpenContextMenu",
"ctrl-shift-e": "editor::ToggleEditPrediction",
"f9": "editor::ToggleBreakpoint",
"shift-f9": "editor::EditLogBreakpoint"
"ctrl-shift-e": "editor::ToggleEditPrediction"
}
},
{
@@ -754,8 +742,6 @@
"escape": "git_panel::ToggleFocus",
"ctrl-enter": "git::Commit",
"alt-enter": "menu::SecondaryConfirm",
"delete": "git::RestoreFile",
"shift-delete": "git::RestoreFile",
"backspace": "git::RestoreFile"
}
},

View File

@@ -14,13 +14,6 @@
{
"use_key_equivalents": true,
"bindings": {
"f4": "debugger::Start",
"f5": "debugger::Continue",
"shift-f5": "debugger::Stop",
"f6": "debugger::Pause",
"f7": "debugger::StepOver",
"f11": "debugger::StepInto",
"shift-f11": "debugger::StepOut",
"home": "menu::SelectFirst",
"shift-pageup": "menu::SelectFirst",
"pageup": "menu::SelectFirst",
@@ -155,8 +148,6 @@
"cmd-\"": "editor::ExpandAllDiffHunks",
"cmd-alt-g b": "editor::ToggleGitBlame",
"cmd-i": "editor::ShowSignatureHelp",
"f9": "editor::ToggleBreakpoint",
"shift-f9": "editor::EditLogBreakpoint",
"ctrl-f12": "editor::GoToDeclaration",
"alt-ctrl-f12": "editor::GoToDeclarationSplit",
"ctrl-cmd-e": "editor::ToggleEditPrediction"
@@ -705,16 +696,6 @@
"ctrl-]": "assistant::CycleNextInlineAssist"
}
},
{
"context": "Prompt",
"use_key_equivalents": true,
"bindings": {
"left": "menu::SelectPrevious",
"right": "menu::SelectNext",
"h": "menu::SelectPrevious",
"l": "menu::SelectNext"
}
},
{
"context": "ProjectSearchBar && !in_replace",
"use_key_equivalents": true,
@@ -775,14 +756,6 @@
"space": "project_panel::Open"
}
},
{
"context": "VariableList",
"use_key_equivalents": true,
"bindings": {
"left": "variable_list::CollapseSelectedEntry",
"right": "variable_list::ExpandSelectedEntry"
}
},
{
"context": "GitPanel && ChangesList",
"use_key_equivalents": true,
@@ -801,8 +774,6 @@
"shift-tab": "git_panel::FocusEditor",
"escape": "git_panel::ToggleFocus",
"cmd-enter": "git::Commit",
"delete": "git::RestoreFile",
"cmd-backspace": "git::RestoreFile",
"backspace": "git::RestoreFile"
}
},

View File

@@ -3,14 +3,7 @@
"bindings": {
"ctrl-alt-s": "zed::OpenSettings",
"ctrl-{": "pane::ActivatePreviousItem",
"ctrl-}": "pane::ActivateNextItem",
"ctrl-f2": "debugger::Stop",
"f6": "debugger::Pause",
"f7": "debugger::StepInto",
"f8": "debugger::StepOver",
"shift-f8": "debugger::StepOut",
"f9": "debugger::Continue",
"alt-shift-f9": "debugger::Start"
"ctrl-}": "pane::ActivateNextItem"
}
},
{
@@ -38,7 +31,6 @@
"shift-alt-up": "editor::MoveLineUp",
"shift-alt-down": "editor::MoveLineDown",
"ctrl-alt-l": "editor::Format",
"ctrl-alt-o": "editor::OrganizeImports",
"shift-f6": "editor::Rename",
"ctrl-alt-left": "pane::GoBack",
"ctrl-alt-right": "pane::GoForward",
@@ -56,9 +48,7 @@
"ctrl-home": "editor::MoveToBeginning",
"ctrl-end": "editor::MoveToEnd",
"ctrl-shift-home": "editor::SelectToBeginning",
"ctrl-shift-end": "editor::SelectToEnd",
"ctrl-f8": "editor::ToggleBreakpoint",
"ctrl-shift-f8": "editor::EditLogBreakpoint"
"ctrl-shift-end": "editor::SelectToEnd"
}
},
{

View File

@@ -2,14 +2,7 @@
{
"bindings": {
"cmd-{": "pane::ActivatePreviousItem",
"cmd-}": "pane::ActivateNextItem",
"ctrl-f2": "debugger::Stop",
"f6": "debugger::Pause",
"f7": "debugger::StepInto",
"f8": "debugger::StepOver",
"shift-f8": "debugger::StepOut",
"f9": "debugger::Continue",
"alt-shift-f9": "debugger::Start"
"cmd-}": "pane::ActivateNextItem"
}
},
{
@@ -36,7 +29,6 @@
"shift-alt-up": "editor::MoveLineUp",
"shift-alt-down": "editor::MoveLineDown",
"cmd-alt-l": "editor::Format",
"ctrl-alt-o": "editor::OrganizeImports",
"shift-f6": "editor::Rename",
"cmd-[": "pane::GoBack",
"cmd-]": "pane::GoForward",
@@ -53,9 +45,7 @@
"cmd-home": "editor::MoveToBeginning",
"cmd-end": "editor::MoveToEnd",
"cmd-shift-home": "editor::SelectToBeginning",
"cmd-shift-end": "editor::SelectToEnd",
"ctrl-f8": "editor::ToggleBreakpoint",
"ctrl-shift-f8": "editor::EditLogBreakpoint"
"cmd-shift-end": "editor::SelectToEnd"
}
},
{

View File

@@ -155,6 +155,7 @@
"z +": ["workspace::SendKeystrokes", "shift-l j z t ^"],
"z t": "editor::ScrollCursorTop",
"z z": "editor::ScrollCursorCenter",
"z l": "vim::ScrollLeftHalfWay",
"z .": ["workspace::SendKeystrokes", "z z ^"],
"z b": "editor::ScrollCursorBottom",
"z a": "editor::ToggleFold",

View File

@@ -14,19 +14,5 @@ Be concise and direct in your responses.
The user has opened a project that contains the following root directories/files:
{{#each worktrees}}
- `{{root_name}}` (absolute path: `{{abs_path}}`)
- {{root_name}} (absolute path: {{abs_path}})
{{/each}}
{{#if has_rules}}
There are rules that apply to these root directories:
{{#each worktrees}}
{{#if rules_file}}
`{{root_name}}/{{rules_file.rel_path}}`:
``````
{{{rules_file.text}}}
``````
{{/if}}
{{/each}}
{{/if}}

View File

@@ -136,11 +136,6 @@
// Whether to use the system provided dialogs for Open and Save As.
// When set to false, Zed will use the built-in keyboard-first pickers.
"use_system_path_prompts": true,
// Whether to use the system provided dialogs for prompts, such as confirmation
// prompts.
// When set to false, Zed will use its built-in prompts. Note that on Linux,
// this option is ignored and Zed will always use the built-in prompts.
"use_system_prompts": true,
// Whether the cursor blinks in the editor.
"cursor_blink": true,
// Cursor shape for the default editor.
@@ -329,8 +324,6 @@
"code_actions": true,
// Whether to show runnables buttons in the gutter.
"runnables": true,
// Whether to show breakpoints in the gutter.
"breakpoints": true,
// Whether to show fold buttons in the gutter.
"folds": true
},
@@ -857,24 +850,8 @@
//
// The minimum column number to show the inline blame information at
// "min_column": 0
},
// How git hunks are displayed visually in the editor.
// This setting can take two values:
//
// 1. Show unstaged hunks filled and staged hunks hollow:
// "hunk_style": "staged_hollow"
// 2. Show unstaged hunks hollow and staged hunks filled:
// "hunk_style": "unstaged_hollow"
"hunk_style": "staged_hollow"
}
},
// The list of custom Git hosting providers.
"git_hosting_providers": [
// {
// "provider": "github",
// "name": "BigCorp GitHub",
// "base_url": "https://code.big-corp.com"
// }
],
// Configuration for how direnv configuration should be loaded. May take 2 values:
// 1. Load direnv configuration using `direnv export json` directly.
// "load_direnv": "direct"
@@ -1460,12 +1437,6 @@
// }
// ]
"ssh_connections": [],
// Configures context servers for use in the Assistant.
"context_servers": {},
"debugger": {
"stepping_granularity": "line",
"save_breakpoints": true,
"button": true
}
"context_servers": {}
}

View File

@@ -1,32 +0,0 @@
[
{
"label": "Debug active PHP file",
"adapter": "php",
"program": "$ZED_FILE",
"request": "launch",
"cwd": "$ZED_WORKTREE_ROOT"
},
{
"label": "Debug active Python file",
"adapter": "python",
"program": "$ZED_FILE",
"request": "launch",
"cwd": "$ZED_WORKTREE_ROOT"
},
{
"label": "Debug active JavaScript file",
"adapter": "javascript",
"program": "$ZED_FILE",
"request": "launch",
"cwd": "$ZED_WORKTREE_ROOT"
},
{
"label": "JavaScript debug terminal",
"adapter": "javascript",
"request": "launch",
"cwd": "$ZED_WORKTREE_ROOT",
"initialize_args": {
"console": "integratedTerminal"
}
}
]

View File

@@ -20,6 +20,7 @@ extension_host.workspace = true
futures.workspace = true
gpui.workspace = true
language.workspace = true
lsp.workspace = true
project.workspace = true
smallvec.workspace = true
ui.workspace = true

View File

@@ -7,7 +7,8 @@ use gpui::{
EventEmitter, InteractiveElement as _, ParentElement as _, Render, SharedString,
StatefulInteractiveElement, Styled, Transformation, Window,
};
use language::{BinaryStatus, LanguageRegistry, LanguageServerId};
use language::{LanguageRegistry, LanguageServerBinaryStatus, LanguageServerId};
use lsp::LanguageServerName;
use project::{
EnvironmentErrorMessage, LanguageServerProgress, LspStoreEvent, Project,
ProjectEnvironmentEvent, WorktreeId,
@@ -22,21 +23,21 @@ actions!(activity_indicator, [ShowErrorMessage]);
pub enum Event {
ShowError {
server_name: SharedString,
lsp_name: LanguageServerName,
error: String,
},
}
pub struct ActivityIndicator {
statuses: Vec<ServerStatus>,
statuses: Vec<LspStatus>,
project: Entity<Project>,
auto_updater: Option<Entity<AutoUpdater>>,
context_menu_handle: PopoverMenuHandle<ContextMenu>,
}
struct ServerStatus {
name: SharedString,
status: BinaryStatus,
struct LspStatus {
name: LanguageServerName,
status: LanguageServerBinaryStatus,
}
struct PendingWork<'a> {
@@ -63,24 +64,11 @@ impl ActivityIndicator {
let auto_updater = AutoUpdater::get(cx);
let this = cx.new(|cx| {
let mut status_events = languages.language_server_binary_statuses();
cx.spawn(async move |this, cx| {
cx.spawn(|this, mut cx| async move {
while let Some((name, status)) = status_events.next().await {
this.update(cx, |this: &mut ActivityIndicator, cx| {
this.update(&mut cx, |this: &mut ActivityIndicator, cx| {
this.statuses.retain(|s| s.name != name);
this.statuses.push(ServerStatus { name, status });
cx.notify();
})?;
}
anyhow::Ok(())
})
.detach();
let mut status_events = languages.dap_server_binary_statuses();
cx.spawn(async move |this, cx| {
while let Some((name, status)) = status_events.next().await {
this.update(cx, |this, cx| {
this.statuses.retain(|s| s.name != name);
this.statuses.push(ServerStatus { name, status });
this.statuses.push(LspStatus { name, status });
cx.notify();
})?;
}
@@ -118,25 +106,25 @@ impl ActivityIndicator {
});
cx.subscribe_in(&this, window, move |_, _, event, window, cx| match event {
Event::ShowError { server_name, error } => {
Event::ShowError { lsp_name, error } => {
let create_buffer = project.update(cx, |project, cx| project.create_buffer(cx));
let project = project.clone();
let error = error.clone();
let server_name = server_name.clone();
cx.spawn_in(window, async move |workspace, cx| {
let lsp_name = lsp_name.clone();
cx.spawn_in(window, |workspace, mut cx| async move {
let buffer = create_buffer.await?;
buffer.update(cx, |buffer, cx| {
buffer.update(&mut cx, |buffer, cx| {
buffer.edit(
[(
0..0,
format!("Language server error: {}\n\n{}", server_name, error),
format!("Language server error: {}\n\n{}", lsp_name, error),
)],
None,
cx,
);
buffer.set_capability(language::Capability::ReadOnly, cx);
})?;
workspace.update_in(cx, |workspace, window, cx| {
workspace.update_in(&mut cx, |workspace, window, cx| {
workspace.add_item_to_active_pane(
Box::new(cx.new(|cx| {
Editor::for_buffer(buffer, Some(project.clone()), window, cx)
@@ -159,9 +147,9 @@ impl ActivityIndicator {
fn show_error_message(&mut self, _: &ShowErrorMessage, _: &mut Window, cx: &mut Context<Self>) {
self.statuses.retain(|status| {
if let BinaryStatus::Failed { error } = &status.status {
if let LanguageServerBinaryStatus::Failed { error } = &status.status {
cx.emit(Event::ShowError {
server_name: status.name.clone(),
lsp_name: status.name.clone(),
error: error.clone(),
});
false
@@ -290,10 +278,12 @@ impl ActivityIndicator {
let mut failed = SmallVec::<[_; 3]>::new();
for status in &self.statuses {
match status.status {
BinaryStatus::CheckingForUpdate => checking_for_update.push(status.name.clone()),
BinaryStatus::Downloading => downloading.push(status.name.clone()),
BinaryStatus::Failed { .. } => failed.push(status.name.clone()),
BinaryStatus::None => {}
LanguageServerBinaryStatus::CheckingForUpdate => {
checking_for_update.push(status.name.clone())
}
LanguageServerBinaryStatus::Downloading => downloading.push(status.name.clone()),
LanguageServerBinaryStatus::Failed { .. } => failed.push(status.name.clone()),
LanguageServerBinaryStatus::None => {}
}
}
@@ -306,7 +296,7 @@ impl ActivityIndicator {
),
message: format!(
"Downloading {}...",
downloading.iter().map(|name| name.as_ref()).fold(
downloading.iter().map(|name| name.0.as_ref()).fold(
String::new(),
|mut acc, s| {
if !acc.is_empty() {
@@ -334,7 +324,7 @@ impl ActivityIndicator {
),
message: format!(
"Checking for updates to {}...",
checking_for_update.iter().map(|name| name.as_ref()).fold(
checking_for_update.iter().map(|name| name.0.as_ref()).fold(
String::new(),
|mut acc, s| {
if !acc.is_empty() {
@@ -364,7 +354,7 @@ impl ActivityIndicator {
"Failed to run {}. Click to show error.",
failed
.iter()
.map(|name| name.as_ref())
.map(|name| name.0.as_ref())
.fold(String::new(), |mut acc, s| {
if !acc.is_empty() {
acc.push_str(", ");

View File

@@ -34,9 +34,9 @@ impl AskPassDelegate {
password_prompt: impl Fn(String, oneshot::Sender<String>, &mut AsyncApp) + Send + Sync + 'static,
) -> Self {
let (tx, mut rx) = mpsc::unbounded::<(String, oneshot::Sender<String>)>();
let task = cx.spawn(async move |cx: &mut AsyncApp| {
let task = cx.spawn(|mut cx| async move {
while let Some((prompt, channel)) = rx.next().await {
password_prompt(prompt, channel, cx);
password_prompt(prompt, channel, &mut cx);
}
});
Self { tx, _task: task }

View File

@@ -98,9 +98,9 @@ pub fn init(
AssistantSettings::register(cx);
SlashCommandSettings::register(cx);
cx.spawn({
cx.spawn(|mut cx| {
let client = client.clone();
async move |cx| {
async move {
let is_search_slash_command_enabled = cx
.update(|cx| cx.wait_for_flag::<SearchSlashCommandFeatureFlag>())?
.await;
@@ -116,7 +116,7 @@ pub fn init(
let semantic_index = SemanticDb::new(
paths::embeddings_dir().join("semantic-index-db.0.mdb"),
Arc::new(embedding_provider),
cx,
&mut cx,
)
.await?;

View File

@@ -98,16 +98,16 @@ impl AssistantPanel {
prompt_builder: Arc<PromptBuilder>,
cx: AsyncWindowContext,
) -> Task<Result<Entity<Self>>> {
cx.spawn(async move |cx| {
cx.spawn(|mut cx| async move {
let slash_commands = Arc::new(SlashCommandWorkingSet::default());
let context_store = workspace
.update(cx, |workspace, cx| {
.update(&mut cx, |workspace, cx| {
let project = workspace.project().clone();
ContextStore::new(project, prompt_builder.clone(), slash_commands, cx)
})?
.await?;
workspace.update_in(cx, |workspace, window, cx| {
workspace.update_in(&mut cx, |workspace, window, cx| {
// TODO: deserialize state.
cx.new(|cx| Self::new(workspace, context_store, window, cx))
})
@@ -357,9 +357,9 @@ impl AssistantPanel {
) -> Task<()> {
let mut status_rx = client.status();
cx.spawn_in(window, async move |this, cx| {
cx.spawn_in(window, |this, mut cx| async move {
while let Some(status) = status_rx.next().await {
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
if this.client_status.is_none()
|| this
.client_status
@@ -371,7 +371,7 @@ impl AssistantPanel {
})
.log_err();
}
this.update(cx, |this, _cx| this.watch_client_status = None)
this.update(&mut cx, |this, _cx| this.watch_client_status = None)
.log_err();
})
}
@@ -576,11 +576,11 @@ impl AssistantPanel {
if self.authenticate_provider_task.is_none() {
self.authenticate_provider_task = Some((
provider.id(),
cx.spawn_in(window, async move |this, cx| {
cx.spawn_in(window, |this, mut cx| async move {
if let Some(future) = load_credentials {
let _ = future.await;
}
this.update(cx, |this, _cx| {
this.update(&mut cx, |this, _cx| {
this.authenticate_provider_task = None;
})
.log_err();
@@ -641,9 +641,9 @@ impl AssistantPanel {
}
} else {
let assistant_panel = assistant_panel.downgrade();
cx.spawn_in(window, async move |workspace, cx| {
cx.spawn_in(window, |workspace, mut cx| async move {
let Some(task) =
assistant_panel.update(cx, |assistant, cx| assistant.authenticate(cx))?
assistant_panel.update(&mut cx, |assistant, cx| assistant.authenticate(cx))?
else {
let answer = cx
.prompt(
@@ -665,7 +665,7 @@ impl AssistantPanel {
return Ok(());
};
task.await?;
if assistant_panel.update(cx, |panel, cx| panel.is_authenticated(cx))? {
if assistant_panel.update(&mut cx, |panel, cx| panel.is_authenticated(cx))? {
cx.update(|window, cx| match inline_assist_target {
InlineAssistTarget::Editor(active_editor, include_context) => {
let assistant_panel = if include_context {
@@ -698,7 +698,7 @@ impl AssistantPanel {
}
})?
} else {
workspace.update_in(cx, |workspace, window, cx| {
workspace.update_in(&mut cx, |workspace, window, cx| {
workspace.focus_panel::<AssistantPanel>(window, cx)
})?;
}
@@ -791,10 +791,10 @@ impl AssistantPanel {
.context_store
.update(cx, |store, cx| store.create_remote_context(cx));
cx.spawn_in(window, async move |this, cx| {
cx.spawn_in(window, |this, mut cx| async move {
let context = task.await?;
this.update_in(cx, |this, window, cx| {
this.update_in(&mut cx, |this, window, cx| {
let workspace = this.workspace.clone();
let project = this.project.clone();
let lsp_adapter_delegate =
@@ -847,9 +847,9 @@ impl AssistantPanel {
self.show_context(editor.clone(), window, cx);
let workspace = self.workspace.clone();
cx.spawn_in(window, async move |_, cx| {
cx.spawn_in(window, move |_, mut cx| async move {
workspace
.update_in(cx, |workspace, window, cx| {
.update_in(&mut cx, |workspace, window, cx| {
workspace.focus_panel::<AssistantPanel>(window, cx);
})
.ok();
@@ -1069,8 +1069,8 @@ impl AssistantPanel {
.filter(|editor| editor.read(cx).context().read(cx).path() == Some(&path))
});
if let Some(existing_context) = existing_context {
return cx.spawn_in(window, async move |this, cx| {
this.update_in(cx, |this, window, cx| {
return cx.spawn_in(window, |this, mut cx| async move {
this.update_in(&mut cx, |this, window, cx| {
this.show_context(existing_context, window, cx)
})
});
@@ -1085,9 +1085,9 @@ impl AssistantPanel {
let lsp_adapter_delegate = make_lsp_adapter_delegate(&project, cx).log_err().flatten();
cx.spawn_in(window, async move |this, cx| {
cx.spawn_in(window, |this, mut cx| async move {
let context = context.await?;
this.update_in(cx, |this, window, cx| {
this.update_in(&mut cx, |this, window, cx| {
let editor = cx.new(|cx| {
ContextEditor::for_context(
context,
@@ -1117,8 +1117,8 @@ impl AssistantPanel {
.filter(|editor| *editor.read(cx).context().read(cx).id() == id)
});
if let Some(existing_context) = existing_context {
return cx.spawn_in(window, async move |this, cx| {
this.update_in(cx, |this, window, cx| {
return cx.spawn_in(window, |this, mut cx| async move {
this.update_in(&mut cx, |this, window, cx| {
this.show_context(existing_context.clone(), window, cx)
})?;
Ok(existing_context)
@@ -1134,9 +1134,9 @@ impl AssistantPanel {
.log_err()
.flatten();
cx.spawn_in(window, async move |this, cx| {
cx.spawn_in(window, |this, mut cx| async move {
let context = context.await?;
this.update_in(cx, |this, window, cx| {
this.update_in(&mut cx, |this, window, cx| {
let editor = cx.new(|cx| {
ContextEditor::for_context(
context,

View File

@@ -1311,9 +1311,9 @@ impl EditorInlineAssists {
assist_ids: Vec::new(),
scroll_lock: None,
highlight_updates: highlight_updates_tx,
_update_highlights: cx.spawn({
_update_highlights: cx.spawn(|cx| {
let editor = editor.downgrade();
async move |cx| {
async move {
while let Ok(()) = highlight_updates_rx.changed().await {
let editor = editor.upgrade().context("editor was dropped")?;
cx.update_global(|assistant: &mut InlineAssistant, cx| {
@@ -1850,7 +1850,7 @@ impl PromptEditor {
fn count_tokens(&mut self, cx: &mut Context<Self>) {
let assist_id = self.id;
self.pending_token_count = cx.spawn(async move |this, cx| {
self.pending_token_count = cx.spawn(|this, mut cx| async move {
cx.background_executor().timer(Duration::from_secs(1)).await;
let token_count = cx
.update_global(|inline_assistant: &mut InlineAssistant, cx| {
@@ -1862,7 +1862,7 @@ impl PromptEditor {
})??
.await?;
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
this.token_counts = Some(token_count);
cx.notify();
})
@@ -2882,7 +2882,7 @@ impl CodegenAlternative {
let request = self.build_request(user_prompt, assistant_panel_context, cx)?;
self.request = Some(request.clone());
cx.spawn(async move |_, cx| model.stream_completion_text(request, &cx).await)
cx.spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await })
.boxed_local()
};
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
@@ -2999,207 +2999,213 @@ impl CodegenAlternative {
let completion = Arc::new(Mutex::new(String::new()));
let completion_clone = completion.clone();
self.generation = cx.spawn(async move |codegen, cx| {
let stream = stream.await;
let message_id = stream
.as_ref()
.ok()
.and_then(|stream| stream.message_id.clone());
let generate = async {
let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
let executor = cx.background_executor().clone();
let message_id = message_id.clone();
let line_based_stream_diff: Task<anyhow::Result<()>> =
cx.background_spawn(async move {
let mut response_latency = None;
let request_start = Instant::now();
let diff = async {
let chunks = StripInvalidSpans::new(stream?.stream);
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default();
self.generation = cx.spawn(|codegen, mut cx| {
async move {
let stream = stream.await;
let message_id = stream
.as_ref()
.ok()
.and_then(|stream| stream.message_id.clone());
let generate = async {
let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
let executor = cx.background_executor().clone();
let message_id = message_id.clone();
let line_based_stream_diff: Task<anyhow::Result<()>> =
cx.background_spawn(async move {
let mut response_latency = None;
let request_start = Instant::now();
let diff = async {
let chunks = StripInvalidSpans::new(stream?.stream);
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default();
let mut new_text = String::new();
let mut base_indent = None;
let mut line_indent = None;
let mut first_line = true;
let mut new_text = String::new();
let mut base_indent = None;
let mut line_indent = None;
let mut first_line = true;
while let Some(chunk) = chunks.next().await {
if response_latency.is_none() {
response_latency = Some(request_start.elapsed());
}
let chunk = chunk?;
completion_clone.lock().push_str(&chunk);
let mut lines = chunk.split('\n').peekable();
while let Some(line) = lines.next() {
new_text.push_str(line);
if line_indent.is_none() {
if let Some(non_whitespace_ch_ix) =
new_text.find(|ch: char| !ch.is_whitespace())
{
line_indent = Some(non_whitespace_ch_ix);
base_indent = base_indent.or(line_indent);
let line_indent = line_indent.unwrap();
let base_indent = base_indent.unwrap();
let indent_delta =
line_indent as i32 - base_indent as i32;
let mut corrected_indent_len = cmp::max(
0,
suggested_line_indent.len as i32 + indent_delta,
)
as usize;
if first_line {
corrected_indent_len = corrected_indent_len
.saturating_sub(
selection_start.column as usize,
);
}
let indent_char = suggested_line_indent.char();
let mut indent_buffer = [0; 4];
let indent_str =
indent_char.encode_utf8(&mut indent_buffer);
new_text.replace_range(
..line_indent,
&indent_str.repeat(corrected_indent_len),
);
}
while let Some(chunk) = chunks.next().await {
if response_latency.is_none() {
response_latency = Some(request_start.elapsed());
}
let chunk = chunk?;
completion_clone.lock().push_str(&chunk);
if line_indent.is_some() {
let char_ops = diff.push_new(&new_text);
line_diff.push_char_operations(&char_ops, &selected_text);
diff_tx
.send((char_ops, line_diff.line_operations()))
.await?;
new_text.clear();
}
if lines.peek().is_some() {
let char_ops = diff.push_new("\n");
line_diff.push_char_operations(&char_ops, &selected_text);
diff_tx
.send((char_ops, line_diff.line_operations()))
.await?;
let mut lines = chunk.split('\n').peekable();
while let Some(line) = lines.next() {
new_text.push_str(line);
if line_indent.is_none() {
// Don't write out the leading indentation in empty lines on the next line
// This is the case where the above if statement didn't clear the buffer
if let Some(non_whitespace_ch_ix) =
new_text.find(|ch: char| !ch.is_whitespace())
{
line_indent = Some(non_whitespace_ch_ix);
base_indent = base_indent.or(line_indent);
let line_indent = line_indent.unwrap();
let base_indent = base_indent.unwrap();
let indent_delta =
line_indent as i32 - base_indent as i32;
let mut corrected_indent_len = cmp::max(
0,
suggested_line_indent.len as i32 + indent_delta,
)
as usize;
if first_line {
corrected_indent_len = corrected_indent_len
.saturating_sub(
selection_start.column as usize,
);
}
let indent_char = suggested_line_indent.char();
let mut indent_buffer = [0; 4];
let indent_str =
indent_char.encode_utf8(&mut indent_buffer);
new_text.replace_range(
..line_indent,
&indent_str.repeat(corrected_indent_len),
);
}
}
if line_indent.is_some() {
let char_ops = diff.push_new(&new_text);
line_diff
.push_char_operations(&char_ops, &selected_text);
diff_tx
.send((char_ops, line_diff.line_operations()))
.await?;
new_text.clear();
}
line_indent = None;
first_line = false;
if lines.peek().is_some() {
let char_ops = diff.push_new("\n");
line_diff
.push_char_operations(&char_ops, &selected_text);
diff_tx
.send((char_ops, line_diff.line_operations()))
.await?;
if line_indent.is_none() {
// Don't write out the leading indentation in empty lines on the next line
// This is the case where the above if statement didn't clear the buffer
new_text.clear();
}
line_indent = None;
first_line = false;
}
}
}
let mut char_ops = diff.push_new(&new_text);
char_ops.extend(diff.finish());
line_diff.push_char_operations(&char_ops, &selected_text);
line_diff.finish(&selected_text);
diff_tx
.send((char_ops, line_diff.line_operations()))
.await?;
anyhow::Ok(())
};
let result = diff.await;
let error_message =
result.as_ref().err().map(|error| error.to_string());
report_assistant_event(
AssistantEvent {
conversation_id: None,
message_id,
kind: AssistantKind::Inline,
phase: AssistantPhase::Response,
model: model_telemetry_id,
model_provider: model_provider_id.to_string(),
response_latency,
error_message,
language_name: language_name.map(|name| name.to_proto()),
},
telemetry,
http_client,
model_api_key,
&executor,
);
result?;
Ok(())
});
while let Some((char_ops, line_ops)) = diff_rx.next().await {
codegen.update(&mut cx, |codegen, cx| {
codegen.last_equal_ranges.clear();
let edits = char_ops
.into_iter()
.filter_map(|operation| match operation {
CharOperation::Insert { text } => {
let edit_start = snapshot.anchor_after(edit_start);
Some((edit_start..edit_start, text))
}
CharOperation::Delete { bytes } => {
let edit_end = edit_start + bytes;
let edit_range = snapshot.anchor_after(edit_start)
..snapshot.anchor_before(edit_end);
edit_start = edit_end;
Some((edit_range, String::new()))
}
CharOperation::Keep { bytes } => {
let edit_end = edit_start + bytes;
let edit_range = snapshot.anchor_after(edit_start)
..snapshot.anchor_before(edit_end);
edit_start = edit_end;
codegen.last_equal_ranges.push(edit_range);
None
}
})
.collect::<Vec<_>>();
if codegen.active {
codegen.apply_edits(edits.iter().cloned(), cx);
codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
}
codegen.edits.extend(edits);
codegen.line_operations = line_ops;
codegen.edit_position = Some(snapshot.anchor_after(edit_start));
let mut char_ops = diff.push_new(&new_text);
char_ops.extend(diff.finish());
line_diff.push_char_operations(&char_ops, &selected_text);
line_diff.finish(&selected_text);
diff_tx
.send((char_ops, line_diff.line_operations()))
.await?;
anyhow::Ok(())
};
let result = diff.await;
let error_message = result.as_ref().err().map(|error| error.to_string());
report_assistant_event(
AssistantEvent {
conversation_id: None,
message_id,
kind: AssistantKind::Inline,
phase: AssistantPhase::Response,
model: model_telemetry_id,
model_provider: model_provider_id.to_string(),
response_latency,
error_message,
language_name: language_name.map(|name| name.to_proto()),
},
telemetry,
http_client,
model_api_key,
&executor,
);
result?;
Ok(())
});
while let Some((char_ops, line_ops)) = diff_rx.next().await {
codegen.update(cx, |codegen, cx| {
codegen.last_equal_ranges.clear();
let edits = char_ops
.into_iter()
.filter_map(|operation| match operation {
CharOperation::Insert { text } => {
let edit_start = snapshot.anchor_after(edit_start);
Some((edit_start..edit_start, text))
}
CharOperation::Delete { bytes } => {
let edit_end = edit_start + bytes;
let edit_range = snapshot.anchor_after(edit_start)
..snapshot.anchor_before(edit_end);
edit_start = edit_end;
Some((edit_range, String::new()))
}
CharOperation::Keep { bytes } => {
let edit_end = edit_start + bytes;
let edit_range = snapshot.anchor_after(edit_start)
..snapshot.anchor_before(edit_end);
edit_start = edit_end;
codegen.last_equal_ranges.push(edit_range);
None
}
})
.collect::<Vec<_>>();
if codegen.active {
codegen.apply_edits(edits.iter().cloned(), cx);
codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
}
codegen.edits.extend(edits);
codegen.line_operations = line_ops;
codegen.edit_position = Some(snapshot.anchor_after(edit_start));
cx.notify();
})?;
}
// Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
// That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
// It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
let batch_diff_task =
codegen.update(cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
let (line_based_stream_diff, ()) = join!(line_based_stream_diff, batch_diff_task);
line_based_stream_diff?;
anyhow::Ok(())
};
let result = generate.await;
let elapsed_time = start_time.elapsed().as_secs_f64();
codegen
.update(cx, |this, cx| {
this.message_id = message_id;
this.last_equal_ranges.clear();
if let Err(error) = result {
this.status = CodegenStatus::Error(error);
} else {
this.status = CodegenStatus::Done;
cx.notify();
})?;
}
this.elapsed_time = Some(elapsed_time);
this.completion = Some(completion.lock().clone());
cx.emit(CodegenEvent::Finished);
cx.notify();
})
.ok();
// Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
// That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
// It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
let batch_diff_task =
codegen.update(&mut cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
let (line_based_stream_diff, ()) =
join!(line_based_stream_diff, batch_diff_task);
line_based_stream_diff?;
anyhow::Ok(())
};
let result = generate.await;
let elapsed_time = start_time.elapsed().as_secs_f64();
codegen
.update(&mut cx, |this, cx| {
this.message_id = message_id;
this.last_equal_ranges.clear();
if let Err(error) = result {
this.status = CodegenStatus::Error(error);
} else {
this.status = CodegenStatus::Done;
}
this.elapsed_time = Some(elapsed_time);
this.completion = Some(completion.lock().clone());
cx.emit(CodegenEvent::Finished);
cx.notify();
})
.ok();
}
});
cx.notify();
}
@@ -3317,7 +3323,7 @@ impl CodegenAlternative {
let new_snapshot = self.buffer.read(cx).snapshot(cx);
let new_range = self.range.to_point(&new_snapshot);
cx.spawn(async move |codegen, cx| {
cx.spawn(|codegen, mut cx| async move {
let (deleted_row_ranges, inserted_row_ranges) = cx
.background_spawn(async move {
let old_text = old_snapshot
@@ -3367,7 +3373,7 @@ impl CodegenAlternative {
.await;
codegen
.update(cx, |codegen, cx| {
.update(&mut cx, |codegen, cx| {
codegen.diff.deleted_row_ranges = deleted_row_ranges;
codegen.diff.inserted_row_ranges = inserted_row_ranges;
cx.notify();
@@ -3563,7 +3569,6 @@ impl CodeActionProvider for AssistantCodeActionProvider {
title: "Fix with Assistant".into(),
..Default::default()
})),
resolved: true,
}]))
} else {
Task::ready(Ok(Vec::new()))
@@ -3581,10 +3586,10 @@ impl CodeActionProvider for AssistantCodeActionProvider {
) -> Task<Result<ProjectTransaction>> {
let editor = self.editor.clone();
let workspace = self.workspace.clone();
window.spawn(cx, async move |cx| {
window.spawn(cx, |mut cx| async move {
let editor = editor.upgrade().context("editor was released")?;
let range = editor
.update(cx, |editor, cx| {
.update(&mut cx, |editor, cx| {
editor.buffer().update(cx, |multibuffer, cx| {
let buffer = buffer.read(cx);
let multibuffer_snapshot = multibuffer.read(cx);
@@ -3619,7 +3624,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
})
})?
.context("invalid range")?;
let assistant_panel = workspace.update(cx, |workspace, cx| {
let assistant_panel = workspace.update(&mut cx, |workspace, cx| {
workspace
.panel::<AssistantPanel>(cx)
.context("assistant panel was released")

View File

@@ -825,7 +825,7 @@ impl PromptEditor {
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
return;
};
self.pending_token_count = cx.spawn(async move |this, cx| {
self.pending_token_count = cx.spawn(|this, mut cx| async move {
cx.background_executor().timer(Duration::from_secs(1)).await;
let request =
cx.update_global(|inline_assistant: &mut TerminalInlineAssistant, cx| {
@@ -833,7 +833,7 @@ impl PromptEditor {
})??;
let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count);
cx.notify();
})
@@ -1140,7 +1140,7 @@ impl Codegen {
let telemetry = self.telemetry.clone();
self.status = CodegenStatus::Pending;
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
self.generation = cx.spawn(async move |this, cx| {
self.generation = cx.spawn(|this, mut cx| async move {
let model_telemetry_id = model.telemetry_id();
let model_provider_id = model.provider_id();
let response = model.stream_completion_text(prompt, &cx).await;
@@ -1197,12 +1197,12 @@ impl Codegen {
}
});
this.update(cx, |this, _| {
this.update(&mut cx, |this, _| {
this.message_id = message_id;
})?;
while let Some(hunk) = hunks_rx.next().await {
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
if let Some(transaction) = &mut this.transaction {
transaction.push(hunk, cx);
cx.notify();
@@ -1216,7 +1216,7 @@ impl Codegen {
let result = generate.await;
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
if let Err(error) = result {
this.status = CodegenStatus::Error(error);
} else {

View File

@@ -39,7 +39,6 @@ fs.workspace = true
futures.workspace = true
fuzzy.workspace = true
git.workspace = true
git_ui.workspace = true
gpui.workspace = true
heed.workspace = true
html_to_markdown.workspace = true

View File

@@ -1,16 +1,14 @@
use crate::thread::{
LastRestoreCheckpoint, MessageId, RequestKind, Thread, ThreadError, ThreadEvent,
};
use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
use crate::thread_store::ThreadStore;
use crate::tool_use::{ToolUse, ToolUseStatus};
use crate::ui::ContextPill;
use collections::HashMap;
use editor::{Editor, MultiBuffer};
use gpui::{
list, percentage, pulsating_between, AbsoluteLength, Animation, AnimationExt, AnyElement, App,
ClickEvent, DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Length, ListAlignment,
ListOffset, ListState, StyleRefinement, Subscription, Task, TextStyleRefinement,
Transformation, UnderlineStyle, WeakEntity,
list, percentage, AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent,
DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Length, ListAlignment, ListOffset,
ListState, StyleRefinement, Subscription, Task, TextStyleRefinement, Transformation,
UnderlineStyle,
};
use language::{Buffer, LanguageRegistry};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
@@ -20,24 +18,19 @@ use settings::Settings as _;
use std::sync::Arc;
use std::time::Duration;
use theme::ThemeSettings;
use ui::{prelude::*, Disclosure, KeyBinding, Tooltip};
use ui::Color;
use ui::{prelude::*, Disclosure, KeyBinding};
use util::ResultExt as _;
use workspace::{OpenOptions, Workspace};
use crate::context_store::{refresh_context_store_text, ContextStore};
pub struct ActiveThread {
language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>,
thread: Entity<Thread>,
context_store: Entity<ContextStore>,
workspace: WeakEntity<Workspace>,
save_thread_task: Option<Task<()>>,
messages: Vec<MessageId>,
list_state: ListState,
rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
rendered_scripting_tool_uses: HashMap<LanguageModelToolUseId, Entity<Markdown>>,
rendered_tool_use_labels: HashMap<LanguageModelToolUseId, Entity<Markdown>>,
editing_message: Option<(MessageId, EditMessageState)>,
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
last_error: Option<ThreadError>,
@@ -53,8 +46,6 @@ impl ActiveThread {
thread: Entity<Thread>,
thread_store: Entity<ThreadStore>,
language_registry: Arc<LanguageRegistry>,
context_store: Entity<ContextStore>,
workspace: WeakEntity<Workspace>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -67,13 +58,10 @@ impl ActiveThread {
language_registry,
thread_store,
thread: thread.clone(),
context_store,
workspace,
save_thread_task: None,
messages: Vec::new(),
rendered_messages_by_id: HashMap::default(),
rendered_scripting_tool_uses: HashMap::default(),
rendered_tool_use_labels: HashMap::default(),
expanded_tool_uses: HashMap::default(),
list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
let this = cx.entity().downgrade();
@@ -90,29 +78,10 @@ impl ActiveThread {
for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
this.push_message(&message.id, message.text.clone(), window, cx);
for tool_use in thread.read(cx).tool_uses_for_message(message.id, cx) {
this.render_tool_use_label_markdown(
tool_use.id.clone(),
tool_use.ui_text.clone(),
window,
cx,
);
}
for tool_use in thread
.read(cx)
.scripting_tool_uses_for_message(message.id, cx)
{
this.render_tool_use_label_markdown(
tool_use.id.clone(),
tool_use.ui_text.clone(),
window,
cx,
);
for tool_use in thread.read(cx).scripting_tool_uses_for_message(message.id) {
this.render_scripting_tool_use_markdown(
tool_use.id.clone(),
tool_use.ui_text.as_ref(),
tool_use.name.as_ref(),
tool_use.input.clone(),
window,
cx,
@@ -142,7 +111,7 @@ impl ActiveThread {
pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
self.last_error.take();
self.thread
.update(cx, |thread, cx| thread.cancel_last_completion(cx))
.update(cx, |thread, _cx| thread.cancel_last_completion())
}
pub fn last_error(&self) -> Option<ThreadError> {
@@ -310,19 +279,6 @@ impl ActiveThread {
.insert(tool_use_id, lua_script);
}
fn render_tool_use_label_markdown(
&mut self,
tool_use_id: LanguageModelToolUseId,
tool_label: impl Into<SharedString>,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.rendered_tool_use_labels.insert(
tool_use_id,
self.render_markdown(tool_label.into(), window, cx),
);
}
fn handle_thread_event(
&mut self,
_thread: &Entity<Thread>,
@@ -337,7 +293,6 @@ impl ActiveThread {
ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
self.save_thread(cx);
}
ThreadEvent::DoneStreaming => {}
ThreadEvent::StreamedAssistantText(message_id, text) => {
if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
markdown.update(cx, |markdown, cx| {
@@ -377,32 +332,14 @@ impl ActiveThread {
cx.notify();
}
ThreadEvent::UsePendingTools => {
let tool_uses = self
.thread
.update(cx, |thread, cx| thread.use_pending_tools(cx));
for tool_use in tool_uses {
self.render_tool_use_label_markdown(
tool_use.id,
tool_use.ui_text.clone(),
window,
cx,
);
}
self.thread.update(cx, |thread, cx| {
thread.use_pending_tools(cx);
});
}
ThreadEvent::ToolFinished {
pending_tool_use,
canceled,
..
pending_tool_use, ..
} => {
let canceled = *canceled;
if let Some(tool_use) = pending_tool_use {
self.render_tool_use_label_markdown(
tool_use.id.clone(),
SharedString::from(tool_use.ui_text.clone()),
window,
cx,
);
self.render_scripting_tool_use_markdown(
tool_use.id.clone(),
tool_use.name.as_ref(),
@@ -413,58 +350,14 @@ impl ActiveThread {
}
if self.thread.read(cx).all_tools_finished() {
let pending_refresh_buffers = self.thread.update(cx, |thread, cx| {
thread.action_log().update(cx, |action_log, _cx| {
action_log.take_stale_buffers_in_context()
})
});
let context_update_task = if !pending_refresh_buffers.is_empty() {
let refresh_task = refresh_context_store_text(
self.context_store.clone(),
&pending_refresh_buffers,
cx,
);
cx.spawn(async move |this, cx| {
let updated_context_ids = refresh_task.await;
this.update(cx, |this, cx| {
this.context_store.read_with(cx, |context_store, cx| {
context_store
.context()
.iter()
.filter(|context| {
updated_context_ids.contains(&context.id())
})
.flat_map(|context| context.snapshot(cx))
.collect()
})
})
})
} else {
Task::ready(anyhow::Ok(Vec::new()))
};
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(model) = model_registry.active_model() {
cx.spawn(async move |this, cx| {
let updated_context = context_update_task.await?;
this.update(cx, |this, cx| {
this.thread.update(cx, |thread, cx| {
thread.attach_tool_results(updated_context, cx);
if !canceled {
thread.send_to_model(model, RequestKind::Chat, cx);
}
});
})
})
.detach();
self.thread.update(cx, |thread, cx| {
thread.send_tool_results_to_model(model, cx);
});
}
}
}
ThreadEvent::CheckpointChanged => cx.notify(),
}
}
@@ -473,9 +366,9 @@ impl ActiveThread {
/// Only one task to save the thread will be in flight at a time.
fn save_thread(&mut self, cx: &mut Context<Self>) {
let thread = self.thread.clone();
self.save_thread_task = Some(cx.spawn(async move |this, cx| {
self.save_thread_task = Some(cx.spawn(|this, mut cx| async move {
let task = this
.update(cx, |this, cx| {
.update(&mut cx, |this, cx| {
this.thread_store
.update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
})
@@ -605,10 +498,9 @@ impl ActiveThread {
let thread = self.thread.read(cx);
// Get all the data we need from thread before we start using it in closures
let checkpoint = thread.checkpoint_for_message(message_id);
let context = thread.context_for_message(message_id);
let tool_uses = thread.tool_uses_for_message(message_id, cx);
let scripting_tool_uses = thread.scripting_tool_uses_for_message(message_id, cx);
let tool_uses = thread.tool_uses_for_message(message_id);
let scripting_tool_uses = thread.scripting_tool_uses_for_message(message_id);
// Don't render user messages that are just there for returning tool results.
if message.role == Role::User
@@ -639,7 +531,7 @@ impl ActiveThread {
.p_2p5()
.child(edit_message_editor)
} else {
div().text_ui(cx).child(markdown.clone())
div().p_2p5().text_ui(cx).child(markdown.clone())
},
)
.when_some(context, |parent, context| {
@@ -659,16 +551,15 @@ impl ActiveThread {
let styled_message = match message.role {
Role::User => v_flex()
.id(("message-container", ix))
.pt_2()
.pl_2()
.pr_2p5()
.pt_2p5()
.px_2p5()
.child(
v_flex()
.bg(colors.editor_background)
.rounded_lg()
.border_1()
.border_color(colors.border)
.shadow_md()
.shadow_sm()
.child(
h_flex()
.py_1()
@@ -759,97 +650,38 @@ impl ActiveThread {
},
),
)
.child(div().p_2().child(message_content)),
),
Role::Assistant => v_flex()
.id(("message-container", ix))
.child(div().py_3().px_4().child(message_content))
.when(
!tool_uses.is_empty() || !scripting_tool_uses.is_empty(),
|parent| {
parent.child(
v_flex()
.children(
tool_uses
.into_iter()
.map(|tool_use| self.render_tool_use(tool_use, cx)),
)
.children(scripting_tool_uses.into_iter().map(|tool_use| {
self.render_scripting_tool_use(tool_use, window, cx)
})),
)
},
.child(message_content),
),
Role::Assistant => {
v_flex()
.id(("message-container", ix))
.child(message_content)
.when(
!tool_uses.is_empty() || !scripting_tool_uses.is_empty(),
|parent| {
parent.child(
v_flex()
.children(
tool_uses
.into_iter()
.map(|tool_use| self.render_tool_use(tool_use, cx)),
)
.children(scripting_tool_uses.into_iter().map(|tool_use| {
self.render_scripting_tool_use(tool_use, cx)
})),
)
},
)
}
Role::System => div().id(("message-container", ix)).py_1().px_2().child(
v_flex()
.bg(colors.editor_background)
.rounded_sm()
.child(div().p_4().child(message_content)),
.child(message_content),
),
};
v_flex()
.when(ix == 0, |parent| parent.child(self.render_rules_item(cx)))
.when_some(checkpoint, |parent, checkpoint| {
let mut is_pending = false;
let mut error = None;
if let Some(last_restore_checkpoint) =
self.thread.read(cx).last_restore_checkpoint()
{
if last_restore_checkpoint.message_id() == message_id {
match last_restore_checkpoint {
LastRestoreCheckpoint::Pending { .. } => is_pending = true,
LastRestoreCheckpoint::Error { error: err, .. } => {
error = Some(err.clone());
}
}
}
}
let restore_checkpoint_button =
Button::new(("restore-checkpoint", ix), "Restore Checkpoint")
.icon(if error.is_some() {
IconName::XCircle
} else {
IconName::Undo
})
.size(ButtonSize::Compact)
.disabled(is_pending)
.icon_color(if error.is_some() {
Some(Color::Error)
} else {
None
})
.on_click(cx.listener(move |this, _, _window, cx| {
this.thread.update(cx, |thread, cx| {
thread
.restore_checkpoint(checkpoint.clone(), cx)
.detach_and_log_err(cx);
});
}));
let restore_checkpoint_button = if is_pending {
restore_checkpoint_button
.with_animation(
("pulsating-restore-checkpoint-button", ix),
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.6, 1.)),
|label, delta| label.alpha(delta),
)
.into_any_element()
} else if let Some(error) = error {
restore_checkpoint_button
.tooltip(Tooltip::text(error.to_string()))
.into_any_element()
} else {
restore_checkpoint_button.into_any_element()
};
parent.child(h_flex().pl_2().child(restore_checkpoint_button))
})
.child(styled_message)
.into_any()
styled_message.into_any()
}
fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
@@ -861,7 +693,7 @@ impl ActiveThread {
let lighter_border = cx.theme().colors().border.opacity(0.5);
div().px_4().child(
div().px_2p5().child(
v_flex()
.rounded_lg()
.border_1()
@@ -897,10 +729,11 @@ impl ActiveThread {
}
}),
))
.child(div().text_ui_sm(cx).children(
self.rendered_tool_use_labels.get(&tool_use.id).cloned(),
))
.truncate(),
.child(
Label::new(tool_use.name)
.size(LabelSize::Small)
.buffer_font(cx),
),
)
.child({
let (icon_name, color, animated) = match &tool_use.status {
@@ -1028,7 +861,6 @@ impl ActiveThread {
fn render_scripting_tool_use(
&self,
tool_use: ToolUse,
window: &Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
let is_open = self
@@ -1074,12 +906,7 @@ impl ActiveThread {
}
}),
))
.child(div().text_ui_sm(cx).child(self.render_markdown(
tool_use.ui_text.clone(),
window,
cx,
)))
.truncate(),
.child(Label::new(tool_use.name)),
)
.child(
Label::new(match tool_use.status {
@@ -1143,86 +970,6 @@ impl ActiveThread {
}),
)
}
fn render_rules_item(&self, cx: &Context<Self>) -> AnyElement {
let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref()
else {
return div().into_any();
};
let rules_files = system_prompt_context
.worktrees
.iter()
.filter_map(|worktree| worktree.rules_file.as_ref())
.collect::<Vec<_>>();
let label_text = match rules_files.as_slice() {
&[] => return div().into_any(),
&[rules_file] => {
format!("Using {:?} file", rules_file.rel_path)
}
rules_files => {
format!("Using {} rules files", rules_files.len())
}
};
div()
.pt_1()
.px_2p5()
.child(
h_flex()
.group("rules-item")
.w_full()
.gap_2()
.justify_between()
.child(
h_flex()
.gap_1p5()
.child(
Icon::new(IconName::File)
.size(IconSize::XSmall)
.color(Color::Disabled),
)
.child(
Label::new(label_text)
.size(LabelSize::XSmall)
.color(Color::Muted)
.buffer_font(cx),
),
)
.child(
div().visible_on_hover("rules-item").child(
Button::new("open-rules", "Open Rules")
.label_size(LabelSize::XSmall)
.on_click(cx.listener(Self::handle_open_rules)),
),
),
)
.into_any()
}
fn handle_open_rules(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref()
else {
return;
};
let abs_paths = system_prompt_context
.worktrees
.iter()
.flat_map(|worktree| worktree.rules_file.as_ref())
.map(|rules_file| rules_file.abs_path.to_path_buf())
.collect::<Vec<_>>();
if let Ok(task) = self.workspace.update(cx, move |workspace, cx| {
// TODO: Open a multibuffer instead? In some cases this doesn't make the set of rules
// files clear. For example, if rules file 1 is already open but rules file 2 is not,
// this would open and focus rules file 2 in a tab that is not next to rules file 1.
workspace.open_paths(abs_paths, OpenOptions::default(), None, window, cx)
}) {
task.detach();
}
}
}
impl Render for ActiveThread {

View File

@@ -31,11 +31,8 @@ use gpui::{actions, App};
use prompt_store::PromptBuilder;
use settings::Settings as _;
pub use crate::active_thread::ActiveThread;
pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate};
pub use crate::inline_assistant::InlineAssistant;
pub use crate::thread::{Message, RequestKind, Thread, ThreadEvent};
pub use crate::thread_store::ThreadStore;
actions!(
assistant2,

View File

@@ -1,33 +1,19 @@
use std::sync::Arc;
use assistant_tool::{ToolSource, ToolWorkingSet};
use collections::HashMap;
use context_server::manager::ContextServerManager;
use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
use gpui::{Action, AnyView, App, EventEmitter, FocusHandle, Focusable, Subscription};
use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
use ui::{
prelude::*, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch, Tooltip,
};
use util::ResultExt as _;
use ui::{prelude::*, Divider, DividerColor, ElevationIndex};
use zed_actions::assistant::DeployPromptLibrary;
use zed_actions::ExtensionCategoryFilter;
pub struct AssistantConfiguration {
focus_handle: FocusHandle,
configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
context_server_manager: Entity<ContextServerManager>,
expanded_context_server_tools: HashMap<Arc<str>, bool>,
tools: Arc<ToolWorkingSet>,
_registry_subscription: Subscription,
}
impl AssistantConfiguration {
pub fn new(
context_server_manager: Entity<ContextServerManager>,
tools: Arc<ToolWorkingSet>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
pub fn new(window: &mut Window, cx: &mut Context<Self>) -> Self {
let focus_handle = cx.focus_handle();
let registry_subscription = cx.subscribe_in(
@@ -50,9 +36,6 @@ impl AssistantConfiguration {
let mut this = Self {
focus_handle,
configuration_views_by_provider: HashMap::default(),
context_server_manager,
expanded_context_server_tools: HashMap::default(),
tools,
_registry_subscription: registry_subscription,
};
this.build_provider_configuration_views(window, cx);
@@ -160,186 +143,6 @@ impl AssistantConfiguration {
}),
)
}
fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
let context_servers = self.context_server_manager.read(cx).all_servers().clone();
let tools_by_source = self.tools.tools_by_source(cx);
let empty = Vec::new();
const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
v_flex()
.p(DynamicSpacing::Base16.rems(cx))
.mt_1()
.gap_2()
.flex_1()
.child(
v_flex()
.gap_0p5()
.child(Headline::new("Context Servers (MCP)").size(HeadlineSize::Small))
.child(Label::new(SUBHEADING).color(Color::Muted)),
)
.children(context_servers.into_iter().map(|context_server| {
let is_running = context_server.client().is_some();
let are_tools_expanded = self
.expanded_context_server_tools
.get(&context_server.id())
.copied()
.unwrap_or_default();
let tools = tools_by_source
.get(&ToolSource::ContextServer {
id: context_server.id().into(),
})
.unwrap_or_else(|| &empty);
let tool_count = tools.len();
v_flex()
.id(SharedString::from(context_server.id()))
.border_1()
.rounded_sm()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().editor_background)
.child(
h_flex()
.justify_between()
.px_2()
.py_1()
.when(are_tools_expanded, |element| {
element
.border_b_1()
.border_color(cx.theme().colors().border)
})
.child(
h_flex()
.gap_2()
.child(
Disclosure::new("tool-list-disclosure", are_tools_expanded)
.on_click(cx.listener({
let context_server_id = context_server.id();
move |this, _event, _window, _cx| {
let is_open = this
.expanded_context_server_tools
.entry(context_server_id.clone())
.or_insert(false);
*is_open = !*is_open;
}
})),
)
.child(Indicator::dot().color(if is_running {
Color::Success
} else {
Color::Error
}))
.child(Label::new(context_server.id()))
.child(
Label::new(format!("{tool_count} tools"))
.color(Color::Muted),
),
)
.child(h_flex().child(
Switch::new("context-server-switch", is_running.into()).on_click({
let context_server_manager =
self.context_server_manager.clone();
let context_server = context_server.clone();
move |state, _window, cx| match state {
ToggleState::Unselected | ToggleState::Indeterminate => {
context_server_manager.update(cx, |this, cx| {
this.stop_server(context_server.clone(), cx)
.log_err();
});
}
ToggleState::Selected => {
cx.spawn({
let context_server_manager =
context_server_manager.clone();
let context_server = context_server.clone();
async move |cx| {
if let Some(start_server_task) =
context_server_manager
.update(cx, |this, cx| {
this.start_server(
context_server,
cx,
)
})
.log_err()
{
start_server_task.await.log_err();
}
}
})
.detach();
}
}
}),
)),
)
.map(|parent| {
if !are_tools_expanded {
return parent;
}
parent.child(v_flex().children(tools.into_iter().enumerate().map(
|(ix, tool)| {
h_flex()
.px_2()
.py_1()
.when(ix < tool_count - 1, |element| {
element
.border_b_1()
.border_color(cx.theme().colors().border)
})
.child(Label::new(tool.name()))
},
)))
})
}))
.child(
h_flex()
.justify_between()
.gap_2()
.child(
h_flex().w_full().child(
Button::new("add-context-server", "Add Context Server")
.style(ButtonStyle::Filled)
.layer(ElevationIndex::ModalSurface)
.full_width()
.icon(IconName::Plus)
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
.disabled(true)
.tooltip(Tooltip::text("Not yet implemented")),
),
)
.child(
h_flex().w_full().child(
Button::new(
"install-context-server-extensions",
"Install Context Server Extensions",
)
.style(ButtonStyle::Filled)
.layer(ElevationIndex::ModalSurface)
.full_width()
.icon(IconName::DatabaseZap)
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
.on_click(|_event, window, cx| {
window.dispatch_action(
zed_actions::Extensions {
category_filter: Some(
ExtensionCategoryFilter::ContextServers,
),
}
.boxed_clone(),
cx,
)
}),
),
),
)
}
}
impl Render for AssistantConfiguration {
@@ -379,8 +182,6 @@ impl Render for AssistantConfiguration {
),
)
.child(Divider::horizontal().color(DividerColor::Border))
.child(self.render_context_servers_section(cx))
.child(Divider::horizontal().color(DividerColor::Border))
.child(
v_flex()
.p(DynamicSpacing::Base16.rems(cx))

View File

@@ -110,16 +110,19 @@ impl AssistantPanel {
prompt_builder: Arc<PromptBuilder>,
cx: AsyncWindowContext,
) -> Task<Result<Entity<Self>>> {
cx.spawn(async move |cx| {
cx.spawn(|mut cx| async move {
let tools = Arc::new(ToolWorkingSet::default());
let thread_store = workspace.update(cx, |workspace, cx| {
log::info!("[assistant2-debug] initializing ThreadStore");
let thread_store = workspace.update(&mut cx, |workspace, cx| {
let project = workspace.project().clone();
ThreadStore::new(project, tools.clone(), prompt_builder.clone(), cx)
})??;
log::info!("[assistant2-debug] finished initializing ThreadStore");
let slash_commands = Arc::new(SlashCommandWorkingSet::default());
log::info!("[assistant2-debug] initializing ContextStore");
let context_store = workspace
.update(cx, |workspace, cx| {
.update(&mut cx, |workspace, cx| {
let project = workspace.project().clone();
assistant_context_editor::ContextStore::new(
project,
@@ -129,8 +132,9 @@ impl AssistantPanel {
)
})?
.await?;
log::info!("[assistant2-debug] finished initializing ContextStore");
workspace.update_in(cx, |workspace, window, cx| {
workspace.update_in(&mut cx, |workspace, window, cx| {
cx.new(|cx| Self::new(workspace, thread_store, context_store, window, cx))
})
})
@@ -143,6 +147,7 @@ impl AssistantPanel {
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
log::info!("[assistant2-debug] AssistantPanel::new");
let thread = thread_store.update(cx, |this, cx| this.create_thread(cx));
let fs = workspace.app_state().fs.clone();
let project = workspace.project().clone();
@@ -150,14 +155,10 @@ impl AssistantPanel {
let workspace = workspace.weak_handle();
let weak_self = cx.entity().downgrade();
let message_editor_context_store =
cx.new(|_cx| crate::context_store::ContextStore::new(workspace.clone()));
let message_editor = cx.new(|cx| {
MessageEditor::new(
fs.clone(),
workspace.clone(),
message_editor_context_store.clone(),
thread_store.downgrade(),
thread.clone(),
window,
@@ -173,8 +174,6 @@ impl AssistantPanel {
thread.clone(),
thread_store.clone(),
language_registry.clone(),
message_editor_context_store.clone(),
workspace.clone(),
window,
cx,
)
@@ -243,17 +242,11 @@ impl AssistantPanel {
.update(cx, |this, cx| this.create_thread(cx));
self.active_view = ActiveView::Thread;
let message_editor_context_store =
cx.new(|_cx| crate::context_store::ContextStore::new(self.workspace.clone()));
self.thread = cx.new(|cx| {
ActiveThread::new(
thread.clone(),
self.thread_store.clone(),
self.language_registry.clone(),
message_editor_context_store.clone(),
self.workspace.clone(),
window,
cx,
)
@@ -262,7 +255,6 @@ impl AssistantPanel {
MessageEditor::new(
self.fs.clone(),
self.workspace.clone(),
message_editor_context_store,
self.thread_store.downgrade(),
thread,
window,
@@ -346,9 +338,9 @@ impl AssistantPanel {
let lsp_adapter_delegate = make_lsp_adapter_delegate(&project, cx).log_err().flatten();
cx.spawn_in(window, async move |this, cx| {
cx.spawn_in(window, |this, mut cx| async move {
let context = context.await?;
this.update_in(cx, |this, window, cx| {
this.update_in(&mut cx, |this, window, cx| {
let editor = cx.new(|cx| {
ContextEditor::for_context(
context,
@@ -379,19 +371,15 @@ impl AssistantPanel {
.thread_store
.update(cx, |this, cx| this.open_thread(thread_id, cx));
cx.spawn_in(window, async move |this, cx| {
cx.spawn_in(window, |this, mut cx| async move {
let thread = open_thread_task.await?;
this.update_in(cx, |this, window, cx| {
this.update_in(&mut cx, |this, window, cx| {
this.active_view = ActiveView::Thread;
let message_editor_context_store =
cx.new(|_cx| crate::context_store::ContextStore::new(this.workspace.clone()));
this.thread = cx.new(|cx| {
ActiveThread::new(
thread.clone(),
this.thread_store.clone(),
this.language_registry.clone(),
message_editor_context_store.clone(),
this.workspace.clone(),
window,
cx,
)
@@ -400,7 +388,6 @@ impl AssistantPanel {
MessageEditor::new(
this.fs.clone(),
this.workspace.clone(),
message_editor_context_store,
this.thread_store.downgrade(),
thread,
window,
@@ -413,13 +400,8 @@ impl AssistantPanel {
}
pub(crate) fn open_configuration(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let context_server_manager = self.thread_store.read(cx).context_server_manager();
let tools = self.thread_store.read(cx).tools();
self.active_view = ActiveView::Configuration;
self.configuration = Some(
cx.new(|cx| AssistantConfiguration::new(context_server_manager, tools, window, cx)),
);
self.configuration = Some(cx.new(|cx| AssistantConfiguration::new(window, cx)));
if let Some(configuration) = self.configuration.as_ref() {
self.configuration_subscription = Some(cx.subscribe_in(
@@ -453,12 +435,12 @@ impl AssistantPanel {
.languages
.language_for_name("Markdown");
let thread = self.active_thread(cx);
cx.spawn_in(window, async move |_this, cx| {
cx.spawn_in(window, |_this, mut cx| async move {
let markdown_language = markdown_language_task.await?;
workspace.update_in(cx, |workspace, window, cx| {
workspace.update_in(&mut cx, |workspace, window, cx| {
let thread = thread.read(cx);
let markdown = thread.to_markdown(cx)?;
let markdown = thread.to_markdown()?;
let thread_summary = thread
.summary()
.map(|summary| summary.to_string())
@@ -925,8 +907,8 @@ impl AssistantPanel {
ThreadError::MaxMonthlySpendReached => {
self.render_max_monthly_spend_reached_error(cx)
}
ThreadError::Message { header, message } => {
self.render_error_message(header, message, cx)
ThreadError::Message(error_message) => {
self.render_error_message(&error_message, cx)
}
})
.into_any(),
@@ -1029,8 +1011,7 @@ impl AssistantPanel {
fn render_error_message(
&self,
header: SharedString,
message: SharedString,
error_message: &SharedString,
cx: &mut Context<Self>,
) -> AnyElement {
v_flex()
@@ -1040,14 +1021,17 @@ impl AssistantPanel {
.gap_1p5()
.items_center()
.child(Icon::new(IconName::XCircle).color(Color::Error))
.child(Label::new(header).weight(FontWeight::MEDIUM)),
.child(
Label::new("Error interacting with language model")
.weight(FontWeight::MEDIUM),
),
)
.child(
div()
.id("error-message")
.max_h_32()
.overflow_y_scroll()
.child(Label::new(message)),
.child(Label::new(error_message.clone())),
)
.child(
h_flex()

View File

@@ -367,7 +367,7 @@ impl CodegenAlternative {
let request = self.build_request(user_prompt, cx)?;
self.request = Some(request.clone());
cx.spawn(async move |_, cx| model.stream_completion_text(request, &cx).await)
cx.spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await })
.boxed_local()
};
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
@@ -480,207 +480,213 @@ impl CodegenAlternative {
let completion = Arc::new(Mutex::new(String::new()));
let completion_clone = completion.clone();
self.generation = cx.spawn(async move |codegen, cx| {
let stream = stream.await;
let message_id = stream
.as_ref()
.ok()
.and_then(|stream| stream.message_id.clone());
let generate = async {
let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
let executor = cx.background_executor().clone();
let message_id = message_id.clone();
let line_based_stream_diff: Task<anyhow::Result<()>> =
cx.background_spawn(async move {
let mut response_latency = None;
let request_start = Instant::now();
let diff = async {
let chunks = StripInvalidSpans::new(stream?.stream);
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default();
self.generation = cx.spawn(|codegen, mut cx| {
async move {
let stream = stream.await;
let message_id = stream
.as_ref()
.ok()
.and_then(|stream| stream.message_id.clone());
let generate = async {
let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
let executor = cx.background_executor().clone();
let message_id = message_id.clone();
let line_based_stream_diff: Task<anyhow::Result<()>> =
cx.background_spawn(async move {
let mut response_latency = None;
let request_start = Instant::now();
let diff = async {
let chunks = StripInvalidSpans::new(stream?.stream);
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default();
let mut new_text = String::new();
let mut base_indent = None;
let mut line_indent = None;
let mut first_line = true;
let mut new_text = String::new();
let mut base_indent = None;
let mut line_indent = None;
let mut first_line = true;
while let Some(chunk) = chunks.next().await {
if response_latency.is_none() {
response_latency = Some(request_start.elapsed());
}
let chunk = chunk?;
completion_clone.lock().push_str(&chunk);
let mut lines = chunk.split('\n').peekable();
while let Some(line) = lines.next() {
new_text.push_str(line);
if line_indent.is_none() {
if let Some(non_whitespace_ch_ix) =
new_text.find(|ch: char| !ch.is_whitespace())
{
line_indent = Some(non_whitespace_ch_ix);
base_indent = base_indent.or(line_indent);
let line_indent = line_indent.unwrap();
let base_indent = base_indent.unwrap();
let indent_delta =
line_indent as i32 - base_indent as i32;
let mut corrected_indent_len = cmp::max(
0,
suggested_line_indent.len as i32 + indent_delta,
)
as usize;
if first_line {
corrected_indent_len = corrected_indent_len
.saturating_sub(
selection_start.column as usize,
);
}
let indent_char = suggested_line_indent.char();
let mut indent_buffer = [0; 4];
let indent_str =
indent_char.encode_utf8(&mut indent_buffer);
new_text.replace_range(
..line_indent,
&indent_str.repeat(corrected_indent_len),
);
}
while let Some(chunk) = chunks.next().await {
if response_latency.is_none() {
response_latency = Some(request_start.elapsed());
}
let chunk = chunk?;
completion_clone.lock().push_str(&chunk);
if line_indent.is_some() {
let char_ops = diff.push_new(&new_text);
line_diff.push_char_operations(&char_ops, &selected_text);
diff_tx
.send((char_ops, line_diff.line_operations()))
.await?;
new_text.clear();
}
if lines.peek().is_some() {
let char_ops = diff.push_new("\n");
line_diff.push_char_operations(&char_ops, &selected_text);
diff_tx
.send((char_ops, line_diff.line_operations()))
.await?;
let mut lines = chunk.split('\n').peekable();
while let Some(line) = lines.next() {
new_text.push_str(line);
if line_indent.is_none() {
// Don't write out the leading indentation in empty lines on the next line
// This is the case where the above if statement didn't clear the buffer
if let Some(non_whitespace_ch_ix) =
new_text.find(|ch: char| !ch.is_whitespace())
{
line_indent = Some(non_whitespace_ch_ix);
base_indent = base_indent.or(line_indent);
let line_indent = line_indent.unwrap();
let base_indent = base_indent.unwrap();
let indent_delta =
line_indent as i32 - base_indent as i32;
let mut corrected_indent_len = cmp::max(
0,
suggested_line_indent.len as i32 + indent_delta,
)
as usize;
if first_line {
corrected_indent_len = corrected_indent_len
.saturating_sub(
selection_start.column as usize,
);
}
let indent_char = suggested_line_indent.char();
let mut indent_buffer = [0; 4];
let indent_str =
indent_char.encode_utf8(&mut indent_buffer);
new_text.replace_range(
..line_indent,
&indent_str.repeat(corrected_indent_len),
);
}
}
if line_indent.is_some() {
let char_ops = diff.push_new(&new_text);
line_diff
.push_char_operations(&char_ops, &selected_text);
diff_tx
.send((char_ops, line_diff.line_operations()))
.await?;
new_text.clear();
}
line_indent = None;
first_line = false;
if lines.peek().is_some() {
let char_ops = diff.push_new("\n");
line_diff
.push_char_operations(&char_ops, &selected_text);
diff_tx
.send((char_ops, line_diff.line_operations()))
.await?;
if line_indent.is_none() {
// Don't write out the leading indentation in empty lines on the next line
// This is the case where the above if statement didn't clear the buffer
new_text.clear();
}
line_indent = None;
first_line = false;
}
}
}
let mut char_ops = diff.push_new(&new_text);
char_ops.extend(diff.finish());
line_diff.push_char_operations(&char_ops, &selected_text);
line_diff.finish(&selected_text);
diff_tx
.send((char_ops, line_diff.line_operations()))
.await?;
anyhow::Ok(())
};
let result = diff.await;
let error_message =
result.as_ref().err().map(|error| error.to_string());
report_assistant_event(
AssistantEvent {
conversation_id: None,
message_id,
kind: AssistantKind::Inline,
phase: AssistantPhase::Response,
model: model_telemetry_id,
model_provider: model_provider_id.to_string(),
response_latency,
error_message,
language_name: language_name.map(|name| name.to_proto()),
},
telemetry,
http_client,
model_api_key,
&executor,
);
result?;
Ok(())
});
while let Some((char_ops, line_ops)) = diff_rx.next().await {
codegen.update(&mut cx, |codegen, cx| {
codegen.last_equal_ranges.clear();
let edits = char_ops
.into_iter()
.filter_map(|operation| match operation {
CharOperation::Insert { text } => {
let edit_start = snapshot.anchor_after(edit_start);
Some((edit_start..edit_start, text))
}
CharOperation::Delete { bytes } => {
let edit_end = edit_start + bytes;
let edit_range = snapshot.anchor_after(edit_start)
..snapshot.anchor_before(edit_end);
edit_start = edit_end;
Some((edit_range, String::new()))
}
CharOperation::Keep { bytes } => {
let edit_end = edit_start + bytes;
let edit_range = snapshot.anchor_after(edit_start)
..snapshot.anchor_before(edit_end);
edit_start = edit_end;
codegen.last_equal_ranges.push(edit_range);
None
}
})
.collect::<Vec<_>>();
if codegen.active {
codegen.apply_edits(edits.iter().cloned(), cx);
codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
}
codegen.edits.extend(edits);
codegen.line_operations = line_ops;
codegen.edit_position = Some(snapshot.anchor_after(edit_start));
let mut char_ops = diff.push_new(&new_text);
char_ops.extend(diff.finish());
line_diff.push_char_operations(&char_ops, &selected_text);
line_diff.finish(&selected_text);
diff_tx
.send((char_ops, line_diff.line_operations()))
.await?;
anyhow::Ok(())
};
let result = diff.await;
let error_message = result.as_ref().err().map(|error| error.to_string());
report_assistant_event(
AssistantEvent {
conversation_id: None,
message_id,
kind: AssistantKind::Inline,
phase: AssistantPhase::Response,
model: model_telemetry_id,
model_provider: model_provider_id.to_string(),
response_latency,
error_message,
language_name: language_name.map(|name| name.to_proto()),
},
telemetry,
http_client,
model_api_key,
&executor,
);
result?;
Ok(())
});
while let Some((char_ops, line_ops)) = diff_rx.next().await {
codegen.update(cx, |codegen, cx| {
codegen.last_equal_ranges.clear();
let edits = char_ops
.into_iter()
.filter_map(|operation| match operation {
CharOperation::Insert { text } => {
let edit_start = snapshot.anchor_after(edit_start);
Some((edit_start..edit_start, text))
}
CharOperation::Delete { bytes } => {
let edit_end = edit_start + bytes;
let edit_range = snapshot.anchor_after(edit_start)
..snapshot.anchor_before(edit_end);
edit_start = edit_end;
Some((edit_range, String::new()))
}
CharOperation::Keep { bytes } => {
let edit_end = edit_start + bytes;
let edit_range = snapshot.anchor_after(edit_start)
..snapshot.anchor_before(edit_end);
edit_start = edit_end;
codegen.last_equal_ranges.push(edit_range);
None
}
})
.collect::<Vec<_>>();
if codegen.active {
codegen.apply_edits(edits.iter().cloned(), cx);
codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
}
codegen.edits.extend(edits);
codegen.line_operations = line_ops;
codegen.edit_position = Some(snapshot.anchor_after(edit_start));
cx.notify();
})?;
}
// Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
// That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
// It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
let batch_diff_task =
codegen.update(cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
let (line_based_stream_diff, ()) = join!(line_based_stream_diff, batch_diff_task);
line_based_stream_diff?;
anyhow::Ok(())
};
let result = generate.await;
let elapsed_time = start_time.elapsed().as_secs_f64();
codegen
.update(cx, |this, cx| {
this.message_id = message_id;
this.last_equal_ranges.clear();
if let Err(error) = result {
this.status = CodegenStatus::Error(error);
} else {
this.status = CodegenStatus::Done;
cx.notify();
})?;
}
this.elapsed_time = Some(elapsed_time);
this.completion = Some(completion.lock().clone());
cx.emit(CodegenEvent::Finished);
cx.notify();
})
.ok();
// Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
// That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
// It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
let batch_diff_task =
codegen.update(&mut cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
let (line_based_stream_diff, ()) =
join!(line_based_stream_diff, batch_diff_task);
line_based_stream_diff?;
anyhow::Ok(())
};
let result = generate.await;
let elapsed_time = start_time.elapsed().as_secs_f64();
codegen
.update(&mut cx, |this, cx| {
this.message_id = message_id;
this.last_equal_ranges.clear();
if let Err(error) = result {
this.status = CodegenStatus::Error(error);
} else {
this.status = CodegenStatus::Done;
}
this.elapsed_time = Some(elapsed_time);
this.completion = Some(completion.lock().clone());
cx.emit(CodegenEvent::Finished);
cx.notify();
})
.ok();
}
});
cx.notify();
}
@@ -798,7 +804,7 @@ impl CodegenAlternative {
let new_snapshot = self.buffer.read(cx).snapshot(cx);
let new_range = self.range.to_point(&new_snapshot);
cx.spawn(async move |codegen, cx| {
cx.spawn(|codegen, mut cx| async move {
let (deleted_row_ranges, inserted_row_ranges) = cx
.background_spawn(async move {
let old_text = old_snapshot
@@ -848,7 +854,7 @@ impl CodegenAlternative {
.await;
codegen
.update(cx, |codegen, cx| {
.update(&mut cx, |codegen, cx| {
codegen.diff.deleted_row_ranges = deleted_row_ranges;
codegen.diff.inserted_row_ranges = inserted_row_ranges;
cx.notify();

View File

@@ -43,6 +43,15 @@ pub enum ContextKind {
}
impl ContextKind {
pub fn label(&self) -> &'static str {
match self {
ContextKind::File => "File",
ContextKind::Directory => "Folder",
ContextKind::FetchedUrl => "Fetch",
ContextKind::Thread => "Thread",
}
}
pub fn icon(&self) -> IconName {
match self {
ContextKind::File => IconName::File,

View File

@@ -1,3 +1,4 @@
mod directory_context_picker;
mod fetch_context_picker;
mod file_context_picker;
mod thread_context_picker;
@@ -14,6 +15,8 @@ use thread_context_picker::{render_thread_context_entry, ThreadContextEntry};
use ui::{prelude::*, ContextMenu, ContextMenuEntry, ContextMenuItem};
use workspace::{notifications::NotifyResultExt, Workspace};
use crate::context::ContextKind;
use crate::context_picker::directory_context_picker::DirectoryContextPicker;
use crate::context_picker::fetch_context_picker::FetchContextPicker;
use crate::context_picker::file_context_picker::FileContextPicker;
use crate::context_picker::thread_context_picker::ThreadContextPicker;
@@ -27,41 +30,17 @@ pub enum ConfirmBehavior {
Close,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ContextPickerMode {
File,
Fetch,
Thread,
}
impl ContextPickerMode {
pub fn label(&self) -> &'static str {
match self {
Self::File => "File/Directory",
Self::Fetch => "Fetch",
Self::Thread => "Thread",
}
}
pub fn icon(&self) -> IconName {
match self {
Self::File => IconName::File,
Self::Fetch => IconName::Globe,
Self::Thread => IconName::MessageCircle,
}
}
}
#[derive(Debug, Clone)]
enum ContextPickerState {
enum ContextPickerMode {
Default(Entity<ContextMenu>),
File(Entity<FileContextPicker>),
Directory(Entity<DirectoryContextPicker>),
Fetch(Entity<FetchContextPicker>),
Thread(Entity<ThreadContextPicker>),
}
pub(super) struct ContextPicker {
mode: ContextPickerState,
mode: ContextPickerMode,
workspace: WeakEntity<Workspace>,
editor: WeakEntity<Editor>,
context_store: WeakEntity<ContextStore>,
@@ -80,7 +59,7 @@ impl ContextPicker {
cx: &mut Context<Self>,
) -> Self {
ContextPicker {
mode: ContextPickerState::Default(ContextMenu::build(
mode: ContextPickerMode::Default(ContextMenu::build(
window,
cx,
|menu, _window, _cx| menu,
@@ -94,7 +73,7 @@ impl ContextPicker {
}
pub fn init(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.mode = ContextPickerState::Default(self.build_menu(window, cx));
self.mode = ContextPickerMode::Default(self.build_menu(window, cx));
cx.notify();
}
@@ -109,9 +88,13 @@ impl ContextPicker {
.enumerate()
.map(|(ix, entry)| self.recent_menu_item(context_picker.clone(), ix, entry));
let mut modes = vec![ContextPickerMode::File, ContextPickerMode::Fetch];
let mut context_kinds = vec![
ContextKind::File,
ContextKind::Directory,
ContextKind::FetchedUrl,
];
if self.allow_threads() {
modes.push(ContextPickerMode::Thread);
context_kinds.push(ContextKind::Thread);
}
let menu = menu
@@ -129,15 +112,15 @@ impl ContextPicker {
})
.extend(recent_entries)
.when(has_recent, |menu| menu.separator())
.extend(modes.into_iter().map(|mode| {
.extend(context_kinds.into_iter().map(|kind| {
let context_picker = context_picker.clone();
ContextMenuEntry::new(mode.label())
.icon(mode.icon())
ContextMenuEntry::new(kind.label())
.icon(kind.icon())
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.handler(move |window, cx| {
context_picker.update(cx, |this, cx| this.select_mode(mode, window, cx))
context_picker.update(cx, |this, cx| this.select_kind(kind, window, cx))
})
}));
@@ -160,17 +143,12 @@ impl ContextPicker {
self.thread_store.is_some()
}
fn select_mode(
&mut self,
mode: ContextPickerMode,
window: &mut Window,
cx: &mut Context<Self>,
) {
fn select_kind(&mut self, kind: ContextKind, window: &mut Window, cx: &mut Context<Self>) {
let context_picker = cx.entity().downgrade();
match mode {
ContextPickerMode::File => {
self.mode = ContextPickerState::File(cx.new(|cx| {
match kind {
ContextKind::File => {
self.mode = ContextPickerMode::File(cx.new(|cx| {
FileContextPicker::new(
context_picker.clone(),
self.workspace.clone(),
@@ -182,8 +160,20 @@ impl ContextPicker {
)
}));
}
ContextPickerMode::Fetch => {
self.mode = ContextPickerState::Fetch(cx.new(|cx| {
ContextKind::Directory => {
self.mode = ContextPickerMode::Directory(cx.new(|cx| {
DirectoryContextPicker::new(
context_picker.clone(),
self.workspace.clone(),
self.context_store.clone(),
self.confirm_behavior,
window,
cx,
)
}));
}
ContextKind::FetchedUrl => {
self.mode = ContextPickerMode::Fetch(cx.new(|cx| {
FetchContextPicker::new(
context_picker.clone(),
self.workspace.clone(),
@@ -194,9 +184,9 @@ impl ContextPicker {
)
}));
}
ContextPickerMode::Thread => {
ContextKind::Thread => {
if let Some(thread_store) = self.thread_store.as_ref() {
self.mode = ContextPickerState::Thread(cx.new(|cx| {
self.mode = ContextPickerMode::Thread(cx.new(|cx| {
ThreadContextPicker::new(
thread_store.clone(),
context_picker.clone(),
@@ -234,7 +224,6 @@ impl ContextPicker {
ElementId::NamedInteger("ctx-recent".into(), ix),
&path,
&path_prefix,
false,
context_store.clone(),
cx,
)
@@ -281,8 +270,10 @@ impl ContextPicker {
context_store.add_file_from_path(project_path.clone(), cx)
});
cx.spawn_in(window, async move |_, cx| task.await.notify_async_err(cx))
.detach();
cx.spawn_in(window, |_, mut cx| async move {
task.await.notify_async_err(&mut cx)
})
.detach();
cx.notify();
}
@@ -305,13 +296,13 @@ impl ContextPicker {
};
let open_thread_task = thread_store.update(cx, |this, cx| this.open_thread(&thread.id, cx));
cx.spawn(async move |this, cx| {
cx.spawn(|this, mut cx| async move {
let thread = open_thread_task.await?;
context_store.update(cx, |context_store, cx| {
context_store.update(&mut cx, |context_store, cx| {
context_store.add_thread(thread, cx);
})?;
this.update(cx, |_this, cx| cx.notify())
this.update(&mut cx, |_this, cx| cx.notify())
})
}
@@ -401,10 +392,11 @@ impl EventEmitter<DismissEvent> for ContextPicker {}
impl Focusable for ContextPicker {
fn focus_handle(&self, cx: &App) -> FocusHandle {
match &self.mode {
ContextPickerState::Default(menu) => menu.focus_handle(cx),
ContextPickerState::File(file_picker) => file_picker.focus_handle(cx),
ContextPickerState::Fetch(fetch_picker) => fetch_picker.focus_handle(cx),
ContextPickerState::Thread(thread_picker) => thread_picker.focus_handle(cx),
ContextPickerMode::Default(menu) => menu.focus_handle(cx),
ContextPickerMode::File(file_picker) => file_picker.focus_handle(cx),
ContextPickerMode::Directory(directory_picker) => directory_picker.focus_handle(cx),
ContextPickerMode::Fetch(fetch_picker) => fetch_picker.focus_handle(cx),
ContextPickerMode::Thread(thread_picker) => thread_picker.focus_handle(cx),
}
}
}
@@ -415,10 +407,13 @@ impl Render for ContextPicker {
.w(px(400.))
.min_w(px(400.))
.map(|parent| match &self.mode {
ContextPickerState::Default(menu) => parent.child(menu.clone()),
ContextPickerState::File(file_picker) => parent.child(file_picker.clone()),
ContextPickerState::Fetch(fetch_picker) => parent.child(fetch_picker.clone()),
ContextPickerState::Thread(thread_picker) => parent.child(thread_picker.clone()),
ContextPickerMode::Default(menu) => parent.child(menu.clone()),
ContextPickerMode::File(file_picker) => parent.child(file_picker.clone()),
ContextPickerMode::Directory(directory_picker) => {
parent.child(directory_picker.clone())
}
ContextPickerMode::Fetch(fetch_picker) => parent.child(fetch_picker.clone()),
ContextPickerMode::Thread(thread_picker) => parent.child(thread_picker.clone()),
})
}
}

View File

@@ -0,0 +1,269 @@
use std::path::Path;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use fuzzy::PathMatch;
use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity};
use picker::{Picker, PickerDelegate};
use project::{PathMatchCandidateSet, ProjectPath, WorktreeId};
use ui::{prelude::*, ListItem};
use util::ResultExt as _;
use workspace::{notifications::NotifyResultExt, Workspace};
use crate::context_picker::{ConfirmBehavior, ContextPicker};
use crate::context_store::ContextStore;
pub struct DirectoryContextPicker {
picker: Entity<Picker<DirectoryContextPickerDelegate>>,
}
impl DirectoryContextPicker {
pub fn new(
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let delegate = DirectoryContextPickerDelegate::new(
context_picker,
workspace,
context_store,
confirm_behavior,
);
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
Self { picker }
}
}
impl Focusable for DirectoryContextPicker {
fn focus_handle(&self, cx: &App) -> FocusHandle {
self.picker.focus_handle(cx)
}
}
impl Render for DirectoryContextPicker {
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
self.picker.clone()
}
}
pub struct DirectoryContextPickerDelegate {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
matches: Vec<PathMatch>,
selected_index: usize,
}
impl DirectoryContextPickerDelegate {
pub fn new(
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
) -> Self {
Self {
context_picker,
workspace,
context_store,
confirm_behavior,
matches: Vec::new(),
selected_index: 0,
}
}
fn search(
&mut self,
query: String,
cancellation_flag: Arc<AtomicBool>,
workspace: &Entity<Workspace>,
cx: &mut Context<Picker<Self>>,
) -> Task<Vec<PathMatch>> {
if query.is_empty() {
let workspace = workspace.read(cx);
let project = workspace.project().read(cx);
let directory_matches = project.worktrees(cx).flat_map(|worktree| {
let worktree = worktree.read(cx);
let path_prefix: Arc<str> = worktree.root_name().into();
worktree.directories(false, 0).map(move |entry| PathMatch {
score: 0.,
positions: Vec::new(),
worktree_id: worktree.id().to_usize(),
path: entry.path.clone(),
path_prefix: path_prefix.clone(),
distance_to_relative_ancestor: 0,
is_dir: true,
})
});
Task::ready(directory_matches.collect())
} else {
let worktrees = workspace.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
let candidate_sets = worktrees
.into_iter()
.map(|worktree| {
let worktree = worktree.read(cx);
PathMatchCandidateSet {
snapshot: worktree.snapshot(),
include_ignored: worktree
.root_entry()
.map_or(false, |entry| entry.is_ignored),
include_root_name: true,
candidates: project::Candidates::Directories,
}
})
.collect::<Vec<_>>();
let executor = cx.background_executor().clone();
cx.foreground_executor().spawn(async move {
fuzzy::match_path_sets(
candidate_sets.as_slice(),
query.as_str(),
None,
false,
100,
&cancellation_flag,
executor,
)
.await
})
}
}
}
impl PickerDelegate for DirectoryContextPickerDelegate {
type ListItem = ListItem;
fn match_count(&self) -> usize {
self.matches.len()
}
fn selected_index(&self) -> usize {
self.selected_index
}
fn set_selected_index(
&mut self,
ix: usize,
_window: &mut Window,
_cx: &mut Context<Picker<Self>>,
) {
self.selected_index = ix;
}
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
"Search folders…".into()
}
fn update_matches(
&mut self,
query: String,
_window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Task<()> {
let Some(workspace) = self.workspace.upgrade() else {
return Task::ready(());
};
let search_task = self.search(query, Arc::<AtomicBool>::default(), &workspace, cx);
cx.spawn(|this, mut cx| async move {
let mut paths = search_task.await;
let empty_path = Path::new("");
paths.retain(|path_match| path_match.path.as_ref() != empty_path);
this.update(&mut cx, |this, _cx| {
this.delegate.matches = paths;
})
.log_err();
})
}
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
let Some(mat) = self.matches.get(self.selected_index) else {
return;
};
let project_path = ProjectPath {
worktree_id: WorktreeId::from_usize(mat.worktree_id),
path: mat.path.clone(),
};
let Some(task) = self
.context_store
.update(cx, |context_store, cx| {
context_store.add_directory(project_path, cx)
})
.ok()
else {
return;
};
let confirm_behavior = self.confirm_behavior;
cx.spawn_in(window, |this, mut cx| async move {
match task.await.notify_async_err(&mut cx) {
None => anyhow::Ok(()),
Some(()) => this.update_in(&mut cx, |this, window, cx| match confirm_behavior {
ConfirmBehavior::KeepOpen => {}
ConfirmBehavior::Close => this.delegate.dismissed(window, cx),
}),
}
})
.detach_and_log_err(cx);
}
fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
self.context_picker
.update(cx, |_, cx| {
cx.emit(DismissEvent);
})
.ok();
}
fn render_match(
&self,
ix: usize,
selected: bool,
_window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Option<Self::ListItem> {
let path_match = &self.matches[ix];
let directory_name = path_match.path.to_string_lossy().to_string();
let added = self.context_store.upgrade().map_or(false, |context_store| {
context_store
.read(cx)
.includes_directory(&path_match.path)
.is_some()
});
Some(
ListItem::new(ix)
.inset(true)
.toggle_state(selected)
.start_slot(
Icon::new(IconName::Folder)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(Label::new(directory_name))
.when(added, |el| {
el.end_slot(
h_flex()
.gap_1()
.child(
Icon::new(IconName::Check)
.size(IconSize::Small)
.color(Color::Success),
)
.child(Label::new("Added").size(LabelSize::Small)),
)
}),
)
}
}

View File

@@ -206,12 +206,12 @@ impl PickerDelegate for FetchContextPickerDelegate {
let http_client = workspace.read(cx).client().http_client().clone();
let url = self.url.clone();
let confirm_behavior = self.confirm_behavior;
cx.spawn_in(window, async move |this, cx| {
cx.spawn_in(window, |this, mut cx| async move {
let text = cx
.background_spawn(Self::build_message(http_client, url.clone()))
.await?;
this.update_in(cx, |this, window, cx| {
this.update_in(&mut cx, |this, window, cx| {
this.delegate
.context_store
.update(cx, |context_store, _cx| {

View File

@@ -99,6 +99,7 @@ impl FileContextPickerDelegate {
query: String,
cancellation_flag: Arc<AtomicBool>,
workspace: &Entity<Workspace>,
cx: &mut Context<Picker<Self>>,
) -> Task<Vec<PathMatch>> {
if query.is_empty() {
@@ -123,14 +124,14 @@ impl FileContextPickerDelegate {
let file_matches = project.worktrees(cx).flat_map(|worktree| {
let worktree = worktree.read(cx);
let path_prefix: Arc<str> = worktree.root_name().into();
worktree.entries(false, 0).map(move |entry| PathMatch {
worktree.files(false, 0).map(move |entry| PathMatch {
score: 0.,
positions: Vec::new(),
worktree_id: worktree.id().to_usize(),
path: entry.path.clone(),
path_prefix: path_prefix.clone(),
distance_to_relative_ancestor: 0,
is_dir: entry.is_dir(),
is_dir: false,
})
});
@@ -148,7 +149,7 @@ impl FileContextPickerDelegate {
.root_entry()
.map_or(false, |entry| entry.is_ignored),
include_root_name: true,
candidates: project::Candidates::Entries,
candidates: project::Candidates::Files,
}
})
.collect::<Vec<_>>();
@@ -191,7 +192,7 @@ impl PickerDelegate for FileContextPickerDelegate {
}
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
"Search files & directories".into()
"Search files…".into()
}
fn update_matches(
@@ -206,11 +207,11 @@ impl PickerDelegate for FileContextPickerDelegate {
let search_task = self.search(query, Arc::<AtomicBool>::default(), &workspace, cx);
cx.spawn_in(window, async move |this, cx| {
cx.spawn_in(window, |this, mut cx| async move {
// TODO: This should be probably be run in the background.
let paths = search_task.await;
this.update(cx, |this, _cx| {
this.update(&mut cx, |this, _cx| {
this.delegate.matches = paths;
})
.log_err();
@@ -222,11 +223,13 @@ impl PickerDelegate for FileContextPickerDelegate {
return;
};
let file_name = mat
let Some(file_name) = mat
.path
.file_name()
.map(|os_str| os_str.to_string_lossy().into_owned())
.unwrap_or(mat.path_prefix.to_string());
else {
return;
};
let full_path = mat.path.display().to_string();
@@ -235,8 +238,6 @@ impl PickerDelegate for FileContextPickerDelegate {
path: mat.path.clone(),
};
let is_directory = mat.is_dir;
let Some(editor_entity) = self.editor.upgrade() else {
return;
};
@@ -287,12 +288,8 @@ impl PickerDelegate for FileContextPickerDelegate {
editor.insert("\n", window, cx); // Needed to end the fold
let file_icon = if is_directory {
FileIcons::get_folder_icon(false, cx)
} else {
FileIcons::get_icon(&Path::new(&full_path), cx)
}
.unwrap_or_else(|| SharedString::new(""));
let file_icon = FileIcons::get_icon(&Path::new(&full_path), cx)
.unwrap_or_else(|| SharedString::new(""));
let placeholder = FoldPlaceholder {
render: render_fold_icon_button(
@@ -333,11 +330,7 @@ impl PickerDelegate for FileContextPickerDelegate {
let Some(task) = self
.context_store
.update(cx, |context_store, cx| {
if is_directory {
context_store.add_directory(project_path, cx)
} else {
context_store.add_file_from_path(project_path, cx)
}
context_store.add_file_from_path(project_path, cx)
})
.ok()
else {
@@ -345,10 +338,10 @@ impl PickerDelegate for FileContextPickerDelegate {
};
let confirm_behavior = self.confirm_behavior;
cx.spawn_in(window, async move |this, cx| {
match task.await.notify_async_err(cx) {
cx.spawn_in(window, |this, mut cx| async move {
match task.await.notify_async_err(&mut cx) {
None => anyhow::Ok(()),
Some(()) => this.update_in(cx, |this, window, cx| match confirm_behavior {
Some(()) => this.update_in(&mut cx, |this, window, cx| match confirm_behavior {
ConfirmBehavior::KeepOpen => {}
ConfirmBehavior::Close => this.delegate.dismissed(window, cx),
}),
@@ -382,7 +375,6 @@ impl PickerDelegate for FileContextPickerDelegate {
ElementId::NamedInteger("file-ctx-picker".into(), ix),
&path_match.path,
&path_match.path_prefix,
path_match.is_dir,
self.context_store.clone(),
cx,
)),
@@ -394,7 +386,6 @@ pub fn render_file_context_entry(
id: ElementId,
path: &Path,
path_prefix: &Arc<str>,
is_directory: bool,
context_store: WeakEntity<ContextStore>,
cx: &App,
) -> Stateful<Div> {
@@ -418,24 +409,13 @@ pub fn render_file_context_entry(
(file_name, Some(directory))
};
let added = context_store.upgrade().and_then(|context_store| {
if is_directory {
context_store
.read(cx)
.includes_directory(path)
.map(FileInclusion::Direct)
} else {
context_store.read(cx).will_include_file_path(path, cx)
}
});
let added = context_store
.upgrade()
.and_then(|context_store| context_store.read(cx).will_include_file_path(path, cx));
let file_icon = if is_directory {
FileIcons::get_folder_icon(false, cx)
} else {
FileIcons::get_icon(&path, cx)
}
.map(Icon::from_path)
.unwrap_or_else(|| Icon::new(IconName::File));
let file_icon = FileIcons::get_icon(&path, cx)
.map(Icon::from_path)
.unwrap_or_else(|| Icon::new(IconName::File));
h_flex()
.id(id)

View File

@@ -149,9 +149,9 @@ impl PickerDelegate for ThreadContextPickerDelegate {
}
});
cx.spawn_in(window, async move |this, cx| {
cx.spawn_in(window, |this, mut cx| async move {
let matches = search_task.await;
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
this.delegate.matches = matches;
this.delegate.selected_index = 0;
cx.notify();
@@ -171,9 +171,9 @@ impl PickerDelegate for ThreadContextPickerDelegate {
let open_thread_task = thread_store.update(cx, |this, cx| this.open_thread(&entry.id, cx));
cx.spawn_in(window, async move |this, cx| {
cx.spawn_in(window, |this, mut cx| async move {
let thread = open_thread_task.await?;
this.update_in(cx, |this, window, cx| {
this.update_in(&mut cx, |this, window, cx| {
this.delegate
.context_store
.update(cx, |context_store, cx| context_store.add_thread(thread, cx))

View File

@@ -9,7 +9,6 @@ use language::Buffer;
use project::{ProjectPath, Worktree};
use rope::Rope;
use text::BufferId;
use util::maybe;
use workspace::Workspace;
use crate::context::{
@@ -75,15 +74,15 @@ impl ContextStore {
return Task::ready(Err(anyhow!("failed to read project")));
};
cx.spawn(async move |this, cx| {
let open_buffer_task = project.update(cx, |project, cx| {
cx.spawn(|this, mut cx| async move {
let open_buffer_task = project.update(&mut cx, |project, cx| {
project.open_buffer(project_path.clone(), cx)
})?;
let buffer_entity = open_buffer_task.await?;
let buffer_id = this.update(cx, |_, cx| buffer_entity.read(cx).remote_id())?;
let buffer_id = this.update(&mut cx, |_, cx| buffer_entity.read(cx).remote_id())?;
let already_included = this.update(cx, |this, _cx| {
let already_included = this.update(&mut cx, |this, _cx| {
match this.will_include_buffer(buffer_id, &project_path.path) {
Some(FileInclusion::Direct(context_id)) => {
this.remove_context(context_id);
@@ -98,7 +97,7 @@ impl ContextStore {
return anyhow::Ok(());
}
let (buffer_info, text_task) = this.update(cx, |_, cx| {
let (buffer_info, text_task) = this.update(&mut cx, |_, cx| {
let buffer = buffer_entity.read(cx);
collect_buffer_info_and_text(
project_path.path.clone(),
@@ -110,7 +109,7 @@ impl ContextStore {
let text = text_task.await;
this.update(cx, |this, _cx| {
this.update(&mut cx, |this, _cx| {
this.insert_file(make_context_buffer(buffer_info, text));
})?;
@@ -123,8 +122,8 @@ impl ContextStore {
buffer_entity: Entity<Buffer>,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
cx.spawn(async move |this, cx| {
let (buffer_info, text_task) = this.update(cx, |_, cx| {
cx.spawn(|this, mut cx| async move {
let (buffer_info, text_task) = this.update(&mut cx, |_, cx| {
let buffer = buffer_entity.read(cx);
let Some(file) = buffer.file() else {
return Err(anyhow!("Buffer has no path."));
@@ -139,7 +138,7 @@ impl ContextStore {
let text = text_task.await;
this.update(cx, |this, _cx| {
this.update(&mut cx, |this, _cx| {
this.insert_file(make_context_buffer(buffer_info, text))
})?;
@@ -179,18 +178,18 @@ impl ContextStore {
}
let worktree_id = project_path.worktree_id;
cx.spawn(async move |this, cx| {
let worktree = project.update(cx, |project, cx| {
cx.spawn(|this, mut cx| async move {
let worktree = project.update(&mut cx, |project, cx| {
project
.worktree_for_id(worktree_id, cx)
.ok_or_else(|| anyhow!("no worktree found for {worktree_id:?}"))
})??;
let files = worktree.update(cx, |worktree, _cx| {
let files = worktree.update(&mut cx, |worktree, _cx| {
collect_files_in_path(worktree, &project_path.path)
})?;
let open_buffers_task = project.update(cx, |project, cx| {
let open_buffers_task = project.update(&mut cx, |project, cx| {
let tasks = files.iter().map(|file_path| {
project.open_buffer(
ProjectPath {
@@ -207,7 +206,7 @@ impl ContextStore {
let mut buffer_infos = Vec::new();
let mut text_tasks = Vec::new();
this.update(cx, |_, cx| {
this.update(&mut cx, |_, cx| {
for (path, buffer_entity) in files.into_iter().zip(buffers) {
// Skip all binary files and other non-UTF8 files
if let Ok(buffer_entity) = buffer_entity {
@@ -236,7 +235,7 @@ impl ContextStore {
bail!("No text files found in {}", &project_path.path.display());
}
this.update(cx, |this, _| {
this.update(&mut cx, |this, _| {
this.insert_directory(&project_path.path, context_buffers);
})?;
@@ -532,59 +531,35 @@ fn collect_files_in_path(worktree: &Worktree, path: &Path) -> Vec<Arc<Path>> {
pub fn refresh_context_store_text(
context_store: Entity<ContextStore>,
changed_buffers: &HashSet<Entity<Buffer>>,
cx: &App,
) -> impl Future<Output = Vec<ContextId>> {
) -> impl Future<Output = ()> {
let mut tasks = Vec::new();
for context in &context_store.read(cx).context {
let id = context.id();
let task = maybe!({
match context {
AssistantContext::File(file_context) => {
if changed_buffers.is_empty()
|| changed_buffers.contains(&file_context.context_buffer.buffer)
{
let context_store = context_store.clone();
return refresh_file_text(context_store, file_context, cx);
}
match context {
AssistantContext::File(file_context) => {
let context_store = context_store.clone();
if let Some(task) = refresh_file_text(context_store, file_context, cx) {
tasks.push(task);
}
AssistantContext::Directory(directory_context) => {
let should_refresh = changed_buffers.is_empty()
|| changed_buffers.iter().any(|buffer| {
let buffer = buffer.read(cx);
buffer_path_log_err(&buffer)
.map_or(false, |path| path.starts_with(&directory_context.path))
});
if should_refresh {
let context_store = context_store.clone();
return refresh_directory_text(context_store, directory_context, cx);
}
}
AssistantContext::Thread(thread_context) => {
if changed_buffers.is_empty() {
let context_store = context_store.clone();
return Some(refresh_thread_text(context_store, thread_context, cx));
}
}
// Intentionally omit refreshing fetched URLs as it doesn't seem all that useful,
// and doing the caching properly could be tricky (unless it's already handled by
// the HttpClient?).
AssistantContext::FetchedUrl(_) => {}
}
None
});
if let Some(task) = task {
tasks.push(task.map(move |_| id));
AssistantContext::Directory(directory_context) => {
let context_store = context_store.clone();
if let Some(task) = refresh_directory_text(context_store, directory_context, cx) {
tasks.push(task);
}
}
AssistantContext::Thread(thread_context) => {
let context_store = context_store.clone();
tasks.push(refresh_thread_text(context_store, thread_context, cx));
}
// Intentionally omit refreshing fetched URLs as it doesn't seem all that useful,
// and doing the caching properly could be tricky (unless it's already handled by
// the HttpClient?).
AssistantContext::FetchedUrl(_) => {}
}
}
future::join_all(tasks)
future::join_all(tasks).map(|_| ())
}
fn refresh_file_text(
@@ -595,10 +570,10 @@ fn refresh_file_text(
let id = file_context.id;
let task = refresh_context_buffer(&file_context.context_buffer, cx);
if let Some(task) = task {
Some(cx.spawn(async move |cx| {
Some(cx.spawn(|mut cx| async move {
let context_buffer = task.await;
context_store
.update(cx, |context_store, _| {
.update(&mut cx, |context_store, _| {
let new_file_context = FileContext { id, context_buffer };
context_store.replace_context(AssistantContext::File(new_file_context));
})
@@ -636,10 +611,10 @@ fn refresh_directory_text(
let id = directory_context.snapshot.id;
let path = directory_context.path.clone();
Some(cx.spawn(async move |cx| {
Some(cx.spawn(|mut cx| async move {
let context_buffers = context_buffers.await;
context_store
.update(cx, |context_store, _| {
.update(&mut cx, |context_store, _| {
let new_directory_context = DirectoryContext::new(id, &path, context_buffers);
context_store.replace_context(AssistantContext::Directory(new_directory_context));
})
@@ -654,9 +629,9 @@ fn refresh_thread_text(
) -> Task<()> {
let id = thread_context.id;
let thread = thread_context.thread.clone();
cx.spawn(async move |cx| {
cx.spawn(move |mut cx| async move {
context_store
.update(cx, |context_store, cx| {
.update(&mut cx, |context_store, cx| {
let text = thread.read(cx).text().into();
context_store.replace_context(AssistantContext::Thread(ThreadContext {
id,

View File

@@ -335,12 +335,12 @@ impl ContextStrip {
context_store.accept_suggested_context(&suggested, cx)
});
cx.spawn_in(window, async move |this, cx| {
match task.await.notify_async_err(cx) {
cx.spawn_in(window, |this, mut cx| async move {
match task.await.notify_async_err(&mut cx) {
None => {}
Some(()) => {
if let Some(this) = this.upgrade() {
this.update(cx, |_, cx| cx.notify())?;
this.update(&mut cx, |_, cx| cx.notify())?;
}
}
}

View File

@@ -276,7 +276,7 @@ impl InlineAssistant {
if is_authenticated() {
handle_assist(window, cx);
} else {
cx.spawn_in(window, async move |_workspace, cx| {
cx.spawn_in(window, |_workspace, mut cx| async move {
let Some(task) = cx.update(|_, cx| {
LanguageModelRegistry::read_global(cx)
.active_provider()
@@ -1456,9 +1456,9 @@ impl EditorInlineAssists {
assist_ids: Vec::new(),
scroll_lock: None,
highlight_updates: highlight_updates_tx,
_update_highlights: cx.spawn({
_update_highlights: cx.spawn(|cx| {
let editor = editor.downgrade();
async move |cx| {
async move {
while let Ok(()) = highlight_updates_rx.changed().await {
let editor = editor.upgrade().context("editor was dropped")?;
cx.update_global(|assistant: &mut InlineAssistant, cx| {
@@ -1729,7 +1729,6 @@ impl CodeActionProvider for AssistantCodeActionProvider {
title: "Fix with Assistant".into(),
..Default::default()
})),
resolved: true,
}]))
} else {
Task::ready(Ok(Vec::new()))
@@ -1748,10 +1747,10 @@ impl CodeActionProvider for AssistantCodeActionProvider {
let editor = self.editor.clone();
let workspace = self.workspace.clone();
let thread_store = self.thread_store.clone();
window.spawn(cx, async move |cx| {
window.spawn(cx, |mut cx| async move {
let editor = editor.upgrade().context("editor was released")?;
let range = editor
.update(cx, |editor, cx| {
.update(&mut cx, |editor, cx| {
editor.buffer().update(cx, |multibuffer, cx| {
let buffer = buffer.read(cx);
let multibuffer_snapshot = multibuffer.read(cx);

View File

@@ -1,27 +1,24 @@
use std::sync::Arc;
use collections::HashSet;
use editor::actions::MoveUp;
use editor::{Editor, EditorElement, EditorEvent, EditorStyle};
use file_icons::FileIcons;
use fs::Fs;
use git::ExpandCommitEditor;
use git_ui::git_panel;
use gpui::{
Animation, AnimationExt, App, DismissEvent, Entity, Focusable, Subscription, TextStyle,
WeakEntity,
};
use language_model::LanguageModelRegistry;
use language_model_selector::ToggleModelSelector;
use project::Project;
use rope::Point;
use settings::Settings;
use std::time::Duration;
use text::Bias;
use theme::ThemeSettings;
use ui::{
prelude::*, ButtonLike, KeyBinding, PlatformStyle, PopoverMenu, PopoverMenuHandle, Tooltip,
prelude::*, ButtonLike, Disclosure, KeyBinding, PlatformStyle, PopoverMenu, PopoverMenuHandle,
Tooltip,
};
use util::ResultExt;
use vim_mode_setting::VimModeSetting;
use workspace::notifications::{NotificationId, NotifyTaskExt};
use workspace::{Toast, Workspace};
@@ -33,13 +30,12 @@ use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::thread::{RequestKind, Thread};
use crate::thread_store::ThreadStore;
use crate::tool_selector::ToolSelector;
use crate::{Chat, ChatMode, RemoveAllContext, ThreadEvent, ToggleContextPicker};
use crate::{Chat, ChatMode, RemoveAllContext, ToggleContextPicker};
pub struct MessageEditor {
thread: Entity<Thread>,
editor: Entity<Editor>,
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
context_store: Entity<ContextStore>,
context_strip: Entity<ContextStrip>,
context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
@@ -47,6 +43,7 @@ pub struct MessageEditor {
inline_context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
model_selector: Entity<AssistantModelSelector>,
tool_selector: Entity<ToolSelector>,
edits_expanded: bool,
_subscriptions: Vec<Subscription>,
}
@@ -54,13 +51,13 @@ impl MessageEditor {
pub fn new(
fs: Arc<dyn Fs>,
workspace: WeakEntity<Workspace>,
context_store: Entity<ContextStore>,
thread_store: WeakEntity<ThreadStore>,
thread: Entity<Thread>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let tools = thread.read(cx).tools().clone();
let context_store = cx.new(|_cx| ContextStore::new(workspace.clone()));
let context_picker_menu_handle = PopoverMenuHandle::default();
let inline_context_picker_menu_handle = PopoverMenuHandle::default();
let model_selector_menu_handle = PopoverMenuHandle::default();
@@ -109,9 +106,8 @@ impl MessageEditor {
];
Self {
editor: editor.clone(),
project: thread.read(cx).project().clone(),
thread,
editor: editor.clone(),
workspace,
context_store,
context_strip,
@@ -128,6 +124,7 @@ impl MessageEditor {
)
}),
tool_selector: cx.new(|cx| ToolSelector::new(tools, cx)),
edits_expanded: false,
_subscriptions: subscriptions,
}
}
@@ -160,7 +157,7 @@ impl MessageEditor {
return;
}
if self.thread.read(cx).is_generating() {
if self.thread.read(cx).is_streaming() {
return;
}
@@ -203,31 +200,16 @@ impl MessageEditor {
text
});
let refresh_task =
refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx);
let system_prompt_context_task = self.thread.read(cx).load_system_prompt_context(cx);
let refresh_task = refresh_context_store_text(self.context_store.clone(), cx);
let thread = self.thread.clone();
let context_store = self.context_store.clone();
let git_store = self.project.read(cx).git_store();
let checkpoint = git_store.read(cx).checkpoint(cx);
cx.spawn(async move |_, cx| {
cx.spawn(move |_, mut cx| async move {
refresh_task.await;
let (system_prompt_context, load_error) = system_prompt_context_task.await;
thread
.update(cx, |thread, cx| {
thread.set_system_prompt_context(system_prompt_context);
if let Some(load_error) = load_error {
cx.emit(ThreadEvent::ShowError(load_error));
}
})
.ok();
let checkpoint = checkpoint.await.log_err();
thread
.update(cx, |thread, cx| {
.update(&mut cx, |thread, cx| {
let context = context_store.read(cx).snapshot(cx).collect::<Vec<_>>();
thread.insert_user_message(user_message, context, checkpoint, cx);
thread.insert_user_message(user_message, context, cx);
thread.send_to_model(model, request_kind, cx);
})
.ok();
@@ -313,9 +295,9 @@ impl MessageEditor {
.thread
.update(cx, |thread, cx| thread.report_feedback(is_positive, cx));
cx.spawn(async move |_, cx| {
cx.spawn(|_, mut cx| async move {
report.await?;
workspace.update(cx, |workspace, cx| {
workspace.update(&mut cx, |workspace, cx| {
let message = if is_positive {
"Positive feedback recorded. Thank you!"
} else {
@@ -344,7 +326,7 @@ impl Render for MessageEditor {
let focus_handle = self.editor.focus_handle(cx);
let inline_context_picker = self.inline_context_picker.clone();
let bg_color = cx.theme().colors().editor_background;
let is_generating = self.thread.read(cx).is_generating();
let is_streaming_completion = self.thread.read(cx).is_streaming();
let is_model_selected = self.is_model_selected(cx);
let is_editor_empty = self.is_editor_empty(cx);
let submit_label_color = if is_editor_empty {
@@ -363,16 +345,12 @@ impl Render for MessageEditor {
px(64.)
};
let project = self.thread.read(cx).project();
let changed_files = if let Some(repository) = project.read(cx).active_repository(cx) {
repository.read(cx).status().count()
} else {
0
};
let changed_buffers = self.thread.read(cx).scripting_changed_buffers(cx);
let changed_buffers_count = changed_buffers.len();
v_flex()
.size_full()
.when(is_generating, |parent| {
.when(is_streaming_completion, |parent| {
let focus_handle = self.editor.focus_handle(cx).clone();
parent.child(
h_flex().py_3().w_full().justify_center().child(
@@ -430,7 +408,7 @@ impl Render for MessageEditor {
),
)
})
.when(changed_files > 0, |parent| {
.when(changed_buffers_count > 0, |parent| {
parent.child(
v_flex()
.mx_2()
@@ -441,60 +419,96 @@ impl Render for MessageEditor {
.rounded_t_md()
.child(
h_flex()
.justify_between()
.gap_2()
.p_2()
.child(
h_flex()
.gap_2()
.child(
IconButton::new(
"edits-disclosure",
IconName::GitBranchSmall,
)
.icon_size(IconSize::Small)
.on_click(
|_ev, _window, cx| {
cx.defer(|cx| {
cx.dispatch_action(&git_panel::ToggleFocus)
});
},
),
)
.child(
Label::new(format!(
"{} {} changed",
changed_files,
if changed_files == 1 { "file" } else { "files" }
))
.size(LabelSize::XSmall)
.color(Color::Muted),
),
Disclosure::new("edits-disclosure", self.edits_expanded)
.on_click(cx.listener(|this, _ev, _window, cx| {
this.edits_expanded = !this.edits_expanded;
cx.notify();
})),
)
.child(
h_flex()
.gap_2()
.child(
Button::new("review", "Review")
.label_size(LabelSize::XSmall)
.on_click(|_event, _window, cx| {
cx.defer(|cx| {
cx.dispatch_action(
&git_ui::project_diff::Diff,
);
});
}),
)
.child(
Button::new("commit", "Commit")
.label_size(LabelSize::XSmall)
.on_click(|_event, _window, cx| {
cx.defer(|cx| {
cx.dispatch_action(&ExpandCommitEditor)
});
}),
),
Label::new("Edits")
.size(LabelSize::XSmall)
.color(Color::Muted),
)
.child(Label::new("").size(LabelSize::XSmall).color(Color::Muted))
.child(
Label::new(format!(
"{} {}",
changed_buffers_count,
if changed_buffers_count == 1 {
"file"
} else {
"files"
}
))
.size(LabelSize::XSmall)
.color(Color::Muted),
),
),
)
.when(self.edits_expanded, |parent| {
parent.child(
v_flex().bg(cx.theme().colors().editor_background).children(
changed_buffers.enumerate().flat_map(|(index, buffer)| {
let file = buffer.read(cx).file()?;
let path = file.path();
let parent_label = path.parent().and_then(|parent| {
let parent_str = parent.to_string_lossy();
if parent_str.is_empty() {
None
} else {
Some(
Label::new(format!(
"{}{}",
parent_str,
std::path::MAIN_SEPARATOR_STR
))
.color(Color::Muted)
.size(LabelSize::Small),
)
}
});
let name_label = path.file_name().map(|name| {
Label::new(name.to_string_lossy().to_string())
.size(LabelSize::Small)
});
let file_icon = FileIcons::get_icon(&path, cx)
.map(Icon::from_path)
.unwrap_or_else(|| Icon::new(IconName::File));
let element = div()
.p_2()
.when(index + 1 < changed_buffers_count, |parent| {
parent
.border_color(cx.theme().colors().border)
.border_b_1()
})
.child(
h_flex()
.gap_2()
.child(file_icon)
.child(
// TODO: handle overflow
h_flex()
.children(parent_label)
.children(name_label),
)
// TODO: show lines changed
.child(Label::new("+").color(Color::Created))
.child(Label::new("-").color(Color::Deleted)),
);
Some(element)
}),
),
)
}),
)
})
.child(
@@ -609,7 +623,7 @@ impl Render for MessageEditor {
.disabled(
is_editor_empty
|| !is_model_selected
|| is_generating,
|| is_streaming_completion,
)
.child(
h_flex()
@@ -644,7 +658,7 @@ impl Render for MessageEditor {
"Type a message to submit",
))
})
.when(is_generating, |button| {
.when(is_streaming_completion, |button| {
button.tooltip(Tooltip::text(
"Cancel to submit a new message",
))

View File

@@ -40,7 +40,7 @@ impl TerminalCodegen {
let telemetry = self.telemetry.clone();
self.status = CodegenStatus::Pending;
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
self.generation = cx.spawn(async move |this, cx| {
self.generation = cx.spawn(|this, mut cx| async move {
let model_telemetry_id = model.telemetry_id();
let model_provider_id = model.provider_id();
let response = model.stream_completion_text(prompt, &cx).await;
@@ -97,12 +97,12 @@ impl TerminalCodegen {
}
});
this.update(cx, |this, _| {
this.update(&mut cx, |this, _| {
this.message_id = message_id;
})?;
while let Some(hunk) = hunks_rx.next().await {
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
if let Some(transaction) = &mut this.transaction {
transaction.push(hunk, cx);
cx.notify();
@@ -116,7 +116,7 @@ impl TerminalCodegen {
let result = generate.await;
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
if let Err(error) = result {
this.status = CodegenStatus::Error(error);
} else {

View File

@@ -1,12 +1,10 @@
use std::fmt::Write as _;
use std::io::Write;
use std::sync::Arc;
use anyhow::{Context as _, Result};
use assistant_tool::{ActionLog, ToolWorkingSet};
use assistant_tool::ToolWorkingSet;
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap, HashSet};
use fs::Fs;
use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
use git;
@@ -17,14 +15,11 @@ use language_model::{
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
Role, StopReason, TokenUsage,
};
use project::git_store::{GitStore, GitStoreCheckpoint};
use project::{Project, Worktree};
use prompt_store::{
AssistantSystemPromptContext, PromptBuilder, RulesFile, WorktreeInfoForSystemPrompt,
};
use project::Project;
use prompt_store::{AssistantSystemPromptWorktree, PromptBuilder};
use scripting_tool::{ScriptingSession, ScriptingTool};
use serde::{Deserialize, Serialize};
use util::{maybe, post_inc, ResultExt as _, TryFutureExt as _};
use util::{post_inc, ResultExt, TryFutureExt as _};
use uuid::Uuid;
use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
@@ -93,31 +88,6 @@ pub struct GitState {
pub diff: Option<String>,
}
#[derive(Clone)]
pub struct ThreadCheckpoint {
message_id: MessageId,
git_checkpoint: GitStoreCheckpoint,
}
pub enum LastRestoreCheckpoint {
Pending {
message_id: MessageId,
},
Error {
message_id: MessageId,
error: String,
},
}
impl LastRestoreCheckpoint {
pub fn message_id(&self) -> MessageId {
match self {
LastRestoreCheckpoint::Pending { message_id } => *message_id,
LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
}
}
}
/// A thread of conversation with the LLM.
pub struct Thread {
id: ThreadId,
@@ -128,16 +98,12 @@ pub struct Thread {
next_message_id: MessageId,
context: BTreeMap<ContextId, ContextSnapshot>,
context_by_message: HashMap<MessageId, Vec<ContextId>>,
system_prompt_context: Option<AssistantSystemPromptContext>,
checkpoints_by_message: HashMap<MessageId, GitStoreCheckpoint>,
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
project: Entity<Project>,
prompt_builder: Arc<PromptBuilder>,
tools: Arc<ToolWorkingSet>,
tool_use: ToolUseState,
action_log: Entity<ActionLog>,
last_restore_checkpoint: Option<LastRestoreCheckpoint>,
scripting_session: Entity<ScriptingSession>,
scripting_tool_use: ToolUseState,
initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
@@ -160,18 +126,14 @@ impl Thread {
next_message_id: MessageId(0),
context: BTreeMap::default(),
context_by_message: HashMap::default(),
system_prompt_context: None,
checkpoints_by_message: HashMap::default(),
completion_count: 0,
pending_completions: Vec::new(),
project: project.clone(),
prompt_builder,
tools: tools.clone(),
last_restore_checkpoint: None,
tool_use: ToolUseState::new(tools.clone()),
tools,
tool_use: ToolUseState::new(),
scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)),
scripting_tool_use: ToolUseState::new(tools),
action_log: cx.new(|_| ActionLog::new()),
scripting_tool_use: ToolUseState::new(),
initial_project_snapshot: {
let project_snapshot = Self::project_snapshot(project, cx);
cx.foreground_executor()
@@ -197,12 +159,11 @@ impl Thread {
.map(|message| message.id.0 + 1)
.unwrap_or(0),
);
let tool_use =
ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |name| {
name != ScriptingTool::NAME
});
let tool_use = ToolUseState::from_serialized_messages(&serialized.messages, |name| {
name != ScriptingTool::NAME
});
let scripting_tool_use =
ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |name| {
ToolUseState::from_serialized_messages(&serialized.messages, |name| {
name == ScriptingTool::NAME
});
let scripting_session = cx.new(|cx| ScriptingSession::new(project.clone(), cx));
@@ -224,16 +185,12 @@ impl Thread {
next_message_id,
context: BTreeMap::default(),
context_by_message: HashMap::default(),
system_prompt_context: None,
checkpoints_by_message: HashMap::default(),
completion_count: 0,
pending_completions: Vec::new(),
last_restore_checkpoint: None,
project,
prompt_builder,
tools,
tool_use,
action_log: cx.new(|_| ActionLog::new()),
scripting_session,
scripting_tool_use,
initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
@@ -280,7 +237,7 @@ impl Thread {
self.messages.iter()
}
pub fn is_generating(&self) -> bool {
pub fn is_streaming(&self) -> bool {
!self.pending_completions.is_empty() || !self.all_tools_finished()
}
@@ -288,66 +245,6 @@ impl Thread {
&self.tools
}
pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
let checkpoint = self.checkpoints_by_message.get(&id).cloned()?;
Some(ThreadCheckpoint {
message_id: id,
git_checkpoint: checkpoint,
})
}
pub fn restore_checkpoint(
&mut self,
checkpoint: ThreadCheckpoint,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
message_id: checkpoint.message_id,
});
cx.emit(ThreadEvent::CheckpointChanged);
let project = self.project.read(cx);
let restore = project
.git_store()
.read(cx)
.restore_checkpoint(checkpoint.git_checkpoint, cx);
cx.spawn(async move |this, cx| {
let result = restore.await;
this.update(cx, |this, cx| {
if let Err(err) = result.as_ref() {
this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
message_id: checkpoint.message_id,
error: err.to_string(),
});
} else {
this.last_restore_checkpoint = None;
this.truncate(checkpoint.message_id, cx);
}
cx.emit(ThreadEvent::CheckpointChanged);
})?;
result
})
}
pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
self.last_restore_checkpoint.as_ref()
}
pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
let Some(message_ix) = self
.messages
.iter()
.rposition(|message| message.id == message_id)
else {
return;
};
for deleted_message in self.messages.drain(message_ix..) {
self.context_by_message.remove(&deleted_message.id);
self.checkpoints_by_message.remove(&deleted_message.id);
}
cx.notify();
}
pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
let context = self.context_by_message.get(&id)?;
Some(
@@ -367,27 +264,23 @@ impl Thread {
.into_iter()
.chain(self.scripting_tool_use.pending_tool_uses());
// If the only pending tool uses left are the ones with errors, then
// that means that we've finished running all of the pending tools.
// If the only pending tool uses left are the ones with errors, then that means that we've finished running all
// of the pending tools.
all_pending_tool_uses.all(|tool_use| tool_use.status.is_error())
}
pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
self.tool_use.tool_uses_for_message(id, cx)
pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
self.tool_use.tool_uses_for_message(id)
}
pub fn scripting_tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
self.scripting_tool_use.tool_uses_for_message(id, cx)
pub fn scripting_tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
self.scripting_tool_use.tool_uses_for_message(id)
}
pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
self.tool_use.tool_results_for_message(id)
}
pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
self.tool_use.tool_result(id)
}
pub fn scripting_tool_results_for_message(
&self,
id: MessageId,
@@ -395,6 +288,13 @@ impl Thread {
self.scripting_tool_use.tool_results_for_message(id)
}
pub fn scripting_changed_buffers<'a>(
&self,
cx: &'a App,
) -> impl ExactSizeIterator<Item = &'a Entity<language::Buffer>> {
self.scripting_session.read(cx).changed_buffers()
}
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
self.tool_use.message_has_tool_results(message_id)
}
@@ -407,7 +307,6 @@ impl Thread {
&mut self,
text: impl Into<String>,
context: Vec<ContextSnapshot>,
checkpoint: Option<GitStoreCheckpoint>,
cx: &mut Context<Self>,
) -> MessageId {
let message_id = self.insert_message(Role::User, text, cx);
@@ -415,9 +314,6 @@ impl Thread {
self.context
.extend(context.into_iter().map(|context| (context.id, context)));
self.context_by_message.insert(message_id, context_ids);
if let Some(checkpoint) = checkpoint {
self.checkpoints_by_message.insert(message_id, checkpoint);
}
message_id
}
@@ -490,9 +386,9 @@ impl Thread {
/// Serializes this thread into a format for storage or telemetry.
pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
let initial_project_snapshot = self.initial_project_snapshot.clone();
cx.spawn(async move |this, cx| {
cx.spawn(|this, cx| async move {
let initial_project_snapshot = initial_project_snapshot.await;
this.read_with(cx, |this, cx| SerializedThread {
this.read_with(&cx, |this, _| SerializedThread {
summary: this.summary_or_default(),
updated_at: this.updated_at(),
messages: this
@@ -502,9 +398,9 @@ impl Thread {
role: message.role,
text: message.text.clone(),
tool_uses: this
.tool_uses_for_message(message.id, cx)
.tool_uses_for_message(message.id)
.into_iter()
.chain(this.scripting_tool_uses_for_message(message.id, cx))
.chain(this.scripting_tool_uses_for_message(message.id))
.map(|tool_use| SerializedToolUse {
id: tool_use.id,
name: tool_use.name,
@@ -528,116 +424,6 @@ impl Thread {
})
}
pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
self.system_prompt_context = Some(context);
}
pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
&self.system_prompt_context
}
pub fn load_system_prompt_context(
&self,
cx: &App,
) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
let project = self.project.read(cx);
let tasks = project
.visible_worktrees(cx)
.map(|worktree| {
Self::load_worktree_info_for_system_prompt(
project.fs().clone(),
worktree.read(cx),
cx,
)
})
.collect::<Vec<_>>();
cx.spawn(async |_cx| {
let results = futures::future::join_all(tasks).await;
let mut first_err = None;
let worktrees = results
.into_iter()
.map(|(worktree, err)| {
if first_err.is_none() && err.is_some() {
first_err = err;
}
worktree
})
.collect::<Vec<_>>();
(AssistantSystemPromptContext::new(worktrees), first_err)
})
}
fn load_worktree_info_for_system_prompt(
fs: Arc<dyn Fs>,
worktree: &Worktree,
cx: &App,
) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
let root_name = worktree.root_name().into();
let abs_path = worktree.abs_path();
// Note that Cline supports `.clinerules` being a directory, but that is not currently
// supported. This doesn't seem to occur often in GitHub repositories.
const RULES_FILE_NAMES: [&'static str; 5] = [
".rules",
".cursorrules",
".windsurfrules",
".clinerules",
"CLAUDE.md",
];
let selected_rules_file = RULES_FILE_NAMES
.into_iter()
.filter_map(|name| {
worktree
.entry_for_path(name)
.filter(|entry| entry.is_file())
.map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
})
.next();
if let Some((rel_rules_path, abs_rules_path)) = selected_rules_file {
cx.spawn(async move |_| {
let rules_file_result = maybe!(async move {
let abs_rules_path = abs_rules_path?;
let text = fs.load(&abs_rules_path).await.with_context(|| {
format!("Failed to load assistant rules file {:?}", abs_rules_path)
})?;
anyhow::Ok(RulesFile {
rel_path: rel_rules_path,
abs_path: abs_rules_path.into(),
text: text.trim().to_string(),
})
})
.await;
let (rules_file, rules_file_error) = match rules_file_result {
Ok(rules_file) => (Some(rules_file), None),
Err(err) => (
None,
Some(ThreadError::Message {
header: "Error loading rules file".into(),
message: format!("{err}").into(),
}),
),
};
let worktree_info = WorktreeInfoForSystemPrompt {
root_name,
abs_path,
rules_file,
};
(worktree_info, rules_file_error)
})
} else {
Task::ready((
WorktreeInfoForSystemPrompt {
root_name,
abs_path,
rules_file: None,
},
None,
))
}
}
pub fn send_to_model(
&mut self,
model: Arc<dyn LanguageModel>,
@@ -675,30 +461,36 @@ impl Thread {
request_kind: RequestKind,
cx: &App,
) -> LanguageModelRequest {
let worktree_root_names = self
.project
.read(cx)
.visible_worktrees(cx)
.map(|worktree| {
let worktree = worktree.read(cx);
AssistantSystemPromptWorktree {
root_name: worktree.root_name().into(),
abs_path: worktree.abs_path(),
}
})
.collect::<Vec<_>>();
let system_prompt = self
.prompt_builder
.generate_assistant_system_prompt(worktree_root_names)
.context("failed to generate assistant system prompt")
.log_err()
.unwrap_or_default();
let mut request = LanguageModelRequest {
messages: vec![],
messages: vec![LanguageModelRequestMessage {
role: Role::System,
content: vec![MessageContent::Text(system_prompt)],
cache: true,
}],
tools: Vec::new(),
stop: Vec::new(),
temperature: None,
};
if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
if let Some(system_prompt) = self
.prompt_builder
.generate_assistant_system_prompt(system_prompt_context)
.context("failed to generate assistant system prompt")
.log_err()
{
request.messages.push(LanguageModelRequestMessage {
role: Role::System,
content: vec![MessageContent::Text(system_prompt)],
cache: true,
});
}
} else {
log::error!("system_prompt_context not set.")
}
let mut referenced_context_ids = HashSet::default();
for message in &self.messages {
@@ -761,39 +553,9 @@ impl Thread {
request.messages.push(context_message);
}
self.attach_stale_files(&mut request.messages, cx);
request
}
fn attach_stale_files(&self, messages: &mut Vec<LanguageModelRequestMessage>, cx: &App) {
const STALE_FILES_HEADER: &str = "These files changed since last read:";
let mut stale_message = String::new();
for stale_file in self.action_log.read(cx).stale_buffers(cx) {
let Some(file) = stale_file.read(cx).file() else {
continue;
};
if stale_message.is_empty() {
write!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
}
writeln!(&mut stale_message, "- {}", file.path().display()).ok();
}
if !stale_message.is_empty() {
let context_message = LanguageModelRequestMessage {
role: Role::User,
content: vec![stale_message.into()],
cache: false,
};
messages.push(context_message);
}
}
pub fn stream_completion(
&mut self,
request: LanguageModelRequest,
@@ -802,10 +564,8 @@ impl Thread {
) {
let pending_completion_id = post_inc(&mut self.completion_count);
let task = cx.spawn(async move |thread, cx| {
let task = cx.spawn(|thread, mut cx| async move {
let stream = model.stream_completion(request, &cx);
let initial_token_usage =
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
let stream_completion = async {
let mut events = stream.await?;
let mut stop_reason = StopReason::EndTurn;
@@ -814,7 +574,7 @@ impl Thread {
while let Some(event) = events.next().await {
let event = event?;
thread.update(cx, |thread, cx| {
thread.update(&mut cx, |thread, cx| {
match event {
LanguageModelCompletionEvent::StartMessage { .. } => {
thread.insert_message(Role::Assistant, String::new(), cx);
@@ -853,17 +613,13 @@ impl Thread {
.rfind(|message| message.role == Role::Assistant)
{
if tool_use.name.as_ref() == ScriptingTool::NAME {
thread.scripting_tool_use.request_tool_use(
last_assistant_message.id,
tool_use,
cx,
);
thread
.scripting_tool_use
.request_tool_use(last_assistant_message.id, tool_use);
} else {
thread.tool_use.request_tool_use(
last_assistant_message.id,
tool_use,
cx,
);
thread
.tool_use
.request_tool_use(last_assistant_message.id, tool_use);
}
}
}
@@ -877,7 +633,7 @@ impl Thread {
smol::future::yield_now().await;
}
thread.update(cx, |thread, cx| {
thread.update(&mut cx, |thread, cx| {
thread
.pending_completions
.retain(|completion| completion.id != pending_completion_id);
@@ -893,52 +649,31 @@ impl Thread {
let result = stream_completion.await;
thread
.update(cx, |thread, cx| {
match result.as_ref() {
Ok(stop_reason) => match stop_reason {
StopReason::ToolUse => {
cx.emit(ThreadEvent::UsePendingTools);
}
StopReason::EndTurn => {}
StopReason::MaxTokens => {}
},
Err(error) => {
if error.is::<PaymentRequiredError>() {
cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
} else if error.is::<MaxMonthlySpendReachedError>() {
cx.emit(ThreadEvent::ShowError(
ThreadError::MaxMonthlySpendReached,
));
} else {
let error_message = error
.chain()
.map(|err| err.to_string())
.collect::<Vec<_>>()
.join("\n");
cx.emit(ThreadEvent::ShowError(ThreadError::Message {
header: "Error interacting with language model".into(),
message: SharedString::from(error_message.clone()),
}));
}
thread.cancel_last_completion(cx);
.update(&mut cx, |thread, cx| match result.as_ref() {
Ok(stop_reason) => match stop_reason {
StopReason::ToolUse => {
cx.emit(ThreadEvent::UsePendingTools);
}
StopReason::EndTurn => {}
StopReason::MaxTokens => {}
},
Err(error) => {
if error.is::<PaymentRequiredError>() {
cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
} else if error.is::<MaxMonthlySpendReachedError>() {
cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
} else {
let error_message = error
.chain()
.map(|err| err.to_string())
.collect::<Vec<_>>()
.join("\n");
cx.emit(ThreadEvent::ShowError(ThreadError::Message(
SharedString::from(error_message.clone()),
)));
}
}
cx.emit(ThreadEvent::DoneStreaming);
if let Ok(initial_usage) = initial_token_usage {
let usage = thread.cumulative_token_usage.clone() - initial_usage;
telemetry::event!(
"Assistant Thread Completion",
thread_id = thread.id().to_string(),
model = model.telemetry_id(),
model_provider = model.provider_id().to_string(),
input_tokens = usage.input_tokens,
output_tokens = usage.output_tokens,
cache_creation_input_tokens = usage.cache_creation_input_tokens,
cache_read_input_tokens = usage.cache_read_input_tokens,
);
thread.cancel_last_completion();
}
})
.ok();
@@ -972,7 +707,7 @@ impl Thread {
cache: false,
});
self.pending_summary = cx.spawn(async move |this, cx| {
self.pending_summary = cx.spawn(|this, mut cx| {
async move {
let stream = model.stream_completion_text(request, &cx);
let mut messages = stream.await?;
@@ -989,7 +724,7 @@ impl Thread {
}
}
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
if !new_summary.is_empty() {
this.summary = Some(new_summary.into());
}
@@ -1000,14 +735,10 @@ impl Thread {
anyhow::Ok(())
}
.log_err()
.await
});
}
pub fn use_pending_tools(
&mut self,
cx: &mut Context<Self>,
) -> impl IntoIterator<Item = PendingToolUse> {
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
let request = self.to_completion_request(RequestKind::Chat, cx);
let pending_tool_uses = self
.tool_use
@@ -1017,22 +748,11 @@ impl Thread {
.cloned()
.collect::<Vec<_>>();
for tool_use in pending_tool_uses.iter() {
for tool_use in pending_tool_uses {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
let task = tool.run(
tool_use.input.clone(),
&request.messages,
self.project.clone(),
self.action_log.clone(),
cx,
);
let task = tool.run(tool_use.input, &request.messages, self.project.clone(), cx);
self.insert_tool_output(
tool_use.id.clone(),
tool_use.ui_text.clone().into(),
task,
cx,
);
self.insert_tool_output(tool_use.id.clone(), task, cx);
}
}
@@ -1044,8 +764,8 @@ impl Thread {
.cloned()
.collect::<Vec<_>>();
for scripting_tool_use in pending_scripting_tool_uses.iter() {
let task = match ScriptingTool::deserialize_input(scripting_tool_use.input.clone()) {
for scripting_tool_use in pending_scripting_tool_uses {
let task = match ScriptingTool::deserialize_input(scripting_tool_use.input) {
Err(err) => Task::ready(Err(err.into())),
Ok(input) => {
let (script_id, script_task) =
@@ -1054,10 +774,10 @@ impl Thread {
});
let session = self.scripting_session.clone();
cx.spawn(async move |_, cx| {
cx.spawn(|_, cx| async move {
script_task.await;
let message = session.read_with(cx, |session, _cx| {
let message = session.read_with(&cx, |session, _cx| {
// Using a id to get the script output seems impractical.
// Why not just include it in the Task result?
// This is because we'll later report the script state as it runs,
@@ -1072,29 +792,22 @@ impl Thread {
}
};
let ui_text: SharedString = scripting_tool_use.name.clone().into();
self.insert_scripting_tool_output(scripting_tool_use.id.clone(), ui_text, task, cx);
self.insert_scripting_tool_output(scripting_tool_use.id.clone(), task, cx);
}
pending_tool_uses
.into_iter()
.chain(pending_scripting_tool_uses)
}
pub fn insert_tool_output(
&mut self,
tool_use_id: LanguageModelToolUseId,
ui_text: SharedString,
output: Task<Result<String>>,
cx: &mut Context<Self>,
) {
let insert_output_task = cx.spawn({
let insert_output_task = cx.spawn(|thread, mut cx| {
let tool_use_id = tool_use_id.clone();
async move |thread, cx| {
async move {
let output = output.await;
thread
.update(cx, |thread, cx| {
.update(&mut cx, |thread, cx| {
let pending_tool_use = thread
.tool_use
.insert_tool_output(tool_use_id.clone(), output);
@@ -1102,7 +815,6 @@ impl Thread {
cx.emit(ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,
canceled: false,
});
})
.ok();
@@ -1110,22 +822,21 @@ impl Thread {
});
self.tool_use
.run_pending_tool(tool_use_id, ui_text, insert_output_task);
.run_pending_tool(tool_use_id, insert_output_task);
}
pub fn insert_scripting_tool_output(
&mut self,
tool_use_id: LanguageModelToolUseId,
ui_text: SharedString,
output: Task<Result<String>>,
cx: &mut Context<Self>,
) {
let insert_output_task = cx.spawn({
let insert_output_task = cx.spawn(|thread, mut cx| {
let tool_use_id = tool_use_id.clone();
async move |thread, cx| {
async move {
let output = output.await;
thread
.update(cx, |thread, cx| {
.update(&mut cx, |thread, cx| {
let pending_tool_use = thread
.scripting_tool_use
.insert_tool_output(tool_use_id.clone(), output);
@@ -1133,7 +844,6 @@ impl Thread {
cx.emit(ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,
canceled: false,
});
})
.ok();
@@ -1141,20 +851,14 @@ impl Thread {
});
self.scripting_tool_use
.run_pending_tool(tool_use_id, ui_text, insert_output_task);
.run_pending_tool(tool_use_id, insert_output_task);
}
pub fn attach_tool_results(
pub fn send_tool_results_to_model(
&mut self,
updated_context: Vec<ContextSnapshot>,
model: Arc<dyn LanguageModel>,
cx: &mut Context<Self>,
) {
self.context.extend(
updated_context
.into_iter()
.map(|context| (context.id, context)),
);
// Insert a user message to contain the tool results.
self.insert_user_message(
// TODO: Sending up a user message without any content results in the model sending back
@@ -1162,28 +866,19 @@ impl Thread {
// so for now we provide some text to keep the model on track.
"Here are the tool results.",
Vec::new(),
None,
cx,
);
self.send_to_model(model, RequestKind::Chat, cx);
}
/// Cancels the last pending completion, if there are any pending.
///
/// Returns whether a completion was canceled.
pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
if self.pending_completions.pop().is_some() {
pub fn cancel_last_completion(&mut self) -> bool {
if let Some(_last_completion) = self.pending_completions.pop() {
true
} else {
let mut canceled = false;
for pending_tool_use in self.tool_use.cancel_pending() {
canceled = true;
cx.emit(ThreadEvent::ToolFinished {
tool_use_id: pending_tool_use.id.clone(),
pending_tool_use: Some(pending_tool_use),
canceled: true,
});
}
canceled
false
}
}
@@ -1219,14 +914,13 @@ impl Thread {
project: Entity<Project>,
cx: &mut Context<Self>,
) -> Task<Arc<ProjectSnapshot>> {
let git_store = project.read(cx).git_store().clone();
let worktree_snapshots: Vec<_> = project
.read(cx)
.visible_worktrees(cx)
.map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
.map(|worktree| Self::worktree_snapshot(worktree, cx))
.collect();
cx.spawn(async move |_, cx| {
cx.spawn(move |_, cx| async move {
let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
let mut unsaved_buffers = Vec::new();
@@ -1252,12 +946,8 @@ impl Thread {
})
}
fn worktree_snapshot(
worktree: Entity<project::Worktree>,
git_store: Entity<GitStore>,
cx: &App,
) -> Task<WorktreeSnapshot> {
cx.spawn(async move |cx| {
fn worktree_snapshot(worktree: Entity<project::Worktree>, cx: &App) -> Task<WorktreeSnapshot> {
cx.spawn(move |cx| async move {
// Get worktree path and snapshot
let worktree_info = cx.update(|app_cx| {
let worktree = worktree.read(app_cx);
@@ -1273,40 +963,42 @@ impl Thread {
};
};
let repo_info = git_store
.update(cx, |git_store, cx| {
git_store
.repositories()
.values()
.find(|repo| repo.read(cx).worktree_id == snapshot.id())
.and_then(|repo| {
let repo = repo.read(cx);
Some((repo.branch().cloned(), repo.local_repository()?))
})
})
.ok()
.flatten();
// Extract git information
let git_state = match repo_info {
let git_state = match snapshot.repositories().first() {
None => None,
Some((branch, repo)) => {
let current_branch = branch.map(|branch| branch.name.to_string());
let remote_url = repo.remote_url("origin");
let head_sha = repo.head_sha();
Some(repo_entry) => {
// Get branch information
let current_branch = repo_entry.branch().map(|branch| branch.name.to_string());
// Get diff asynchronously
let diff = repo
.diff(git::repository::DiffType::HeadToWorktree, cx.clone())
.await
.ok();
// Get repository info
let repo_result = worktree.read_with(&cx, |worktree, _cx| {
if let project::Worktree::Local(local_worktree) = &worktree {
local_worktree.get_local_repo(repo_entry).map(|local_repo| {
let repo = local_repo.repo();
(repo.remote_url("origin"), repo.head_sha(), repo.clone())
})
} else {
None
}
});
Some(GitState {
remote_url,
head_sha,
current_branch,
diff,
})
match repo_result {
Ok(Some((remote_url, head_sha, repository))) => {
// Get diff asynchronously
let diff = repository
.diff(git::repository::DiffType::HeadToWorktree, cx)
.await
.ok();
Some(GitState {
remote_url,
head_sha,
current_branch,
diff,
})
}
Err(_) | Ok(None) => None,
}
}
};
@@ -1317,7 +1009,7 @@ impl Thread {
})
}
pub fn to_markdown(&self, cx: &App) -> Result<String> {
pub fn to_markdown(&self) -> Result<String> {
let mut markdown = Vec::new();
if let Some(summary) = self.summary() {
@@ -1336,7 +1028,7 @@ impl Thread {
)?;
writeln!(markdown, "{}\n", message.text)?;
for tool_use in self.tool_uses_for_message(message.id, cx) {
for tool_use in self.tool_uses_for_message(message.id) {
writeln!(
markdown,
"**Use Tool: {} ({})**",
@@ -1365,14 +1057,6 @@ impl Thread {
Ok(String::from_utf8_lossy(&markdown).to_string())
}
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
pub fn project(&self) -> &Entity<Project> {
&self.project
}
pub fn cumulative_token_usage(&self) -> TokenUsage {
self.cumulative_token_usage.clone()
}
@@ -1382,10 +1066,7 @@ impl Thread {
pub enum ThreadError {
PaymentRequired,
MaxMonthlySpendReached,
Message {
header: SharedString,
message: SharedString,
},
Message(SharedString),
}
#[derive(Debug, Clone)]
@@ -1393,7 +1074,6 @@ pub enum ThreadEvent {
ShowError(ThreadError),
StreamedCompletion,
StreamedAssistantText(MessageId, String),
DoneStreaming,
MessageAdded(MessageId),
MessageEdited(MessageId),
MessageDeleted(MessageId),
@@ -1404,10 +1084,7 @@ pub enum ThreadEvent {
tool_use_id: LanguageModelToolUseId,
/// The pending tool use that corresponds to this tool.
pending_tool_use: Option<PendingToolUse>,
/// Whether the tool was canceled by the user.
canceled: bool,
},
CheckpointChanged,
}
impl EventEmitter<ThreadEvent> for Thread {}

View File

@@ -20,7 +20,7 @@ use prompt_store::PromptBuilder;
use serde::{Deserialize, Serialize};
use util::ResultExt as _;
use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId};
use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadId};
pub fn init(cx: &mut App) {
ThreadsDatabase::init(cx);
@@ -65,14 +65,6 @@ impl ThreadStore {
Ok(this)
}
pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
self.context_server_manager.clone()
}
pub fn tools(&self) -> Arc<ToolWorkingSet> {
self.tools.clone()
}
/// Returns the number of threads.
pub fn thread_count(&self) -> usize {
self.threads.len()
@@ -106,14 +98,14 @@ impl ThreadStore {
) -> Task<Result<Entity<Thread>>> {
let id = id.clone();
let database_future = ThreadsDatabase::global_future(cx);
cx.spawn(async move |this, cx| {
cx.spawn(|this, mut cx| async move {
let database = database_future.await.map_err(|err| anyhow!(err))?;
let thread = database
.try_find_thread(id.clone())
.await?
.ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
let thread = this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
cx.new(|cx| {
Thread::deserialize(
id.clone(),
@@ -124,19 +116,7 @@ impl ThreadStore {
cx,
)
})
})?;
let (system_prompt_context, load_error) = thread
.update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
.await;
thread.update(cx, |thread, cx| {
thread.set_system_prompt_context(system_prompt_context);
if let Some(load_error) = load_error {
cx.emit(ThreadEvent::ShowError(load_error));
}
})?;
Ok(thread)
})
})
}
@@ -145,23 +125,23 @@ impl ThreadStore {
thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
let database_future = ThreadsDatabase::global_future(cx);
cx.spawn(async move |this, cx| {
cx.spawn(|this, mut cx| async move {
let serialized_thread = serialized_thread.await?;
let database = database_future.await.map_err(|err| anyhow!(err))?;
database.save_thread(metadata, serialized_thread).await?;
this.update(cx, |this, cx| this.reload(cx))?.await
this.update(&mut cx, |this, cx| this.reload(cx))?.await
})
}
pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
let id = id.clone();
let database_future = ThreadsDatabase::global_future(cx);
cx.spawn(async move |this, cx| {
cx.spawn(|this, mut cx| async move {
let database = database_future.await.map_err(|err| anyhow!(err))?;
database.delete_thread(id.clone()).await?;
this.update(cx, |this, _cx| {
this.update(&mut cx, |this, _cx| {
this.threads.retain(|thread| thread.id != id)
})
})
@@ -169,14 +149,14 @@ impl ThreadStore {
pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
let database_future = ThreadsDatabase::global_future(cx);
cx.spawn(async move |this, cx| {
cx.spawn(|this, mut cx| async move {
let threads = database_future
.await
.map_err(|err| anyhow!(err))?
.list_threads()
.await?;
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
this.threads = threads;
cx.notify();
})
@@ -205,7 +185,7 @@ impl ThreadStore {
cx.spawn({
let server = server.clone();
let server_id = server_id.clone();
async move |this, cx| {
|this, mut cx| async move {
let Some(protocol) = server.client() else {
return;
};
@@ -230,7 +210,7 @@ impl ThreadStore {
})
.collect::<Vec<_>>();
this.update(cx, |this, _cx| {
this.update(&mut cx, |this, _cx| {
this.context_server_tool_ids.insert(server_id, tool_ids);
})
.log_err();

View File

@@ -1,50 +1,17 @@
use std::sync::Arc;
use assistant_settings::{AgentProfile, AssistantSettings};
use assistant_tool::{ToolSource, ToolWorkingSet};
use collections::HashMap;
use gpui::{Entity, Subscription};
use gpui::Entity;
use scripting_tool::ScriptingTool;
use settings::{Settings as _, SettingsStore};
use ui::{prelude::*, ContextMenu, PopoverMenu, Tooltip};
pub struct ToolSelector {
profiles: HashMap<Arc<str>, AgentProfile>,
tools: Arc<ToolWorkingSet>,
_subscriptions: Vec<Subscription>,
}
impl ToolSelector {
pub fn new(tools: Arc<ToolWorkingSet>, cx: &mut Context<Self>) -> Self {
let settings_subscription = cx.observe_global::<SettingsStore>(move |this, cx| {
this.refresh_profiles(cx);
});
let mut this = Self {
profiles: HashMap::default(),
tools,
_subscriptions: vec![settings_subscription],
};
this.refresh_profiles(cx);
this
}
fn refresh_profiles(&mut self, cx: &mut Context<Self>) {
let settings = AssistantSettings::get_global(cx);
let mut profiles = settings.profiles.clone();
let read_only = AgentProfile::read_only();
if !profiles.contains_key(read_only.name.as_ref()) {
profiles.insert(read_only.name.clone().into(), read_only);
}
let code_writer = AgentProfile::code_writer();
if !profiles.contains_key(code_writer.name.as_ref()) {
profiles.insert(code_writer.name.clone().into(), code_writer);
}
self.profiles = profiles;
pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut Context<Self>) -> Self {
Self { tools }
}
fn build_context_menu(
@@ -52,40 +19,13 @@ impl ToolSelector {
window: &mut Window,
cx: &mut Context<Self>,
) -> Entity<ContextMenu> {
let profiles = self.profiles.clone();
let tool_set = self.tools.clone();
ContextMenu::build_persistent(window, cx, move |mut menu, _window, cx| {
ContextMenu::build(window, cx, |mut menu, _window, cx| {
let icon_position = IconPosition::End;
let tools_by_source = self.tools.tools_by_source(cx);
menu = menu.header("Profiles");
for (_id, profile) in profiles.clone() {
menu = menu.toggleable_entry(profile.name.clone(), false, icon_position, None, {
let tools = tool_set.clone();
move |_window, cx| {
tools.disable_source(ToolSource::Native, cx);
tools.enable(
ToolSource::Native,
&profile
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
);
if profile.tools.contains_key(ScriptingTool::NAME) {
tools.enable_scripting_tool();
}
}
});
}
menu = menu.separator();
let tools_by_source = tool_set.tools_by_source(cx);
let all_tools_enabled = tool_set.are_all_tools_enabled();
let all_tools_enabled = self.tools.are_all_tools_enabled();
menu = menu.toggleable_entry("All Tools", all_tools_enabled, icon_position, None, {
let tools = tool_set.clone();
let tools = self.tools.clone();
move |_window, cx| {
if all_tools_enabled {
tools.disable_all_tools(cx);
@@ -101,7 +41,7 @@ impl ToolSelector {
.map(|tool| {
let source = tool.source();
let name = tool.name().into();
let is_enabled = tool_set.is_enabled(&source, &name);
let is_enabled = self.tools.is_enabled(&source, &name);
(source, name, is_enabled)
})
@@ -111,7 +51,7 @@ impl ToolSelector {
tools.push((
ToolSource::Native,
ScriptingTool::NAME.into(),
tool_set.is_scripting_tool_enabled(),
self.tools.is_scripting_tool_enabled(),
));
tools.sort_by(|(_, name_a, _), (_, name_b, _)| name_a.cmp(name_b));
}
@@ -120,7 +60,7 @@ impl ToolSelector {
ToolSource::Native => menu.separator().header("Zed Tools"),
ToolSource::ContextServer { id } => {
let all_tools_from_source_enabled =
tool_set.are_all_tools_from_source_enabled(&source);
self.tools.are_all_tools_from_source_enabled(&source);
menu.separator().header(id).toggleable_entry(
"All Tools",
@@ -128,7 +68,7 @@ impl ToolSelector {
icon_position,
None,
{
let tools = tool_set.clone();
let tools = self.tools.clone();
let source = source.clone();
move |_window, cx| {
if all_tools_from_source_enabled {
@@ -144,7 +84,7 @@ impl ToolSelector {
for (source, name, is_enabled) in tools {
menu = menu.toggleable_entry(name.clone(), is_enabled, icon_position, None, {
let tools = tool_set.clone();
let tools = self.tools.clone();
move |_window, _cx| {
if name.as_ref() == ScriptingTool::NAME {
if is_enabled {

View File

@@ -1,11 +1,10 @@
use std::sync::Arc;
use anyhow::Result;
use assistant_tool::ToolWorkingSet;
use collections::HashMap;
use futures::future::Shared;
use futures::FutureExt as _;
use gpui::{App, SharedString, Task};
use gpui::{SharedString, Task};
use language_model::{
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
LanguageModelToolUseId, MessageContent, Role,
@@ -18,7 +17,6 @@ use crate::thread_store::SerializedMessage;
pub struct ToolUse {
pub id: LanguageModelToolUseId,
pub name: SharedString,
pub ui_text: SharedString,
pub status: ToolUseStatus,
pub input: serde_json::Value,
}
@@ -32,7 +30,6 @@ pub enum ToolUseStatus {
}
pub struct ToolUseState {
tools: Arc<ToolWorkingSet>,
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
@@ -40,9 +37,8 @@ pub struct ToolUseState {
}
impl ToolUseState {
pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
pub fn new() -> Self {
Self {
tools,
tool_uses_by_assistant_message: HashMap::default(),
tool_uses_by_user_message: HashMap::default(),
tool_results: HashMap::default(),
@@ -54,11 +50,10 @@ impl ToolUseState {
///
/// Accepts a function to filter the tools that should be used to populate the state.
pub fn from_serialized_messages(
tools: Arc<ToolWorkingSet>,
messages: &[SerializedMessage],
mut filter_by_tool_name: impl FnMut(&str) -> bool,
) -> Self {
let mut this = Self::new(tools);
let mut this = Self::new();
let mut tool_names_by_id = HashMap::default();
for message in messages {
@@ -123,27 +118,11 @@ impl ToolUseState {
this
}
pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
let mut pending_tools = Vec::new();
for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id,
content: "Tool canceled by user".into(),
is_error: true,
},
);
pending_tools.push(tool_use.clone());
}
pending_tools
}
pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
self.pending_tool_uses_by_id.values().collect()
}
pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
return Vec::new();
};
@@ -178,7 +157,6 @@ impl ToolUseState {
tool_uses.push(ToolUse {
id: tool_use.id.clone(),
name: tool_use.name.clone().into(),
ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
input: tool_use.input.clone(),
status,
})
@@ -187,19 +165,6 @@ impl ToolUseState {
tool_uses
}
pub fn tool_ui_label(
&self,
tool_name: &str,
input: &serde_json::Value,
cx: &App,
) -> SharedString {
if let Some(tool) = self.tools.tool(tool_name, cx) {
tool.ui_text(input).into()
} else {
"Unknown tool".into()
}
}
pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
let empty = Vec::new();
@@ -217,18 +182,10 @@ impl ToolUseState {
.map_or(false, |results| !results.is_empty())
}
pub fn tool_result(
&self,
tool_use_id: &LanguageModelToolUseId,
) -> Option<&LanguageModelToolResult> {
self.tool_results.get(tool_use_id)
}
pub fn request_tool_use(
&mut self,
assistant_message_id: MessageId,
tool_use: LanguageModelToolUse,
cx: &App,
) {
self.tool_uses_by_assistant_message
.entry(assistant_message_id)
@@ -248,24 +205,15 @@ impl ToolUseState {
PendingToolUse {
assistant_message_id,
id: tool_use.id,
name: tool_use.name.clone(),
ui_text: self
.tool_ui_label(&tool_use.name, &tool_use.input, cx)
.into(),
name: tool_use.name,
input: tool_use.input,
status: PendingToolUseStatus::Idle,
},
);
}
pub fn run_pending_tool(
&mut self,
tool_use_id: LanguageModelToolUseId,
ui_text: SharedString,
task: Task<()>,
) {
pub fn run_pending_tool(&mut self, tool_use_id: LanguageModelToolUseId, task: Task<()>) {
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
tool_use.ui_text = ui_text.into();
tool_use.status = PendingToolUseStatus::Running {
_task: task.shared(),
};
@@ -278,12 +226,12 @@ impl ToolUseState {
output: Result<String>,
) -> Option<PendingToolUse> {
match output {
Ok(tool_result) => {
Ok(output) => {
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
content: tool_result.into(),
content: output.into(),
is_error: false,
},
);
@@ -315,17 +263,9 @@ impl ToolUseState {
) {
if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
for tool_use in tool_uses {
if self.tool_results.contains_key(&tool_use.id) {
// Do not send tool uses until they are completed
request_message
.content
.push(MessageContent::ToolUse(tool_use.clone()));
} else {
log::debug!(
"skipped tool use {:?} because it is still pending",
tool_use
);
}
request_message
.content
.push(MessageContent::ToolUse(tool_use.clone()));
}
}
}
@@ -338,19 +278,9 @@ impl ToolUseState {
if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
for tool_use_id in tool_uses {
if let Some(tool_result) = self.tool_results.get(tool_use_id) {
request_message.content.push(MessageContent::ToolResult(
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
is_error: tool_result.is_error,
content: if tool_result.content.is_empty() {
// Surprisingly, the API fails if we return an empty string here.
// It thinks we are sending a tool use without a tool result.
"<Tool returned an empty string>".into()
} else {
tool_result.content.clone()
},
},
));
request_message
.content
.push(MessageContent::ToolResult(tool_result.clone()));
}
}
}
@@ -364,7 +294,6 @@ pub struct PendingToolUse {
#[allow(unused)]
pub assistant_message_id: MessageId,
pub name: Arc<str>,
pub ui_text: Arc<str>,
pub input: serde_json::Value,
pub status: PendingToolUseStatus,
}

View File

@@ -1144,9 +1144,9 @@ impl AssistantContext {
fn set_language(&mut self, cx: &mut Context<Self>) {
let markdown = self.language_registry.language_for_name("Markdown");
cx.spawn(async move |this, cx| {
cx.spawn(|this, mut cx| async move {
let markdown = markdown.await?;
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
this.buffer
.update(cx, |buffer, cx| buffer.set_language(Some(markdown), cx));
})
@@ -1188,7 +1188,7 @@ impl AssistantContext {
return;
};
let debounce = self.token_count.is_some();
self.pending_token_count = cx.spawn(async move |this, cx| {
self.pending_token_count = cx.spawn(|this, mut cx| {
async move {
if debounce {
cx.background_executor()
@@ -1197,14 +1197,13 @@ impl AssistantContext {
}
let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count);
this.start_cache_warming(&model, cx);
cx.notify()
})
}
.log_err()
.await
});
}
@@ -1343,7 +1342,7 @@ impl AssistantContext {
};
let model = Arc::clone(model);
self.pending_cache_warming_task = cx.spawn(async move |this, cx| {
self.pending_cache_warming_task = cx.spawn(|this, mut cx| {
async move {
match model.stream_completion(request, &cx).await {
Ok(mut stream) => {
@@ -1354,14 +1353,13 @@ impl AssistantContext {
log::warn!("Cache warming failed: {}", e);
}
};
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
this.update_cache_status_for_completion(cx);
})
.ok();
anyhow::Ok(())
}
.log_err()
.await
});
}
@@ -1918,7 +1916,7 @@ impl AssistantContext {
});
self.reparse(cx);
let insert_output_task = cx.spawn(async move |this, cx| {
let insert_output_task = cx.spawn(|this, mut cx| async move {
let run_command = async {
let mut stream = output.await?;
@@ -1935,7 +1933,7 @@ impl AssistantContext {
while let Some(event) = stream.next().await {
let event = event?;
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
this.buffer.update(cx, |buffer, _cx| {
buffer.finalize_last_transaction();
buffer.start_transaction()
@@ -2036,7 +2034,7 @@ impl AssistantContext {
})?;
}
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
this.buffer.update(cx, |buffer, cx| {
buffer.finalize_last_transaction();
buffer.start_transaction();
@@ -2082,7 +2080,7 @@ impl AssistantContext {
let command_result = run_command.await;
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
let version = this.version.clone();
let timestamp = this.next_timestamp();
let Some(invoked_slash_command) = this.invoked_slash_commands.get_mut(&command_id)
@@ -2212,7 +2210,7 @@ impl AssistantContext {
let pending_completion_id = post_inc(&mut self.completion_count);
let task = cx.spawn({
async move |this, cx| {
|this, mut cx| async move {
let stream = model.stream_completion(request, &cx);
let assistant_message_id = assistant_message.id;
let mut response_latency = None;
@@ -2227,7 +2225,7 @@ impl AssistantContext {
}
let event = event?;
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
let message_ix = this
.message_anchors
.iter()
@@ -2266,7 +2264,7 @@ impl AssistantContext {
})?;
smol::future::yield_now().await;
}
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
this.pending_completions
.retain(|completion| completion.id != pending_completion_id);
this.summarize(false, cx);
@@ -2278,7 +2276,7 @@ impl AssistantContext {
let result = stream_completion.await;
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
let error_message = if let Some(error) = result.as_ref().err() {
if error.is::<PaymentRequiredError>() {
cx.emit(ContextEvent::ShowPaymentRequiredError);
@@ -2788,7 +2786,7 @@ impl AssistantContext {
cache: false,
});
self.pending_summary = cx.spawn(async move |this, cx| {
self.pending_summary = cx.spawn(|this, mut cx| {
async move {
let stream = model.stream_completion_text(request, &cx);
let mut messages = stream.await?;
@@ -2797,7 +2795,7 @@ impl AssistantContext {
while let Some(message) = messages.stream.next().await {
let text = message?;
let mut lines = text.lines();
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
let version = this.version.clone();
let timestamp = this.next_timestamp();
let summary = this.summary.get_or_insert(ContextSummary::default());
@@ -2821,7 +2819,7 @@ impl AssistantContext {
}
}
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
let version = this.version.clone();
let timestamp = this.next_timestamp();
if let Some(summary) = this.summary.as_mut() {
@@ -2839,7 +2837,6 @@ impl AssistantContext {
anyhow::Ok(())
}
.log_err()
.await
});
}
}
@@ -2946,12 +2943,12 @@ impl AssistantContext {
return;
}
self.pending_save = cx.spawn(async move |this, cx| {
self.pending_save = cx.spawn(|this, mut cx| async move {
if let Some(debounce) = debounce {
cx.background_executor().timer(debounce).await;
}
let (old_path, summary) = this.read_with(cx, |this, _| {
let (old_path, summary) = this.read_with(&cx, |this, _| {
let path = this.path.clone();
let summary = if let Some(summary) = this.summary.as_ref() {
if summary.done {
@@ -2966,7 +2963,7 @@ impl AssistantContext {
})?;
if let Some(summary) = summary {
let context = this.read_with(cx, |this, cx| this.serialize(cx))?;
let context = this.read_with(&cx, |this, cx| this.serialize(cx))?;
let mut discriminant = 1;
let mut new_path;
loop {
@@ -2998,7 +2995,7 @@ impl AssistantContext {
}
}
this.update(cx, |this, _| this.path = Some(new_path))?;
this.update(&mut cx, |this, _| this.path = Some(new_path))?;
}
Ok(())

View File

@@ -229,7 +229,6 @@ impl ContextEditor {
editor.set_show_git_diff_gutter(false, cx);
editor.set_show_code_actions(false, cx);
editor.set_show_runnables(false, cx);
editor.set_show_breakpoints(false, cx);
editor.set_show_wrap_guides(false, cx);
editor.set_show_indent_guides(false, cx);
editor.set_completion_provider(Some(Box::new(completion_provider)));
@@ -907,7 +906,7 @@ impl ContextEditor {
if editor_state.opened_patch != patch {
state.update_task = Some({
let this = this.clone();
cx.spawn_in(window, async move |_, cx| {
cx.spawn_in(window, |_, cx| async move {
Self::update_patch_editor(this.clone(), patch, cx)
.await
.log_err();
@@ -1070,9 +1069,10 @@ impl ContextEditor {
})
.ok();
} else {
patch_state.update_task = Some(cx.spawn_in(window, async move |this, cx| {
Self::open_patch_editor(this, new_patch, cx).await.log_err();
}));
patch_state.update_task =
Some(cx.spawn_in(window, move |this, cx| async move {
Self::open_patch_editor(this, new_patch, cx).await.log_err();
}));
}
}
}
@@ -1102,10 +1102,10 @@ impl ContextEditor {
async fn open_patch_editor(
this: WeakEntity<Self>,
patch: AssistantPatch,
cx: &mut AsyncWindowContext,
mut cx: AsyncWindowContext,
) -> Result<()> {
let project = this.read_with(cx, |this, _| this.project.clone())?;
let resolved_patch = patch.resolve(project.clone(), cx).await;
let project = this.read_with(&cx, |this, _| this.project.clone())?;
let resolved_patch = patch.resolve(project.clone(), &mut cx).await;
let editor = cx.new_window_entity(|window, cx| {
let editor = ProposedChangesEditor::new(
@@ -1129,7 +1129,7 @@ impl ContextEditor {
editor
})?;
this.update(cx, |this, _| {
this.update(&mut cx, |this, _| {
if let Some(patch_state) = this.patches.get_mut(&patch.range) {
patch_state.editor = Some(PatchEditorState {
editor: editor.downgrade(),
@@ -1138,8 +1138,8 @@ impl ContextEditor {
patch_state.update_task.take();
}
})?;
this.read_with(cx, |this, _| this.workspace.clone())?
.update_in(cx, |workspace, window, cx| {
this.read_with(&cx, |this, _| this.workspace.clone())?
.update_in(&mut cx, |workspace, window, cx| {
workspace.add_item_to_active_pane(Box::new(editor.clone()), None, false, window, cx)
})
.log_err();
@@ -1150,11 +1150,11 @@ impl ContextEditor {
async fn update_patch_editor(
this: WeakEntity<Self>,
patch: AssistantPatch,
cx: &mut AsyncWindowContext,
mut cx: AsyncWindowContext,
) -> Result<()> {
let project = this.update(cx, |this, _| this.project.clone())?;
let resolved_patch = patch.resolve(project.clone(), cx).await;
this.update_in(cx, |this, window, cx| {
let project = this.update(&mut cx, |this, _| this.project.clone())?;
let resolved_patch = patch.resolve(project.clone(), &mut cx).await;
this.update_in(&mut cx, |this, window, cx| {
let patch_state = this.patches.get_mut(&patch.range)?;
let locations = resolved_patch
@@ -1624,14 +1624,14 @@ impl ContextEditor {
.map(|path| Workspace::project_path_for_path(project.clone(), &path, false, cx))
.collect::<Vec<_>>();
cx.spawn(async move |_, cx| {
cx.spawn(move |_, cx| async move {
let mut paths = vec![];
let mut worktrees = vec![];
let opened_paths = futures::future::join_all(tasks).await;
for (worktree, project_path) in opened_paths.into_iter().flatten() {
let Ok(worktree_root_name) =
worktree.read_with(cx, |worktree, _| worktree.root_name().to_string())
worktree.read_with(&cx, |worktree, _| worktree.root_name().to_string())
else {
continue;
};
@@ -1648,12 +1648,12 @@ impl ContextEditor {
};
window
.spawn(cx, async move |cx| {
.spawn(cx, |mut cx| async move {
let (paths, dragged_file_worktrees) = paths.await;
let cmd_name = FileSlashCommand.name();
context_editor_view
.update_in(cx, |context_editor, window, cx| {
.update_in(&mut cx, |context_editor, window, cx| {
let file_argument = paths
.into_iter()
.map(|path| path.to_string_lossy().to_string())
@@ -2199,9 +2199,9 @@ impl ContextEditor {
.log_err();
if let Some(client) = client {
cx.spawn(async move |this, cx| {
client.authenticate_and_connect(true, cx).await?;
this.update(cx, |_, cx| cx.notify())
cx.spawn(|this, mut cx| async move {
client.authenticate_and_connect(true, &mut cx).await?;
this.update(&mut cx, |_, cx| cx.notify())
})
.detach_and_log_err(cx)
}
@@ -3160,10 +3160,10 @@ impl FollowableItem for ContextEditor {
assistant_panel_delegate.open_remote_context(workspace, context_id, window, cx)
});
Some(window.spawn(cx, async move |cx| {
Some(window.spawn(cx, |mut cx| async move {
let context_editor = context_editor_task.await?;
context_editor
.update_in(cx, |context_editor, window, cx| {
.update_in(&mut cx, |context_editor, window, cx| {
context_editor.remote_id = Some(id);
context_editor.editor.update(cx, |editor, cx| {
editor.apply_update_proto(

View File

@@ -164,9 +164,9 @@ impl PickerDelegate for SavedContextPickerDelegate {
cx: &mut Context<Picker<Self>>,
) -> Task<()> {
let search = self.store.read(cx).search(query, cx);
cx.spawn(async move |this, cx| {
cx.spawn(|this, mut cx| async move {
let matches = search.await;
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
let host_contexts = this.delegate.store.read(cx).host_contexts();
this.delegate.matches = host_contexts
.iter()

View File

@@ -100,7 +100,7 @@ impl ContextStore {
let fs = project.read(cx).fs().clone();
let languages = project.read(cx).languages().clone();
let telemetry = project.read(cx).client().telemetry().clone();
cx.spawn(async move |cx| {
cx.spawn(|mut cx| async move {
const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
@@ -125,15 +125,16 @@ impl ContextStore {
languages,
slash_commands,
telemetry,
_watch_updates: cx.spawn(async move |this, cx| {
_watch_updates: cx.spawn(|this, mut cx| {
async move {
while events.next().await.is_some() {
this.update(cx, |this, cx| this.reload(cx))?.await.log_err();
this.update(&mut cx, |this, cx| this.reload(cx))?
.await
.log_err();
}
anyhow::Ok(())
}
.log_err()
.await
}),
client_subscription: None,
_project_subscriptions: vec![
@@ -394,7 +395,7 @@ impl ContextStore {
let prompt_builder = self.prompt_builder.clone();
let slash_commands = self.slash_commands.clone();
let request = self.client.request(proto::CreateContext { project_id });
cx.spawn(async move |this, cx| {
cx.spawn(|this, mut cx| async move {
let response = request.await?;
let context_id = ContextId::from_proto(response.context_id);
let context_proto = response.context.context("invalid context")?;
@@ -420,8 +421,8 @@ impl ContextStore {
.collect::<Result<Vec<_>>>()
})
.await?;
context.update(cx, |context, cx| context.apply_ops(operations, cx))?;
this.update(cx, |this, cx| {
context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
this.update(&mut cx, |this, cx| {
if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
existing_context
} else {
@@ -456,7 +457,7 @@ impl ContextStore {
let prompt_builder = self.prompt_builder.clone();
let slash_commands = self.slash_commands.clone();
cx.spawn(async move |this, cx| {
cx.spawn(|this, mut cx| async move {
let saved_context = load.await?;
let context = cx.new(|cx| {
AssistantContext::deserialize(
@@ -470,7 +471,7 @@ impl ContextStore {
cx,
)
})?;
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
existing_context
} else {
@@ -488,7 +489,7 @@ impl ContextStore {
) -> Task<Result<()>> {
let fs = self.fs.clone();
cx.spawn(async move |this, cx| {
cx.spawn(|this, mut cx| async move {
fs.remove_file(
&path,
RemoveOptions {
@@ -498,7 +499,7 @@ impl ContextStore {
)
.await?;
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
this.contexts.retain(|context| {
context
.upgrade()
@@ -564,7 +565,7 @@ impl ContextStore {
});
let prompt_builder = self.prompt_builder.clone();
let slash_commands = self.slash_commands.clone();
cx.spawn(async move |this, cx| {
cx.spawn(|this, mut cx| async move {
let response = request.await?;
let context_proto = response.context.context("invalid context")?;
let context = cx.new(|cx| {
@@ -589,8 +590,8 @@ impl ContextStore {
.collect::<Result<Vec<_>>>()
})
.await?;
context.update(cx, |context, cx| context.apply_ops(operations, cx))?;
this.update(cx, |this, cx| {
context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?;
this.update(&mut cx, |this, cx| {
if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
existing_context
} else {
@@ -699,12 +700,12 @@ impl ContextStore {
project_id,
contexts,
});
cx.spawn(async move |this, cx| {
cx.spawn(|this, cx| async move {
let response = request.await?;
let mut context_ids = Vec::new();
let mut operations = Vec::new();
this.read_with(cx, |this, cx| {
this.read_with(&cx, |this, cx| {
for context_version_proto in response.contexts {
let context_version = ContextVersion::from_proto(&context_version_proto);
let context_id = ContextId::from_proto(context_version_proto.context_id);
@@ -767,7 +768,7 @@ impl ContextStore {
fn reload(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
let fs = self.fs.clone();
cx.spawn(async move |this, cx| {
cx.spawn(|this, mut cx| async move {
fs.create_dir(contexts_dir()).await?;
let mut paths = fs.read_dir(contexts_dir()).await?;
@@ -807,7 +808,7 @@ impl ContextStore {
}
contexts.sort_unstable_by_key(|context| Reverse(context.mtime));
this.update(cx, |this, cx| {
this.update(&mut cx, |this, cx| {
this.contexts_metadata = contexts;
cx.notify();
})
@@ -818,7 +819,7 @@ impl ContextStore {
cx.update_entity(
&self.context_server_manager,
|context_server_manager, cx| {
for server in context_server_manager.running_servers() {
for server in context_server_manager.servers() {
context_server_manager
.restart_server(&server.id(), cx)
.detach_and_log_err(cx);
@@ -849,7 +850,7 @@ impl ContextStore {
cx.spawn({
let server = server.clone();
let server_id = server_id.clone();
async move |this, cx| {
|this, mut cx| async move {
let Some(protocol) = server.client() else {
return;
};
@@ -874,7 +875,7 @@ impl ContextStore {
})
.collect::<Vec<_>>();
this.update( cx, |this, _cx| {
this.update(&mut cx, |this, _cx| {
this.context_server_slash_command_ids
.insert(server_id.clone(), slash_command_ids);
})

View File

@@ -59,7 +59,7 @@ impl SlashCommandCompletionProvider {
let command_name = command_name.to_string();
let editor = self.editor.clone();
let workspace = self.workspace.clone();
window.spawn(cx, async move |cx| {
window.spawn(cx, |mut cx| async move {
let matches = match_strings(
&candidates,
&command_name,

View File

@@ -100,7 +100,7 @@ impl PickerDelegate for SlashCommandDelegate {
cx: &mut Context<Picker<Self>>,
) -> Task<()> {
let all_commands = self.all_commands.clone();
cx.spawn_in(window, async move |this, cx| {
cx.spawn_in(window, |this, mut cx| async move {
let filtered_commands = cx
.background_spawn(async move {
if query.is_empty() {
@@ -119,7 +119,7 @@ impl PickerDelegate for SlashCommandDelegate {
})
.await;
this.update_in(cx, |this, window, cx| {
this.update_in(&mut cx, |this, window, cx| {
this.delegate.filtered_commands = filtered_commands;
this.delegate.set_selected_index(0, window, cx);
cx.notify();

View File

@@ -1,44 +0,0 @@
[package]
name = "assistant_eval"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[[bin]]
name = "assistant_eval"
path = "src/main.rs"
[dependencies]
anyhow.workspace = true
assistant2.workspace = true
assistant_tool.workspace = true
assistant_tools.workspace = true
clap.workspace = true
client.workspace = true
collections.workspace = true
context_server.workspace = true
env_logger.workspace = true
fs.workspace = true
futures.workspace = true
gpui.workspace = true
gpui_tokio.workspace = true
itertools.workspace = true
language.workspace = true
language_model.workspace = true
language_models.workspace = true
node_runtime.workspace = true
project.workspace = true
prompt_store.workspace = true
regex.workspace = true
release_channel.workspace = true
reqwest_client.workspace = true
serde.workspace = true
serde_json.workspace = true
serde_json_lenient.workspace = true
settings.workspace = true
smol.workspace = true
util.workspace = true

View File

@@ -1,77 +0,0 @@
# Tool Evals
A framework for evaluating and benchmarking AI assistant performance in the Zed editor.
## Overview
Tool Evals provides a headless environment for running assistants evaluations on code repositories. It automates the process of:
1. Cloning and setting up test repositories
2. Sending prompts to language models
3. Allowing the assistant to use tools to modify code
4. Collecting metrics on performance
5. Evaluating results against known good solutions
## How It Works
The system consists of several key components:
- **Eval**: Loads test cases from the evaluation_data directory, clones repos, and executes evaluations
- **HeadlessAssistant**: Provides a headless environment for running the AI assistant
- **Judge**: Compares AI-generated diffs with reference solutions and scores their functional similarity
The evaluation flow:
1. An evaluation is loaded from the evaluation_data directory
2. The target repository is cloned and checked out at a specific commit
3. A HeadlessAssistant instance is created with the specified language model
4. The user prompt is sent to the assistant
5. The assistant responds and uses tools to modify code
6. Upon completion, a diff is generated from the changes
7. Results are saved including the diff, assistant's response, and performance metrics
8. If a reference solution exists, a Judge evaluates the similarity of the solution
## Setup Requirements
### Prerequisites
- Rust and Cargo
- Git
- Network access to clone repositories
- Appropriate API keys for language models and git services (Anthropic, GitHub, etc.)
### Environment Variables
Ensure you have the required API keys set, either from a dev run of Zed or via these environment variables:
- `ZED_ANTHROPIC_API_KEY` for Claude models
- `ZED_OPENAI_API_KEY` for OpenAI models
- `ZED_GITHUB_API_KEY` for GitHub API (or similar)
## Usage
### Running a Single Evaluation
To run a specific evaluation:
```bash
cargo run -p assistant_eval -- bubbletea-add-set-window-title
```
The arguments are regex patterns for the evaluation names to run, so to run all evaluations that contain `bubbletea`, run:
```bash
cargo run -p assistant_eval -- bubbletea
```
To run all evaluations:
```bash
cargo run -p assistant_eval -- --all
```
## Evaluation Data Structure
Each evaluation should be placed in the `evaluation_data` directory with the following structure:
* `prompt.txt`: The user's prompt.
* `original.diff`: The `git diff` of the change anticipated for this prompt.
* `setup.json`: Information about the repo used for the evaluation.

View File

@@ -1,52 +0,0 @@
// Copied from `crates/zed/build.rs`, with removal of code for including the zed icon on windows.
use std::process::Command;
fn main() {
if cfg!(target_os = "macos") {
println!("cargo:rustc-env=MACOSX_DEPLOYMENT_TARGET=10.15.7");
// Weakly link ReplayKit to ensure Zed can be used on macOS 10.15+.
println!("cargo:rustc-link-arg=-Wl,-weak_framework,ReplayKit");
// Seems to be required to enable Swift concurrency
println!("cargo:rustc-link-arg=-Wl,-rpath,/usr/lib/swift");
// Register exported Objective-C selectors, protocols, etc
println!("cargo:rustc-link-arg=-Wl,-ObjC");
}
// Populate git sha environment variable if git is available
println!("cargo:rerun-if-changed=../../.git/logs/HEAD");
println!(
"cargo:rustc-env=TARGET={}",
std::env::var("TARGET").unwrap()
);
if let Ok(output) = Command::new("git").args(["rev-parse", "HEAD"]).output() {
if output.status.success() {
let git_sha = String::from_utf8_lossy(&output.stdout);
let git_sha = git_sha.trim();
println!("cargo:rustc-env=ZED_COMMIT_SHA={git_sha}");
if let Ok(build_profile) = std::env::var("PROFILE") {
if build_profile == "release" {
// This is currently the best way to make `cargo build ...`'s build script
// to print something to stdout without extra verbosity.
println!(
"cargo:warning=Info: using '{git_sha}' hash for ZED_COMMIT_SHA env var"
);
}
}
}
}
#[cfg(target_os = "windows")]
{
#[cfg(target_env = "msvc")]
{
// todo(windows): This is to avoid stack overflow. Remove it when solved.
println!("cargo:rustc-link-arg=/stack:{}", 8 * 1024 * 1024);
}
}
}

View File

@@ -1,267 +0,0 @@
use crate::headless_assistant::{HeadlessAppState, HeadlessAssistant};
use anyhow::anyhow;
use assistant2::RequestKind;
use collections::HashMap;
use gpui::{App, Task};
use language_model::{LanguageModel, TokenUsage};
use serde::{Deserialize, Serialize};
use std::{
fs,
io::Write,
path::{Path, PathBuf},
sync::Arc,
time::Duration,
};
use util::command::new_smol_command;
pub struct Eval {
pub name: String,
pub path: PathBuf,
pub repo_path: PathBuf,
pub eval_setup: EvalSetup,
pub user_prompt: String,
}
#[derive(Debug, Serialize)]
pub struct EvalOutput {
pub diff: String,
pub last_message: String,
pub elapsed_time: Duration,
pub assistant_response_count: usize,
pub tool_use_counts: HashMap<Arc<str>, u32>,
pub token_usage: TokenUsage,
}
#[derive(Deserialize)]
pub struct EvalSetup {
pub url: String,
pub base_sha: String,
}
impl Eval {
/// Loads the eval from a path (typically in `evaluation_data`). Clones and checks out the repo
/// if necessary.
pub async fn load(name: String, path: PathBuf, repos_dir: &Path) -> anyhow::Result<Self> {
let prompt_path = path.join("prompt.txt");
let user_prompt = smol::unblock(|| std::fs::read_to_string(prompt_path)).await?;
let setup_path = path.join("setup.json");
let setup_contents = smol::unblock(|| std::fs::read_to_string(setup_path)).await?;
let eval_setup = serde_json_lenient::from_str_lenient::<EvalSetup>(&setup_contents)?;
let repo_path = repos_dir.join(repo_dir_name(&eval_setup.url));
Ok(Eval {
name,
path,
repo_path,
eval_setup,
user_prompt,
})
}
pub fn run(
self,
app_state: Arc<HeadlessAppState>,
model: Arc<dyn LanguageModel>,
cx: &mut App,
) -> Task<anyhow::Result<EvalOutput>> {
cx.spawn(async move |cx| {
checkout_repo(&self.eval_setup, &self.repo_path).await?;
let (assistant, done_rx) =
cx.update(|cx| HeadlessAssistant::new(app_state.clone(), cx))??;
let _worktree = assistant
.update(cx, |assistant, cx| {
assistant.project.update(cx, |project, cx| {
project.create_worktree(&self.repo_path, true, cx)
})
})?
.await?;
let start_time = std::time::SystemTime::now();
let (system_prompt_context, load_error) = cx
.update(|cx| {
assistant
.read(cx)
.thread
.read(cx)
.load_system_prompt_context(cx)
})?
.await;
if let Some(load_error) = load_error {
return Err(anyhow!("{:?}", load_error));
};
assistant.update(cx, |assistant, cx| {
assistant.thread.update(cx, |thread, cx| {
let context = vec![];
thread.insert_user_message(self.user_prompt.clone(), context, None, cx);
thread.set_system_prompt_context(system_prompt_context);
thread.send_to_model(model, RequestKind::Chat, cx);
});
})?;
done_rx.recv().await??;
let elapsed_time = start_time.elapsed()?;
let diff = query_git(&self.repo_path, vec!["diff"]).await?;
assistant.update(cx, |assistant, cx| {
let thread = assistant.thread.read(cx);
let last_message = thread.messages().last().unwrap();
if last_message.role != language_model::Role::Assistant {
return Err(anyhow!("Last message is not from assistant"));
}
let assistant_response_count = thread
.messages()
.filter(|message| message.role == language_model::Role::Assistant)
.count();
Ok(EvalOutput {
diff,
last_message: last_message.text.clone(),
elapsed_time,
assistant_response_count,
tool_use_counts: assistant.tool_use_counts.clone(),
token_usage: thread.cumulative_token_usage(),
})
})?
})
}
}
impl EvalOutput {
// Method to save the output to a directory
pub fn save_to_directory(
&self,
output_dir: &Path,
eval_output_value: String,
) -> anyhow::Result<()> {
// Create the output directory if it doesn't exist
fs::create_dir_all(&output_dir)?;
// Save the diff to a file
let diff_path = output_dir.join("diff.patch");
let mut diff_file = fs::File::create(&diff_path)?;
diff_file.write_all(self.diff.as_bytes())?;
// Save the last message to a file
let message_path = output_dir.join("assistant_response.txt");
let mut message_file = fs::File::create(&message_path)?;
message_file.write_all(self.last_message.as_bytes())?;
// Current metrics for this run
let current_metrics = serde_json::json!({
"elapsed_time_ms": self.elapsed_time.as_millis(),
"assistant_response_count": self.assistant_response_count,
"tool_use_counts": self.tool_use_counts,
"token_usage": self.token_usage,
"eval_output_value": eval_output_value,
});
// Get current timestamp in milliseconds
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)?
.as_millis()
.to_string();
// Path to metrics file
let metrics_path = output_dir.join("metrics.json");
// Load existing metrics if the file exists, or create a new object
let mut historical_metrics = if metrics_path.exists() {
let metrics_content = fs::read_to_string(&metrics_path)?;
serde_json::from_str::<serde_json::Value>(&metrics_content)
.unwrap_or_else(|_| serde_json::json!({}))
} else {
serde_json::json!({})
};
// Add new run with timestamp as key
if let serde_json::Value::Object(ref mut map) = historical_metrics {
map.insert(timestamp, current_metrics);
}
// Write updated metrics back to file
let metrics_json = serde_json::to_string_pretty(&historical_metrics)?;
let mut metrics_file = fs::File::create(&metrics_path)?;
metrics_file.write_all(metrics_json.as_bytes())?;
Ok(())
}
}
fn repo_dir_name(url: &str) -> String {
url.trim_start_matches("https://")
.replace(|c: char| !c.is_alphanumeric(), "_")
}
async fn checkout_repo(eval_setup: &EvalSetup, repo_path: &Path) -> anyhow::Result<()> {
if !repo_path.exists() {
smol::unblock({
let repo_path = repo_path.to_path_buf();
|| std::fs::create_dir_all(repo_path)
})
.await?;
run_git(repo_path, vec!["init"]).await?;
run_git(repo_path, vec!["remote", "add", "origin", &eval_setup.url]).await?;
} else {
let actual_origin = query_git(repo_path, vec!["remote", "get-url", "origin"]).await?;
if actual_origin != eval_setup.url {
return Err(anyhow!(
"remote origin {} does not match expected origin {}",
actual_origin,
eval_setup.url
));
}
// TODO: consider including "-x" to remove ignored files. The downside of this is that it will
// also remove build artifacts, and so prevent incremental reuse there.
run_git(repo_path, vec!["clean", "--force", "-d"]).await?;
run_git(repo_path, vec!["reset", "--hard", "HEAD"]).await?;
}
run_git(
repo_path,
vec!["fetch", "--depth", "1", "origin", &eval_setup.base_sha],
)
.await?;
run_git(repo_path, vec!["checkout", &eval_setup.base_sha]).await?;
Ok(())
}
async fn run_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<()> {
let exit_status = new_smol_command("git")
.current_dir(repo_path)
.args(args.clone())
.status()
.await?;
if exit_status.success() {
Ok(())
} else {
Err(anyhow!(
"`git {}` failed with {}",
args.join(" "),
exit_status,
))
}
}
async fn query_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<String> {
let output = new_smol_command("git")
.current_dir(repo_path)
.args(args.clone())
.output()
.await?;
if output.status.success() {
Ok(String::from_utf8(output.stdout)?.trim().to_string())
} else {
Err(anyhow!(
"`git {}` failed with {}",
args.join(" "),
output.status
))
}
}

View File

@@ -1,236 +0,0 @@
use anyhow::anyhow;
use assistant2::{RequestKind, Thread, ThreadEvent, ThreadStore};
use assistant_tool::ToolWorkingSet;
use client::{Client, UserStore};
use collections::HashMap;
use futures::StreamExt;
use gpui::{prelude::*, App, AsyncApp, Entity, SemanticVersion, Subscription, Task};
use language::LanguageRegistry;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
LanguageModelRequest,
};
use node_runtime::NodeRuntime;
use project::{Project, RealFs};
use prompt_store::PromptBuilder;
use settings::SettingsStore;
use smol::channel;
use std::sync::Arc;
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
pub struct HeadlessAppState {
pub languages: Arc<LanguageRegistry>,
pub client: Arc<Client>,
pub user_store: Entity<UserStore>,
pub fs: Arc<dyn fs::Fs>,
pub node_runtime: NodeRuntime,
// Additional fields not present in `workspace::AppState`.
pub prompt_builder: Arc<PromptBuilder>,
}
pub struct HeadlessAssistant {
pub thread: Entity<Thread>,
pub project: Entity<Project>,
#[allow(dead_code)]
pub thread_store: Entity<ThreadStore>,
pub tool_use_counts: HashMap<Arc<str>, u32>,
pub done_tx: channel::Sender<anyhow::Result<()>>,
_subscription: Subscription,
}
impl HeadlessAssistant {
pub fn new(
app_state: Arc<HeadlessAppState>,
cx: &mut App,
) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
let env = None;
let project = Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
app_state.fs.clone(),
env,
cx,
);
let tools = Arc::new(ToolWorkingSet::default());
let thread_store =
ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
let headless_thread = cx.new(move |cx| Self {
_subscription: cx.subscribe(&thread, Self::handle_thread_event),
thread,
project,
thread_store,
tool_use_counts: HashMap::default(),
done_tx,
});
Ok((headless_thread, done_rx))
}
fn handle_thread_event(
&mut self,
thread: Entity<Thread>,
event: &ThreadEvent,
cx: &mut Context<Self>,
) {
match event {
ThreadEvent::ShowError(err) => self
.done_tx
.send_blocking(Err(anyhow!("{:?}", err)))
.unwrap(),
ThreadEvent::DoneStreaming => {
let thread = thread.read(cx);
if let Some(message) = thread.messages().last() {
println!("Message: {}", message.text,);
}
if thread.all_tools_finished() {
self.done_tx.send_blocking(Ok(())).unwrap()
}
}
ThreadEvent::UsePendingTools => {
thread.update(cx, |thread, cx| {
thread.use_pending_tools(cx);
});
}
ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,
..
} => {
if let Some(pending_tool_use) = pending_tool_use {
println!(
"Used tool {} with input: {}",
pending_tool_use.name, pending_tool_use.input
);
*self
.tool_use_counts
.entry(pending_tool_use.name.clone())
.or_insert(0) += 1;
}
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
println!("Tool result: {:?}", tool_result);
}
if thread.read(cx).all_tools_finished() {
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(model) = model_registry.active_model() {
thread.update(cx, |thread, cx| {
thread.attach_tool_results(vec![], cx);
thread.send_to_model(model, RequestKind::Chat, cx);
});
}
}
}
_ => {}
}
}
}
pub fn init(cx: &mut App) -> Arc<HeadlessAppState> {
release_channel::init(SemanticVersion::default(), cx);
gpui_tokio::init(cx);
let mut settings_store = SettingsStore::new(cx);
settings_store
.set_default_settings(settings::default_settings().as_ref(), cx)
.unwrap();
cx.set_global(settings_store);
client::init_settings(cx);
Project::init_settings(cx);
let client = Client::production(cx);
cx.set_http_client(client.http_client().clone());
let git_binary_path = None;
let fs = Arc::new(RealFs::new(git_binary_path));
let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language::init(cx);
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
assistant_tools::init(client.http_client().clone(), cx);
context_server::init(cx);
let stdout_is_a_pty = false;
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
assistant2::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
Arc::new(HeadlessAppState {
languages,
client,
user_store,
fs,
node_runtime: NodeRuntime::unavailable(),
prompt_builder,
})
}
pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
let model_registry = LanguageModelRegistry::read_global(cx);
let model = model_registry
.available_models(cx)
.find(|model| model.id().0 == model_name);
let Some(model) = model else {
return Err(anyhow!(
"No language model named {} was available. Available models: {}",
model_name,
model_registry
.available_models(cx)
.map(|model| model.id().0.clone())
.collect::<Vec<_>>()
.join(", ")
));
};
Ok(model)
}
pub fn authenticate_model_provider(
provider_id: LanguageModelProviderId,
cx: &mut App,
) -> Task<std::result::Result<(), AuthenticateError>> {
let model_registry = LanguageModelRegistry::read_global(cx);
let model_provider = model_registry.provider(&provider_id).unwrap();
model_provider.authenticate(cx)
}
pub async fn send_language_model_request(
model: Arc<dyn LanguageModel>,
request: LanguageModelRequest,
cx: &mut AsyncApp,
) -> anyhow::Result<String> {
match model.stream_completion_text(request, &cx).await {
Ok(mut stream) => {
let mut full_response = String::new();
// Process the response stream
while let Some(chunk_result) = stream.stream.next().await {
match chunk_result {
Ok(chunk_str) => {
full_response.push_str(&chunk_str);
}
Err(err) => {
return Err(anyhow!(
"Error receiving response from language model: {err}"
));
}
}
}
Ok(full_response)
}
Err(err) => Err(anyhow!(
"Failed to get response from language model. Error was: {err}"
)),
}
}

View File

@@ -1,121 +0,0 @@
use crate::eval::EvalOutput;
use crate::headless_assistant::send_language_model_request;
use anyhow::anyhow;
use gpui::{App, Task};
use language_model::{
LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
};
use std::{path::Path, sync::Arc};
pub struct Judge {
pub original_diff: Option<String>,
#[allow(dead_code)]
pub original_message: Option<String>,
pub model: Arc<dyn LanguageModel>,
}
impl Judge {
pub async fn load(eval_path: &Path, model: Arc<dyn LanguageModel>) -> anyhow::Result<Judge> {
let original_diff_path = eval_path.join("original.diff");
let original_diff = smol::unblock(move || {
if std::fs::exists(&original_diff_path)? {
anyhow::Ok(Some(std::fs::read_to_string(&original_diff_path)?))
} else {
anyhow::Ok(None)
}
});
let original_message_path = eval_path.join("original_message.txt");
let original_message = smol::unblock(move || {
if std::fs::exists(&original_message_path)? {
anyhow::Ok(Some(std::fs::read_to_string(&original_message_path)?))
} else {
anyhow::Ok(None)
}
});
Ok(Self {
original_diff: original_diff.await?,
original_message: original_message.await?,
model,
})
}
pub fn run(&self, eval_output: &EvalOutput, cx: &mut App) -> Task<anyhow::Result<String>> {
let Some(original_diff) = self.original_diff.as_ref() else {
return Task::ready(Err(anyhow!("No original.diff found")));
};
// TODO: check for empty diff?
let prompt = diff_comparison_prompt(&original_diff, &eval_output.diff);
let request = LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
role: Role::User,
content: vec![MessageContent::Text(prompt)],
cache: false,
}],
temperature: Some(0.0),
tools: Vec::new(),
stop: Vec::new(),
};
let model = self.model.clone();
cx.spawn(async move |cx| send_language_model_request(model, request, cx).await)
}
}
pub fn diff_comparison_prompt(original_diff: &str, new_diff: &str) -> String {
format!(
r#"# Git Diff Similarity Evaluation Template
## Instructions
Compare the two diffs and score them between 0.0 and 1.0 based on their functional similarity.
- 1.0 = Perfect functional match (achieves identical results)
- 0.0 = No functional similarity whatsoever
## Evaluation Criteria
Please consider the following aspects in order of importance:
1. **Functional Equivalence (60%)**
- Do both diffs achieve the same end result?
- Are the changes functionally equivalent despite possibly using different approaches?
- Do the modifications address the same issues or implement the same features?
2. **Logical Structure (20%)**
- Are the logical flows similar?
- Do the modifications affect the same code paths?
- Are control structures (if/else, loops, etc.) modified in similar ways?
3. **Code Content (15%)**
- Are similar lines added/removed?
- Are the same variables, functions, or methods being modified?
- Are the same APIs or libraries being used?
4. **File Layout (5%)**
- Are the same files being modified?
- Are changes occurring in similar locations within files?
## Input
Original Diff:
```git
{}
```
New Diff:
```git
{}
```
## Output Format
THE ONLY OUTPUT SHOULD BE A SCORE BETWEEN 0.0 AND 1.0.
Example output:
0.85"#,
original_diff, new_diff
)
}

View File

@@ -1,243 +0,0 @@
mod eval;
mod headless_assistant;
mod judge;
use clap::Parser;
use eval::{Eval, EvalOutput};
use futures::future;
use gpui::{Application, AsyncApp};
use headless_assistant::{authenticate_model_provider, find_model, HeadlessAppState};
use itertools::Itertools;
use judge::Judge;
use language_model::{LanguageModel, LanguageModelRegistry};
use regex::Regex;
use reqwest_client::ReqwestClient;
use std::{cmp, path::PathBuf, sync::Arc};
#[derive(Parser, Debug)]
#[command(
name = "assistant_eval",
disable_version_flag = true,
before_help = "Tool eval runner"
)]
struct Args {
/// Regexes to match the names of evals to run.
eval_name_regexes: Vec<String>,
/// Runs all evals in `evaluation_data`, causes the regex to be ignored.
#[arg(long)]
all: bool,
/// Name of the model (default: "claude-3-7-sonnet-latest")
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
model_name: String,
/// Name of the editor model (default: value of `--model_name`).
#[arg(long)]
editor_model_name: Option<String>,
/// Name of the judge model (default: value of `--model_name`).
#[arg(long)]
judge_model_name: Option<String>,
/// Number of evaluations to run concurrently (default: 10)
#[arg(short, long, default_value = "10")]
concurrency: usize,
}
fn main() {
env_logger::init();
let args = Args::parse();
let http_client = Arc::new(ReqwestClient::new());
let app = Application::headless().with_http_client(http_client.clone());
let crate_dir = PathBuf::from("../zed-agent-bench");
let evaluation_data_dir = crate_dir.join("evaluation_data").canonicalize().unwrap();
let repos_dir = crate_dir.join("repos");
if !repos_dir.exists() {
std::fs::create_dir_all(&repos_dir).unwrap();
}
let repos_dir = repos_dir.canonicalize().unwrap();
let all_evals = std::fs::read_dir(&evaluation_data_dir)
.unwrap()
.map(|path| path.unwrap().file_name().to_string_lossy().to_string())
.collect::<Vec<_>>();
let evals_to_run = if args.all {
all_evals
} else {
args.eval_name_regexes
.into_iter()
.map(|regex_string| Regex::new(&regex_string).unwrap())
.flat_map(|regex| {
all_evals
.iter()
.filter(|eval_name| regex.is_match(eval_name))
.cloned()
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
};
if evals_to_run.is_empty() {
panic!("Names of evals to run must be provided or `--all` specified");
}
println!("Will run the following evals: {evals_to_run:?}");
println!("Running up to {} evals concurrently", args.concurrency);
let editor_model_name = if let Some(model_name) = args.editor_model_name {
model_name
} else {
args.model_name.clone()
};
let judge_model_name = if let Some(model_name) = args.judge_model_name {
model_name
} else {
args.model_name.clone()
};
app.run(move |cx| {
let app_state = headless_assistant::init(cx);
let model = find_model(&args.model_name, cx).unwrap();
let editor_model = find_model(&editor_model_name, cx).unwrap();
let judge_model = find_model(&judge_model_name, cx).unwrap();
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.set_active_model(Some(model.clone()), cx);
registry.set_editor_model(Some(editor_model.clone()), cx);
});
let model_provider_id = model.provider_id();
let editor_model_provider_id = editor_model.provider_id();
let judge_model_provider_id = judge_model.provider_id();
cx.spawn(async move |cx| {
// Authenticate all model providers first
cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
.unwrap()
.await
.unwrap();
cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx))
.unwrap()
.await
.unwrap();
cx.update(|cx| authenticate_model_provider(judge_model_provider_id.clone(), cx))
.unwrap()
.await
.unwrap();
let eval_load_futures = evals_to_run
.into_iter()
.map(|eval_name| {
let eval_path = evaluation_data_dir.join(&eval_name);
let load_future = Eval::load(eval_name.clone(), eval_path, &repos_dir);
async move {
match load_future.await {
Ok(eval) => Some(eval),
Err(err) => {
// TODO: Persist errors / surface errors at the end.
println!("Error loading {eval_name}: {err}");
None
}
}
}
})
.collect::<Vec<_>>();
let loaded_evals = future::join_all(eval_load_futures)
.await
.into_iter()
.flatten()
.collect::<Vec<_>>();
// The evals need to be loaded and grouped by URL before concurrently running, since
// evals that use the same remote URL will use the same working directory.
let mut evals_grouped_by_url: Vec<Vec<Eval>> = loaded_evals
.into_iter()
.map(|eval| (eval.eval_setup.url.clone(), eval))
.into_group_map()
.into_values()
.collect::<Vec<_>>();
// Sort groups in descending order, so that bigger groups start first.
evals_grouped_by_url.sort_by_key(|evals| cmp::Reverse(evals.len()));
let result_futures = evals_grouped_by_url
.into_iter()
.map(|evals| {
let model = model.clone();
let judge_model = judge_model.clone();
let app_state = app_state.clone();
let cx = cx.clone();
async move {
let mut results = Vec::new();
for eval in evals {
let name = eval.name.clone();
println!("Starting eval named {}", name);
let result = run_eval(
eval,
model.clone(),
judge_model.clone(),
app_state.clone(),
cx.clone(),
)
.await;
results.push((name, result));
}
results
}
})
.collect::<Vec<_>>();
let results = future::join_all(result_futures)
.await
.into_iter()
.flatten()
.collect::<Vec<_>>();
// Process results in order of completion
for (eval_name, result) in results {
match result {
Ok((eval_output, judge_output)) => {
println!("Generated diff for {eval_name}:\n");
println!("{}\n", eval_output.diff);
println!("Last message for {eval_name}:\n");
println!("{}\n", eval_output.last_message);
println!("Elapsed time: {:?}", eval_output.elapsed_time);
println!(
"Assistant response count: {}",
eval_output.assistant_response_count
);
println!("Tool use counts: {:?}", eval_output.tool_use_counts);
println!("Judge output for {eval_name}: {judge_output}");
}
Err(err) => {
// TODO: Persist errors / surface errors at the end.
println!("Error running {eval_name}: {err}");
}
}
}
cx.update(|cx| cx.quit()).unwrap();
})
.detach();
});
println!("Done running evals");
}
async fn run_eval(
eval: Eval,
model: Arc<dyn LanguageModel>,
judge_model: Arc<dyn LanguageModel>,
app_state: Arc<HeadlessAppState>,
cx: AsyncApp,
) -> anyhow::Result<(EvalOutput, String)> {
let path = eval.path.clone();
let judge = Judge::load(&path, judge_model).await?;
let eval_output = cx.update(|cx| eval.run(app_state, model, cx))?.await?;
let judge_output = cx.update(|cx| judge.run(&eval_output, cx))?.await?;
eval_output.save_to_directory(&path, judge_output.to_string())?;
Ok((eval_output, judge_output))
}

View File

@@ -14,7 +14,6 @@ path = "src/assistant_settings.rs"
[dependencies]
anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
collections.workspace = true
feature_flags.workspace = true
gpui.workspace = true
language_model.workspace = true

View File

@@ -1,59 +0,0 @@
use std::sync::Arc;
use collections::HashMap;
use gpui::SharedString;
/// A profile for the Zed Agent that controls its behavior.
#[derive(Debug, Clone)]
pub struct AgentProfile {
/// The name of the profile.
pub name: SharedString,
pub tools: HashMap<Arc<str>, bool>,
#[allow(dead_code)]
pub context_servers: HashMap<Arc<str>, ContextServerPreset>,
}
#[derive(Debug, Clone)]
pub struct ContextServerPreset {
#[allow(dead_code)]
pub tools: HashMap<Arc<str>, bool>,
}
impl AgentProfile {
pub fn read_only() -> Self {
Self {
name: "Read-only".into(),
tools: HashMap::from_iter([
("diagnostics".into(), true),
("fetch".into(), true),
("list-directory".into(), true),
("now".into(), true),
("path-search".into(), true),
("read-file".into(), true),
("regex-search".into(), true),
("thinking".into(), true),
]),
context_servers: HashMap::default(),
}
}
pub fn code_writer() -> Self {
Self {
name: "Code Writer".into(),
tools: HashMap::from_iter([
("bash".into(), true),
("delete-path".into(), true),
("diagnostics".into(), true),
("edit-files".into(), true),
("fetch".into(), true),
("list-directory".into(), true),
("now".into(), true),
("path-search".into(), true),
("read-file".into(), true),
("regex-search".into(), true),
("thinking".into(), true),
]),
context_servers: HashMap::default(),
}
}
}

View File

@@ -1,10 +1,7 @@
mod agent_profile;
use std::sync::Arc;
use ::open_ai::Model as OpenAiModel;
use anthropic::Model as AnthropicModel;
use collections::HashMap;
use deepseek::Model as DeepseekModel;
use feature_flags::FeatureFlagAppExt;
use gpui::{App, Pixels};
@@ -15,8 +12,6 @@ use schemars::{schema::Schema, JsonSchema};
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources};
pub use crate::agent_profile::*;
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
@@ -71,7 +66,6 @@ pub struct AssistantSettings {
pub inline_alternatives: Vec<LanguageModelSelection>,
pub using_outdated_settings_version: bool,
pub enable_experimental_live_diffs: bool,
pub profiles: HashMap<Arc<str>, AgentProfile>,
}
impl AssistantSettings {
@@ -172,7 +166,6 @@ impl AssistantSettingsContent {
editor_model: None,
inline_alternatives: None,
enable_experimental_live_diffs: None,
profiles: None,
},
VersionedAssistantSettingsContent::V2(settings) => settings.clone(),
},
@@ -194,7 +187,6 @@ impl AssistantSettingsContent {
editor_model: None,
inline_alternatives: None,
enable_experimental_live_diffs: None,
profiles: None,
},
}
}
@@ -324,7 +316,6 @@ impl Default for VersionedAssistantSettingsContent {
editor_model: None,
inline_alternatives: None,
enable_experimental_live_diffs: None,
profiles: None,
})
}
}
@@ -361,8 +352,6 @@ pub struct AssistantSettingsContentV2 {
///
/// Default: false
enable_experimental_live_diffs: Option<bool>,
#[schemars(skip)]
profiles: Option<HashMap<Arc<str>, AgentProfileContent>>,
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
@@ -399,12 +388,6 @@ impl Default for LanguageModelSelection {
}
}
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, JsonSchema)]
pub struct AgentProfileContent {
pub name: Arc<str>,
pub tools: HashMap<Arc<str>, bool>,
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
/// Whether the Assistant is enabled.
@@ -499,24 +482,6 @@ impl Settings for AssistantSettings {
&mut settings.enable_experimental_live_diffs,
value.enable_experimental_live_diffs,
);
merge(
&mut settings.profiles,
value.profiles.map(|profiles| {
profiles
.into_iter()
.map(|(id, profile)| {
(
id,
AgentProfile {
name: profile.name.into(),
tools: profile.tools,
context_servers: HashMap::default(),
},
)
})
.collect()
}),
);
}
Ok(settings)
@@ -581,7 +546,6 @@ mod tests {
default_width: None,
default_height: None,
enable_experimental_live_diffs: None,
profiles: None,
}),
)
},

View File

@@ -77,8 +77,8 @@ impl SlashCommand for AutoCommand {
let cx: &mut App = cx;
cx.spawn(async move |cx| {
let task = project_index.read_with(cx, |project_index, cx| {
cx.spawn(|cx: gpui::AsyncApp| async move {
let task = project_index.read_with(&cx, |project_index, cx| {
project_index.flush_summary_backlogs(cx)
})?;
@@ -117,9 +117,9 @@ impl SlashCommand for AutoCommand {
return Task::ready(Err(anyhow!("no project indexer")));
};
let task = window.spawn(cx, async move |cx| {
let task = window.spawn(cx, |cx| async move {
let summaries = project_index
.read_with(cx, |project_index, cx| project_index.all_summaries(cx))?
.read_with(&cx, |project_index, cx| project_index.all_summaries(cx))?
.await?;
commands_for_summaries(&summaries, &original_prompt, &cx).await

View File

@@ -186,7 +186,7 @@ impl SlashCommand for DiagnosticsSlashCommand {
let task = collect_diagnostics(workspace.read(cx).project().clone(), options, cx);
window.spawn(cx, async move |_| {
window.spawn(cx, move |_| async move {
task.await?
.map(|output| output.to_event_stream())
.ok_or_else(|| anyhow!("No diagnostics found"))
@@ -268,7 +268,7 @@ fn collect_diagnostics(
})
.collect();
cx.spawn(async move |cx| {
cx.spawn(|mut cx| async move {
let mut output = SlashCommandOutput::default();
if let Some(error_source) = error_source.as_ref() {
@@ -299,7 +299,7 @@ fn collect_diagnostics(
}
if let Some(buffer) = project_handle
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.update(&mut cx, |project, cx| project.open_buffer(project_path, cx))?
.await
.log_err()
{

View File

@@ -241,7 +241,7 @@ fn collect_files(
.collect::<Vec<_>>();
let (events_tx, events_rx) = mpsc::unbounded();
cx.spawn(async move |cx| {
cx.spawn(|mut cx| async move {
for snapshot in snapshots {
let worktree_id = snapshot.id();
let mut directory_stack: Vec<Arc<Path>> = Vec::new();
@@ -352,7 +352,7 @@ fn collect_files(
)))?;
} else if entry.is_file() {
let Some(open_buffer_task) = project_handle
.update(cx, |project, cx| {
.update(&mut cx, |project, cx| {
project.open_buffer((worktree_id, &entry.path), cx)
})
.ok()
@@ -361,7 +361,7 @@ fn collect_files(
};
if let Some(buffer) = open_buffer_task.await.log_err() {
let mut output = SlashCommandOutput::default();
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
let snapshot = buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
append_buffer_to_output(
&snapshot,
Some(&path_including_worktree_name),

View File

@@ -99,7 +99,7 @@ impl SlashCommand for ProjectSlashCommand {
return Task::ready(Err(anyhow::anyhow!("no project indexer")));
};
window.spawn(cx, async move |cx| {
window.spawn(cx, |mut cx| async move {
let current_model = current_model.ok_or_else(|| anyhow!("no model selected"))?;
let prompt =
@@ -123,7 +123,7 @@ impl SlashCommand for ProjectSlashCommand {
.search_queries;
let results = project_index
.read_with(cx, |project_index, cx| {
.read_with(&cx, |project_index, cx| {
project_index.search(search_queries.clone(), 25, cx)
})?
.await?;

View File

@@ -109,9 +109,9 @@ impl SlashCommand for SearchSlashCommand {
return Task::ready(Err(anyhow::anyhow!("no project indexer")));
};
window.spawn(cx, async move |cx| {
window.spawn(cx, |cx| async move {
let results = project_index
.read_with(cx, |project_index, cx| {
.read_with(&cx, |project_index, cx| {
project_index.search(vec![query.clone()], limit.unwrap_or(5), cx)
})?
.await?;

View File

@@ -86,7 +86,7 @@ impl SlashCommand for TabSlashCommand {
tab_items_for_queries(workspace, &[current_query], cancel, false, window, cx);
let comment_id = cx.theme().syntax().highlight_id("comment").map(HighlightId);
window.spawn(cx, async move |_| {
window.spawn(cx, |_| async move {
let tab_items = tab_items_search.await?;
let run_command = tab_items.len() == 1;
let tab_completion_items = tab_items.into_iter().filter_map(|(path, ..)| {
@@ -172,11 +172,11 @@ fn tab_items_for_queries(
) -> Task<anyhow::Result<Vec<(Option<PathBuf>, BufferSnapshot, usize)>>> {
let empty_query = queries.is_empty() || queries.iter().all(|query| query.trim().is_empty());
let queries = queries.to_owned();
window.spawn(cx, async move |cx| {
window.spawn(cx, |mut cx| async move {
let mut open_buffers =
workspace
.context("no workspace")?
.update(cx, |workspace, cx| {
.update(&mut cx, |workspace, cx| {
if strict_match && empty_query {
let snapshot = active_item_buffer(workspace, cx)?;
let full_path = snapshot.resolve_file_path(cx, true);

View File

@@ -14,11 +14,9 @@ path = "src/assistant_tool.rs"
[dependencies]
anyhow.workspace = true
collections.workspace = true
clock.workspace = true
derive_more.workspace = true
gpui.workspace = true
language.workspace = true
language_model.workspace = true
gpui.workspace = true
parking_lot.workspace = true
project.workspace = true
serde.workspace = true

View File

@@ -4,9 +4,7 @@ mod tool_working_set;
use std::sync::Arc;
use anyhow::Result;
use collections::{HashMap, HashSet};
use gpui::{App, Context, Entity, SharedString, Task};
use language::Buffer;
use gpui::{App, Entity, SharedString, Task};
use language_model::LanguageModelRequestMessage;
use project::Project;
@@ -43,70 +41,12 @@ pub trait Tool: 'static + Send + Sync {
serde_json::Value::Object(serde_json::Map::default())
}
/// Returns markdown to be displayed in the UI for this tool.
fn ui_text(&self, input: &serde_json::Value) -> String;
/// Runs the tool with the provided input.
fn run(
self: Arc<Self>,
input: serde_json::Value,
messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>>;
}
/// Tracks actions performed by tools in a thread
#[derive(Debug)]
pub struct ActionLog {
/// Buffers that user manually added to the context, and whose content has
/// changed since the model last saw them.
stale_buffers_in_context: HashSet<Entity<Buffer>>,
/// Buffers that we want to notify the model about when they change.
tracked_buffers: HashMap<Entity<Buffer>, TrackedBuffer>,
}
#[derive(Debug, Default)]
struct TrackedBuffer {
version: clock::Global,
}
impl ActionLog {
/// Creates a new, empty action log.
pub fn new() -> Self {
Self {
stale_buffers_in_context: HashSet::default(),
tracked_buffers: HashMap::default(),
}
}
/// Track a buffer as read, so we can notify the model about user edits.
pub fn buffer_read(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
tracked_buffer.version = buffer.read(cx).version();
}
/// Mark a buffer as edited, so we can refresh it in the context
pub fn buffer_edited(&mut self, buffers: HashSet<Entity<Buffer>>, cx: &mut Context<Self>) {
for buffer in &buffers {
let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
tracked_buffer.version = buffer.read(cx).version();
}
self.stale_buffers_in_context.extend(buffers);
}
/// Iterate over buffers changed since last read or edited by the model
pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator<Item = &'a Entity<Buffer>> {
self.tracked_buffers
.iter()
.filter(|(buffer, tracked)| tracked.version != buffer.read(cx).version)
.map(|(buffer, _)| buffer)
}
/// Takes and returns the set of buffers pending refresh, clearing internal state.
pub fn take_stale_buffers_in_context(&mut self) -> HashSet<Entity<Buffer>> {
std::mem::take(&mut self.stale_buffers_in_context)
}
}

View File

@@ -19,9 +19,6 @@ collections.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
html_to_markdown.workspace = true
http_client.workspace = true
itertools.workspace = true
language.workspace = true
language_model.workspace = true
project.workspace = true
@@ -29,18 +26,15 @@ release_channel.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
theme.workspace = true
ui.workspace = true
util.workspace = true
workspace.workspace = true
worktree.workspace = true
settings.workspace = true
[dev-dependencies]
rand.workspace = true
collections = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
rand.workspace = true
workspace = { workspace = true, features = ["test-support"] }
unindent.workspace = true

View File

@@ -2,33 +2,26 @@ mod bash_tool;
mod delete_path_tool;
mod diagnostics_tool;
mod edit_files_tool;
mod fetch_tool;
mod list_directory_tool;
mod now_tool;
mod path_search_tool;
mod read_file_tool;
mod regex_search_tool;
mod thinking_tool;
use std::sync::Arc;
mod regex_search;
use assistant_tool::ToolRegistry;
use gpui::App;
use http_client::HttpClientWithUrl;
use crate::bash_tool::BashTool;
use crate::delete_path_tool::DeletePathTool;
use crate::diagnostics_tool::DiagnosticsTool;
use crate::edit_files_tool::EditFilesTool;
use crate::fetch_tool::FetchTool;
use crate::list_directory_tool::ListDirectoryTool;
use crate::now_tool::NowTool;
use crate::path_search_tool::PathSearchTool;
use crate::read_file_tool::ReadFileTool;
use crate::regex_search_tool::RegexSearchTool;
use crate::thinking_tool::ThinkingTool;
use crate::regex_search::RegexSearchTool;
pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
pub fn init(cx: &mut App) {
assistant_tool::init(cx);
crate::edit_files_tool::log::init(cx);
@@ -42,6 +35,4 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(PathSearchTool);
registry.register_tool(ReadFileTool);
registry.register_tool(RegexSearchTool);
registry.register_tool(ThinkingTool);
registry.register_tool(FetchTool::new(http_client));
}

View File

@@ -1,5 +1,5 @@
use anyhow::{anyhow, Context as _, Result};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::Tool;
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use project::Project;
@@ -32,19 +32,11 @@ impl Tool for BashTool {
serde_json::to_value(&schema).unwrap()
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<BashToolInput>(input.clone()) {
Ok(input) => format!("`$ {}`", input.command),
Err(_) => "Run bash command".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let input: BashToolInput = match serde_json::from_value(input) {
@@ -57,7 +49,7 @@ impl Tool for BashTool {
};
let working_directory = worktree.read(cx).abs_path();
cx.spawn(async move |_| {
cx.spawn(|_| async move {
// Add 2>&1 to merge stderr into stdout for proper interleaving.
let command = format!("({}) 2>&1", input.command);

View File

@@ -1,15 +1,16 @@
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, AppContext, Entity, Task};
use assistant_tool::Tool;
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::{fs, path::PathBuf, sync::Arc};
use util::paths::PathMatcher;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct DeletePathToolInput {
/// The path of the file or directory to delete.
/// The glob to match files in the project to delete.
///
/// <example>
/// If the project has the following files:
@@ -18,9 +19,9 @@ pub struct DeletePathToolInput {
/// - directory2/a/things.txt
/// - directory3/a/other.txt
///
/// You can delete the first file by providing a path of "directory1/a/something.txt"
/// You can delete the first two files by providing a glob of "*thing*.txt"
/// </example>
pub path: String,
pub glob: String,
}
pub struct DeletePathTool;
@@ -39,40 +40,126 @@ impl Tool for DeletePathTool {
serde_json::to_value(&schema).unwrap()
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<DeletePathToolInput>(input.clone()) {
Ok(input) => format!("Delete “`{}`”", input.path),
Err(_) => "Delete path".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let path_str = match serde_json::from_value::<DeletePathToolInput>(input) {
Ok(input) => input.path,
let glob = match serde_json::from_value::<DeletePathToolInput>(input) {
Ok(input) => input.glob,
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let path_matcher = match PathMatcher::new(&[glob.clone()]) {
Ok(matcher) => matcher,
Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {}", err))),
};
match project
.read(cx)
.find_project_path(&path_str, cx)
.and_then(|path| project.update(cx, |project, cx| project.delete_file(path, false, cx)))
{
Some(deletion_task) => cx.background_spawn(async move {
match deletion_task.await {
Ok(()) => Ok(format!("Deleted {path_str}")),
Err(err) => Err(anyhow!("Failed to delete {path_str}: {err}")),
struct Match {
display_path: String,
path: PathBuf,
}
let mut matches = Vec::new();
let mut deleted_paths = Vec::new();
let mut errors = Vec::new();
for worktree_handle in project.read(cx).worktrees(cx) {
let worktree = worktree_handle.read(cx);
let worktree_root = worktree.abs_path().to_path_buf();
// Don't consider ignored entries.
for entry in worktree.entries(false, 0) {
if path_matcher.is_match(&entry.path) {
matches.push(Match {
path: worktree_root.join(&entry.path),
display_path: entry.path.display().to_string(),
});
}
}),
None => Task::ready(Err(anyhow!(
"Couldn't delete {path_str} because that path isn't in this project."
))),
}
}
if matches.is_empty() {
return Task::ready(Ok(format!("No paths in the project matched {glob:?}")));
}
let paths_matched = matches.len();
// Delete the files
for Match { path, display_path } in matches {
match fs::remove_file(&path) {
Ok(()) => {
deleted_paths.push(display_path);
}
Err(file_err) => {
// Try to remove directory if it's not a file. Retrying as a directory
// on error saves a syscall compared to checking whether it's
// a directory up front for every single file.
if let Err(dir_err) = fs::remove_dir_all(&path) {
let error = if path.is_dir() {
format!("Failed to delete directory {}: {dir_err}", display_path)
} else {
format!("Failed to delete file {}: {file_err}", display_path)
};
errors.push(error);
} else {
deleted_paths.push(display_path);
}
}
}
}
if errors.is_empty() {
// 0 deleted paths should never happen if there were no errors;
// we already returned if matches was empty.
let answer = if deleted_paths.len() == 1 {
format!(
"Deleted {}",
deleted_paths.first().unwrap_or(&String::new())
)
} else {
// Sort to group entries in the same directory together
deleted_paths.sort();
let mut buf = format!("Deleted these {} paths:\n", deleted_paths.len());
for path in deleted_paths.iter() {
buf.push('\n');
buf.push_str(path);
}
buf
};
Task::ready(Ok(answer))
} else {
if deleted_paths.is_empty() {
Task::ready(Err(anyhow!(
"{glob:?} matched {} deleted because of {}:\n{}",
if paths_matched == 1 {
"1 path, but it was not".to_string()
} else {
format!("{} paths, but none were", paths_matched)
},
if errors.len() == 1 {
"this error".to_string()
} else {
format!("{} errors", errors.len())
},
errors.join("\n")
)))
} else {
// Sort to group entries in the same directory together
deleted_paths.sort();
Task::ready(Ok(format!(
"Deleted {} paths matching glob {glob:?}:\n{}\n\nErrors:\n{}",
deleted_paths.len(),
deleted_paths.join("\n"),
errors.join("\n")
)))
}
}
}
}

View File

@@ -1 +1 @@
Deletes the file or directory (and the directory's contents, recursively) at the specified path in the project, and returns confirmation of the deletion.
Deletes all files and directories in the project which match the given glob, and returns a list of the paths that were deleted.

View File

@@ -1,5 +1,5 @@
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::Tool;
use gpui::{App, Entity, Task};
use language::{DiagnosticSeverity, OffsetRangeExt};
use language_model::LanguageModelRequestMessage;
@@ -46,41 +46,28 @@ impl Tool for DiagnosticsTool {
serde_json::to_value(&schema).unwrap()
}
fn ui_text(&self, input: &serde_json::Value) -> String {
if let Some(path) = serde_json::from_value::<DiagnosticsToolInput>(input.clone())
.ok()
.and_then(|input| input.path)
{
format!("Check diagnostics for “`{}`”", path.display())
} else {
"Check project diagnostics".to_string()
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
if let Some(path) = serde_json::from_value::<DiagnosticsToolInput>(input)
.ok()
.and_then(|input| input.path)
{
let input = match serde_json::from_value::<DiagnosticsToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
};
if let Some(path) = input.path {
let Some(project_path) = project.read(cx).find_project_path(&path, cx) else {
return Task::ready(Err(anyhow!(
"Could not find path {} in project",
path.display()
)));
return Task::ready(Err(anyhow!("Could not find path in project")));
};
let buffer = project.update(cx, |project, cx| project.open_buffer(project_path, cx));
cx.spawn(async move |cx| {
cx.spawn(|cx| async move {
let mut output = String::new();
let buffer = buffer.await?;
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let snapshot = buffer.read_with(&cx, |buffer, _cx| buffer.snapshot())?;
for (_, group) in snapshot.diagnostic_groups(None) {
let entry = &group.entries[group.primary_ix];

View File

@@ -1,30 +1,32 @@
mod edit_action;
pub mod log;
mod replace;
use anyhow::{anyhow, Context, Result};
use assistant_tool::{ActionLog, Tool};
use assistant_tool::Tool;
use collections::HashSet;
use edit_action::{EditAction, EditActionParser};
use futures::StreamExt;
use gpui::{App, AsyncApp, Entity, Task};
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
use log::{EditToolLog, EditToolRequestId};
use project::Project;
use replace::{replace_exact, replace_with_flexible_indent};
use project::{search::SearchQuery, Project};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::fmt::Write;
use std::sync::Arc;
use util::paths::PathMatcher;
use util::ResultExt;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFilesToolInput {
/// High-level edit instructions. These will be interpreted by a smaller
/// model, so explain the changes you want that model to make and which
/// file paths need changing. The description should be concise and clear.
/// file paths need changing.
///
/// The description should be concise and clear. We will show this
/// description to the user as well.
///
/// WARNING: When specifying which file paths need changing, you MUST
/// start each path with one of the project's root directories.
@@ -55,21 +57,6 @@ pub struct EditFilesToolInput {
/// Notice how we never specify code snippets in the instructions!
/// </example>
pub edit_instructions: String,
/// A user-friendly description of what changes are being made.
/// This will be shown to the user in the UI to describe the edit operation. The screen real estate for this UI will be extremely
/// constrained, so make the description extremely terse.
///
/// <example>
/// For fixing a broken authentication system:
/// "Fix auth bug in login flow"
/// </example>
///
/// <example>
/// For adding unit tests to a module:
/// "Add tests for user profile logic"
/// </example>
pub display_description: String,
}
pub struct EditFilesTool;
@@ -88,19 +75,11 @@ impl Tool for EditFilesTool {
serde_json::to_value(&schema).unwrap()
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<EditFilesToolInput>(input.clone()) {
Ok(input) => input.display_description,
Err(_) => "Edit files".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let input = match serde_json::from_value::<EditFilesToolInput>(input) {
@@ -114,16 +93,10 @@ impl Tool for EditFilesTool {
log.new_request(input.edit_instructions.clone(), cx)
});
let task = EditToolRequest::new(
input,
messages,
project,
action_log,
Some((log.clone(), req_id)),
cx,
);
let task =
EditToolRequest::new(input, messages, project, Some((log.clone(), req_id)), cx);
cx.spawn(async move |cx| {
cx.spawn(|mut cx| async move {
let result = task.await;
let str_result = match &result {
@@ -131,53 +104,38 @@ impl Tool for EditFilesTool {
Err(err) => Err(err.to_string()),
};
log.update(cx, |log, cx| log.set_tool_output(req_id, str_result, cx))
.log_err();
log.update(&mut cx, |log, cx| {
log.set_tool_output(req_id, str_result, cx)
})
.log_err();
result
})
}
None => EditToolRequest::new(input, messages, project, action_log, None, cx),
None => EditToolRequest::new(input, messages, project, None, cx),
}
}
}
struct EditToolRequest {
parser: EditActionParser,
editor_response: EditorResponse,
changed_buffers: HashSet<Entity<language::Buffer>>,
bad_searches: Vec<BadSearch>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
}
enum EditorResponse {
/// The editor model hasn't produced any actions yet.
/// If we don't have any by the end, we'll return its message to the architect model.
Message(String),
/// The editor model produced at least one action.
Actions {
applied: Vec<AppliedAction>,
search_errors: Vec<SearchError>,
},
}
struct AppliedAction {
source: String,
buffer: Entity<language::Buffer>,
log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
}
#[derive(Debug)]
enum SearchError {
NoMatch {
file_path: String,
search: String,
},
EmptyBuffer {
file_path: String,
search: String,
exists: bool,
},
enum DiffResult {
BadSearch(BadSearch),
Diff(language::Diff),
}
#[derive(Debug)]
struct BadSearch {
file_path: String,
search: String,
}
impl EditToolRequest {
@@ -185,8 +143,7 @@ impl EditToolRequest {
input: EditFilesToolInput,
messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
cx: &mut App,
) -> Task<Result<String>> {
let model_registry = LanguageModelRegistry::read_global(cx);
@@ -195,23 +152,12 @@ impl EditToolRequest {
};
let mut messages = messages.to_vec();
// Remove the last tool use (this run) to prevent an invalid request
'outer: for message in messages.iter_mut().rev() {
for (index, content) in message.content.iter().enumerate().rev() {
match content {
MessageContent::ToolUse(_) => {
message.content.remove(index);
break 'outer;
}
MessageContent::ToolResult(_) => {
// If we find any tool results before a tool use, the request is already valid
break 'outer;
}
MessageContent::Text(_) | MessageContent::Image(_) => {}
}
}
if let Some(last_message) = messages.last_mut() {
// Strip out tool use from the last message because we're in the middle of executing a tool call.
last_message
.content
.retain(|content| !matches!(content, language_model::MessageContent::ToolUse(_)))
}
messages.push(LanguageModelRequestMessage {
role: Role::User,
content: vec![
@@ -221,7 +167,7 @@ impl EditToolRequest {
cache: false,
});
cx.spawn(async move |cx| {
cx.spawn(|mut cx| async move {
let llm_request = LanguageModelRequest {
messages,
tools: vec![],
@@ -234,30 +180,24 @@ impl EditToolRequest {
let mut request = Self {
parser: EditActionParser::new(),
editor_response: EditorResponse::Message(String::with_capacity(256)),
action_log,
changed_buffers: HashSet::default(),
bad_searches: Vec::new(),
project,
tool_log,
log,
};
while let Some(chunk) = chunks.stream.next().await {
request.process_response_chunk(&chunk?, cx).await?;
request.process_response_chunk(&chunk?, &mut cx).await?;
}
request.finalize(cx).await
request.finalize(&mut cx).await
})
}
async fn process_response_chunk(&mut self, chunk: &str, cx: &mut AsyncApp) -> Result<()> {
let new_actions = self.parser.parse_chunk(chunk);
if let EditorResponse::Message(ref mut message) = self.editor_response {
if new_actions.is_empty() {
message.push_str(chunk);
}
}
if let Some((ref log, req_id)) = self.tool_log {
if let Some((ref log, req_id)) = self.log {
log.update(cx, |log, cx| {
log.push_editor_response_chunk(req_id, chunk, &new_actions, cx)
})
@@ -271,11 +211,7 @@ impl EditToolRequest {
Ok(())
}
async fn apply_action(
&mut self,
(action, source): (EditAction, String),
cx: &mut AsyncApp,
) -> Result<()> {
async fn apply_action(&mut self, action: EditAction, cx: &mut AsyncApp) -> Result<()> {
let project_path = self.project.read_with(cx, |project, cx| {
project
.find_project_path(action.file_path(), cx)
@@ -287,11 +223,6 @@ impl EditToolRequest {
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?;
enum DiffResult {
Diff(language::Diff),
SearchError(SearchError),
}
let result = match action {
EditAction::Replace {
old,
@@ -301,39 +232,7 @@ impl EditToolRequest {
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
cx.background_executor()
.spawn(async move {
if snapshot.is_empty() {
let exists = snapshot
.file()
.map_or(false, |file| file.disk_state().exists());
let error = SearchError::EmptyBuffer {
file_path: file_path.display().to_string(),
exists,
search: old,
};
return anyhow::Ok(DiffResult::SearchError(error));
}
let replace_result =
// Try to match exactly
replace_exact(&old, &new, &snapshot)
.await
// If that fails, try being flexible about indentation
.or_else(|| replace_with_flexible_indent(&old, &new, &snapshot));
let Some(diff) = replace_result else {
let error = SearchError::NoMatch {
search: old,
file_path: file_path.display().to_string(),
};
return Ok(DiffResult::SearchError(error));
};
Ok(DiffResult::Diff(diff))
})
.spawn(Self::replace_diff(old, new, file_path, snapshot))
.await
}
EditAction::Write { content, .. } => Ok(DiffResult::Diff(
@@ -344,179 +243,133 @@ impl EditToolRequest {
}?;
match result {
DiffResult::SearchError(error) => {
self.push_search_error(error);
DiffResult::BadSearch(invalid_replace) => {
self.bad_searches.push(invalid_replace);
}
DiffResult::Diff(diff) => {
let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
self.push_applied_action(AppliedAction { source, buffer });
self.changed_buffers.insert(buffer);
}
}
anyhow::Ok(())
Ok(())
}
fn push_search_error(&mut self, error: SearchError) {
match &mut self.editor_response {
EditorResponse::Message(_) => {
self.editor_response = EditorResponse::Actions {
applied: Vec::new(),
search_errors: vec![error],
};
}
EditorResponse::Actions { search_errors, .. } => {
search_errors.push(error);
}
}
}
async fn replace_diff(
old: String,
new: String,
file_path: std::path::PathBuf,
snapshot: language::BufferSnapshot,
) -> Result<DiffResult> {
let query = SearchQuery::text(
old.clone(),
false,
true,
true,
PathMatcher::new(&[])?,
PathMatcher::new(&[])?,
None,
)?;
fn push_applied_action(&mut self, action: AppliedAction) {
match &mut self.editor_response {
EditorResponse::Message(_) => {
self.editor_response = EditorResponse::Actions {
applied: vec![action],
search_errors: Vec::new(),
};
}
EditorResponse::Actions { applied, .. } => {
applied.push(action);
}
let matches = query.search(&snapshot, None).await;
if matches.is_empty() {
return Ok(DiffResult::BadSearch(BadSearch {
search: new.clone(),
file_path: file_path.display().to_string(),
}));
}
let edit_range = matches[0].clone();
let diff = language::text_diff(&old, &new);
let edits = diff
.into_iter()
.map(|(old_range, text)| {
let start = edit_range.start + old_range.start;
let end = edit_range.start + old_range.end;
(start..end, text)
})
.collect::<Vec<_>>();
let diff = language::Diff {
base_version: snapshot.version().clone(),
line_ending: snapshot.line_ending(),
edits,
};
anyhow::Ok(DiffResult::Diff(diff))
}
async fn finalize(self, cx: &mut AsyncApp) -> Result<String> {
match self.editor_response {
EditorResponse::Message(message) => Err(anyhow!(
"No edits were applied! You might need to provide more context.\n\n{}",
message
)),
EditorResponse::Actions {
applied,
search_errors,
} => {
let mut output = String::with_capacity(1024);
let mut answer = match self.changed_buffers.len() {
0 => "No files were edited.".to_string(),
1 => "Successfully edited ".to_string(),
_ => "Successfully edited these files:\n\n".to_string(),
};
let parse_errors = self.parser.errors();
let has_errors = !search_errors.is_empty() || !parse_errors.is_empty();
// Save each buffer once at the end
for buffer in self.changed_buffers {
let (path, save_task) = self.project.update(cx, |project, cx| {
let path = buffer
.read(cx)
.file()
.map(|file| file.path().display().to_string());
if has_errors {
let error_count = search_errors.len() + parse_errors.len();
let task = project.save_buffer(buffer.clone(), cx);
if applied.is_empty() {
writeln!(
&mut output,
"{} errors occurred! No edits were applied.",
error_count,
)?;
} else {
writeln!(
&mut output,
"{} errors occurred, but {} edits were correctly applied.",
error_count,
applied.len(),
)?;
(path, task)
})?;
writeln!(
&mut output,
"# {} SEARCH/REPLACE block(s) applied:\n\nDo not re-send these since they are already applied!\n",
applied.len()
)?;
}
} else {
write!(
&mut output,
"Successfully applied! Here's a list of applied edits:"
)?;
}
save_task.await?;
let mut changed_buffers = HashSet::default();
if let Some(path) = path {
writeln!(&mut answer, "{}", path)?;
}
}
for action in applied {
changed_buffers.insert(action.buffer);
write!(&mut output, "\n\n{}", action.source)?;
}
let errors = self.parser.errors();
for buffer in &changed_buffers {
self.project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.await?;
}
if errors.is_empty() && self.bad_searches.is_empty() {
Ok(answer.trim_end().to_string())
} else {
if !self.bad_searches.is_empty() {
writeln!(
&mut answer,
"\nThese searches failed because they didn't match any strings:"
)?;
self.action_log
.update(cx, |log, cx| log.buffer_edited(changed_buffers.clone(), cx))
.log_err();
if !search_errors.is_empty() {
for replace in self.bad_searches {
writeln!(
&mut output,
"\n\n## {} SEARCH/REPLACE block(s) failed to match:\n",
search_errors.len()
)?;
for error in search_errors {
match error {
SearchError::NoMatch { file_path, search } => {
writeln!(
&mut output,
"### No exact match in: `{}`\n```\n{}\n```\n",
file_path, search,
)?;
}
SearchError::EmptyBuffer {
file_path,
exists: true,
search,
} => {
writeln!(
&mut output,
"### No match because `{}` is empty:\n```\n{}\n```\n",
file_path, search,
)?;
}
SearchError::EmptyBuffer {
file_path,
exists: false,
search,
} => {
writeln!(
&mut output,
"### No match because `{}` does not exist:\n```\n{}\n```\n",
file_path, search,
)?;
}
}
}
write!(&mut output,
"The SEARCH section must exactly match an existing block of lines including all white \
space, comments, indentation, docstrings, etc."
&mut answer,
"- '{}' does not appear in `{}`",
replace.search.replace("\r", "\\r").replace("\n", "\\n"),
replace.file_path
)?;
}
if !parse_errors.is_empty() {
writeln!(
&mut output,
"\n\n## {} SEARCH/REPLACE blocks failed to parse:",
parse_errors.len()
)?;
writeln!(&mut answer, "Make sure to use exact searches.")?;
}
for error in parse_errors {
writeln!(&mut output, "- {}", error)?;
}
}
if !errors.is_empty() {
writeln!(
&mut answer,
"\nThese SEARCH/REPLACE blocks failed to parse:"
)?;
if has_errors {
writeln!(&mut output,
"\n\nYou can fix errors by running the tool again. You can include instructions, \
but errors are part of the conversation so you don't need to repeat them.",
)?;
Err(anyhow!(output))
} else {
Ok(output)
for error in errors {
writeln!(&mut answer, "- {}", error)?;
}
}
writeln!(
&mut answer,
"\nYou can fix errors by running the tool again. You can include instructions,\
but errors are part of the conversation so you don't need to repeat them."
)?;
Err(anyhow!(answer))
}
}
}

View File

@@ -1,8 +1,4 @@
use std::{
mem::take,
ops::Range,
path::{Path, PathBuf},
};
use std::path::{Path, PathBuf};
use util::ResultExt;
/// Represents an edit action to be performed on a file.
@@ -32,14 +28,12 @@ impl EditAction {
#[derive(Debug)]
pub struct EditActionParser {
state: State,
pre_fence_line: Vec<u8>,
marker_ix: usize,
line: usize,
column: usize,
marker_ix: usize,
action_source: Vec<u8>,
fence_start_offset: usize,
block_range: Range<usize>,
old_range: Range<usize>,
new_range: Range<usize>,
old_bytes: Vec<u8>,
new_bytes: Vec<u8>,
errors: Vec<ParseError>,
}
@@ -64,14 +58,12 @@ impl EditActionParser {
pub fn new() -> Self {
Self {
state: State::Default,
pre_fence_line: Vec::new(),
marker_ix: 0,
line: 1,
column: 0,
action_source: Vec::new(),
fence_start_offset: 0,
marker_ix: 0,
block_range: Range::default(),
old_range: Range::default(),
new_range: Range::default(),
old_bytes: Vec::new(),
new_bytes: Vec::new(),
errors: Vec::new(),
}
}
@@ -84,7 +76,7 @@ impl EditActionParser {
///
/// If a block fails to parse, it will simply be skipped and an error will be recorded.
/// All errors can be accessed through the `EditActionsParser::errors` method.
pub fn parse_chunk(&mut self, input: &str) -> Vec<(EditAction, String)> {
pub fn parse_chunk(&mut self, input: &str) -> Vec<EditAction> {
use State::*;
const FENCE: &[u8] = b"```";
@@ -105,21 +97,20 @@ impl EditActionParser {
self.column += 1;
}
let action_offset = self.action_source.len();
match &self.state {
Default => match self.match_marker(byte, FENCE, false) {
Default => match match_marker(byte, FENCE, false, &mut self.marker_ix) {
MarkerMatch::Complete => {
self.fence_start_offset = action_offset + 1 - FENCE.len();
self.to_state(OpenFence);
}
MarkerMatch::Partial => {}
MarkerMatch::None => {
if self.marker_ix > 0 {
self.marker_ix = 0;
} else if self.action_source.ends_with(b"\n") {
self.action_source.clear();
} else if self.pre_fence_line.ends_with(b"\n") {
self.pre_fence_line.clear();
}
self.pre_fence_line.push(byte);
}
},
OpenFence => {
@@ -134,34 +125,39 @@ impl EditActionParser {
}
}
SearchBlock => {
if self.extend_block_range(byte, DIVIDER, NL_DIVIDER) {
self.old_range = take(&mut self.block_range);
if collect_until_marker(
byte,
DIVIDER,
NL_DIVIDER,
true,
&mut self.marker_ix,
&mut self.old_bytes,
) {
self.to_state(ReplaceBlock);
}
}
ReplaceBlock => {
if self.extend_block_range(byte, REPLACE_MARKER, NL_REPLACE_MARKER) {
self.new_range = take(&mut self.block_range);
if collect_until_marker(
byte,
REPLACE_MARKER,
NL_REPLACE_MARKER,
true,
&mut self.marker_ix,
&mut self.new_bytes,
) {
self.to_state(CloseFence);
}
}
CloseFence => {
if self.expect_marker(byte, FENCE, false) {
self.action_source.push(byte);
if let Some(action) = self.action() {
actions.push(action);
}
self.errors();
self.reset();
continue;
}
}
};
self.action_source.push(byte);
}
actions
@@ -172,44 +168,48 @@ impl EditActionParser {
&self.errors
}
fn action(&mut self) -> Option<(EditAction, String)> {
let old_range = take(&mut self.old_range);
let new_range = take(&mut self.new_range);
fn action(&mut self) -> Option<EditAction> {
if self.old_bytes.is_empty() && self.new_bytes.is_empty() {
self.push_error(ParseErrorKind::NoOp);
return None;
}
let action_source = take(&mut self.action_source);
let action_source = String::from_utf8(action_source).log_err()?;
let mut pre_fence_line = std::mem::take(&mut self.pre_fence_line);
let mut file_path_bytes = action_source[..self.fence_start_offset].to_owned();
if pre_fence_line.ends_with(b"\n") {
pre_fence_line.pop();
pop_carriage_return(&mut pre_fence_line);
}
if file_path_bytes.ends_with("\n") {
file_path_bytes.pop();
if file_path_bytes.ends_with("\r") {
file_path_bytes.pop();
let file_path = PathBuf::from(String::from_utf8(pre_fence_line).log_err()?);
let content = String::from_utf8(std::mem::take(&mut self.new_bytes)).log_err()?;
if self.old_bytes.is_empty() {
Some(EditAction::Write { file_path, content })
} else {
let old = String::from_utf8(std::mem::take(&mut self.old_bytes)).log_err()?;
Some(EditAction::Replace {
file_path,
old,
new: content,
})
}
}
fn expect_marker(&mut self, byte: u8, marker: &'static [u8], trailing_newline: bool) -> bool {
match match_marker(byte, marker, trailing_newline, &mut self.marker_ix) {
MarkerMatch::Complete => true,
MarkerMatch::Partial => false,
MarkerMatch::None => {
self.push_error(ParseErrorKind::ExpectedMarker {
expected: marker,
found: byte,
});
self.reset();
false
}
}
let file_path = PathBuf::from(file_path_bytes);
if old_range.is_empty() {
return Some((
EditAction::Write {
file_path,
content: action_source[new_range].to_owned(),
},
action_source,
));
}
let old = action_source[old_range].to_owned();
let new = action_source[new_range].to_owned();
let action = EditAction::Replace {
file_path,
old,
new,
};
Some((action, action_source))
}
fn to_state(&mut self, state: State) {
@@ -218,95 +218,18 @@ impl EditActionParser {
}
fn reset(&mut self) {
self.action_source.clear();
self.block_range = Range::default();
self.old_range = Range::default();
self.new_range = Range::default();
self.fence_start_offset = 0;
self.marker_ix = 0;
self.pre_fence_line.clear();
self.old_bytes.clear();
self.new_bytes.clear();
self.to_state(State::Default);
}
fn expect_marker(&mut self, byte: u8, marker: &'static [u8], trailing_newline: bool) -> bool {
match self.match_marker(byte, marker, trailing_newline) {
MarkerMatch::Complete => true,
MarkerMatch::Partial => false,
MarkerMatch::None => {
self.errors.push(ParseError {
line: self.line,
column: self.column,
expected: marker,
found: byte,
});
self.reset();
false
}
}
}
fn extend_block_range(&mut self, byte: u8, marker: &[u8], nl_marker: &[u8]) -> bool {
let marker = if self.block_range.is_empty() {
// do not require another newline if block is empty
marker
} else {
nl_marker
};
let offset = self.action_source.len();
match self.match_marker(byte, marker, true) {
MarkerMatch::Complete => {
if self.action_source[self.block_range.clone()].ends_with(b"\r") {
self.block_range.end -= 1;
}
true
}
MarkerMatch::Partial => false,
MarkerMatch::None => {
if self.marker_ix > 0 {
self.marker_ix = 0;
self.block_range.end = offset;
// The beginning of marker might match current byte
match self.match_marker(byte, marker, true) {
MarkerMatch::Complete => return true,
MarkerMatch::Partial => return false,
MarkerMatch::None => { /* no match, keep collecting */ }
}
}
if self.block_range.is_empty() {
self.block_range.start = offset;
}
self.block_range.end = offset + 1;
false
}
}
}
fn match_marker(&mut self, byte: u8, marker: &[u8], trailing_newline: bool) -> MarkerMatch {
if trailing_newline && self.marker_ix >= marker.len() {
if byte == b'\n' {
MarkerMatch::Complete
} else if byte == b'\r' {
MarkerMatch::Partial
} else {
MarkerMatch::None
}
} else if byte == marker[self.marker_ix] {
self.marker_ix += 1;
if self.marker_ix < marker.len() || trailing_newline {
MarkerMatch::Partial
} else {
MarkerMatch::Complete
}
} else {
MarkerMatch::None
}
fn push_error(&mut self, kind: ParseErrorKind) {
self.errors.push(ParseError {
line: self.line,
column: self.column,
kind,
});
}
}
@@ -317,24 +240,114 @@ enum MarkerMatch {
Complete,
}
fn match_marker(
byte: u8,
marker: &[u8],
trailing_newline: bool,
marker_ix: &mut usize,
) -> MarkerMatch {
if trailing_newline && *marker_ix >= marker.len() {
if byte == b'\n' {
MarkerMatch::Complete
} else if byte == b'\r' {
MarkerMatch::Partial
} else {
MarkerMatch::None
}
} else if byte == marker[*marker_ix] {
*marker_ix += 1;
if *marker_ix < marker.len() || trailing_newline {
MarkerMatch::Partial
} else {
MarkerMatch::Complete
}
} else {
MarkerMatch::None
}
}
fn collect_until_marker(
byte: u8,
marker: &[u8],
nl_marker: &[u8],
trailing_newline: bool,
marker_ix: &mut usize,
buf: &mut Vec<u8>,
) -> bool {
let marker = if buf.is_empty() {
// do not require another newline if block is empty
marker
} else {
nl_marker
};
match match_marker(byte, marker, trailing_newline, marker_ix) {
MarkerMatch::Complete => {
pop_carriage_return(buf);
true
}
MarkerMatch::Partial => false,
MarkerMatch::None => {
if *marker_ix > 0 {
buf.extend_from_slice(&marker[..*marker_ix]);
*marker_ix = 0;
// The beginning of marker might match current byte
match match_marker(byte, marker, trailing_newline, marker_ix) {
MarkerMatch::Complete => return true,
MarkerMatch::Partial => return false,
MarkerMatch::None => { /* no match, keep collecting */ }
}
}
buf.push(byte);
false
}
}
}
fn pop_carriage_return(buf: &mut Vec<u8>) {
if buf.ends_with(b"\r") {
buf.pop();
}
}
#[derive(Debug, PartialEq, Eq)]
pub struct ParseError {
line: usize,
column: usize,
expected: &'static [u8],
found: u8,
kind: ParseErrorKind,
}
#[derive(Debug, PartialEq, Eq)]
pub enum ParseErrorKind {
ExpectedMarker { expected: &'static [u8], found: u8 },
NoOp,
}
impl std::fmt::Display for ParseErrorKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ParseErrorKind::ExpectedMarker { expected, found } => {
write!(
f,
"Expected marker {:?}, found {:?}",
String::from_utf8_lossy(expected),
*found as char
)
}
ParseErrorKind::NoOp => {
write!(f, "No search or replace")
}
}
}
}
impl std::fmt::Display for ParseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"input:{}:{}: Expected marker {:?}, found {:?}",
self.line,
self.column,
String::from_utf8_lossy(self.expected),
self.found as char
)
write!(f, "input:{}:{}: {}", self.line, self.column, self.kind)
}
}
@@ -342,7 +355,6 @@ impl std::fmt::Display for ParseError {
mod tests {
use super::*;
use rand::prelude::*;
use util::line_endings;
#[test]
fn test_simple_edit_action() {
@@ -359,16 +371,16 @@ fn replacement() {}
let mut parser = EditActionParser::new();
let actions = parser.parse_chunk(input);
assert_no_errors(&parser);
assert_eq!(actions.len(), 1);
assert_eq!(
actions[0].0,
actions[0],
EditAction::Replace {
file_path: PathBuf::from("src/main.rs"),
old: "fn original() {}".to_string(),
new: "fn replacement() {}".to_string(),
}
);
assert_eq!(parser.errors().len(), 0);
}
#[test]
@@ -386,16 +398,16 @@ fn replacement() {}
let mut parser = EditActionParser::new();
let actions = parser.parse_chunk(input);
assert_no_errors(&parser);
assert_eq!(actions.len(), 1);
assert_eq!(
actions[0].0,
actions[0],
EditAction::Replace {
file_path: PathBuf::from("src/main.rs"),
old: "fn original() {}".to_string(),
new: "fn replacement() {}".to_string(),
}
);
assert_eq!(parser.errors().len(), 0);
}
#[test]
@@ -417,16 +429,16 @@ This change makes the function better.
let mut parser = EditActionParser::new();
let actions = parser.parse_chunk(input);
assert_no_errors(&parser);
assert_eq!(actions.len(), 1);
assert_eq!(
actions[0].0,
actions[0],
EditAction::Replace {
file_path: PathBuf::from("src/main.rs"),
old: "fn original() {}".to_string(),
new: "fn replacement() {}".to_string(),
}
);
assert_eq!(parser.errors().len(), 0);
}
#[test]
@@ -455,27 +467,24 @@ fn new_util() -> bool { true }
let mut parser = EditActionParser::new();
let actions = parser.parse_chunk(input);
assert_no_errors(&parser);
assert_eq!(actions.len(), 2);
let (action, _) = &actions[0];
assert_eq!(
action,
&EditAction::Replace {
actions[0],
EditAction::Replace {
file_path: PathBuf::from("src/main.rs"),
old: "fn original() {}".to_string(),
new: "fn replacement() {}".to_string(),
}
);
let (action2, _) = &actions[1];
assert_eq!(
action2,
&EditAction::Replace {
actions[1],
EditAction::Replace {
file_path: PathBuf::from("src/utils.rs"),
old: "fn old_util() -> bool { false }".to_string(),
new: "fn new_util() -> bool { true }".to_string(),
}
);
assert_eq!(parser.errors().len(), 0);
}
#[test]
@@ -507,18 +516,16 @@ fn replacement() {
let mut parser = EditActionParser::new();
let actions = parser.parse_chunk(input);
assert_no_errors(&parser);
assert_eq!(actions.len(), 1);
let (action, _) = &actions[0];
assert_eq!(
action,
&EditAction::Replace {
actions[0],
EditAction::Replace {
file_path: PathBuf::from("src/main.rs"),
old: "fn original() {\n println!(\"This is the original function\");\n let x = 42;\n if x > 0 {\n println!(\"Positive number\");\n }\n}".to_string(),
new: "fn replacement() {\n println!(\"This is the replacement function\");\n let x = 100;\n if x > 50 {\n println!(\"Large number\");\n } else {\n println!(\"Small number\");\n }\n}".to_string(),
}
);
assert_eq!(parser.errors().len(), 0);
}
#[test]
@@ -539,16 +546,16 @@ fn new_function() {
let mut parser = EditActionParser::new();
let actions = parser.parse_chunk(input);
assert_no_errors(&parser);
assert_eq!(actions.len(), 1);
assert_eq!(
actions[0].0,
actions[0],
EditAction::Write {
file_path: PathBuf::from("src/main.rs"),
content: "fn new_function() {\n println!(\"This function is being added\");\n}"
.to_string(),
}
);
assert_eq!(parser.errors().len(), 0);
}
#[test]
@@ -566,11 +573,9 @@ fn this_will_be_deleted() {
let mut parser = EditActionParser::new();
let actions = parser.parse_chunk(&input);
assert_no_errors(&parser);
assert_eq!(actions.len(), 1);
assert_eq!(
actions[0].0,
actions[0],
EditAction::Replace {
file_path: PathBuf::from("src/main.rs"),
old: "fn this_will_be_deleted() {\n println!(\"Deleting this function\");\n}"
@@ -578,13 +583,12 @@ fn this_will_be_deleted() {
new: "".to_string(),
}
);
assert_eq!(parser.errors().len(), 0);
let mut parser = EditActionParser::new();
let actions = parser.parse_chunk(&input.replace("\n", "\r\n"));
assert_no_errors(&parser);
assert_eq!(actions.len(), 1);
assert_eq!(
actions[0].0,
actions[0],
EditAction::Replace {
file_path: PathBuf::from("src/main.rs"),
old:
@@ -593,6 +597,7 @@ fn this_will_be_deleted() {
new: "".to_string(),
}
);
assert_eq!(parser.errors().len(), 0);
}
#[test]
@@ -608,15 +613,15 @@ fn this_will_be_deleted() {
let mut parser = EditActionParser::new();
let actions = parser.parse_chunk(input);
assert_eq!(actions.len(), 1);
assert_eq!(
actions[0].0,
EditAction::Write {
file_path: PathBuf::from("src/main.rs"),
content: String::new(),
}
);
assert_no_errors(&parser);
// Should not create an action when both sections are empty
assert_eq!(actions.len(), 0);
// Check that the NoOp error was added
assert_eq!(parser.errors().len(), 1);
match parser.errors()[0].kind {
ParseErrorKind::NoOp => {}
_ => panic!("Expected NoOp error"),
}
}
#[test]
@@ -637,27 +642,26 @@ fn replacement() {}"#;
let mut parser = EditActionParser::new();
let actions1 = parser.parse_chunk(input_part1);
assert_no_errors(&parser);
assert_eq!(actions1.len(), 0);
assert_eq!(parser.errors().len(), 0);
let actions2 = parser.parse_chunk(input_part2);
// No actions should be complete yet
assert_no_errors(&parser);
assert_eq!(actions2.len(), 0);
assert_eq!(parser.errors().len(), 0);
let actions3 = parser.parse_chunk(input_part3);
// The third chunk should complete the action
assert_no_errors(&parser);
assert_eq!(actions3.len(), 1);
let (action, _) = &actions3[0];
assert_eq!(
action,
&EditAction::Replace {
actions3[0],
EditAction::Replace {
file_path: PathBuf::from("src/main.rs"),
old: "fn original() {}".to_string(),
new: "fn replacement() {}".to_string(),
}
);
assert_eq!(parser.errors().len(), 0);
}
#[test]
@@ -666,35 +670,28 @@ fn replacement() {}"#;
let actions1 = parser.parse_chunk("src/main.rs\n```rust\n<<<<<<< SEARCH\n");
// Check parser is in the correct state
assert_no_errors(&parser);
assert_eq!(parser.state, State::SearchBlock);
assert_eq!(
parser.action_source,
b"src/main.rs\n```rust\n<<<<<<< SEARCH\n"
);
assert_eq!(parser.pre_fence_line, b"src/main.rs\n");
assert_eq!(parser.errors().len(), 0);
// Continue parsing
let actions2 = parser.parse_chunk("original code\n=======\n");
assert_no_errors(&parser);
assert_eq!(parser.state, State::ReplaceBlock);
assert_eq!(
&parser.action_source[parser.old_range.clone()],
b"original code"
);
assert_eq!(parser.old_bytes, b"original code");
assert_eq!(parser.errors().len(), 0);
let actions3 = parser.parse_chunk("replacement code\n>>>>>>> REPLACE\n```\n");
// After complete parsing, state should reset
assert_no_errors(&parser);
assert_eq!(parser.state, State::Default);
assert_eq!(parser.action_source, b"\n");
assert!(parser.old_range.is_empty());
assert!(parser.new_range.is_empty());
assert_eq!(parser.pre_fence_line, b"\n");
assert!(parser.old_bytes.is_empty());
assert!(parser.new_bytes.is_empty());
assert_eq!(actions1.len(), 0);
assert_eq!(actions2.len(), 0);
assert_eq!(actions3.len(), 1);
assert_eq!(parser.errors().len(), 0);
}
#[test]
@@ -748,10 +745,9 @@ fn new_utils_func() {}
// Only the second block should be parsed
assert_eq!(actions.len(), 1);
let (action, _) = &actions[0];
assert_eq!(
action,
&EditAction::Replace {
actions[0],
EditAction::Replace {
file_path: PathBuf::from("src/utils.rs"),
old: "fn utils_func() {}".to_string(),
new: "fn new_utils_func() {}".to_string(),
@@ -760,7 +756,7 @@ fn new_utils_func() {}
assert_eq!(parser.errors().len(), 1);
assert_eq!(
parser.errors()[0].to_string(),
"input:8:1: Expected marker \"```\", found '<'"
"input:8:1: Expected marker \"```\", found '<'".to_string()
);
// The parser should continue after an error
@@ -787,65 +783,64 @@ fn new_utils_func() {}
let (chunk, rest) = remaining.split_at(chunk_size);
let chunk_actions = parser.parse_chunk(chunk);
actions.extend(chunk_actions);
actions.extend(parser.parse_chunk(chunk));
remaining = rest;
}
assert_examples_in_system_prompt(&actions, parser.errors());
}
fn assert_examples_in_system_prompt(actions: &[(EditAction, String)], errors: &[ParseError]) {
fn assert_examples_in_system_prompt(actions: &[EditAction], errors: &[ParseError]) {
assert_eq!(actions.len(), 5);
assert_eq!(
actions[0].0,
actions[0],
EditAction::Replace {
file_path: PathBuf::from("mathweb/flask/app.py"),
old: "from flask import Flask".to_string(),
new: line_endings!("import math\nfrom flask import Flask").to_string(),
},
new: "import math\nfrom flask import Flask".to_string(),
}
.fix_lf(),
);
assert_eq!(
actions[1].0,
actions[1],
EditAction::Replace {
file_path: PathBuf::from("mathweb/flask/app.py"),
old: line_endings!("def factorial(n):\n \"compute factorial\"\n\n if n == 0:\n return 1\n else:\n return n * factorial(n-1)\n").to_string(),
old: "def factorial(n):\n \"compute factorial\"\n\n if n == 0:\n return 1\n else:\n return n * factorial(n-1)\n".to_string(),
new: "".to_string(),
}
.fix_lf()
);
assert_eq!(
actions[2].0,
actions[2],
EditAction::Replace {
file_path: PathBuf::from("mathweb/flask/app.py"),
old: " return str(factorial(n))".to_string(),
new: " return str(math.factorial(n))".to_string(),
},
}
.fix_lf(),
);
assert_eq!(
actions[3].0,
actions[3],
EditAction::Write {
file_path: PathBuf::from("hello.py"),
content: line_endings!(
"def hello():\n \"print a greeting\"\n\n print(\"hello\")"
)
.to_string(),
},
content: "def hello():\n \"print a greeting\"\n\n print(\"hello\")"
.to_string(),
}
.fix_lf(),
);
assert_eq!(
actions[4].0,
actions[4],
EditAction::Replace {
file_path: PathBuf::from("main.py"),
old: line_endings!(
"def hello():\n \"print a greeting\"\n\n print(\"hello\")"
)
.to_string(),
old: "def hello():\n \"print a greeting\"\n\n print(\"hello\")".to_string(),
new: "from hello import hello".to_string(),
},
}
.fix_lf(),
);
// The system prompt includes some text that would produce errors
@@ -865,6 +860,29 @@ fn new_utils_func() {}
);
}
impl EditAction {
fn fix_lf(self: EditAction) -> EditAction {
#[cfg(windows)]
match self {
EditAction::Replace {
file_path,
old,
new,
} => EditAction::Replace {
file_path: file_path.clone(),
old: old.replace("\n", "\r\n"),
new: new.replace("\n", "\r\n"),
},
EditAction::Write { file_path, content } => EditAction::Write {
file_path: file_path.clone(),
content: content.replace("\n", "\r\n"),
},
}
#[cfg(not(windows))]
self
}
}
#[test]
fn test_print_error() {
let input = r#"src/main.rs
@@ -886,20 +904,4 @@ fn replacement() {}
assert_eq!(format!("{}", error), expected_error);
}
// helpers
fn assert_no_errors(parser: &EditActionParser) {
let errors = parser.errors();
assert!(
errors.is_empty(),
"Expected no errors, but found:\n\n{}",
errors
.iter()
.map(|e| e.to_string())
.collect::<Vec<String>>()
.join("\n")
);
}
}

View File

@@ -80,7 +80,7 @@ impl EditToolLog {
&mut self,
id: EditToolRequestId,
chunk: &str,
new_actions: &[(EditAction, String)],
new_actions: &[EditAction],
cx: &mut Context<Self>,
) {
if let Some(request) = self.requests.get_mut(id.0 as usize) {
@@ -92,9 +92,7 @@ impl EditToolLog {
response.push_str(chunk);
}
}
request
.parsed_edits
.extend(new_actions.iter().cloned().map(|(action, _)| action));
request.parsed_edits.extend(new_actions.iter().cloned());
cx.emit(EditToolLogEvent::Updated);
}

View File

@@ -1,525 +0,0 @@
use language::{BufferSnapshot, Diff, Point, ToOffset};
use project::search::SearchQuery;
use util::{paths::PathMatcher, ResultExt as _};
/// Performs an exact string replacement in a buffer, requiring precise character-for-character matching.
/// Uses the search functionality to locate the first occurrence of the exact string.
/// Returns None if no exact match is found in the buffer.
pub async fn replace_exact(old: &str, new: &str, snapshot: &BufferSnapshot) -> Option<Diff> {
let query = SearchQuery::text(
old,
false,
true,
true,
PathMatcher::new(&[]).ok()?,
PathMatcher::new(&[]).ok()?,
None,
)
.log_err()?;
let matches = query.search(&snapshot, None).await;
if matches.is_empty() {
return None;
}
let edit_range = matches[0].clone();
let diff = language::text_diff(&old, &new);
let edits = diff
.into_iter()
.map(|(old_range, text)| {
let start = edit_range.start + old_range.start;
let end = edit_range.start + old_range.end;
(start..end, text)
})
.collect::<Vec<_>>();
let diff = language::Diff {
base_version: snapshot.version().clone(),
line_ending: snapshot.line_ending(),
edits,
};
Some(diff)
}
/// Performs a replacement that's indentation-aware - matches text content ignoring leading whitespace differences.
/// When replacing, preserves the indentation level found in the buffer at each matching line.
/// Returns None if no match found or if indentation is offset inconsistently across matched lines.
pub fn replace_with_flexible_indent(old: &str, new: &str, buffer: &BufferSnapshot) -> Option<Diff> {
let (old_lines, old_min_indent) = lines_with_min_indent(old);
let (new_lines, new_min_indent) = lines_with_min_indent(new);
let min_indent = old_min_indent.min(new_min_indent);
let old_lines = drop_lines_prefix(&old_lines, min_indent);
let new_lines = drop_lines_prefix(&new_lines, min_indent);
let max_row = buffer.max_point().row;
'windows: for start_row in 0..max_row.saturating_sub(old_lines.len() as u32 - 1) {
let mut common_leading = None;
let end_row = start_row + old_lines.len() as u32 - 1;
if end_row > max_row {
// The buffer ends before fully matching the pattern
return None;
}
let start_point = Point::new(start_row, 0);
let end_point = Point::new(end_row, buffer.line_len(end_row));
let range = start_point.to_offset(buffer)..end_point.to_offset(buffer);
let window_text = buffer.text_for_range(range.clone());
let mut window_lines = window_text.lines();
let mut old_lines_iter = old_lines.iter();
while let (Some(window_line), Some(old_line)) = (window_lines.next(), old_lines_iter.next())
{
let line_trimmed = window_line.trim_start();
if line_trimmed != old_line.trim_start() {
continue 'windows;
}
if line_trimmed.is_empty() {
continue;
}
let line_leading = &window_line[..window_line.len() - old_line.len()];
match &common_leading {
Some(common_leading) if common_leading != line_leading => {
continue 'windows;
}
Some(_) => (),
None => common_leading = Some(line_leading.to_string()),
}
}
if let Some(common_leading) = common_leading {
let line_ending = buffer.line_ending();
let replacement = new_lines
.iter()
.map(|new_line| {
if new_line.trim().is_empty() {
new_line.to_string()
} else {
common_leading.to_string() + new_line
}
})
.collect::<Vec<_>>()
.join(line_ending.as_str());
let diff = Diff {
base_version: buffer.version().clone(),
line_ending,
edits: vec![(range, replacement.into())],
};
return Some(diff);
}
}
None
}
fn drop_lines_prefix<'a>(lines: &'a [&str], prefix_len: usize) -> Vec<&'a str> {
lines
.iter()
.map(|line| line.get(prefix_len..).unwrap_or(""))
.collect()
}
fn lines_with_min_indent(input: &str) -> (Vec<&str>, usize) {
let mut lines = Vec::new();
let mut min_indent: Option<usize> = None;
for line in input.lines() {
lines.push(line);
if !line.trim().is_empty() {
let indent = line.len() - line.trim_start().len();
min_indent = Some(min_indent.map_or(indent, |m| m.min(indent)));
}
}
(lines, min_indent.unwrap_or(0))
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::prelude::*;
use gpui::TestAppContext;
use unindent::Unindent;
#[gpui::test]
fn test_replace_consistent_indentation(cx: &mut TestAppContext) {
let whole = r#"
fn test() {
let x = 5;
println!("x = {}", x);
let y = 10;
}
"#
.unindent();
let old = r#"
let x = 5;
println!("x = {}", x);
"#
.unindent();
let new = r#"
let x = 42;
println!("New value: {}", x);
"#
.unindent();
let expected = r#"
fn test() {
let x = 42;
println!("New value: {}", x);
let y = 10;
}
"#
.unindent();
assert_eq!(
test_replace_with_flexible_indent(cx, &whole, &old, &new),
Some(expected.to_string())
);
}
#[gpui::test]
fn test_replace_inconsistent_indentation(cx: &mut TestAppContext) {
let whole = r#"
fn test() {
if condition {
println!("{}", 43);
}
}
"#
.unindent();
let old = r#"
if condition {
println!("{}", 43);
"#
.unindent();
let new = r#"
if condition {
println!("{}", 42);
"#
.unindent();
assert_eq!(
test_replace_with_flexible_indent(cx, &whole, &old, &new),
None
);
}
#[gpui::test]
fn test_replace_with_empty_lines(cx: &mut TestAppContext) {
// Test with empty lines
let whole = r#"
fn test() {
let x = 5;
println!("x = {}", x);
}
"#
.unindent();
let old = r#"
let x = 5;
println!("x = {}", x);
"#
.unindent();
let new = r#"
let x = 10;
println!("New x: {}", x);
"#
.unindent();
let expected = r#"
fn test() {
let x = 10;
println!("New x: {}", x);
}
"#
.unindent();
assert_eq!(
test_replace_with_flexible_indent(cx, &whole, &old, &new),
Some(expected.to_string())
);
}
#[gpui::test]
fn test_replace_no_match(cx: &mut TestAppContext) {
// Test with no match
let whole = r#"
fn test() {
let x = 5;
}
"#
.unindent();
let old = r#"
let y = 10;
"#
.unindent();
let new = r#"
let y = 20;
"#
.unindent();
assert_eq!(
test_replace_with_flexible_indent(cx, &whole, &old, &new),
None
);
}
#[gpui::test]
fn test_replace_whole_ends_before_matching_old(cx: &mut TestAppContext) {
let whole = r#"
fn test() {
let x = 5;
"#
.unindent();
let old = r#"
let x = 5;
println!("x = {}", x);
"#
.unindent();
let new = r#"
let x = 10;
println!("x = {}", x);
"#
.unindent();
// Should return None because whole doesn't fully contain the old text
assert_eq!(
test_replace_with_flexible_indent(cx, &whole, &old, &new),
None
);
}
#[test]
fn test_lines_with_min_indent() {
// Empty string
assert_eq!(lines_with_min_indent(""), (vec![], 0));
// Single line without indentation
assert_eq!(lines_with_min_indent("hello"), (vec!["hello"], 0));
// Multiple lines with no indentation
assert_eq!(
lines_with_min_indent("line1\nline2\nline3"),
(vec!["line1", "line2", "line3"], 0)
);
// Multiple lines with consistent indentation
assert_eq!(
lines_with_min_indent(" line1\n line2\n line3"),
(vec![" line1", " line2", " line3"], 2)
);
// Multiple lines with varying indentation
assert_eq!(
lines_with_min_indent(" line1\n line2\n line3"),
(vec![" line1", " line2", " line3"], 2)
);
// Lines with mixed indentation and empty lines
assert_eq!(
lines_with_min_indent(" line1\n\n line2"),
(vec![" line1", "", " line2"], 2)
);
}
#[gpui::test]
fn test_replace_with_missing_indent_uneven_match(cx: &mut TestAppContext) {
let whole = r#"
fn test() {
if true {
let x = 5;
println!("x = {}", x);
}
}
"#
.unindent();
let old = r#"
let x = 5;
println!("x = {}", x);
"#
.unindent();
let new = r#"
let x = 42;
println!("x = {}", x);
"#
.unindent();
let expected = r#"
fn test() {
if true {
let x = 42;
println!("x = {}", x);
}
}
"#
.unindent();
assert_eq!(
test_replace_with_flexible_indent(cx, &whole, &old, &new),
Some(expected.to_string())
);
}
#[gpui::test]
fn test_replace_big_example(cx: &mut TestAppContext) {
let whole = r#"
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_is_valid_age() {
assert!(is_valid_age(0));
assert!(!is_valid_age(151));
}
}
"#
.unindent();
let old = r#"
#[test]
fn test_is_valid_age() {
assert!(is_valid_age(0));
assert!(!is_valid_age(151));
}
"#
.unindent();
let new = r#"
#[test]
fn test_is_valid_age() {
assert!(is_valid_age(0));
assert!(!is_valid_age(151));
}
#[test]
fn test_group_people_by_age() {
let people = vec![
Person::new("Young One", 5, "young@example.com").unwrap(),
Person::new("Teen One", 15, "teen@example.com").unwrap(),
Person::new("Teen Two", 18, "teen2@example.com").unwrap(),
Person::new("Adult One", 25, "adult@example.com").unwrap(),
];
let groups = group_people_by_age(&people);
assert_eq!(groups.get(&0).unwrap().len(), 1); // One person in 0-9
assert_eq!(groups.get(&10).unwrap().len(), 2); // Two people in 10-19
assert_eq!(groups.get(&20).unwrap().len(), 1); // One person in 20-29
}
"#
.unindent();
let expected = r#"
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_is_valid_age() {
assert!(is_valid_age(0));
assert!(!is_valid_age(151));
}
#[test]
fn test_group_people_by_age() {
let people = vec![
Person::new("Young One", 5, "young@example.com").unwrap(),
Person::new("Teen One", 15, "teen@example.com").unwrap(),
Person::new("Teen Two", 18, "teen2@example.com").unwrap(),
Person::new("Adult One", 25, "adult@example.com").unwrap(),
];
let groups = group_people_by_age(&people);
assert_eq!(groups.get(&0).unwrap().len(), 1); // One person in 0-9
assert_eq!(groups.get(&10).unwrap().len(), 2); // Two people in 10-19
assert_eq!(groups.get(&20).unwrap().len(), 1); // One person in 20-29
}
}
"#
.unindent();
assert_eq!(
test_replace_with_flexible_indent(cx, &whole, &old, &new),
Some(expected.to_string())
);
}
#[test]
fn test_drop_lines_prefix() {
// Empty array
assert_eq!(drop_lines_prefix(&[], 2), Vec::<&str>::new());
// Zero prefix length
assert_eq!(
drop_lines_prefix(&["line1", "line2"], 0),
vec!["line1", "line2"]
);
// Normal prefix drop
assert_eq!(
drop_lines_prefix(&[" line1", " line2"], 2),
vec!["line1", "line2"]
);
// Prefix longer than some lines
assert_eq!(drop_lines_prefix(&[" line1", "a"], 2), vec!["line1", ""]);
// Prefix longer than all lines
assert_eq!(drop_lines_prefix(&["a", "b"], 5), vec!["", ""]);
// Mixed length lines
assert_eq!(
drop_lines_prefix(&[" line1", " line2", " line3"], 2),
vec![" line1", "line2", " line3"]
);
}
fn test_replace_with_flexible_indent(
cx: &mut TestAppContext,
whole: &str,
old: &str,
new: &str,
) -> Option<String> {
// Create a local buffer with the test content
let buffer = cx.new(|cx| language::Buffer::local(whole, cx));
// Get the buffer snapshot
let buffer_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
// Call replace_flexible and transform the result
replace_with_flexible_indent(old, new, &buffer_snapshot).map(|diff| {
buffer.update(cx, |buffer, cx| {
let _ = buffer.apply_diff(diff, cx);
buffer.text()
})
})
}
}

View File

@@ -1,160 +0,0 @@
use std::cell::RefCell;
use std::rc::Rc;
use std::sync::Arc;
use anyhow::{anyhow, bail, Context as _, Result};
use assistant_tool::{ActionLog, Tool};
use futures::AsyncReadExt as _;
use gpui::{App, AppContext as _, Entity, Task};
use html_to_markdown::{convert_html_to_markdown, markdown, TagHandler};
use http_client::{AsyncBody, HttpClientWithUrl};
use language_model::LanguageModelRequestMessage;
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
enum ContentType {
Html,
Plaintext,
Json,
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FetchToolInput {
/// The URL to fetch.
url: String,
}
pub struct FetchTool {
http_client: Arc<HttpClientWithUrl>,
}
impl FetchTool {
pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
Self { http_client }
}
async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
let mut url = url.to_owned();
if !url.starts_with("https://") && !url.starts_with("http://") {
url = format!("https://{url}");
}
let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
let mut body = Vec::new();
response
.body_mut()
.read_to_end(&mut body)
.await
.context("error reading response body")?;
if response.status().is_client_error() {
let text = String::from_utf8_lossy(body.as_slice());
bail!(
"status error {}, response: {text:?}",
response.status().as_u16()
);
}
let Some(content_type) = response.headers().get("content-type") else {
bail!("missing Content-Type header");
};
let content_type = content_type
.to_str()
.context("invalid Content-Type header")?;
let content_type = match content_type {
"text/html" => ContentType::Html,
"text/plain" => ContentType::Plaintext,
"application/json" => ContentType::Json,
_ => ContentType::Html,
};
match content_type {
ContentType::Html => {
let mut handlers: Vec<TagHandler> = vec![
Rc::new(RefCell::new(markdown::WebpageChromeRemover)),
Rc::new(RefCell::new(markdown::ParagraphHandler)),
Rc::new(RefCell::new(markdown::HeadingHandler)),
Rc::new(RefCell::new(markdown::ListHandler)),
Rc::new(RefCell::new(markdown::TableHandler::new())),
Rc::new(RefCell::new(markdown::StyledTextHandler)),
];
if url.contains("wikipedia.org") {
use html_to_markdown::structure::wikipedia;
handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover)));
handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler)));
handlers.push(Rc::new(
RefCell::new(wikipedia::WikipediaCodeHandler::new()),
));
} else {
handlers.push(Rc::new(RefCell::new(markdown::CodeHandler)));
}
convert_html_to_markdown(&body[..], &mut handlers)
}
ContentType::Plaintext => Ok(std::str::from_utf8(&body)?.to_owned()),
ContentType::Json => {
let json: serde_json::Value = serde_json::from_slice(&body)?;
Ok(format!(
"```json\n{}\n```",
serde_json::to_string_pretty(&json)?
))
}
}
}
}
impl Tool for FetchTool {
fn name(&self) -> String {
"fetch".to_string()
}
fn description(&self) -> String {
include_str!("./fetch_tool/description.md").to_string()
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(FetchToolInput);
serde_json::to_value(&schema).unwrap()
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<FetchToolInput>(input.clone()) {
Ok(input) => format!("Fetch `{}`", input.url),
Err(_) => "Fetch URL".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let input = match serde_json::from_value::<FetchToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let text = cx.background_spawn({
let http_client = self.http_client.clone();
let url = input.url.clone();
async move { Self::build_message(http_client, &url).await }
});
cx.foreground_executor().spawn(async move {
let text = text.await?;
if text.trim().is_empty() {
bail!("no textual content found");
}
Ok(text)
})
}
}

View File

@@ -1 +0,0 @@
Fetches a URL and returns the content as Markdown.

Some files were not shown because too many files have changed in this diff Show More