Compare commits
2 Commits
windows/re
...
fix-keep-e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
585110b2cb | ||
|
|
4059a21422 |
26
.github/workflows/ci.yml
vendored
@@ -269,10 +269,6 @@ jobs:
|
||||
mkdir -p ./../.cargo
|
||||
cp ./.cargo/ci-config.toml ./../.cargo/config.toml
|
||||
|
||||
- name: Check that Cargo.lock is up to date
|
||||
run: |
|
||||
cargo update --locked --workspace
|
||||
|
||||
- name: cargo clippy
|
||||
run: ./script/clippy
|
||||
|
||||
@@ -771,7 +767,7 @@ jobs:
|
||||
timeout-minutes: 120
|
||||
name: Create a Windows installer
|
||||
runs-on: [self-hosted, Windows, X64]
|
||||
if: true && (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling'))
|
||||
if: false && (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling'))
|
||||
needs: [windows_tests]
|
||||
env:
|
||||
AZURE_TENANT_ID: ${{ secrets.AZURE_SIGNING_TENANT_ID }}
|
||||
@@ -807,16 +803,16 @@ jobs:
|
||||
name: ZedEditorUserSetup-x64-${{ github.event.pull_request.head.sha || github.sha }}.exe
|
||||
path: ${{ env.SETUP_PATH }}
|
||||
|
||||
# - name: Upload Artifacts to release
|
||||
# uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 # v1
|
||||
# # Re-enable when we are ready to publish windows preview releases
|
||||
# if: ${{ !(contains(github.event.pull_request.labels.*.name, 'run-bundling')) && env.RELEASE_CHANNEL == 'preview' }} # upload only preview
|
||||
# with:
|
||||
# draft: true
|
||||
# prerelease: ${{ env.RELEASE_CHANNEL == 'preview' }}
|
||||
# files: ${{ env.SETUP_PATH }}
|
||||
# env:
|
||||
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Upload Artifacts to release
|
||||
uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 # v1
|
||||
# Re-enable when we are ready to publish windows preview releases
|
||||
if: ${{ !(contains(github.event.pull_request.labels.*.name, 'run-bundling')) && env.RELEASE_CHANNEL == 'preview' }} # upload only preview
|
||||
with:
|
||||
draft: true
|
||||
prerelease: ${{ env.RELEASE_CHANNEL == 'preview' }}
|
||||
files: ${{ env.SETUP_PATH }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
auto-release-preview:
|
||||
name: Auto release preview
|
||||
|
||||
20
.github/workflows/release_nightly.yml
vendored
@@ -111,11 +111,6 @@ jobs:
|
||||
echo "Publishing version: ${version} on release channel nightly"
|
||||
echo "nightly" > crates/zed/RELEASE_CHANNEL
|
||||
|
||||
- name: Setup Sentry CLI
|
||||
uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2
|
||||
with:
|
||||
token: ${{ SECRETS.SENTRY_AUTH_TOKEN }}
|
||||
|
||||
- name: Create macOS app bundle
|
||||
run: script/bundle-mac
|
||||
|
||||
@@ -141,11 +136,6 @@ jobs:
|
||||
- name: Install Linux dependencies
|
||||
run: ./script/linux && ./script/install-mold 2.34.0
|
||||
|
||||
- name: Setup Sentry CLI
|
||||
uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2
|
||||
with:
|
||||
token: ${{ SECRETS.SENTRY_AUTH_TOKEN }}
|
||||
|
||||
- name: Limit target directory size
|
||||
run: script/clear-target-dir-if-larger-than 100
|
||||
|
||||
@@ -178,11 +168,6 @@ jobs:
|
||||
- name: Install Linux dependencies
|
||||
run: ./script/linux
|
||||
|
||||
- name: Setup Sentry CLI
|
||||
uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2
|
||||
with:
|
||||
token: ${{ SECRETS.SENTRY_AUTH_TOKEN }}
|
||||
|
||||
- name: Limit target directory size
|
||||
run: script/clear-target-dir-if-larger-than 100
|
||||
|
||||
@@ -277,11 +262,6 @@ jobs:
|
||||
Write-Host "Publishing version: $version on release channel nightly"
|
||||
"nightly" | Set-Content -Path "crates/zed/RELEASE_CHANNEL"
|
||||
|
||||
- name: Setup Sentry CLI
|
||||
uses: matbour/setup-sentry-cli@3e938c54b3018bdd019973689ef984e033b0454b #v2
|
||||
with:
|
||||
token: ${{ SECRETS.SENTRY_AUTH_TOKEN }}
|
||||
|
||||
- name: Build Zed installer
|
||||
working-directory: ${{ env.ZED_WORKSPACE }}
|
||||
run: script/bundle-windows.ps1
|
||||
|
||||
69
Cargo.lock
generated
@@ -6,7 +6,6 @@ version = 4
|
||||
name = "acp_thread"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"agent-client-protocol",
|
||||
"agentic-coding-protocol",
|
||||
"anyhow",
|
||||
"assistant_tool",
|
||||
@@ -136,23 +135,11 @@ dependencies = [
|
||||
"zstd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent-client-protocol"
|
||||
version = "0.0.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72ec54650c1fc2d63498bab47eeeaa9eddc7d239d53f615b797a0e84f7ccc87b"
|
||||
dependencies = [
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent_servers"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"acp_thread",
|
||||
"agent-client-protocol",
|
||||
"agentic-coding-protocol",
|
||||
"anyhow",
|
||||
"collections",
|
||||
@@ -168,7 +155,6 @@ dependencies = [
|
||||
"nix 0.29.0",
|
||||
"paths",
|
||||
"project",
|
||||
"rand 0.8.5",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -209,9 +195,9 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"acp_thread",
|
||||
"agent",
|
||||
"agent-client-protocol",
|
||||
"agent_servers",
|
||||
"agent_settings",
|
||||
"agentic-coding-protocol",
|
||||
"ai_onboarding",
|
||||
"anyhow",
|
||||
"assistant_context",
|
||||
@@ -224,7 +210,6 @@ dependencies = [
|
||||
"chrono",
|
||||
"client",
|
||||
"collections",
|
||||
"command_palette_hooks",
|
||||
"component",
|
||||
"context_server",
|
||||
"db",
|
||||
@@ -246,7 +231,6 @@ dependencies = [
|
||||
"jsonschema",
|
||||
"language",
|
||||
"language_model",
|
||||
"language_models",
|
||||
"languages",
|
||||
"log",
|
||||
"lsp",
|
||||
@@ -285,7 +269,6 @@ dependencies = [
|
||||
"time_format",
|
||||
"tree-sitter-md",
|
||||
"ui",
|
||||
"ui_input",
|
||||
"unindent",
|
||||
"urlencoding",
|
||||
"util",
|
||||
@@ -361,7 +344,6 @@ dependencies = [
|
||||
"proto",
|
||||
"serde",
|
||||
"smallvec",
|
||||
"telemetry",
|
||||
"ui",
|
||||
"workspace-hack",
|
||||
"zed_actions",
|
||||
@@ -1887,7 +1869,9 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"aws-smithy-runtime-api",
|
||||
"aws-smithy-types",
|
||||
"futures 0.3.31",
|
||||
"http_client",
|
||||
"tokio",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
@@ -4258,7 +4242,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "dap-types"
|
||||
version = "0.0.1"
|
||||
source = "git+https://github.com/zed-industries/dap-types?rev=1b461b310481d01e02b2603c16d7144b926339f8#1b461b310481d01e02b2603c16d7144b926339f8"
|
||||
source = "git+https://github.com/zed-industries/dap-types?rev=7f39295b441614ca9dbf44293e53c32f666897f9#7f39295b441614ca9dbf44293e53c32f666897f9"
|
||||
dependencies = [
|
||||
"schemars",
|
||||
"serde",
|
||||
@@ -4980,7 +4964,6 @@ dependencies = [
|
||||
"text",
|
||||
"theme",
|
||||
"time",
|
||||
"tree-sitter-bash",
|
||||
"tree-sitter-html",
|
||||
"tree-sitter-python",
|
||||
"tree-sitter-rust",
|
||||
@@ -5386,13 +5369,11 @@ dependencies = [
|
||||
"log",
|
||||
"lsp",
|
||||
"parking_lot",
|
||||
"pretty_assertions",
|
||||
"semantic_version",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"task",
|
||||
"toml 0.8.20",
|
||||
"url",
|
||||
"util",
|
||||
"wasm-encoder 0.221.3",
|
||||
"wasmparser 0.221.3",
|
||||
@@ -6377,7 +6358,6 @@ dependencies = [
|
||||
"buffer_diff",
|
||||
"call",
|
||||
"chrono",
|
||||
"client",
|
||||
"collections",
|
||||
"command_palette_hooks",
|
||||
"component",
|
||||
@@ -7419,9 +7399,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "grid"
|
||||
version = "0.17.0"
|
||||
version = "0.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "71b01d27060ad58be4663b9e4ac9e2d4806918e8876af8912afbddd1a91d5eaa"
|
||||
checksum = "be136d9dacc2a13cc70bb6c8f902b414fb2641f8db1314637c6b7933411a8f82"
|
||||
|
||||
[[package]]
|
||||
name = "group"
|
||||
@@ -7692,12 +7672,6 @@ version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
|
||||
|
||||
[[package]]
|
||||
name = "hex-literal"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bcaaec4551594c969335c98c903c1397853d4198408ea609190f420500f6be71"
|
||||
|
||||
[[package]]
|
||||
name = "hexf-parse"
|
||||
version = "0.2.1"
|
||||
@@ -7879,7 +7853,6 @@ dependencies = [
|
||||
"derive_more 0.99.19",
|
||||
"futures 0.3.31",
|
||||
"http 1.3.1",
|
||||
"http-body 1.0.1",
|
||||
"log",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -9042,7 +9015,6 @@ dependencies = [
|
||||
"task",
|
||||
"text",
|
||||
"theme",
|
||||
"toml 0.8.20",
|
||||
"tree-sitter",
|
||||
"tree-sitter-elixir",
|
||||
"tree-sitter-embedded-template",
|
||||
@@ -9125,11 +9097,11 @@ dependencies = [
|
||||
"client",
|
||||
"collections",
|
||||
"component",
|
||||
"convert_case 0.8.0",
|
||||
"copilot",
|
||||
"credentials_provider",
|
||||
"deepseek",
|
||||
"editor",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"google_ai",
|
||||
"gpui",
|
||||
@@ -9418,7 +9390,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "libwebrtc"
|
||||
version = "0.3.10"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4"
|
||||
dependencies = [
|
||||
"cxx",
|
||||
"jni",
|
||||
@@ -9498,7 +9470,7 @@ checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856"
|
||||
[[package]]
|
||||
name = "livekit"
|
||||
version = "0.7.8"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"futures-util",
|
||||
@@ -9521,7 +9493,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "livekit-api"
|
||||
version = "0.4.2"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"http 0.2.12",
|
||||
@@ -9545,7 +9517,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "livekit-protocol"
|
||||
version = "0.3.9"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"livekit-runtime",
|
||||
@@ -9562,7 +9534,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "livekit-runtime"
|
||||
version = "0.4.0"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4"
|
||||
dependencies = [
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
@@ -11028,7 +11000,6 @@ dependencies = [
|
||||
"ui",
|
||||
"workspace",
|
||||
"workspace-hack",
|
||||
"zed_actions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -14804,7 +14775,6 @@ dependencies = [
|
||||
"fs",
|
||||
"fuzzy",
|
||||
"gpui",
|
||||
"itertools 0.14.0",
|
||||
"language",
|
||||
"log",
|
||||
"menu",
|
||||
@@ -15985,12 +15955,13 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "taffy"
|
||||
version = "0.8.3"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7aaef0ac998e6527d6d0d5582f7e43953bb17221ac75bb8eb2fcc2db3396db1c"
|
||||
checksum = "e8b61630cba2afd2c851821add2e1bb1b7851a2436e839ab73b56558b009035e"
|
||||
dependencies = [
|
||||
"arrayvec",
|
||||
"grid",
|
||||
"num-traits",
|
||||
"serde",
|
||||
"slotmap",
|
||||
]
|
||||
@@ -17717,6 +17688,7 @@ checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
|
||||
name = "vim"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"agent_ui",
|
||||
"anyhow",
|
||||
"assets",
|
||||
"async-compat",
|
||||
@@ -18550,7 +18522,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "webrtc-sys"
|
||||
version = "0.3.7"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"cxx",
|
||||
@@ -18563,15 +18535,13 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "webrtc-sys-build"
|
||||
version = "0.3.6"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4"
|
||||
dependencies = [
|
||||
"fs2",
|
||||
"hex-literal",
|
||||
"regex",
|
||||
"reqwest 0.11.27",
|
||||
"scratch",
|
||||
"semver",
|
||||
"sha2",
|
||||
"zip",
|
||||
]
|
||||
|
||||
@@ -20195,7 +20165,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zed"
|
||||
version = "0.198.0"
|
||||
version = "0.197.0"
|
||||
dependencies = [
|
||||
"activity_indicator",
|
||||
"agent",
|
||||
@@ -20236,7 +20206,6 @@ dependencies = [
|
||||
"extension",
|
||||
"extension_host",
|
||||
"extensions_ui",
|
||||
"feature_flags",
|
||||
"feedback",
|
||||
"file_finder",
|
||||
"fs",
|
||||
|
||||
14
Cargo.toml
@@ -413,7 +413,6 @@ zlog_settings = { path = "crates/zlog_settings" }
|
||||
#
|
||||
|
||||
agentic-coding-protocol = "0.0.10"
|
||||
agent-client-protocol = "0.0.11"
|
||||
aho-corasick = "1.1"
|
||||
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
|
||||
any_vec = "0.14"
|
||||
@@ -460,7 +459,7 @@ core-video = { version = "0.4.3", features = ["metal"] }
|
||||
cpal = "0.16"
|
||||
criterion = { version = "0.5", features = ["html_reports"] }
|
||||
ctor = "0.4.0"
|
||||
dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "1b461b310481d01e02b2603c16d7144b926339f8" }
|
||||
dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "7f39295b441614ca9dbf44293e53c32f666897f9" }
|
||||
dashmap = "6.0"
|
||||
derive_more = "0.99.17"
|
||||
dirs = "4.0"
|
||||
@@ -483,7 +482,6 @@ heed = { version = "0.21.0", features = ["read-txn-no-tls"] }
|
||||
hex = "0.4.3"
|
||||
html5ever = "0.27.0"
|
||||
http = "1.1"
|
||||
http-body = "1.0"
|
||||
hyper = "0.14"
|
||||
ignore = "0.4.22"
|
||||
image = "0.25.1"
|
||||
@@ -674,13 +672,8 @@ features = [
|
||||
"Win32_Globalization",
|
||||
"Win32_Graphics_Direct2D",
|
||||
"Win32_Graphics_Direct2D_Common",
|
||||
"Win32_Graphics_Direct3D",
|
||||
"Win32_Graphics_Direct3D11",
|
||||
"Win32_Graphics_Direct3D_Fxc",
|
||||
"Win32_Graphics_DirectComposition",
|
||||
"Win32_Graphics_DirectWrite",
|
||||
"Win32_Graphics_Dwm",
|
||||
"Win32_Graphics_Dxgi",
|
||||
"Win32_Graphics_Dxgi_Common",
|
||||
"Win32_Graphics_Gdi",
|
||||
"Win32_Graphics_Imaging",
|
||||
@@ -725,11 +718,6 @@ workspace-hack = { path = "tooling/workspace-hack" }
|
||||
split-debuginfo = "unpacked"
|
||||
codegen-units = 16
|
||||
|
||||
# mirror configuration for crates compiled for the build platform
|
||||
# (without this cargo will compile ~400 crates twice)
|
||||
[profile.dev.build-override]
|
||||
codegen-units = 16
|
||||
|
||||
[profile.dev.package]
|
||||
taffy = { opt-level = 3 }
|
||||
cranelift-codegen = { opt-level = 3 }
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M7.25669 0.999943C8.27509 0.993825 9.24655 1.42125 9.9227 2.17279C11.4427 1.85079 12.9991 2.53518 13.7733 3.86518C14.159 4.5149 14.3171 5.26409 14.2372 5.99994H13.2967C13.3789 5.42185 13.265 4.8321 12.9686 4.32514C12.2353 3.06961 10.6088 2.63919 9.33676 3.36322L6.48032 4.98822C6.46926 4.99697 6.46284 5.01135 6.46372 5.02533V6.38568L9.91294 4.42084C10.0565 4.33818 10.2336 4.33823 10.3768 4.42084L13.1502 5.99994H11.2948L9.88364 5.19623C9.87034 5.19054 9.85459 5.19128 9.84262 5.19916L8.64926 5.87983L8.8602 5.99994H7.99985C6.89539 6.00004 5.99988 6.89547 5.99985 7.99994V9.34955L3.90219 8.15522C3.75815 8.07431 3.66897 7.92228 3.66977 7.75873V4.53803C3.66977 4.50828 3.67172 4.4654 3.67172 4.44135C3.08836 4.65262 2.59832 5.0599 2.28794 5.59174C1.55635 6.84647 1.99122 8.44936 3.26059 9.17475L5.99985 10.7363V11.6162C5.87564 11.6568 5.73827 11.6456 5.6229 11.579L2.7977 9.96869C2.77156 9.95382 2.73449 9.9311 2.71372 9.91889C2.60687 10.5231 2.7194 11.1466 3.0311 11.6777C3.6435 12.7209 4.87159 13.1902 5.99985 12.9023V13.8398C4.50443 14.1233 2.98758 13.4424 2.22641 12.1347C1.71174 11.2677 1.60096 10.2237 1.9227 9.27045C0.880739 8.13295 0.703328 6.46023 1.48325 5.13373C1.98739 4.26024 2.84863 3.64401 3.84653 3.44233C4.3245 1.9837 5.70306 0.996447 7.25669 0.999943ZM7.25766 1.91498C5.78932 1.9143 4.59839 3.08914 4.59751 4.53803V7.79193C4.59926 7.80578 4.60735 7.81796 4.61997 7.82416L5.8143 8.50483L5.81626 4.57611C5.81537 4.41216 5.90431 4.2606 6.04868 4.17963L8.87387 2.56928C8.89868 2.55441 8.93612 2.53379 8.95786 2.5224C8.48035 2.13046 7.8788 1.91498 7.25766 1.91498Z" fill="black"/>
|
||||
<path d="M13.5 6C14.6046 6 15.5 6.89543 15.5 8V13.5C15.5 14.6046 14.6046 15.5 13.5 15.5H8C6.89543 15.5 6 14.6046 6 13.5V8C6 6.89543 6.89543 6 8 6H13.5ZM10.8916 8.02539C10.0563 8.02539 9.33453 8.27982 8.81934 8.76562C8.30213 9.25335 8.02547 9.94371 8.02539 10.748C8.02539 11.557 8.29852 12.2492 8.81543 12.7373C9.33013 13.2232 10.0521 13.4746 10.8916 13.4746C11.9865 13.4745 12.8545 13.1022 13.3076 12.3525C13.3894 12.2176 13.4521 12.0693 13.4521 11.8857C13.4521 11.4795 13.0933 11.2773 12.7842 11.2773C12.6604 11.2774 12.5292 11.3025 12.4072 11.3779C12.2862 11.4529 12.2058 11.5586 12.1494 11.666L12.1475 11.6689C11.9677 12.0213 11.5535 12.246 10.8955 12.2461C10.4219 12.2461 10.0667 12.0932 9.83008 11.8506C9.59255 11.607 9.44141 11.2389 9.44141 10.748C9.44148 10.264 9.59319 9.89628 9.83203 9.65137C10.0702 9.40725 10.4255 9.25391 10.8916 9.25391C11.4912 9.25399 11.9415 9.50614 12.1289 9.8916V9.89062C12.1888 10.0157 12.276 10.1311 12.4023 10.2129C12.5303 10.2956 12.6724 10.3271 12.8115 10.3271C12.9661 10.3271 13.1303 10.2857 13.2627 10.1758C13.4018 10.0603 13.4746 9.89383 13.4746 9.71582C13.4746 9.61857 13.4542 9.52036 13.4199 9.42773L13.3818 9.33691C12.9749 8.49175 11.9927 8.02548 10.8916 8.02539ZM10.3203 8.97852L10.1494 9.03516C10.2095 9.01178 10.2716 8.99089 10.3359 8.97363C10.3307 8.97505 10.3256 8.97706 10.3203 8.97852ZM10.4814 8.94141C10.4969 8.9385 10.5126 8.93616 10.5283 8.93359C10.5126 8.93617 10.4969 8.9385 10.4814 8.94141ZM10.6709 8.91504C10.6819 8.91399 10.693 8.913 10.7041 8.91211C10.693 8.913 10.6819 8.91399 10.6709 8.91504Z" fill="black" fill-opacity="0.95"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 3.2 KiB |
@@ -1,7 +1 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M10.6667 6C11.003 6.44823 11.2208 6.97398 11.3001 7.52867" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M12.9094 3.75732C13.7621 4.6095 14.3383 5.69876 14.5629 6.88315C14.7875 8.06754 14.6502 9.29213 14.1688 10.3973" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M2.66675 2L13.6667 13" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M5.33333 4.66669L4.942 5.05802C4.85494 5.1456 4.75136 5.21504 4.63726 5.2623C4.52317 5.30957 4.40083 5.33372 4.27733 5.33335H2.66667C2.48986 5.33335 2.32029 5.40359 2.19526 5.52862C2.07024 5.65364 2 5.82321 2 6.00002V10C2 10.1768 2.07024 10.3464 2.19526 10.4714C2.32029 10.5964 2.48986 10.6667 2.66667 10.6667H4.27733C4.40083 10.6663 4.52317 10.6905 4.63726 10.7377C4.75136 10.785 4.85494 10.8544 4.942 10.942L7.19733 13.198C7.26307 13.2639 7.34687 13.3088 7.43813 13.3269C7.52939 13.3451 7.62399 13.3358 7.70995 13.3002C7.79591 13.2646 7.86936 13.2042 7.921 13.1268C7.97263 13.0494 8.00013 12.9584 8 12.8654V7.33335" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M7.21875 2.78136C7.28267 2.71719 7.36421 2.67345 7.45303 2.65568C7.54184 2.63791 7.63393 2.64691 7.71762 2.68154C7.80132 2.71618 7.87284 2.77488 7.92312 2.85022C7.97341 2.92555 8.0002 3.01412 8.00008 3.10469V3.56202" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-volume-off"><path d="M16 9a5 5 0 0 1 .95 2.293"/><path d="M19.364 5.636a9 9 0 0 1 1.889 9.96"/><path d="m2 2 20 20"/><path d="m7 7-.587.587A1.4 1.4 0 0 1 5.416 8H3a1 1 0 0 0-1 1v6a1 1 0 0 0 1 1h2.416a1.4 1.4 0 0 1 .997.413l3.383 3.384A.705.705 0 0 0 11 19.298V11"/><path d="M9.828 4.172A.686.686 0 0 1 11 4.657v.686"/></svg>
|
||||
|
||||
|
Before Width: | Height: | Size: 1.6 KiB After Width: | Height: | Size: 527 B |
@@ -1,5 +1 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M8 3.13467C7.99987 3.04181 7.97223 2.95107 7.92057 2.8739C7.86892 2.79674 7.79557 2.7366 7.70977 2.70108C7.62397 2.66557 7.52958 2.65626 7.43849 2.67434C7.34741 2.69242 7.26373 2.73707 7.198 2.80266L4.942 5.058C4.85494 5.14558 4.75136 5.21502 4.63726 5.26228C4.52317 5.30954 4.40083 5.33369 4.27733 5.33333H2.66667C2.48986 5.33333 2.32029 5.40357 2.19526 5.52859C2.07024 5.65362 2 5.82319 2 6V10C2 10.1768 2.07024 10.3464 2.19526 10.4714C2.32029 10.5964 2.48986 10.6667 2.66667 10.6667H4.27733C4.40083 10.6663 4.52317 10.6905 4.63726 10.7377C4.75136 10.785 4.85494 10.8544 4.942 10.942L7.19733 13.198C7.26307 13.2639 7.34687 13.3087 7.43813 13.3269C7.52939 13.3451 7.62399 13.3358 7.70995 13.3002C7.79591 13.2645 7.86936 13.2042 7.921 13.1268C7.97263 13.0494 8.00013 12.9584 8 12.8653V3.13467Z" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M10.6667 6C11.0995 6.57699 11.3334 7.27877 11.3334 8C11.3334 8.72123 11.0995 9.42301 10.6667 10" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M12.9094 12.2427C13.4666 11.6855 13.9085 11.0241 14.2101 10.2961C14.5116 9.56815 14.6668 8.78793 14.6668 7.99999C14.6668 7.21205 14.5116 6.43183 14.2101 5.70387C13.9085 4.97591 13.4666 4.31448 12.9094 3.75732" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-volume-2"><path d="M11 4.702a.705.705 0 0 0-1.203-.498L6.413 7.587A1.4 1.4 0 0 1 5.416 8H3a1 1 0 0 0-1 1v6a1 1 0 0 0 1 1h2.416a1.4 1.4 0 0 1 .997.413l3.383 3.384A.705.705 0 0 0 11 19.298z"/><path d="M16 9a5 5 0 0 1 0 6"/><path d="M19.364 18.364a9 9 0 0 0 0-12.728"/></svg>
|
||||
|
||||
|
Before Width: | Height: | Size: 1.4 KiB After Width: | Height: | Size: 475 B |
@@ -1 +0,0 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-cloud-download-icon lucide-cloud-download"><path d="M12 13v8l-4-4"/><path d="m12 21 4-4"/><path d="M4.393 15.269A7 7 0 1 1 15.71 8h1.79a4.5 4.5 0 0 1 2.436 8.284"/></svg>
|
||||
|
Before Width: | Height: | Size: 372 B |
@@ -1,5 +1,8 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M10.437 11.0461L13.4831 8L10.437 4.95392" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M13 8L8 8" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M6.6553 13.4659H4.21843C3.89528 13.4659 3.58537 13.3375 3.35687 13.109C3.12837 12.8805 3 12.5706 3 12.2475V3.71843C3 3.39528 3.12837 3.08537 3.35687 2.85687C3.58537 2.62837 3.89528 2.5 4.21843 2.5H6.6553" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<svg width="15" height="15" viewBox="0 0 15 15" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
clip-rule="evenodd"
|
||||
d="M3 1C2.44771 1 2 1.44772 2 2V13C2 13.5523 2.44772 14 3 14H10.5C10.7761 14 11 13.7761 11 13.5C11 13.2239 10.7761 13 10.5 13H3V2L10.5 2C10.7761 2 11 1.77614 11 1.5C11 1.22386 10.7761 1 10.5 1H3ZM12.6036 4.89645C12.4083 4.70118 12.0917 4.70118 11.8964 4.89645C11.7012 5.09171 11.7012 5.40829 11.8964 5.60355L13.2929 7H6.5C6.22386 7 6 7.22386 6 7.5C6 7.77614 6.22386 8 6.5 8H13.2929L11.8964 9.39645C11.7012 9.59171 11.7012 9.90829 11.8964 10.1036C12.0917 10.2988 12.4083 10.2988 12.6036 10.1036L14.8536 7.85355C15.0488 7.65829 15.0488 7.34171 14.8536 7.14645L12.6036 4.89645Z"
|
||||
fill="currentColor"
|
||||
/>
|
||||
</svg>
|
||||
|
||||
|
Before Width: | Height: | Size: 637 B After Width: | Height: | Size: 768 B |
@@ -1,3 +0,0 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M9.628 11.0743V10.4575H8.45562L8.65084 10.2445C8.75911 10.1264 8.96952 9.79454 9.11862 9.50789C9.52153 8.73047 9.51798 7.25107 9.11862 6.43992C8.58614 5.35722 7.49453 4.56381 6.24942 4.35703C4.59252 4.08192 2.86196 5.00312 2.14045 6.54287C1.77038 7.33182 1.77038 8.64437 2.14045 9.43333C2.45905 10.1122 3.11309 10.8204 3.73609 11.1595C4.51439 11.5828 5.18264 11.676 7.51312 11.6848L9.62627 11.6928L9.628 11.0743ZM5.30605 10.169C4.24109 10.0111 3.45215 9.07124 3.45659 7.96813C3.45659 7.33004 3.70064 6.80022 4.18697 6.36182C4.67685 5.91986 5.1312 5.77344 5.86602 5.82048C7.00287 5.89236 7.82382 6.79845 7.82382 7.98056C7.82382 8.61332 7.71996 8.91682 7.33036 9.42534C6.90172 9.98444 6.08345 10.2853 5.30692 10.1699M15.1374 10.9802V10.2684H11.8138V4.47509H10.1986V11.6928H15.1374V10.9802Z" fill="black"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 916 B |
@@ -1,5 +1,3 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M8 12.2028V14.3042" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M12.2027 6.94928V8.11672C12.2027 9.20041 11.7599 10.2397 10.9717 11.006C10.1836 11.7723 9.11457 12.2028 7.99992 12.2028C6.88527 12.2028 5.81627 11.7723 5.02809 11.006C4.23991 10.2397 3.79712 9.20041 3.79712 8.11672V6.94928" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M10.1015 3.63555C10.1015 2.56426 9.16065 1.6958 8.00008 1.6958C6.83951 1.6958 5.89868 2.56426 5.89868 3.63555V8.16165C5.89868 9.23294 6.83951 10.1014 8.00008 10.1014C9.16065 10.1014 10.1015 9.23294 10.1015 8.16165V3.63555Z" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<svg width="15" height="15" viewBox="0 0 15 15" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M3.72742 8.83338C3.63539 8.57302 3.34973 8.43656 3.08937 8.52858C2.82901 8.6206 2.69255 8.90626 2.78458 9.16662C2.86101 9.38288 2.95188 9.59228 3.056 9.79364C3.81427 11.2601 5.27842 12.3044 7.00014 12.4753L7.00014 14L5.50014 14C5.22399 14 5.00014 14.2239 5.00014 14.5C5.00014 14.7761 5.22399 15 5.50014 15L7.50014 15L9.50014 15C9.77628 15 10.0001 14.7761 10.0001 14.5C10.0001 14.2239 9.77628 14 9.50014 14L8.00014 14L8.00014 12.4753C9.72168 12.3043 11.1857 11.26 11.9439 9.79364C12.048 9.59228 12.1389 9.38288 12.2153 9.16662C12.3073 8.90626 12.1709 8.6206 11.9105 8.52858C11.6501 8.43656 11.3645 8.57302 11.2725 8.83338C11.2114 9.00607 11.1388 9.17337 11.0556 9.33433C10.3899 10.6218 9.04706 11.5 7.49994 11.5C5.95282 11.5 4.60997 10.6218 3.94428 9.33433C3.86104 9.17337 3.78845 9.00607 3.72742 8.83338ZM5.5 3.5L5.5 7.5C5.5 8.60457 6.39543 9.5 7.5 9.5C8.60457 9.5 9.5 8.60457 9.5 7.5L9.5 3.5C9.5 2.39543 8.60457 1.5 7.5 1.5C6.39543 1.5 5.5 2.39543 5.5 3.5ZM4.5 7.5C4.5 9.15685 5.84315 10.5 7.5 10.5C9.15685 10.5 10.5 9.15685 10.5 7.5L10.5 3.5C10.5 1.84315 9.15685 0.5 7.5 0.5C5.84315 0.5 4.5 1.84315 4.5 3.5L4.5 7.5Z" fill="black"/>
|
||||
</svg>
|
||||
|
||||
|
Before Width: | Height: | Size: 847 B After Width: | Height: | Size: 1.3 KiB |
@@ -1,8 +1,3 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M3 3L13 13" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M12 9C12 8.74858 12 8.49375 12 8.23839V7" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M4.00043 7V8.09869C3.98856 8.86731 4.22157 9.62164 4.66938 10.2643C5.11718 10.907 5.75924 11.4085 6.51267 11.7042C7.2661 11.9999 8.09632 12.0761 8.89619 11.923C9.47851 11.8115 10.0253 11.5823 10.5 11.2539" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M10 6V3.62904C9.99714 3.26103 9.8347 2.90448 9.53885 2.6168C9.24299 2.32913 8.83093 2.12707 8.36903 2.04316C7.90713 1.95926 7.42226 1.9984 6.99252 2.15427C6.56278 2.31015 6.21317 2.57369 6 2.90245" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M6 6V8.00088C6.00031 8.39636 6.10356 8.78287 6.29674 9.11159C6.48991 9.44031 6.76433 9.69649 7.08534 9.84779C7.40634 9.99909 7.75954 10.0387 8.10032 9.96165C8.4411 9.88459 8.75417 9.69431 9 9.41483" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M8 12V14" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<svg width="15" height="15" viewBox="0 0 15 15" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M12.87 1.83637C13.0557 1.63204 13.0407 1.31581 12.8363 1.13006C12.632 0.944307 12.3158 0.959365 12.13 1.16369L10.4589 3.00199C10.2216 1.58215 8.98719 0.5 7.5 0.5C5.84315 0.5 4.5 1.84315 4.5 3.5L4.5 7.5C4.5 8.0754 4.66199 8.61297 4.94286 9.06958L4.24966 9.8321C4.1363 9.6744 4.03412 9.5081 3.94428 9.33433C3.86104 9.17337 3.78845 9.00607 3.72742 8.83338C3.63539 8.57302 3.34973 8.43656 3.08937 8.52858C2.82901 8.6206 2.69255 8.90626 2.78458 9.16662C2.86101 9.38288 2.95188 9.59228 3.056 9.79364C3.20094 10.074 3.37167 10.3388 3.56506 10.5852L2.13003 12.1637C1.94428 12.368 1.95933 12.6842 2.16366 12.87C2.36799 13.0558 2.68422 13.0407 2.86997 12.8364L4.25951 11.3079C5.01297 11.9497 5.95951 12.372 7.00014 12.4753L7.00014 14L5.50014 14C5.22399 14 5.00014 14.2239 5.00014 14.5C5.00014 14.7761 5.22399 15 5.50014 15L7.50014 15L9.50014 15C9.77628 15 10.0001 14.7761 10.0001 14.5C10.0001 14.2239 9.77628 14 9.50014 14L8.00014 14L8.00014 12.4753C9.72168 12.3043 11.1857 11.26 11.9439 9.79364C12.048 9.59228 12.1389 9.38288 12.2153 9.16662C12.3073 8.90626 12.1709 8.6206 11.9105 8.52858C11.6501 8.43656 11.3645 8.57302 11.2725 8.83338C11.2114 9.00607 11.1388 9.17337 11.0556 9.33433C10.3899 10.6218 9.04706 11.5 7.49994 11.5C6.523 11.5 5.62751 11.1498 4.93254 10.5675L5.60604 9.82669C6.12251 10.2476 6.78178 10.5 7.5 10.5C9.15685 10.5 10.5 9.15685 10.5 7.5L10.5 4.44333L12.87 1.83637ZM9.5 4.05673L9.5 3.5C9.5 2.39543 8.60457 1.5 7.5 1.5C6.39543 1.5 5.5 2.39543 5.5 3.5L5.5 7.5C5.5 7.77755 5.55653 8.04189 5.65872 8.28214L9.5 4.05673ZM6.28022 9.08509L9.5 5.54333L9.5 7.5C9.5 8.60457 8.60457 9.5 7.5 9.5C7.04083 9.5 6.6178 9.34527 6.28022 9.08509Z" fill="black"/>
|
||||
</svg>
|
||||
|
||||
|
Before Width: | Height: | Size: 1.3 KiB After Width: | Height: | Size: 1.8 KiB |
@@ -1,5 +1,8 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M12.8 3H3.2C2.53726 3 2 3.51167 2 4.14286V9.85714C2 10.4883 2.53726 11 3.2 11H12.8C13.4627 11 14 10.4883 14 9.85714V4.14286C14 3.51167 13.4627 3 12.8 3Z" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M5.33325 14H10.6666" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M8 11.3333V14" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<svg width="15" height="15" viewBox="0 0 15 15" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
clip-rule="evenodd"
|
||||
d="M1 3.25C1 3.11193 1.11193 3 1.25 3H13.75C13.8881 3 14 3.11193 14 3.25V10.75C14 10.8881 13.8881 11 13.75 11H1.25C1.11193 11 1 10.8881 1 10.75V3.25ZM1.25 2C0.559643 2 0 2.55964 0 3.25V10.75C0 11.4404 0.559644 12 1.25 12H5.07341L4.82991 13.2986C4.76645 13.6371 5.02612 13.95 5.37049 13.95H9.62951C9.97389 13.95 10.2336 13.6371 10.1701 13.2986L9.92659 12H13.75C14.4404 12 15 11.4404 15 10.75V3.25C15 2.55964 14.4404 2 13.75 2H1.25ZM9.01091 12H5.98909L5.79222 13.05H9.20778L9.01091 12Z"
|
||||
fill="currentColor"
|
||||
/>
|
||||
</svg>
|
||||
|
||||
|
Before Width: | Height: | Size: 569 B After Width: | Height: | Size: 677 B |
@@ -872,6 +872,8 @@
|
||||
"tab": "git_panel::FocusEditor",
|
||||
"shift-tab": "git_panel::FocusEditor",
|
||||
"escape": "git_panel::ToggleFocus",
|
||||
"ctrl-enter": "git::Commit",
|
||||
"ctrl-shift-enter": "git::Amend",
|
||||
"alt-enter": "menu::SecondaryConfirm",
|
||||
"delete": ["git::RestoreFile", { "skip_prompt": false }],
|
||||
"backspace": ["git::RestoreFile", { "skip_prompt": false }],
|
||||
@@ -908,9 +910,7 @@
|
||||
"ctrl-g backspace": "git::RestoreTrackedFiles",
|
||||
"ctrl-g shift-backspace": "git::TrashUntrackedFiles",
|
||||
"ctrl-space": "git::StageAll",
|
||||
"ctrl-shift-space": "git::UnstageAll",
|
||||
"ctrl-enter": "git::Commit",
|
||||
"ctrl-shift-enter": "git::Amend"
|
||||
"ctrl-shift-space": "git::UnstageAll"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -975,14 +975,9 @@
|
||||
"context": "CollabPanel && not_editing",
|
||||
"bindings": {
|
||||
"ctrl-backspace": "collab_panel::Remove",
|
||||
"space": "menu::Confirm"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "CollabPanel",
|
||||
"bindings": {
|
||||
"alt-up": "collab_panel::MoveChannelUp",
|
||||
"alt-down": "collab_panel::MoveChannelDown"
|
||||
"space": "menu::Confirm",
|
||||
"ctrl-up": "collab_panel::MoveChannelUp",
|
||||
"ctrl-down": "collab_panel::MoveChannelDown"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -1137,10 +1132,7 @@
|
||||
"alt-ctrl-f": "keymap_editor::ToggleKeystrokeSearch",
|
||||
"alt-c": "keymap_editor::ToggleConflictFilter",
|
||||
"enter": "keymap_editor::EditBinding",
|
||||
"alt-enter": "keymap_editor::CreateBinding",
|
||||
"ctrl-c": "keymap_editor::CopyAction",
|
||||
"ctrl-shift-c": "keymap_editor::CopyContext",
|
||||
"ctrl-t": "keymap_editor::ShowMatchingKeybinds"
|
||||
"alt-enter": "keymap_editor::CreateBinding"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@@ -950,6 +950,8 @@
|
||||
"tab": "git_panel::FocusEditor",
|
||||
"shift-tab": "git_panel::FocusEditor",
|
||||
"escape": "git_panel::ToggleFocus",
|
||||
"cmd-enter": "git::Commit",
|
||||
"cmd-shift-enter": "git::Amend",
|
||||
"backspace": ["git::RestoreFile", { "skip_prompt": false }],
|
||||
"delete": ["git::RestoreFile", { "skip_prompt": false }],
|
||||
"cmd-backspace": ["git::RestoreFile", { "skip_prompt": true }],
|
||||
@@ -999,9 +1001,7 @@
|
||||
"ctrl-g backspace": "git::RestoreTrackedFiles",
|
||||
"ctrl-g shift-backspace": "git::TrashUntrackedFiles",
|
||||
"cmd-ctrl-y": "git::StageAll",
|
||||
"cmd-ctrl-shift-y": "git::UnstageAll",
|
||||
"cmd-enter": "git::Commit",
|
||||
"cmd-shift-enter": "git::Amend"
|
||||
"cmd-ctrl-shift-y": "git::UnstageAll"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -1037,15 +1037,9 @@
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"ctrl-backspace": "collab_panel::Remove",
|
||||
"space": "menu::Confirm"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "CollabPanel",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"alt-up": "collab_panel::MoveChannelUp",
|
||||
"alt-down": "collab_panel::MoveChannelDown"
|
||||
"space": "menu::Confirm",
|
||||
"cmd-up": "collab_panel::MoveChannelUp",
|
||||
"cmd-down": "collab_panel::MoveChannelDown"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -1239,10 +1233,7 @@
|
||||
"cmd-alt-f": "keymap_editor::ToggleKeystrokeSearch",
|
||||
"cmd-alt-c": "keymap_editor::ToggleConflictFilter",
|
||||
"enter": "keymap_editor::EditBinding",
|
||||
"alt-enter": "keymap_editor::CreateBinding",
|
||||
"cmd-c": "keymap_editor::CopyAction",
|
||||
"cmd-shift-c": "keymap_editor::CopyContext",
|
||||
"cmd-t": "keymap_editor::ShowMatchingKeybinds"
|
||||
"alt-enter": "keymap_editor::CreateBinding"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@@ -13,9 +13,9 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Editor && vim_mode == insert",
|
||||
"context": "Editor && vim_mode == insert && !menu",
|
||||
"bindings": {
|
||||
// "j k": "vim::NormalBefore"
|
||||
// "j k": "vim::SwitchToNormalMode"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
"ctrl-alt-s": "zed::OpenSettings",
|
||||
"ctrl-{": "pane::ActivatePreviousItem",
|
||||
"ctrl-}": "pane::ActivateNextItem",
|
||||
"shift-escape": null, // Unmap workspace::zoom
|
||||
"ctrl-f2": "debugger::Stop",
|
||||
"f6": "debugger::Pause",
|
||||
"f7": "debugger::StepInto",
|
||||
@@ -45,8 +44,8 @@
|
||||
"ctrl-alt-right": "pane::GoForward",
|
||||
"alt-f7": "editor::FindAllReferences",
|
||||
"ctrl-alt-f7": "editor::FindAllReferences",
|
||||
"ctrl-b": "editor::GoToDefinition", // Conflicts with workspace::ToggleLeftDock
|
||||
"ctrl-alt-b": "editor::GoToDefinitionSplit", // Conflicts with workspace::ToggleRightDock
|
||||
// "ctrl-b": "editor::GoToDefinition", // Conflicts with workspace::ToggleLeftDock
|
||||
// "ctrl-alt-b": "editor::GoToDefinitionSplit", // Conflicts with workspace::ToggleLeftDock
|
||||
"ctrl-shift-b": "editor::GoToTypeDefinition",
|
||||
"ctrl-alt-shift-b": "editor::GoToTypeDefinitionSplit",
|
||||
"f2": "editor::GoToDiagnostic",
|
||||
@@ -101,27 +100,12 @@
|
||||
"shift shift": "command_palette::Toggle",
|
||||
"ctrl-alt-shift-n": "project_symbols::Toggle",
|
||||
"alt-0": "git_panel::ToggleFocus",
|
||||
"alt-1": "project_panel::ToggleFocus",
|
||||
"alt-1": "workspace::ToggleLeftDock",
|
||||
"alt-5": "debug_panel::ToggleFocus",
|
||||
"alt-6": "diagnostics::Deploy",
|
||||
"alt-7": "outline_panel::ToggleFocus"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Pane", // this is to override the default Pane mappings to switch tabs
|
||||
"bindings": {
|
||||
"alt-1": "project_panel::ToggleFocus",
|
||||
"alt-2": null, // Bookmarks (left dock)
|
||||
"alt-3": null, // Find Panel (bottom dock)
|
||||
"alt-4": null, // Run Panel (bottom dock)
|
||||
"alt-5": "debug_panel::ToggleFocus",
|
||||
"alt-6": "diagnostics::Deploy",
|
||||
"alt-7": "outline_panel::ToggleFocus",
|
||||
"alt-8": null, // Services (bottom dock)
|
||||
"alt-9": null, // Git History (bottom dock)
|
||||
"alt-0": "git_panel::ToggleFocus"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Workspace || Editor",
|
||||
"bindings": {
|
||||
@@ -167,9 +151,6 @@
|
||||
{ "context": "OutlinePanel", "bindings": { "alt-7": "workspace::CloseActiveDock" } },
|
||||
{
|
||||
"context": "Dock || Workspace || Terminal || OutlinePanel || ProjectPanel || CollabPanel || (Editor && mode == auto_height)",
|
||||
"bindings": {
|
||||
"escape": "editor::ToggleFocus",
|
||||
"shift-escape": "workspace::CloseActiveDock"
|
||||
}
|
||||
"bindings": { "escape": "editor::ToggleFocus" }
|
||||
}
|
||||
]
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
"cmd-{": "pane::ActivatePreviousItem",
|
||||
"cmd-}": "pane::ActivateNextItem",
|
||||
"cmd-0": "git_panel::ToggleFocus", // overrides `cmd-0` zoom reset
|
||||
"shift-escape": null, // Unmap workspace::zoom
|
||||
"ctrl-f2": "debugger::Stop",
|
||||
"f6": "debugger::Pause",
|
||||
"f7": "debugger::StepInto",
|
||||
@@ -109,21 +108,6 @@
|
||||
"cmd-7": "outline_panel::ToggleFocus"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Pane", // this is to override the default Pane mappings to switch tabs
|
||||
"bindings": {
|
||||
"cmd-1": "project_panel::ToggleFocus",
|
||||
"cmd-2": null, // Bookmarks (left dock)
|
||||
"cmd-3": null, // Find Panel (bottom dock)
|
||||
"cmd-4": null, // Run Panel (bottom dock)
|
||||
"cmd-5": "debug_panel::ToggleFocus",
|
||||
"cmd-6": "diagnostics::Deploy",
|
||||
"cmd-7": "outline_panel::ToggleFocus",
|
||||
"cmd-8": null, // Services (bottom dock)
|
||||
"cmd-9": null, // Git History (bottom dock)
|
||||
"cmd-0": "git_panel::ToggleFocus"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Workspace || Editor",
|
||||
"bindings": {
|
||||
@@ -162,15 +146,11 @@
|
||||
}
|
||||
},
|
||||
{ "context": "GitPanel", "bindings": { "cmd-0": "workspace::CloseActiveDock" } },
|
||||
{ "context": "ProjectPanel", "bindings": { "cmd-1": "workspace::CloseActiveDock" } },
|
||||
{ "context": "DebugPanel", "bindings": { "cmd-5": "workspace::CloseActiveDock" } },
|
||||
{ "context": "Diagnostics > Editor", "bindings": { "cmd-6": "pane::CloseActiveItem" } },
|
||||
{ "context": "OutlinePanel", "bindings": { "cmd-7": "workspace::CloseActiveDock" } },
|
||||
{
|
||||
"context": "Dock || Workspace || Terminal || OutlinePanel || ProjectPanel || CollabPanel || (Editor && mode == auto_height)",
|
||||
"bindings": {
|
||||
"escape": "editor::ToggleFocus",
|
||||
"shift-escape": "workspace::CloseActiveDock"
|
||||
}
|
||||
"bindings": { "escape": "editor::ToggleFocus" }
|
||||
}
|
||||
]
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Editor && mode == full",
|
||||
"context": "Editor",
|
||||
"bindings": {
|
||||
"cmd-l": "go_to_line::Toggle",
|
||||
"ctrl-shift-d": "editor::DuplicateLineDown",
|
||||
@@ -15,12 +15,7 @@
|
||||
"cmd-enter": "editor::NewlineBelow",
|
||||
"cmd-alt-enter": "editor::NewlineAbove",
|
||||
"cmd-shift-l": "editor::SelectLine",
|
||||
"cmd-shift-t": "outline::Toggle"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Editor",
|
||||
"bindings": {
|
||||
"cmd-shift-t": "outline::Toggle",
|
||||
"alt-backspace": "editor::DeleteToPreviousWordStart",
|
||||
"alt-shift-backspace": "editor::DeleteToNextWordEnd",
|
||||
"alt-delete": "editor::DeleteToNextWordEnd",
|
||||
@@ -44,6 +39,10 @@
|
||||
"ctrl-_": "editor::ConvertToSnakeCase"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Editor && mode == full",
|
||||
"bindings": {}
|
||||
},
|
||||
{
|
||||
"context": "BufferSearchBar",
|
||||
"bindings": {
|
||||
|
||||
@@ -220,8 +220,6 @@
|
||||
{
|
||||
"context": "vim_mode == normal",
|
||||
"bindings": {
|
||||
"i": "vim::InsertBefore",
|
||||
"a": "vim::InsertAfter",
|
||||
"ctrl-[": "editor::Cancel",
|
||||
":": "command_palette::Toggle",
|
||||
"c": "vim::PushChange",
|
||||
@@ -355,7 +353,9 @@
|
||||
"shift-d": "vim::DeleteToEndOfLine",
|
||||
"shift-j": "vim::JoinLines",
|
||||
"shift-y": "vim::YankLine",
|
||||
"i": "vim::InsertBefore",
|
||||
"shift-i": "vim::InsertFirstNonWhitespace",
|
||||
"a": "vim::InsertAfter",
|
||||
"shift-a": "vim::InsertEndOfLine",
|
||||
"o": "vim::InsertLineBelow",
|
||||
"shift-o": "vim::InsertLineAbove",
|
||||
@@ -377,8 +377,6 @@
|
||||
{
|
||||
"context": "vim_mode == helix_normal && !menu",
|
||||
"bindings": {
|
||||
"i": "vim::HelixInsert",
|
||||
"a": "vim::HelixAppend",
|
||||
"ctrl-[": "editor::Cancel",
|
||||
";": "vim::HelixCollapseSelection",
|
||||
":": "command_palette::Toggle",
|
||||
@@ -581,6 +579,13 @@
|
||||
"shift-u": "git::UnstageAndNext" // "d shift-u"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "VimControl && (AgentDiff || editor_agent_diff)",
|
||||
"bindings": {
|
||||
"d p": "agent::Reject",
|
||||
"d u": "agent::Keep"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "vim_operator == gu",
|
||||
"bindings": {
|
||||
|
||||
@@ -691,10 +691,7 @@
|
||||
// 5. Never show the scrollbar:
|
||||
// "never"
|
||||
"show": null
|
||||
},
|
||||
// Default depth to expand outline items in the current file.
|
||||
// Set to 0 to collapse all items that have children, 1 or higher to collapse items at that depth or deeper.
|
||||
"expand_outlines_with_depth": 100
|
||||
}
|
||||
},
|
||||
"collaboration_panel": {
|
||||
// Whether to show the collaboration panel button in the status bar.
|
||||
@@ -1079,10 +1076,6 @@
|
||||
// Send anonymized usage data like what languages you're using Zed with.
|
||||
"metrics": true
|
||||
},
|
||||
// Whether to disable all AI features in Zed.
|
||||
//
|
||||
// Default: false
|
||||
"disable_ai": false,
|
||||
// Automatically update Zed. This setting may be ignored on Linux if
|
||||
// installed through a package manager.
|
||||
"auto_update": true,
|
||||
@@ -1719,7 +1712,6 @@
|
||||
"openai": {
|
||||
"api_url": "https://api.openai.com/v1"
|
||||
},
|
||||
"openai_compatible": {},
|
||||
"open_router": {
|
||||
"api_url": "https://openrouter.ai/api/v1"
|
||||
},
|
||||
|
||||
@@ -15,15 +15,13 @@
|
||||
"adapter": "JavaScript",
|
||||
"program": "$ZED_FILE",
|
||||
"request": "launch",
|
||||
"cwd": "$ZED_WORKTREE_ROOT",
|
||||
"type": "pwa-node"
|
||||
"cwd": "$ZED_WORKTREE_ROOT"
|
||||
},
|
||||
{
|
||||
"label": "JavaScript debug terminal",
|
||||
"adapter": "JavaScript",
|
||||
"request": "launch",
|
||||
"cwd": "$ZED_WORKTREE_ROOT",
|
||||
"console": "integratedTerminal",
|
||||
"type": "pwa-node"
|
||||
"console": "integratedTerminal"
|
||||
}
|
||||
]
|
||||
|
||||
@@ -16,7 +16,6 @@ doctest = false
|
||||
test-support = ["gpui/test-support", "project/test-support"]
|
||||
|
||||
[dependencies]
|
||||
agent-client-protocol.workspace = true
|
||||
agentic-coding-protocol.workspace = true
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
|
||||
@@ -1,26 +1,20 @@
|
||||
use std::{path::Path, rc::Rc};
|
||||
|
||||
use agent_client_protocol as acp;
|
||||
use agentic_coding_protocol as acp;
|
||||
use anyhow::Result;
|
||||
use gpui::{AsyncApp, Entity, Task};
|
||||
use project::Project;
|
||||
use ui::App;
|
||||
|
||||
use crate::AcpThread;
|
||||
use futures::future::{FutureExt as _, LocalBoxFuture};
|
||||
|
||||
pub trait AgentConnection {
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Entity<AcpThread>>>;
|
||||
|
||||
fn authenticate(&self, cx: &mut App) -> Task<Result<()>>;
|
||||
|
||||
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>>;
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
||||
fn request_any(
|
||||
&self,
|
||||
params: acp::AnyAgentRequest,
|
||||
) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>>;
|
||||
}
|
||||
|
||||
impl AgentConnection for acp::AgentConnection {
|
||||
fn request_any(
|
||||
&self,
|
||||
params: acp::AnyAgentRequest,
|
||||
) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>> {
|
||||
let task = self.request_any(params);
|
||||
async move { Ok(task.await?) }.boxed_local()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,449 +0,0 @@
|
||||
// Translates old acp agents into the new schema
|
||||
use agent_client_protocol as acp;
|
||||
use agentic_coding_protocol::{self as acp_old, AgentRequest as _};
|
||||
use anyhow::{Context as _, Result};
|
||||
use futures::channel::oneshot;
|
||||
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||
use project::Project;
|
||||
use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc};
|
||||
use ui::App;
|
||||
|
||||
use crate::{AcpThread, AgentConnection};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OldAcpClientDelegate {
|
||||
thread: Rc<RefCell<WeakEntity<AcpThread>>>,
|
||||
cx: AsyncApp,
|
||||
next_tool_call_id: Rc<RefCell<u64>>,
|
||||
// sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
|
||||
}
|
||||
|
||||
impl OldAcpClientDelegate {
|
||||
pub fn new(thread: Rc<RefCell<WeakEntity<AcpThread>>>, cx: AsyncApp) -> Self {
|
||||
Self {
|
||||
thread,
|
||||
cx,
|
||||
next_tool_call_id: Rc::new(RefCell::new(0)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl acp_old::Client for OldAcpClientDelegate {
|
||||
async fn stream_assistant_message_chunk(
|
||||
&self,
|
||||
params: acp_old::StreamAssistantMessageChunkParams,
|
||||
) -> Result<(), acp_old::Error> {
|
||||
let cx = &mut self.cx.clone();
|
||||
|
||||
cx.update(|cx| {
|
||||
self.thread
|
||||
.borrow()
|
||||
.update(cx, |thread, cx| match params.chunk {
|
||||
acp_old::AssistantMessageChunk::Text { text } => {
|
||||
thread.push_assistant_content_block(text.into(), false, cx)
|
||||
}
|
||||
acp_old::AssistantMessageChunk::Thought { thought } => {
|
||||
thread.push_assistant_content_block(thought.into(), true, cx)
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn request_tool_call_confirmation(
|
||||
&self,
|
||||
request: acp_old::RequestToolCallConfirmationParams,
|
||||
) -> Result<acp_old::RequestToolCallConfirmationResponse, acp_old::Error> {
|
||||
let cx = &mut self.cx.clone();
|
||||
|
||||
let old_acp_id = *self.next_tool_call_id.borrow() + 1;
|
||||
self.next_tool_call_id.replace(old_acp_id);
|
||||
|
||||
let tool_call = into_new_tool_call(
|
||||
acp::ToolCallId(old_acp_id.to_string().into()),
|
||||
request.tool_call,
|
||||
);
|
||||
|
||||
let mut options = match request.confirmation {
|
||||
acp_old::ToolCallConfirmation::Edit { .. } => vec![(
|
||||
acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
|
||||
acp::PermissionOptionKind::AllowAlways,
|
||||
"Always Allow Edits".to_string(),
|
||||
)],
|
||||
acp_old::ToolCallConfirmation::Execute { root_command, .. } => vec![(
|
||||
acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
|
||||
acp::PermissionOptionKind::AllowAlways,
|
||||
format!("Always Allow {}", root_command),
|
||||
)],
|
||||
acp_old::ToolCallConfirmation::Mcp {
|
||||
server_name,
|
||||
tool_name,
|
||||
..
|
||||
} => vec![
|
||||
(
|
||||
acp_old::ToolCallConfirmationOutcome::AlwaysAllowMcpServer,
|
||||
acp::PermissionOptionKind::AllowAlways,
|
||||
format!("Always Allow {}", server_name),
|
||||
),
|
||||
(
|
||||
acp_old::ToolCallConfirmationOutcome::AlwaysAllowTool,
|
||||
acp::PermissionOptionKind::AllowAlways,
|
||||
format!("Always Allow {}", tool_name),
|
||||
),
|
||||
],
|
||||
acp_old::ToolCallConfirmation::Fetch { .. } => vec![(
|
||||
acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
|
||||
acp::PermissionOptionKind::AllowAlways,
|
||||
"Always Allow".to_string(),
|
||||
)],
|
||||
acp_old::ToolCallConfirmation::Other { .. } => vec![(
|
||||
acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
|
||||
acp::PermissionOptionKind::AllowAlways,
|
||||
"Always Allow".to_string(),
|
||||
)],
|
||||
};
|
||||
|
||||
options.extend([
|
||||
(
|
||||
acp_old::ToolCallConfirmationOutcome::Allow,
|
||||
acp::PermissionOptionKind::AllowOnce,
|
||||
"Allow".to_string(),
|
||||
),
|
||||
(
|
||||
acp_old::ToolCallConfirmationOutcome::Reject,
|
||||
acp::PermissionOptionKind::RejectOnce,
|
||||
"Reject".to_string(),
|
||||
),
|
||||
]);
|
||||
|
||||
let mut outcomes = Vec::with_capacity(options.len());
|
||||
let mut acp_options = Vec::with_capacity(options.len());
|
||||
|
||||
for (index, (outcome, kind, label)) in options.into_iter().enumerate() {
|
||||
outcomes.push(outcome);
|
||||
acp_options.push(acp::PermissionOption {
|
||||
id: acp::PermissionOptionId(index.to_string().into()),
|
||||
label,
|
||||
kind,
|
||||
})
|
||||
}
|
||||
|
||||
let response = cx
|
||||
.update(|cx| {
|
||||
self.thread.borrow().update(cx, |thread, cx| {
|
||||
thread.request_tool_call_permission(tool_call, acp_options, cx)
|
||||
})
|
||||
})?
|
||||
.context("Failed to update thread")?
|
||||
.await;
|
||||
|
||||
let outcome = match response {
|
||||
Ok(option_id) => outcomes[option_id.0.parse::<usize>().unwrap_or(0)],
|
||||
Err(oneshot::Canceled) => acp_old::ToolCallConfirmationOutcome::Cancel,
|
||||
};
|
||||
|
||||
Ok(acp_old::RequestToolCallConfirmationResponse {
|
||||
id: acp_old::ToolCallId(old_acp_id),
|
||||
outcome: outcome,
|
||||
})
|
||||
}
|
||||
|
||||
async fn push_tool_call(
|
||||
&self,
|
||||
request: acp_old::PushToolCallParams,
|
||||
) -> Result<acp_old::PushToolCallResponse, acp_old::Error> {
|
||||
let cx = &mut self.cx.clone();
|
||||
|
||||
let old_acp_id = *self.next_tool_call_id.borrow() + 1;
|
||||
self.next_tool_call_id.replace(old_acp_id);
|
||||
|
||||
cx.update(|cx| {
|
||||
self.thread.borrow().update(cx, |thread, cx| {
|
||||
thread.upsert_tool_call(
|
||||
into_new_tool_call(acp::ToolCallId(old_acp_id.to_string().into()), request),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})?
|
||||
.context("Failed to update thread")?;
|
||||
|
||||
Ok(acp_old::PushToolCallResponse {
|
||||
id: acp_old::ToolCallId(old_acp_id),
|
||||
})
|
||||
}
|
||||
|
||||
async fn update_tool_call(
|
||||
&self,
|
||||
request: acp_old::UpdateToolCallParams,
|
||||
) -> Result<(), acp_old::Error> {
|
||||
let cx = &mut self.cx.clone();
|
||||
|
||||
cx.update(|cx| {
|
||||
self.thread.borrow().update(cx, |thread, cx| {
|
||||
thread.update_tool_call(
|
||||
acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId(request.tool_call_id.0.to_string().into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
status: Some(into_new_tool_call_status(request.status)),
|
||||
content: Some(
|
||||
request
|
||||
.content
|
||||
.into_iter()
|
||||
.map(into_new_tool_call_content)
|
||||
.collect::<Vec<_>>(),
|
||||
),
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})?
|
||||
.context("Failed to update thread")??;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_plan(&self, request: acp_old::UpdatePlanParams) -> Result<(), acp_old::Error> {
|
||||
let cx = &mut self.cx.clone();
|
||||
|
||||
cx.update(|cx| {
|
||||
self.thread.borrow().update(cx, |thread, cx| {
|
||||
thread.update_plan(
|
||||
acp::Plan {
|
||||
entries: request
|
||||
.entries
|
||||
.into_iter()
|
||||
.map(into_new_plan_entry)
|
||||
.collect(),
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})?
|
||||
.context("Failed to update thread")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn read_text_file(
|
||||
&self,
|
||||
acp_old::ReadTextFileParams { path, line, limit }: acp_old::ReadTextFileParams,
|
||||
) -> Result<acp_old::ReadTextFileResponse, acp_old::Error> {
|
||||
let content = self
|
||||
.cx
|
||||
.update(|cx| {
|
||||
self.thread.borrow().update(cx, |thread, cx| {
|
||||
thread.read_text_file(path, line, limit, false, cx)
|
||||
})
|
||||
})?
|
||||
.context("Failed to update thread")?
|
||||
.await?;
|
||||
Ok(acp_old::ReadTextFileResponse { content })
|
||||
}
|
||||
|
||||
async fn write_text_file(
|
||||
&self,
|
||||
acp_old::WriteTextFileParams { path, content }: acp_old::WriteTextFileParams,
|
||||
) -> Result<(), acp_old::Error> {
|
||||
self.cx
|
||||
.update(|cx| {
|
||||
self.thread
|
||||
.borrow()
|
||||
.update(cx, |thread, cx| thread.write_text_file(path, content, cx))
|
||||
})?
|
||||
.context("Failed to update thread")?
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams) -> acp::ToolCall {
|
||||
acp::ToolCall {
|
||||
id: id,
|
||||
label: request.label,
|
||||
kind: acp_kind_from_old_icon(request.icon),
|
||||
status: acp::ToolCallStatus::InProgress,
|
||||
content: request
|
||||
.content
|
||||
.into_iter()
|
||||
.map(into_new_tool_call_content)
|
||||
.collect(),
|
||||
locations: request
|
||||
.locations
|
||||
.into_iter()
|
||||
.map(into_new_tool_call_location)
|
||||
.collect(),
|
||||
raw_input: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn acp_kind_from_old_icon(icon: acp_old::Icon) -> acp::ToolKind {
|
||||
match icon {
|
||||
acp_old::Icon::FileSearch => acp::ToolKind::Search,
|
||||
acp_old::Icon::Folder => acp::ToolKind::Search,
|
||||
acp_old::Icon::Globe => acp::ToolKind::Search,
|
||||
acp_old::Icon::Hammer => acp::ToolKind::Other,
|
||||
acp_old::Icon::LightBulb => acp::ToolKind::Think,
|
||||
acp_old::Icon::Pencil => acp::ToolKind::Edit,
|
||||
acp_old::Icon::Regex => acp::ToolKind::Search,
|
||||
acp_old::Icon::Terminal => acp::ToolKind::Execute,
|
||||
}
|
||||
}
|
||||
|
||||
fn into_new_tool_call_status(status: acp_old::ToolCallStatus) -> acp::ToolCallStatus {
|
||||
match status {
|
||||
acp_old::ToolCallStatus::Running => acp::ToolCallStatus::InProgress,
|
||||
acp_old::ToolCallStatus::Finished => acp::ToolCallStatus::Completed,
|
||||
acp_old::ToolCallStatus::Error => acp::ToolCallStatus::Failed,
|
||||
}
|
||||
}
|
||||
|
||||
fn into_new_tool_call_content(content: acp_old::ToolCallContent) -> acp::ToolCallContent {
|
||||
match content {
|
||||
acp_old::ToolCallContent::Markdown { markdown } => markdown.into(),
|
||||
acp_old::ToolCallContent::Diff { diff } => acp::ToolCallContent::Diff {
|
||||
diff: into_new_diff(diff),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn into_new_diff(diff: acp_old::Diff) -> acp::Diff {
|
||||
acp::Diff {
|
||||
path: diff.path,
|
||||
old_text: diff.old_text,
|
||||
new_text: diff.new_text,
|
||||
}
|
||||
}
|
||||
|
||||
fn into_new_tool_call_location(location: acp_old::ToolCallLocation) -> acp::ToolCallLocation {
|
||||
acp::ToolCallLocation {
|
||||
path: location.path,
|
||||
line: location.line,
|
||||
}
|
||||
}
|
||||
|
||||
fn into_new_plan_entry(entry: acp_old::PlanEntry) -> acp::PlanEntry {
|
||||
acp::PlanEntry {
|
||||
content: entry.content,
|
||||
priority: into_new_plan_priority(entry.priority),
|
||||
status: into_new_plan_status(entry.status),
|
||||
}
|
||||
}
|
||||
|
||||
fn into_new_plan_priority(priority: acp_old::PlanEntryPriority) -> acp::PlanEntryPriority {
|
||||
match priority {
|
||||
acp_old::PlanEntryPriority::Low => acp::PlanEntryPriority::Low,
|
||||
acp_old::PlanEntryPriority::Medium => acp::PlanEntryPriority::Medium,
|
||||
acp_old::PlanEntryPriority::High => acp::PlanEntryPriority::High,
|
||||
}
|
||||
}
|
||||
|
||||
fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatus {
|
||||
match status {
|
||||
acp_old::PlanEntryStatus::Pending => acp::PlanEntryStatus::Pending,
|
||||
acp_old::PlanEntryStatus::InProgress => acp::PlanEntryStatus::InProgress,
|
||||
acp_old::PlanEntryStatus::Completed => acp::PlanEntryStatus::Completed,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Unauthenticated;
|
||||
|
||||
impl Error for Unauthenticated {}
|
||||
impl fmt::Display for Unauthenticated {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Unauthenticated")
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OldAcpAgentConnection {
|
||||
pub name: &'static str,
|
||||
pub connection: acp_old::AgentConnection,
|
||||
pub child_status: Task<Result<()>>,
|
||||
}
|
||||
|
||||
impl AgentConnection for OldAcpAgentConnection {
|
||||
fn name(&self) -> &'static str {
|
||||
self.name
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
_cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let task = self.connection.request_any(
|
||||
acp_old::InitializeParams {
|
||||
protocol_version: acp_old::ProtocolVersion::latest(),
|
||||
}
|
||||
.into_any(),
|
||||
);
|
||||
cx.spawn(async move |cx| {
|
||||
let result = task.await?;
|
||||
let result = acp_old::InitializeParams::response_from_any(result)?;
|
||||
|
||||
if !result.is_authenticated {
|
||||
anyhow::bail!(Unauthenticated)
|
||||
}
|
||||
|
||||
cx.update(|cx| {
|
||||
let thread = cx.new(|cx| {
|
||||
let session_id = acp::SessionId("acp-old-no-id".into());
|
||||
AcpThread::new(self.clone(), project, session_id, cx)
|
||||
});
|
||||
thread
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
|
||||
let task = self
|
||||
.connection
|
||||
.request_any(acp_old::AuthenticateParams.into_any());
|
||||
cx.foreground_executor().spawn(async move {
|
||||
task.await?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>> {
|
||||
let chunks = params
|
||||
.prompt
|
||||
.into_iter()
|
||||
.filter_map(|block| match block {
|
||||
acp::ContentBlock::Text(text) => {
|
||||
Some(acp_old::UserMessageChunk::Text { text: text.text })
|
||||
}
|
||||
acp::ContentBlock::ResourceLink(link) => Some(acp_old::UserMessageChunk::Path {
|
||||
path: link.uri.into(),
|
||||
}),
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let task = self
|
||||
.connection
|
||||
.request_any(acp_old::SendUserMessageParams { chunks }.into_any());
|
||||
cx.foreground_executor().spawn(async move {
|
||||
task.await?;
|
||||
anyhow::Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn cancel(&self, _session_id: &acp::SessionId, cx: &mut App) {
|
||||
let task = self
|
||||
.connection
|
||||
.request_any(acp_old::CancelSendMessageParams.into_any());
|
||||
cx.foreground_executor()
|
||||
.spawn(async move {
|
||||
task.await?;
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx)
|
||||
}
|
||||
}
|
||||
@@ -308,12 +308,7 @@ mod tests {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn needs_confirmation(
|
||||
&self,
|
||||
_input: &serde_json::Value,
|
||||
_project: &Entity<Project>,
|
||||
_cx: &App,
|
||||
) -> bool {
|
||||
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ impl Tool for ContextServerTool {
|
||||
}
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ use std::{
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use thiserror::Error;
|
||||
use util::{ResultExt as _, post_inc};
|
||||
use util::{ResultExt as _, debug_panic, post_inc};
|
||||
use uuid::Uuid;
|
||||
use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
|
||||
|
||||
@@ -942,7 +942,7 @@ impl Thread {
|
||||
}
|
||||
|
||||
pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
|
||||
self.tool_use.tool_uses_for_message(id, &self.project, cx)
|
||||
self.tool_use.tool_uses_for_message(id, cx)
|
||||
}
|
||||
|
||||
pub fn tool_results_for_message(
|
||||
@@ -1582,18 +1582,20 @@ impl Thread {
|
||||
model: Arc<dyn LanguageModel>,
|
||||
cx: &mut App,
|
||||
) -> Option<PendingToolUse> {
|
||||
// Represent notification as a simulated `project_notifications` tool call
|
||||
let tool_name = Arc::from("project_notifications");
|
||||
let tool = self.tools.read(cx).tool(&tool_name, cx)?;
|
||||
let action_log = self.action_log.read(cx);
|
||||
|
||||
if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
|
||||
if !action_log.has_unnotified_user_edits() {
|
||||
return None;
|
||||
}
|
||||
|
||||
if self
|
||||
.action_log
|
||||
.update(cx, |log, cx| log.unnotified_user_edits(cx).is_none())
|
||||
{
|
||||
// Represent notification as a simulated `project_notifications` tool call
|
||||
let tool_name = Arc::from("project_notifications");
|
||||
let Some(tool) = self.tools.read(cx).tool(&tool_name, cx) else {
|
||||
debug_panic!("`project_notifications` tool not found");
|
||||
return None;
|
||||
};
|
||||
|
||||
if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
|
||||
return None;
|
||||
}
|
||||
|
||||
@@ -2037,12 +2039,6 @@ impl Thread {
|
||||
if let Some(retry_strategy) =
|
||||
Thread::get_retry_strategy(completion_error)
|
||||
{
|
||||
log::info!(
|
||||
"Retrying with {:?} for language model completion error {:?}",
|
||||
retry_strategy,
|
||||
completion_error
|
||||
);
|
||||
|
||||
retry_scheduled = thread
|
||||
.handle_retryable_error_with_delay(
|
||||
&completion_error,
|
||||
@@ -2252,14 +2248,15 @@ impl Thread {
|
||||
..
|
||||
}
|
||||
| AuthenticationError { .. }
|
||||
| PermissionError { .. }
|
||||
| NoApiKey { .. }
|
||||
| ApiEndpointNotFound { .. }
|
||||
| PromptTooLarge { .. } => None,
|
||||
| PermissionError { .. } => None,
|
||||
// These errors might be transient, so retry them
|
||||
SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
|
||||
SerializeRequest { .. }
|
||||
| BuildRequestBody { .. }
|
||||
| PromptTooLarge { .. }
|
||||
| ApiEndpointNotFound { .. }
|
||||
| NoApiKey { .. } => Some(RetryStrategy::Fixed {
|
||||
delay: BASE_RETRY_DELAY,
|
||||
max_attempts: 1,
|
||||
max_attempts: 2,
|
||||
}),
|
||||
// Retry all other 4xx and 5xx errors once.
|
||||
HttpResponseError { status_code, .. }
|
||||
@@ -2557,7 +2554,7 @@ impl Thread {
|
||||
return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
|
||||
}
|
||||
|
||||
if tool.needs_confirmation(&tool_use.input, &self.project, cx)
|
||||
if tool.needs_confirmation(&tool_use.input, cx)
|
||||
&& !AgentSettings::get_global(cx).always_allow_tool_actions
|
||||
{
|
||||
self.tool_use.confirm_tool_use(
|
||||
@@ -5495,7 +5492,7 @@ fn main() {{
|
||||
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
|
||||
let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
|
||||
|
||||
let provider = Arc::new(FakeLanguageModelProvider::default());
|
||||
let provider = Arc::new(FakeLanguageModelProvider);
|
||||
let model = provider.test_model();
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(model);
|
||||
|
||||
|
||||
@@ -41,9 +41,6 @@ use std::{
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
|
||||
pub static ZED_STATELESS: std::sync::LazyLock<bool> =
|
||||
std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty()));
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum DataType {
|
||||
#[serde(rename = "json")]
|
||||
@@ -877,11 +874,7 @@ impl ThreadsDatabase {
|
||||
|
||||
let needs_migration_from_heed = mdb_path.exists();
|
||||
|
||||
let connection = if *ZED_STATELESS {
|
||||
Connection::open_memory(Some("THREAD_FALLBACK_DB"))
|
||||
} else {
|
||||
Connection::open_file(&sqlite_path.to_string_lossy())
|
||||
};
|
||||
let connection = Connection::open_file(&sqlite_path.to_string_lossy());
|
||||
|
||||
connection.exec(indoc! {"
|
||||
CREATE TABLE IF NOT EXISTS threads (
|
||||
|
||||
@@ -165,12 +165,7 @@ impl ToolUseState {
|
||||
self.pending_tool_uses_by_id.values().collect()
|
||||
}
|
||||
|
||||
pub fn tool_uses_for_message(
|
||||
&self,
|
||||
id: MessageId,
|
||||
project: &Entity<Project>,
|
||||
cx: &App,
|
||||
) -> Vec<ToolUse> {
|
||||
pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
|
||||
let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
|
||||
return Vec::new();
|
||||
};
|
||||
@@ -216,10 +211,7 @@ impl ToolUseState {
|
||||
|
||||
let (icon, needs_confirmation) =
|
||||
if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
|
||||
(
|
||||
tool.icon(),
|
||||
tool.needs_confirmation(&tool_use.input, project, cx),
|
||||
)
|
||||
(tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
|
||||
} else {
|
||||
(IconName::Cog, false)
|
||||
};
|
||||
|
||||
@@ -18,7 +18,6 @@ doctest = false
|
||||
|
||||
[dependencies]
|
||||
acp_thread.workspace = true
|
||||
agent-client-protocol.workspace = true
|
||||
agentic-coding-protocol.workspace = true
|
||||
anyhow.workspace = true
|
||||
collections.workspace = true
|
||||
@@ -29,7 +28,6 @@ itertools.workspace = true
|
||||
log.workspace = true
|
||||
paths.workspace = true
|
||||
project.workspace = true
|
||||
rand.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
@@ -41,7 +39,6 @@ ui.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
watch.workspace = true
|
||||
indoc.workspace = true
|
||||
which.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
mod claude;
|
||||
mod codex;
|
||||
mod gemini;
|
||||
mod mcp_server;
|
||||
mod settings;
|
||||
mod stdio_agent_server;
|
||||
|
||||
#[cfg(test)]
|
||||
mod e2e_tests;
|
||||
|
||||
pub use claude::*;
|
||||
pub use codex::*;
|
||||
pub use gemini::*;
|
||||
pub use settings::*;
|
||||
pub use stdio_agent_server::*;
|
||||
|
||||
use acp_thread::AgentConnection;
|
||||
use acp_thread::AcpThread;
|
||||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use gpui::{App, AsyncApp, Entity, SharedString, Task};
|
||||
@@ -21,7 +20,6 @@ use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
path::{Path, PathBuf},
|
||||
rc::Rc,
|
||||
sync::Arc,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
@@ -35,14 +33,14 @@ pub trait AgentServer: Send {
|
||||
fn name(&self) -> &'static str;
|
||||
fn empty_state_headline(&self) -> &'static str;
|
||||
fn empty_state_message(&self) -> &'static str;
|
||||
fn supports_always_allow(&self) -> bool;
|
||||
|
||||
fn connect(
|
||||
fn new_thread(
|
||||
&self,
|
||||
// these will go away when old_acp is fully removed
|
||||
root_dir: &Path,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Rc<dyn AgentConnection>>>;
|
||||
) -> Task<Result<Entity<AcpThread>>>;
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AgentServerCommand {
|
||||
|
||||
@@ -1,35 +1,39 @@
|
||||
mod mcp_server;
|
||||
pub mod tools;
|
||||
mod tools;
|
||||
|
||||
use collections::HashMap;
|
||||
use context_server::listener::McpServerTool;
|
||||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
use smol::process::Child;
|
||||
use std::cell::RefCell;
|
||||
use std::fmt::Display;
|
||||
use std::path::Path;
|
||||
use std::pin::pin;
|
||||
use std::rc::Rc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use agent_client_protocol as acp;
|
||||
use agentic_coding_protocol::{
|
||||
self as acp, AnyAgentRequest, AnyAgentResult, Client, ProtocolVersion,
|
||||
StreamAssistantMessageChunkParams, ToolCallContent, UpdateToolCallParams,
|
||||
};
|
||||
use anyhow::{Result, anyhow};
|
||||
use futures::channel::oneshot;
|
||||
use futures::{AsyncBufReadExt, AsyncWriteExt};
|
||||
use futures::future::LocalBoxFuture;
|
||||
use futures::{AsyncBufReadExt, AsyncWriteExt, SinkExt};
|
||||
use futures::{
|
||||
AsyncRead, AsyncWrite, FutureExt, StreamExt,
|
||||
channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
|
||||
io::BufReader,
|
||||
select_biased,
|
||||
};
|
||||
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
|
||||
use crate::claude::mcp_server::ClaudeMcpServer;
|
||||
use crate::claude::tools::ClaudeTool;
|
||||
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
|
||||
use acp_thread::{AcpThread, AgentConnection};
|
||||
use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ClaudeCode;
|
||||
@@ -44,51 +48,36 @@ impl AgentServer for ClaudeCode {
|
||||
}
|
||||
|
||||
fn empty_state_message(&self) -> &'static str {
|
||||
"How can I help you today?"
|
||||
""
|
||||
}
|
||||
|
||||
fn logo(&self) -> ui::IconName {
|
||||
ui::IconName::AiClaude
|
||||
}
|
||||
|
||||
fn connect(
|
||||
&self,
|
||||
_root_dir: &Path,
|
||||
_project: &Entity<Project>,
|
||||
_cx: &mut App,
|
||||
) -> Task<Result<Rc<dyn AgentConnection>>> {
|
||||
let connection = ClaudeAgentConnection {
|
||||
sessions: Default::default(),
|
||||
};
|
||||
|
||||
Task::ready(Ok(Rc::new(connection) as _))
|
||||
}
|
||||
}
|
||||
|
||||
struct ClaudeAgentConnection {
|
||||
sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
|
||||
}
|
||||
|
||||
impl AgentConnection for ClaudeAgentConnection {
|
||||
fn name(&self) -> &'static str {
|
||||
ClaudeCode.name()
|
||||
fn supports_always_allow(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
&self,
|
||||
root_dir: &Path,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let cwd = cwd.to_owned();
|
||||
let project = project.clone();
|
||||
let root_dir = root_dir.to_path_buf();
|
||||
let title = self.name().into();
|
||||
cx.spawn(async move |cx| {
|
||||
let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
|
||||
let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?;
|
||||
let (mut delegate_tx, delegate_rx) = watch::channel(None);
|
||||
let tool_id_map = Rc::new(RefCell::new(HashMap::default()));
|
||||
|
||||
let mcp_server = ClaudeMcpServer::new(delegate_rx, tool_id_map.clone(), cx).await?;
|
||||
|
||||
let mut mcp_servers = HashMap::default();
|
||||
mcp_servers.insert(
|
||||
mcp_server::SERVER_NAME.to_string(),
|
||||
permission_mcp_server.server_config()?,
|
||||
mcp_server.server_config()?,
|
||||
);
|
||||
let mcp_config = McpConfig { mcp_servers };
|
||||
|
||||
@@ -113,158 +102,192 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||
|
||||
let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
|
||||
let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
|
||||
let (cancel_tx, mut cancel_rx) = mpsc::unbounded::<oneshot::Sender<Result<()>>>();
|
||||
|
||||
let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
|
||||
let session_id = Uuid::new_v4();
|
||||
|
||||
log::trace!("Starting session with id: {}", session_id);
|
||||
|
||||
cx.background_spawn({
|
||||
let session_id = session_id.clone();
|
||||
async move {
|
||||
let mut outgoing_rx = Some(outgoing_rx);
|
||||
cx.background_spawn(async move {
|
||||
let mut outgoing_rx = Some(outgoing_rx);
|
||||
let mut mode = ClaudeSessionMode::Start;
|
||||
|
||||
let mut child = spawn_claude(
|
||||
&command,
|
||||
ClaudeSessionMode::Start,
|
||||
session_id.clone(),
|
||||
&mcp_config_path,
|
||||
&cwd,
|
||||
)
|
||||
.await?;
|
||||
loop {
|
||||
let mut child =
|
||||
spawn_claude(&command, mode, session_id, &mcp_config_path, &root_dir)
|
||||
.await?;
|
||||
mode = ClaudeSessionMode::Resume;
|
||||
|
||||
let pid = child.id();
|
||||
log::trace!("Spawned (pid: {})", pid);
|
||||
|
||||
ClaudeAgentSession::handle_io(
|
||||
outgoing_rx.take().unwrap(),
|
||||
incoming_message_tx.clone(),
|
||||
child.stdin.take().unwrap(),
|
||||
child.stdout.take().unwrap(),
|
||||
)
|
||||
.await?;
|
||||
let mut io_fut = pin!(
|
||||
ClaudeAgentConnection::handle_io(
|
||||
outgoing_rx.take().unwrap(),
|
||||
incoming_message_tx.clone(),
|
||||
child.stdin.take().unwrap(),
|
||||
child.stdout.take().unwrap(),
|
||||
)
|
||||
.fuse()
|
||||
);
|
||||
|
||||
select_biased! {
|
||||
done_tx = cancel_rx.next() => {
|
||||
if let Some(done_tx) = done_tx {
|
||||
log::trace!("Interrupted (pid: {})", pid);
|
||||
let result = send_interrupt(pid as i32);
|
||||
outgoing_rx.replace(io_fut.await?);
|
||||
done_tx.send(result).log_err();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
result = io_fut => {
|
||||
result?;
|
||||
}
|
||||
}
|
||||
|
||||
log::trace!("Stopped (pid: {})", pid);
|
||||
|
||||
drop(mcp_config_path);
|
||||
anyhow::Ok(())
|
||||
break;
|
||||
}
|
||||
|
||||
drop(mcp_config_path);
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach();
|
||||
|
||||
let end_turn_tx = Rc::new(RefCell::new(None));
|
||||
let handler_task = cx.spawn({
|
||||
let end_turn_tx = end_turn_tx.clone();
|
||||
let thread_rx = thread_rx.clone();
|
||||
async move |cx| {
|
||||
while let Some(message) = incoming_message_rx.next().await {
|
||||
ClaudeAgentSession::handle_message(
|
||||
thread_rx.clone(),
|
||||
message,
|
||||
end_turn_tx.clone(),
|
||||
cx,
|
||||
)
|
||||
.await
|
||||
cx.new(|cx| {
|
||||
let end_turn_tx = Rc::new(RefCell::new(None));
|
||||
let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async());
|
||||
delegate_tx.send(Some(delegate.clone())).log_err();
|
||||
|
||||
let handler_task = cx.foreground_executor().spawn({
|
||||
let end_turn_tx = end_turn_tx.clone();
|
||||
let tool_id_map = tool_id_map.clone();
|
||||
let delegate = delegate.clone();
|
||||
async move {
|
||||
while let Some(message) = incoming_message_rx.next().await {
|
||||
ClaudeAgentConnection::handle_message(
|
||||
delegate.clone(),
|
||||
message,
|
||||
end_turn_tx.clone(),
|
||||
tool_id_map.clone(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
let thread =
|
||||
cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))?;
|
||||
let mut connection = ClaudeAgentConnection {
|
||||
delegate,
|
||||
outgoing_tx,
|
||||
end_turn_tx,
|
||||
cancel_tx,
|
||||
session_id,
|
||||
_handler_task: handler_task,
|
||||
_mcp_server: None,
|
||||
};
|
||||
|
||||
thread_tx.send(thread.downgrade())?;
|
||||
|
||||
let session = ClaudeAgentSession {
|
||||
outgoing_tx,
|
||||
end_turn_tx,
|
||||
_handler_task: handler_task,
|
||||
_mcp_server: Some(permission_mcp_server),
|
||||
};
|
||||
|
||||
self.sessions.borrow_mut().insert(session_id, session);
|
||||
|
||||
Ok(thread)
|
||||
connection._mcp_server = Some(mcp_server);
|
||||
acp_thread::AcpThread::new(connection, title, None, project.clone(), cx)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
|
||||
Task::ready(Err(anyhow!("Authentication not supported")))
|
||||
}
|
||||
#[cfg(unix)]
|
||||
fn send_interrupt(pid: libc::pid_t) -> anyhow::Result<()> {
|
||||
let pid = nix::unistd::Pid::from_raw(pid);
|
||||
|
||||
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>> {
|
||||
let sessions = self.sessions.borrow();
|
||||
let Some(session) = sessions.get(¶ms.session_id) else {
|
||||
return Task::ready(Err(anyhow!(
|
||||
"Attempted to send message to nonexistent session {}",
|
||||
params.session_id
|
||||
)));
|
||||
};
|
||||
nix::sys::signal::kill(pid, nix::sys::signal::SIGINT)
|
||||
.map_err(|e| anyhow!("Failed to interrupt process: {}", e))
|
||||
}
|
||||
|
||||
let (tx, rx) = oneshot::channel();
|
||||
session.end_turn_tx.borrow_mut().replace(tx);
|
||||
#[cfg(windows)]
|
||||
fn send_interrupt(_pid: i32) -> anyhow::Result<()> {
|
||||
panic!("Cancel not implemented on Windows")
|
||||
}
|
||||
|
||||
let mut content = String::new();
|
||||
for chunk in params.prompt {
|
||||
match chunk {
|
||||
acp::ContentBlock::Text(text_content) => {
|
||||
content.push_str(&text_content.text);
|
||||
impl AgentConnection for ClaudeAgentConnection {
|
||||
/// Send a request to the agent and wait for a response.
|
||||
fn request_any(
|
||||
&self,
|
||||
params: AnyAgentRequest,
|
||||
) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>> {
|
||||
let delegate = self.delegate.clone();
|
||||
let end_turn_tx = self.end_turn_tx.clone();
|
||||
let outgoing_tx = self.outgoing_tx.clone();
|
||||
let mut cancel_tx = self.cancel_tx.clone();
|
||||
let session_id = self.session_id;
|
||||
async move {
|
||||
match params {
|
||||
// todo: consider sending an empty request so we get the init response?
|
||||
AnyAgentRequest::InitializeParams(_) => Ok(AnyAgentResult::InitializeResponse(
|
||||
acp::InitializeResponse {
|
||||
is_authenticated: true,
|
||||
protocol_version: ProtocolVersion::latest(),
|
||||
},
|
||||
)),
|
||||
AnyAgentRequest::AuthenticateParams(_) => {
|
||||
Err(anyhow!("Authentication not supported"))
|
||||
}
|
||||
acp::ContentBlock::ResourceLink(resource_link) => {
|
||||
content.push_str(&format!("@{}", resource_link.uri));
|
||||
AnyAgentRequest::SendUserMessageParams(message) => {
|
||||
delegate.clear_completed_plan_entries().await?;
|
||||
|
||||
let (tx, rx) = oneshot::channel();
|
||||
end_turn_tx.borrow_mut().replace(tx);
|
||||
let mut content = String::new();
|
||||
for chunk in message.chunks {
|
||||
match chunk {
|
||||
agentic_coding_protocol::UserMessageChunk::Text { text } => {
|
||||
content.push_str(&text)
|
||||
}
|
||||
agentic_coding_protocol::UserMessageChunk::Path { path } => {
|
||||
content.push_str(&format!("@{path:?}"))
|
||||
}
|
||||
}
|
||||
}
|
||||
outgoing_tx.unbounded_send(SdkMessage::User {
|
||||
message: Message {
|
||||
role: Role::User,
|
||||
content: Content::UntaggedText(content),
|
||||
id: None,
|
||||
model: None,
|
||||
stop_reason: None,
|
||||
stop_sequence: None,
|
||||
usage: None,
|
||||
},
|
||||
session_id: Some(session_id),
|
||||
})?;
|
||||
rx.await??;
|
||||
Ok(AnyAgentResult::SendUserMessageResponse(
|
||||
acp::SendUserMessageResponse,
|
||||
))
|
||||
}
|
||||
acp::ContentBlock::Audio(_)
|
||||
| acp::ContentBlock::Image(_)
|
||||
| acp::ContentBlock::Resource(_) => {
|
||||
// TODO
|
||||
AnyAgentRequest::CancelSendMessageParams(_) => {
|
||||
let (done_tx, done_rx) = oneshot::channel();
|
||||
cancel_tx.send(done_tx).await?;
|
||||
done_rx.await??;
|
||||
|
||||
Ok(AnyAgentResult::CancelSendMessageResponse(
|
||||
acp::CancelSendMessageResponse,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
|
||||
message: Message {
|
||||
role: Role::User,
|
||||
content: Content::UntaggedText(content),
|
||||
id: None,
|
||||
model: None,
|
||||
stop_reason: None,
|
||||
stop_sequence: None,
|
||||
usage: None,
|
||||
},
|
||||
session_id: Some(params.session_id.to_string()),
|
||||
}) {
|
||||
return Task::ready(Err(anyhow!(err)));
|
||||
}
|
||||
|
||||
cx.foreground_executor().spawn(async move {
|
||||
rx.await??;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
|
||||
let sessions = self.sessions.borrow();
|
||||
let Some(session) = sessions.get(&session_id) else {
|
||||
log::warn!("Attempted to cancel nonexistent session {}", session_id);
|
||||
return;
|
||||
};
|
||||
|
||||
session
|
||||
.outgoing_tx
|
||||
.unbounded_send(SdkMessage::new_interrupt_message())
|
||||
.log_err();
|
||||
.boxed_local()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum ClaudeSessionMode {
|
||||
Start,
|
||||
#[expect(dead_code)]
|
||||
Resume,
|
||||
}
|
||||
|
||||
async fn spawn_claude(
|
||||
command: &AgentServerCommand,
|
||||
mode: ClaudeSessionMode,
|
||||
session_id: acp::SessionId,
|
||||
session_id: Uuid,
|
||||
mcp_config_path: &Path,
|
||||
root_dir: &Path,
|
||||
) -> Result<Child> {
|
||||
@@ -282,16 +305,10 @@ async fn spawn_claude(
|
||||
&format!(
|
||||
"mcp__{}__{}",
|
||||
mcp_server::SERVER_NAME,
|
||||
mcp_server::PermissionTool::NAME,
|
||||
mcp_server::PERMISSION_TOOL
|
||||
),
|
||||
"--allowedTools",
|
||||
&format!(
|
||||
"mcp__{}__{},mcp__{}__{}",
|
||||
mcp_server::SERVER_NAME,
|
||||
mcp_server::EditTool::NAME,
|
||||
mcp_server::SERVER_NAME,
|
||||
mcp_server::ReadTool::NAME
|
||||
),
|
||||
"mcp__zed__Read,mcp__zed__Edit",
|
||||
"--disallowedTools",
|
||||
"Read,Edit",
|
||||
])
|
||||
@@ -310,135 +327,105 @@ async fn spawn_claude(
|
||||
Ok(child)
|
||||
}
|
||||
|
||||
struct ClaudeAgentSession {
|
||||
struct ClaudeAgentConnection {
|
||||
delegate: AcpClientDelegate,
|
||||
session_id: Uuid,
|
||||
outgoing_tx: UnboundedSender<SdkMessage>,
|
||||
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
|
||||
_mcp_server: Option<ClaudeZedMcpServer>,
|
||||
cancel_tx: UnboundedSender<oneshot::Sender<Result<()>>>,
|
||||
_mcp_server: Option<ClaudeMcpServer>,
|
||||
_handler_task: Task<()>,
|
||||
}
|
||||
|
||||
impl ClaudeAgentSession {
|
||||
impl ClaudeAgentConnection {
|
||||
async fn handle_message(
|
||||
mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
delegate: AcpClientDelegate,
|
||||
message: SdkMessage,
|
||||
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
|
||||
cx: &mut AsyncApp,
|
||||
tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
|
||||
) {
|
||||
match message {
|
||||
// we should only be sending these out, they don't need to be in the thread
|
||||
SdkMessage::ControlRequest { .. } => {}
|
||||
SdkMessage::Assistant {
|
||||
message,
|
||||
session_id: _,
|
||||
}
|
||||
| SdkMessage::User {
|
||||
message,
|
||||
session_id: _,
|
||||
} => {
|
||||
let Some(thread) = thread_rx
|
||||
.recv()
|
||||
.await
|
||||
.log_err()
|
||||
.and_then(|entity| entity.upgrade())
|
||||
else {
|
||||
log::error!("Received an SDK message but thread is gone");
|
||||
return;
|
||||
};
|
||||
|
||||
SdkMessage::Assistant { message, .. } | SdkMessage::User { message, .. } => {
|
||||
for chunk in message.content.chunks() {
|
||||
match chunk {
|
||||
ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.push_assistant_content_block(text.into(), false, cx)
|
||||
delegate
|
||||
.stream_assistant_message_chunk(StreamAssistantMessageChunkParams {
|
||||
chunk: acp::AssistantMessageChunk::Text { text },
|
||||
})
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
ContentChunk::ToolUse { id, name, input } => {
|
||||
let claude_tool = ClaudeTool::infer(&name, input);
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
|
||||
thread.update_plan(
|
||||
acp::Plan {
|
||||
entries: params
|
||||
.todos
|
||||
.into_iter()
|
||||
.map(Into::into)
|
||||
.collect(),
|
||||
},
|
||||
cx,
|
||||
)
|
||||
} else {
|
||||
thread.upsert_tool_call(
|
||||
claude_tool.as_acp(acp::ToolCallId(id.into())),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})
|
||||
.log_err();
|
||||
if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
|
||||
delegate
|
||||
.update_plan(acp::UpdatePlanParams {
|
||||
entries: params.todos.into_iter().map(Into::into).collect(),
|
||||
})
|
||||
.await
|
||||
.log_err();
|
||||
} else if let Some(resp) = delegate
|
||||
.push_tool_call(claude_tool.as_acp())
|
||||
.await
|
||||
.log_err()
|
||||
{
|
||||
tool_id_map.borrow_mut().insert(id, resp.id);
|
||||
}
|
||||
}
|
||||
ContentChunk::ToolResult {
|
||||
content,
|
||||
tool_use_id,
|
||||
} => {
|
||||
let content = content.to_string();
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.update_tool_call(
|
||||
acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId(tool_use_id.into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::Completed),
|
||||
content: (!content.is_empty())
|
||||
.then(|| vec![content.into()]),
|
||||
..Default::default()
|
||||
let id = tool_id_map.borrow_mut().remove(&tool_use_id);
|
||||
if let Some(id) = id {
|
||||
let content = content.to_string();
|
||||
delegate
|
||||
.update_tool_call(UpdateToolCallParams {
|
||||
tool_call_id: id,
|
||||
status: acp::ToolCallStatus::Finished,
|
||||
// Don't unset existing content
|
||||
content: (!content.is_empty()).then_some(
|
||||
ToolCallContent::Markdown {
|
||||
// For now we only include text content
|
||||
markdown: content,
|
||||
},
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.log_err();
|
||||
),
|
||||
})
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
ContentChunk::Image
|
||||
| ContentChunk::Document
|
||||
| ContentChunk::Thinking
|
||||
| ContentChunk::RedactedThinking
|
||||
| ContentChunk::WebSearchToolResult => {
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.push_assistant_content_block(
|
||||
format!("Unsupported content: {:?}", chunk).into(),
|
||||
false,
|
||||
cx,
|
||||
)
|
||||
delegate
|
||||
.stream_assistant_message_chunk(StreamAssistantMessageChunkParams {
|
||||
chunk: acp::AssistantMessageChunk::Text {
|
||||
text: format!("Unsupported content: {:?}", chunk),
|
||||
},
|
||||
})
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
SdkMessage::Result {
|
||||
is_error,
|
||||
subtype,
|
||||
result,
|
||||
..
|
||||
is_error, subtype, ..
|
||||
} => {
|
||||
if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() {
|
||||
if is_error {
|
||||
end_turn_tx
|
||||
.send(Err(anyhow!(
|
||||
"Error: {}",
|
||||
result.unwrap_or_else(|| subtype.to_string())
|
||||
)))
|
||||
.ok();
|
||||
end_turn_tx.send(Err(anyhow!("Error: {subtype}"))).ok();
|
||||
} else {
|
||||
end_turn_tx.send(Ok(())).ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
SdkMessage::System { .. } | SdkMessage::ControlResponse { .. } => {}
|
||||
SdkMessage::System { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -605,14 +592,16 @@ enum SdkMessage {
|
||||
Assistant {
|
||||
message: Message, // from Anthropic SDK
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
session_id: Option<String>,
|
||||
session_id: Option<Uuid>,
|
||||
},
|
||||
|
||||
// A user message
|
||||
User {
|
||||
message: Message, // from Anthropic SDK
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
session_id: Option<String>,
|
||||
session_id: Option<Uuid>,
|
||||
},
|
||||
|
||||
// Emitted as the last message in a conversation
|
||||
Result {
|
||||
subtype: ResultErrorType,
|
||||
@@ -637,26 +626,6 @@ enum SdkMessage {
|
||||
#[serde(rename = "permissionMode")]
|
||||
permission_mode: PermissionMode,
|
||||
},
|
||||
/// Messages used to control the conversation, outside of chat messages to the model
|
||||
ControlRequest {
|
||||
request_id: String,
|
||||
request: ControlRequest,
|
||||
},
|
||||
/// Response to a control request
|
||||
ControlResponse { response: ControlResponse },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "subtype", rename_all = "snake_case")]
|
||||
enum ControlRequest {
|
||||
/// Cancel the current conversation
|
||||
Interrupt,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct ControlResponse {
|
||||
request_id: String,
|
||||
subtype: ResultErrorType,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -677,24 +646,6 @@ impl Display for ResultErrorType {
|
||||
}
|
||||
}
|
||||
|
||||
impl SdkMessage {
|
||||
fn new_interrupt_message() -> Self {
|
||||
use rand::Rng;
|
||||
// In the Claude Code TS SDK they just generate a random 12 character string,
|
||||
// `Math.random().toString(36).substring(2, 15)`
|
||||
let request_id = rand::thread_rng()
|
||||
.sample_iter(&rand::distributions::Alphanumeric)
|
||||
.take(12)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
Self::ControlRequest {
|
||||
request_id,
|
||||
request: ControlRequest::Interrupt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct McpServer {
|
||||
name: String,
|
||||
@@ -710,12 +661,27 @@ enum PermissionMode {
|
||||
Plan,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct McpConfig {
|
||||
mcp_servers: HashMap<String, McpServerConfig>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct McpServerConfig {
|
||||
command: String,
|
||||
args: Vec<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
env: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");
|
||||
crate::common_e2e_tests!(ClaudeCode);
|
||||
|
||||
pub fn local_command() -> AgentServerCommand {
|
||||
AgentServerCommand {
|
||||
|
||||
@@ -1,53 +1,78 @@
|
||||
use std::path::PathBuf;
|
||||
use std::{cell::RefCell, rc::Rc};
|
||||
|
||||
use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams};
|
||||
use acp_thread::AcpThread;
|
||||
use agent_client_protocol as acp;
|
||||
use acp_thread::AcpClientDelegate;
|
||||
use agentic_coding_protocol::{self as acp, Client, ReadTextFileParams, WriteTextFileParams};
|
||||
use anyhow::{Context, Result};
|
||||
use collections::HashMap;
|
||||
use context_server::listener::{McpServerTool, ToolResponse};
|
||||
use context_server::types::{
|
||||
Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities,
|
||||
ToolAnnotations, ToolResponseContent, ToolsCapabilities, requests,
|
||||
use context_server::{
|
||||
listener::McpServer,
|
||||
types::{
|
||||
CallToolParams, CallToolResponse, Implementation, InitializeParams, InitializeResponse,
|
||||
ListToolsResponse, ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations,
|
||||
ToolResponseContent, ToolsCapabilities, requests,
|
||||
},
|
||||
};
|
||||
use gpui::{App, AsyncApp, Task, WeakEntity};
|
||||
use gpui::{App, AsyncApp, Task};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::debug_panic;
|
||||
|
||||
pub struct ClaudeZedMcpServer {
|
||||
server: context_server::listener::McpServer,
|
||||
use crate::claude::{
|
||||
McpServerConfig,
|
||||
tools::{ClaudeTool, EditToolParams, EditToolResponse, ReadToolParams, ReadToolResponse},
|
||||
};
|
||||
|
||||
pub struct ClaudeMcpServer {
|
||||
server: McpServer,
|
||||
}
|
||||
|
||||
pub const SERVER_NAME: &str = "zed";
|
||||
pub const READ_TOOL: &str = "Read";
|
||||
pub const EDIT_TOOL: &str = "Edit";
|
||||
pub const PERMISSION_TOOL: &str = "Confirmation";
|
||||
|
||||
impl ClaudeZedMcpServer {
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
struct PermissionToolParams {
|
||||
tool_name: String,
|
||||
input: serde_json::Value,
|
||||
tool_use_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct PermissionToolResponse {
|
||||
behavior: PermissionToolBehavior,
|
||||
updated_input: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum PermissionToolBehavior {
|
||||
Allow,
|
||||
Deny,
|
||||
}
|
||||
|
||||
impl ClaudeMcpServer {
|
||||
pub async fn new(
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
delegate: watch::Receiver<Option<AcpClientDelegate>>,
|
||||
tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<Self> {
|
||||
let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
|
||||
let mut mcp_server = McpServer::new(cx).await?;
|
||||
mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
|
||||
|
||||
mcp_server.add_tool(PermissionTool {
|
||||
thread_rx: thread_rx.clone(),
|
||||
});
|
||||
mcp_server.add_tool(ReadTool {
|
||||
thread_rx: thread_rx.clone(),
|
||||
});
|
||||
mcp_server.add_tool(EditTool {
|
||||
thread_rx: thread_rx.clone(),
|
||||
mcp_server.handle_request::<requests::ListTools>(Self::handle_list_tools);
|
||||
mcp_server.handle_request::<requests::CallTool>(move |request, cx| {
|
||||
Self::handle_call_tool(request, delegate.clone(), tool_id_map.clone(), cx)
|
||||
});
|
||||
|
||||
Ok(Self { server: mcp_server })
|
||||
}
|
||||
|
||||
pub fn server_config(&self) -> Result<McpServerConfig> {
|
||||
#[cfg(not(test))]
|
||||
let zed_path = std::env::current_exe()
|
||||
.context("finding current executable path for use in mcp_server")?;
|
||||
|
||||
#[cfg(test)]
|
||||
let zed_path = crate::e2e_tests::get_zed_path();
|
||||
.context("finding current executable path for use in mcp_server")?
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
|
||||
Ok(McpServerConfig {
|
||||
command: zed_path,
|
||||
@@ -81,222 +106,195 @@ impl ClaudeZedMcpServer {
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct McpConfig {
|
||||
pub mcp_servers: HashMap<String, McpServerConfig>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct McpServerConfig {
|
||||
pub command: PathBuf,
|
||||
pub args: Vec<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub env: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
// Tools
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct PermissionTool {
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
pub struct PermissionToolParams {
|
||||
tool_name: String,
|
||||
input: serde_json::Value,
|
||||
tool_use_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PermissionToolResponse {
|
||||
behavior: PermissionToolBehavior,
|
||||
updated_input: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum PermissionToolBehavior {
|
||||
Allow,
|
||||
Deny,
|
||||
}
|
||||
|
||||
impl McpServerTool for PermissionTool {
|
||||
type Input = PermissionToolParams;
|
||||
type Output = ();
|
||||
|
||||
const NAME: &'static str = "Confirmation";
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Request permission for tool calls"
|
||||
fn handle_list_tools(_: (), cx: &App) -> Task<Result<ListToolsResponse>> {
|
||||
cx.foreground_executor().spawn(async move {
|
||||
Ok(ListToolsResponse {
|
||||
tools: vec![
|
||||
Tool {
|
||||
name: PERMISSION_TOOL.into(),
|
||||
input_schema: schemars::schema_for!(PermissionToolParams).into(),
|
||||
description: None,
|
||||
annotations: None,
|
||||
},
|
||||
Tool {
|
||||
name: READ_TOOL.into(),
|
||||
input_schema: schemars::schema_for!(ReadToolParams).into(),
|
||||
description: Some("Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents.".to_string()),
|
||||
annotations: Some(ToolAnnotations {
|
||||
title: Some("Read file".to_string()),
|
||||
read_only_hint: Some(true),
|
||||
destructive_hint: Some(false),
|
||||
open_world_hint: Some(false),
|
||||
// if time passes the contents might change, but it's not going to do anything different
|
||||
// true or false seem too strong, let's try a none.
|
||||
idempotent_hint: None,
|
||||
}),
|
||||
},
|
||||
Tool {
|
||||
name: EDIT_TOOL.into(),
|
||||
input_schema: schemars::schema_for!(EditToolParams).into(),
|
||||
description: Some("Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better.".to_string()),
|
||||
annotations: Some(ToolAnnotations {
|
||||
title: Some("Edit file".to_string()),
|
||||
read_only_hint: Some(false),
|
||||
destructive_hint: Some(false),
|
||||
open_world_hint: Some(false),
|
||||
idempotent_hint: Some(false),
|
||||
}),
|
||||
},
|
||||
],
|
||||
next_cursor: None,
|
||||
meta: None,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<ToolResponse<Self::Output>> {
|
||||
let mut thread_rx = self.thread_rx.clone();
|
||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||
anyhow::bail!("Thread closed");
|
||||
};
|
||||
fn handle_call_tool(
|
||||
request: CallToolParams,
|
||||
mut delegate_watch: watch::Receiver<Option<AcpClientDelegate>>,
|
||||
tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
|
||||
cx: &App,
|
||||
) -> Task<Result<CallToolResponse>> {
|
||||
cx.spawn(async move |cx| {
|
||||
let Some(delegate) = delegate_watch.recv().await? else {
|
||||
debug_panic!("Sent None delegate");
|
||||
anyhow::bail!("Server not available");
|
||||
};
|
||||
|
||||
let claude_tool = ClaudeTool::infer(&input.tool_name, input.input.clone());
|
||||
let tool_call_id = acp::ToolCallId(input.tool_use_id.context("Tool ID required")?.into());
|
||||
let allow_option_id = acp::PermissionOptionId("allow".into());
|
||||
let reject_option_id = acp::PermissionOptionId("reject".into());
|
||||
if request.name.as_str() == PERMISSION_TOOL {
|
||||
let input =
|
||||
serde_json::from_value(request.arguments.context("Arguments required")?)?;
|
||||
|
||||
let chosen_option = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_permission(
|
||||
claude_tool.as_acp(tool_call_id),
|
||||
vec![
|
||||
acp::PermissionOption {
|
||||
id: allow_option_id.clone(),
|
||||
label: "Allow".into(),
|
||||
kind: acp::PermissionOptionKind::AllowOnce,
|
||||
},
|
||||
acp::PermissionOption {
|
||||
id: reject_option_id.clone(),
|
||||
label: "Reject".into(),
|
||||
kind: acp::PermissionOptionKind::RejectOnce,
|
||||
},
|
||||
],
|
||||
cx,
|
||||
let result =
|
||||
Self::handle_permissions_tool_call(input, delegate, tool_id_map, cx).await?;
|
||||
Ok(CallToolResponse {
|
||||
content: vec![ToolResponseContent::Text {
|
||||
text: serde_json::to_string(&result)?,
|
||||
}],
|
||||
is_error: None,
|
||||
meta: None,
|
||||
})
|
||||
} else if request.name.as_str() == READ_TOOL {
|
||||
let input =
|
||||
serde_json::from_value(request.arguments.context("Arguments required")?)?;
|
||||
|
||||
let result = Self::handle_read_tool_call(input, delegate, cx).await?;
|
||||
Ok(CallToolResponse {
|
||||
content: vec![ToolResponseContent::Text {
|
||||
text: serde_json::to_string(&result)?,
|
||||
}],
|
||||
is_error: None,
|
||||
meta: None,
|
||||
})
|
||||
} else if request.name.as_str() == EDIT_TOOL {
|
||||
let input =
|
||||
serde_json::from_value(request.arguments.context("Arguments required")?)?;
|
||||
|
||||
let result = Self::handle_edit_tool_call(input, delegate, cx).await?;
|
||||
Ok(CallToolResponse {
|
||||
content: vec![ToolResponseContent::Text {
|
||||
text: serde_json::to_string(&result)?,
|
||||
}],
|
||||
is_error: None,
|
||||
meta: None,
|
||||
})
|
||||
} else {
|
||||
anyhow::bail!("Unsupported tool");
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_read_tool_call(
|
||||
params: ReadToolParams,
|
||||
delegate: AcpClientDelegate,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<ReadToolResponse>> {
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let response = delegate
|
||||
.read_text_file(ReadTextFileParams {
|
||||
path: params.abs_path,
|
||||
line: params.offset,
|
||||
limit: params.limit,
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(ReadToolResponse {
|
||||
content: response.content,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_edit_tool_call(
|
||||
params: EditToolParams,
|
||||
delegate: AcpClientDelegate,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<EditToolResponse>> {
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let response = delegate
|
||||
.read_text_file_reusing_snapshot(ReadTextFileParams {
|
||||
path: params.abs_path.clone(),
|
||||
line: None,
|
||||
limit: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let new_content = response.content.replace(¶ms.old_text, ¶ms.new_text);
|
||||
if new_content == response.content {
|
||||
return Err(anyhow::anyhow!("The old_text was not found in the content"));
|
||||
}
|
||||
|
||||
delegate
|
||||
.write_text_file(WriteTextFileParams {
|
||||
path: params.abs_path,
|
||||
content: new_content,
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(EditToolResponse)
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_permissions_tool_call(
|
||||
params: PermissionToolParams,
|
||||
delegate: AcpClientDelegate,
|
||||
tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<PermissionToolResponse>> {
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let claude_tool = ClaudeTool::infer(¶ms.tool_name, params.input.clone());
|
||||
|
||||
let tool_call_id = match params.tool_use_id {
|
||||
Some(tool_use_id) => tool_id_map
|
||||
.borrow()
|
||||
.get(&tool_use_id)
|
||||
.cloned()
|
||||
.context("Tool call ID not found")?,
|
||||
|
||||
None => delegate.push_tool_call(claude_tool.as_acp()).await?.id,
|
||||
};
|
||||
|
||||
let outcome = delegate
|
||||
.request_existing_tool_call_confirmation(
|
||||
tool_call_id,
|
||||
claude_tool.confirmation(None),
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
.await?;
|
||||
|
||||
let response = if chosen_option == allow_option_id {
|
||||
PermissionToolResponse {
|
||||
behavior: PermissionToolBehavior::Allow,
|
||||
updated_input: input.input,
|
||||
match outcome {
|
||||
acp::ToolCallConfirmationOutcome::Allow
|
||||
| acp::ToolCallConfirmationOutcome::AlwaysAllow
|
||||
| acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer
|
||||
| acp::ToolCallConfirmationOutcome::AlwaysAllowTool => Ok(PermissionToolResponse {
|
||||
behavior: PermissionToolBehavior::Allow,
|
||||
updated_input: params.input,
|
||||
}),
|
||||
acp::ToolCallConfirmationOutcome::Reject
|
||||
| acp::ToolCallConfirmationOutcome::Cancel => Ok(PermissionToolResponse {
|
||||
behavior: PermissionToolBehavior::Deny,
|
||||
updated_input: params.input,
|
||||
}),
|
||||
}
|
||||
} else {
|
||||
debug_assert_eq!(chosen_option, reject_option_id);
|
||||
PermissionToolResponse {
|
||||
behavior: PermissionToolBehavior::Deny,
|
||||
updated_input: input.input,
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ToolResponse {
|
||||
content: vec![ToolResponseContent::Text {
|
||||
text: serde_json::to_string(&response)?,
|
||||
}],
|
||||
structured_content: (),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ReadTool {
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
}
|
||||
|
||||
impl McpServerTool for ReadTool {
|
||||
type Input = ReadToolParams;
|
||||
type Output = ();
|
||||
|
||||
const NAME: &'static str = "Read";
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents."
|
||||
}
|
||||
|
||||
fn annotations(&self) -> ToolAnnotations {
|
||||
ToolAnnotations {
|
||||
title: Some("Read file".to_string()),
|
||||
read_only_hint: Some(true),
|
||||
destructive_hint: Some(false),
|
||||
open_world_hint: Some(false),
|
||||
idempotent_hint: None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<ToolResponse<Self::Output>> {
|
||||
let mut thread_rx = self.thread_rx.clone();
|
||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||
anyhow::bail!("Thread closed");
|
||||
};
|
||||
|
||||
let content = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.read_text_file(input.abs_path, input.offset, input.limit, false, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
Ok(ToolResponse {
|
||||
content: vec![ToolResponseContent::Text { text: content }],
|
||||
structured_content: (),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EditTool {
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
}
|
||||
|
||||
impl McpServerTool for EditTool {
|
||||
type Input = EditToolParams;
|
||||
type Output = ();
|
||||
|
||||
const NAME: &'static str = "Edit";
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better."
|
||||
}
|
||||
|
||||
fn annotations(&self) -> ToolAnnotations {
|
||||
ToolAnnotations {
|
||||
title: Some("Edit file".to_string()),
|
||||
read_only_hint: Some(false),
|
||||
destructive_hint: Some(false),
|
||||
open_world_hint: Some(false),
|
||||
idempotent_hint: Some(false),
|
||||
}
|
||||
}
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<ToolResponse<Self::Output>> {
|
||||
let mut thread_rx = self.thread_rx.clone();
|
||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||
anyhow::bail!("Thread closed");
|
||||
};
|
||||
|
||||
let content = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.read_text_file(input.abs_path.clone(), None, None, true, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let new_content = content.replace(&input.old_text, &input.new_text);
|
||||
if new_content == content {
|
||||
return Err(anyhow::anyhow!("The old_text was not found in the content"));
|
||||
}
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.write_text_file(input.abs_path, new_content, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
Ok(ToolResponse {
|
||||
content: vec![],
|
||||
structured_content: (),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use agent_client_protocol as acp;
|
||||
use agentic_coding_protocol::{self as acp, PushToolCallParams, ToolCallLocation};
|
||||
use itertools::Itertools;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -115,36 +115,51 @@ impl ClaudeTool {
|
||||
Self::Other { name, .. } => name.clone(),
|
||||
}
|
||||
}
|
||||
pub fn content(&self) -> Vec<acp::ToolCallContent> {
|
||||
|
||||
pub fn content(&self) -> Option<acp::ToolCallContent> {
|
||||
match &self {
|
||||
Self::Other { input, .. } => vec![
|
||||
format!(
|
||||
Self::Other { input, .. } => Some(acp::ToolCallContent::Markdown {
|
||||
markdown: format!(
|
||||
"```json\n{}```",
|
||||
serde_json::to_string_pretty(&input).unwrap_or("{}".to_string())
|
||||
)
|
||||
.into(),
|
||||
],
|
||||
Self::Task(Some(params)) => vec![params.prompt.clone().into()],
|
||||
Self::NotebookRead(Some(params)) => {
|
||||
vec![params.notebook_path.display().to_string().into()]
|
||||
}
|
||||
Self::NotebookEdit(Some(params)) => vec![params.new_source.clone().into()],
|
||||
Self::Terminal(Some(params)) => vec![
|
||||
format!(
|
||||
),
|
||||
}),
|
||||
Self::Task(Some(params)) => Some(acp::ToolCallContent::Markdown {
|
||||
markdown: params.prompt.clone(),
|
||||
}),
|
||||
Self::NotebookRead(Some(params)) => Some(acp::ToolCallContent::Markdown {
|
||||
markdown: params.notebook_path.display().to_string(),
|
||||
}),
|
||||
Self::NotebookEdit(Some(params)) => Some(acp::ToolCallContent::Markdown {
|
||||
markdown: params.new_source.clone(),
|
||||
}),
|
||||
Self::Terminal(Some(params)) => Some(acp::ToolCallContent::Markdown {
|
||||
markdown: format!(
|
||||
"`{}`\n\n{}",
|
||||
params.command,
|
||||
params.description.as_deref().unwrap_or_default()
|
||||
)
|
||||
.into(),
|
||||
],
|
||||
Self::ReadFile(Some(params)) => vec![params.abs_path.display().to_string().into()],
|
||||
Self::Ls(Some(params)) => vec![params.path.display().to_string().into()],
|
||||
Self::Glob(Some(params)) => vec![params.to_string().into()],
|
||||
Self::Grep(Some(params)) => vec![format!("`{params}`").into()],
|
||||
Self::WebFetch(Some(params)) => vec![params.prompt.clone().into()],
|
||||
Self::WebSearch(Some(params)) => vec![params.to_string().into()],
|
||||
Self::TodoWrite(Some(params)) => vec![
|
||||
params
|
||||
),
|
||||
}),
|
||||
Self::ReadFile(Some(params)) => Some(acp::ToolCallContent::Markdown {
|
||||
markdown: params.abs_path.display().to_string(),
|
||||
}),
|
||||
Self::Ls(Some(params)) => Some(acp::ToolCallContent::Markdown {
|
||||
markdown: params.path.display().to_string(),
|
||||
}),
|
||||
Self::Glob(Some(params)) => Some(acp::ToolCallContent::Markdown {
|
||||
markdown: params.to_string(),
|
||||
}),
|
||||
Self::Grep(Some(params)) => Some(acp::ToolCallContent::Markdown {
|
||||
markdown: format!("`{params}`"),
|
||||
}),
|
||||
Self::WebFetch(Some(params)) => Some(acp::ToolCallContent::Markdown {
|
||||
markdown: params.prompt.clone(),
|
||||
}),
|
||||
Self::WebSearch(Some(params)) => Some(acp::ToolCallContent::Markdown {
|
||||
markdown: params.to_string(),
|
||||
}),
|
||||
Self::TodoWrite(Some(params)) => Some(acp::ToolCallContent::Markdown {
|
||||
markdown: params
|
||||
.todos
|
||||
.iter()
|
||||
.map(|todo| {
|
||||
@@ -159,39 +174,34 @@ impl ClaudeTool {
|
||||
todo.content
|
||||
)
|
||||
})
|
||||
.join("\n")
|
||||
.into(),
|
||||
],
|
||||
Self::ExitPlanMode(Some(params)) => vec![params.plan.clone().into()],
|
||||
Self::Edit(Some(params)) => vec![acp::ToolCallContent::Diff {
|
||||
.join("\n"),
|
||||
}),
|
||||
Self::ExitPlanMode(Some(params)) => Some(acp::ToolCallContent::Markdown {
|
||||
markdown: params.plan.clone(),
|
||||
}),
|
||||
Self::Edit(Some(params)) => Some(acp::ToolCallContent::Diff {
|
||||
diff: acp::Diff {
|
||||
path: params.abs_path.clone(),
|
||||
old_text: Some(params.old_text.clone()),
|
||||
new_text: params.new_text.clone(),
|
||||
},
|
||||
}],
|
||||
Self::Write(Some(params)) => vec![acp::ToolCallContent::Diff {
|
||||
}),
|
||||
Self::Write(Some(params)) => Some(acp::ToolCallContent::Diff {
|
||||
diff: acp::Diff {
|
||||
path: params.file_path.clone(),
|
||||
old_text: None,
|
||||
new_text: params.content.clone(),
|
||||
},
|
||||
}],
|
||||
}),
|
||||
Self::MultiEdit(Some(params)) => {
|
||||
// todo: show multiple edits in a multibuffer?
|
||||
params
|
||||
.edits
|
||||
.first()
|
||||
.map(|edit| {
|
||||
vec![acp::ToolCallContent::Diff {
|
||||
diff: acp::Diff {
|
||||
path: params.file_path.clone(),
|
||||
old_text: Some(edit.old_string.clone()),
|
||||
new_text: edit.new_string.clone(),
|
||||
},
|
||||
}]
|
||||
})
|
||||
.unwrap_or_default()
|
||||
params.edits.first().map(|edit| acp::ToolCallContent::Diff {
|
||||
diff: acp::Diff {
|
||||
path: params.file_path.clone(),
|
||||
old_text: Some(edit.old_string.clone()),
|
||||
new_text: edit.new_string.clone(),
|
||||
},
|
||||
})
|
||||
}
|
||||
Self::Task(None)
|
||||
| Self::NotebookRead(None)
|
||||
@@ -207,80 +217,181 @@ impl ClaudeTool {
|
||||
| Self::ExitPlanMode(None)
|
||||
| Self::Edit(None)
|
||||
| Self::Write(None)
|
||||
| Self::MultiEdit(None) => vec![],
|
||||
| Self::MultiEdit(None) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn kind(&self) -> acp::ToolKind {
|
||||
pub fn icon(&self) -> acp::Icon {
|
||||
match self {
|
||||
Self::Task(_) => acp::ToolKind::Think,
|
||||
Self::NotebookRead(_) => acp::ToolKind::Read,
|
||||
Self::NotebookEdit(_) => acp::ToolKind::Edit,
|
||||
Self::Edit(_) => acp::ToolKind::Edit,
|
||||
Self::MultiEdit(_) => acp::ToolKind::Edit,
|
||||
Self::Write(_) => acp::ToolKind::Edit,
|
||||
Self::ReadFile(_) => acp::ToolKind::Read,
|
||||
Self::Ls(_) => acp::ToolKind::Search,
|
||||
Self::Glob(_) => acp::ToolKind::Search,
|
||||
Self::Grep(_) => acp::ToolKind::Search,
|
||||
Self::Terminal(_) => acp::ToolKind::Execute,
|
||||
Self::WebSearch(_) => acp::ToolKind::Search,
|
||||
Self::WebFetch(_) => acp::ToolKind::Fetch,
|
||||
Self::TodoWrite(_) => acp::ToolKind::Think,
|
||||
Self::ExitPlanMode(_) => acp::ToolKind::Think,
|
||||
Self::Other { .. } => acp::ToolKind::Other,
|
||||
Self::Task(_) => acp::Icon::Hammer,
|
||||
Self::NotebookRead(_) => acp::Icon::FileSearch,
|
||||
Self::NotebookEdit(_) => acp::Icon::Pencil,
|
||||
Self::Edit(_) => acp::Icon::Pencil,
|
||||
Self::MultiEdit(_) => acp::Icon::Pencil,
|
||||
Self::Write(_) => acp::Icon::Pencil,
|
||||
Self::ReadFile(_) => acp::Icon::FileSearch,
|
||||
Self::Ls(_) => acp::Icon::Folder,
|
||||
Self::Glob(_) => acp::Icon::FileSearch,
|
||||
Self::Grep(_) => acp::Icon::Regex,
|
||||
Self::Terminal(_) => acp::Icon::Terminal,
|
||||
Self::WebSearch(_) => acp::Icon::Globe,
|
||||
Self::WebFetch(_) => acp::Icon::Globe,
|
||||
Self::TodoWrite(_) => acp::Icon::LightBulb,
|
||||
Self::ExitPlanMode(_) => acp::Icon::Hammer,
|
||||
Self::Other { .. } => acp::Icon::Hammer,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn confirmation(&self, description: Option<String>) -> acp::ToolCallConfirmation {
|
||||
match &self {
|
||||
Self::Edit(_) | Self::Write(_) | Self::NotebookEdit(_) | Self::MultiEdit(_) => {
|
||||
acp::ToolCallConfirmation::Edit { description }
|
||||
}
|
||||
Self::WebFetch(params) => acp::ToolCallConfirmation::Fetch {
|
||||
urls: params
|
||||
.as_ref()
|
||||
.map(|p| vec![p.url.clone()])
|
||||
.unwrap_or_default(),
|
||||
description,
|
||||
},
|
||||
Self::Terminal(Some(BashToolParams {
|
||||
description,
|
||||
command,
|
||||
..
|
||||
})) => acp::ToolCallConfirmation::Execute {
|
||||
command: command.clone(),
|
||||
root_command: command.clone(),
|
||||
description: description.clone(),
|
||||
},
|
||||
Self::ExitPlanMode(Some(params)) => acp::ToolCallConfirmation::Other {
|
||||
description: if let Some(description) = description {
|
||||
format!("{description} {}", params.plan)
|
||||
} else {
|
||||
params.plan.clone()
|
||||
},
|
||||
},
|
||||
Self::Task(Some(params)) => acp::ToolCallConfirmation::Other {
|
||||
description: if let Some(description) = description {
|
||||
format!("{description} {}", params.description)
|
||||
} else {
|
||||
params.description.clone()
|
||||
},
|
||||
},
|
||||
Self::Ls(Some(LsToolParams { path, .. }))
|
||||
| Self::ReadFile(Some(ReadToolParams { abs_path: path, .. })) => {
|
||||
let path = path.display();
|
||||
acp::ToolCallConfirmation::Other {
|
||||
description: if let Some(description) = description {
|
||||
format!("{description} {path}")
|
||||
} else {
|
||||
path.to_string()
|
||||
},
|
||||
}
|
||||
}
|
||||
Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => {
|
||||
let path = notebook_path.display();
|
||||
acp::ToolCallConfirmation::Other {
|
||||
description: if let Some(description) = description {
|
||||
format!("{description} {path}")
|
||||
} else {
|
||||
path.to_string()
|
||||
},
|
||||
}
|
||||
}
|
||||
Self::Glob(Some(params)) => acp::ToolCallConfirmation::Other {
|
||||
description: if let Some(description) = description {
|
||||
format!("{description} {params}")
|
||||
} else {
|
||||
params.to_string()
|
||||
},
|
||||
},
|
||||
Self::Grep(Some(params)) => acp::ToolCallConfirmation::Other {
|
||||
description: if let Some(description) = description {
|
||||
format!("{description} {params}")
|
||||
} else {
|
||||
params.to_string()
|
||||
},
|
||||
},
|
||||
Self::WebSearch(Some(params)) => acp::ToolCallConfirmation::Other {
|
||||
description: if let Some(description) = description {
|
||||
format!("{description} {params}")
|
||||
} else {
|
||||
params.to_string()
|
||||
},
|
||||
},
|
||||
Self::TodoWrite(Some(params)) => {
|
||||
let params = params.todos.iter().map(|todo| &todo.content).join(", ");
|
||||
acp::ToolCallConfirmation::Other {
|
||||
description: if let Some(description) = description {
|
||||
format!("{description} {params}")
|
||||
} else {
|
||||
params
|
||||
},
|
||||
}
|
||||
}
|
||||
Self::Terminal(None)
|
||||
| Self::Task(None)
|
||||
| Self::NotebookRead(None)
|
||||
| Self::ExitPlanMode(None)
|
||||
| Self::Ls(None)
|
||||
| Self::Glob(None)
|
||||
| Self::Grep(None)
|
||||
| Self::ReadFile(None)
|
||||
| Self::WebSearch(None)
|
||||
| Self::TodoWrite(None)
|
||||
| Self::Other { .. } => acp::ToolCallConfirmation::Other {
|
||||
description: description.unwrap_or("".to_string()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn locations(&self) -> Vec<acp::ToolCallLocation> {
|
||||
match &self {
|
||||
Self::Edit(Some(EditToolParams { abs_path, .. })) => vec![acp::ToolCallLocation {
|
||||
Self::Edit(Some(EditToolParams { abs_path, .. })) => vec![ToolCallLocation {
|
||||
path: abs_path.clone(),
|
||||
line: None,
|
||||
}],
|
||||
Self::MultiEdit(Some(MultiEditToolParams { file_path, .. })) => {
|
||||
vec![acp::ToolCallLocation {
|
||||
path: file_path.clone(),
|
||||
line: None,
|
||||
}]
|
||||
}
|
||||
Self::Write(Some(WriteToolParams { file_path, .. })) => {
|
||||
vec![acp::ToolCallLocation {
|
||||
vec![ToolCallLocation {
|
||||
path: file_path.clone(),
|
||||
line: None,
|
||||
}]
|
||||
}
|
||||
Self::Write(Some(WriteToolParams { file_path, .. })) => vec![ToolCallLocation {
|
||||
path: file_path.clone(),
|
||||
line: None,
|
||||
}],
|
||||
Self::ReadFile(Some(ReadToolParams {
|
||||
abs_path, offset, ..
|
||||
})) => vec![acp::ToolCallLocation {
|
||||
})) => vec![ToolCallLocation {
|
||||
path: abs_path.clone(),
|
||||
line: *offset,
|
||||
}],
|
||||
Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => {
|
||||
vec![acp::ToolCallLocation {
|
||||
vec![ToolCallLocation {
|
||||
path: notebook_path.clone(),
|
||||
line: None,
|
||||
}]
|
||||
}
|
||||
Self::NotebookEdit(Some(NotebookEditToolParams { notebook_path, .. })) => {
|
||||
vec![acp::ToolCallLocation {
|
||||
vec![ToolCallLocation {
|
||||
path: notebook_path.clone(),
|
||||
line: None,
|
||||
}]
|
||||
}
|
||||
Self::Glob(Some(GlobToolParams {
|
||||
path: Some(path), ..
|
||||
})) => vec![acp::ToolCallLocation {
|
||||
})) => vec![ToolCallLocation {
|
||||
path: path.clone(),
|
||||
line: None,
|
||||
}],
|
||||
Self::Ls(Some(LsToolParams { path, .. })) => vec![acp::ToolCallLocation {
|
||||
Self::Ls(Some(LsToolParams { path, .. })) => vec![ToolCallLocation {
|
||||
path: path.clone(),
|
||||
line: None,
|
||||
}],
|
||||
Self::Grep(Some(GrepToolParams {
|
||||
path: Some(path), ..
|
||||
})) => vec![acp::ToolCallLocation {
|
||||
})) => vec![ToolCallLocation {
|
||||
path: PathBuf::from(path),
|
||||
line: None,
|
||||
}],
|
||||
@@ -303,15 +414,12 @@ impl ClaudeTool {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_acp(&self, id: acp::ToolCallId) -> acp::ToolCall {
|
||||
acp::ToolCall {
|
||||
id,
|
||||
kind: self.kind(),
|
||||
status: acp::ToolCallStatus::InProgress,
|
||||
pub fn as_acp(&self) -> PushToolCallParams {
|
||||
PushToolCallParams {
|
||||
label: self.label(),
|
||||
content: self.content(),
|
||||
icon: self.icon(),
|
||||
locations: self.locations(),
|
||||
raw_input: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -326,6 +434,10 @@ pub struct EditToolParams {
|
||||
pub new_text: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct EditToolResponse;
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
pub struct ReadToolParams {
|
||||
/// The absolute path to the file to read.
|
||||
@@ -338,6 +450,12 @@ pub struct ReadToolParams {
|
||||
pub limit: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ReadToolResponse {
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
pub struct WriteToolParams {
|
||||
/// Absolute path for new file
|
||||
|
||||
@@ -1,317 +0,0 @@
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::anyhow;
|
||||
use collections::HashMap;
|
||||
use context_server::listener::McpServerTool;
|
||||
use context_server::types::requests;
|
||||
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
|
||||
use futures::channel::{mpsc, oneshot};
|
||||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
use smol::stream::StreamExt as _;
|
||||
use std::cell::RefCell;
|
||||
use std::rc::Rc;
|
||||
use std::{path::Path, sync::Arc};
|
||||
use util::ResultExt;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||
|
||||
use crate::mcp_server::ZedMcpServer;
|
||||
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings, mcp_server};
|
||||
use acp_thread::{AcpThread, AgentConnection};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Codex;
|
||||
|
||||
impl AgentServer for Codex {
|
||||
fn name(&self) -> &'static str {
|
||||
"Codex"
|
||||
}
|
||||
|
||||
fn empty_state_headline(&self) -> &'static str {
|
||||
"Welcome to Codex"
|
||||
}
|
||||
|
||||
fn empty_state_message(&self) -> &'static str {
|
||||
"What can I help with?"
|
||||
}
|
||||
|
||||
fn logo(&self) -> ui::IconName {
|
||||
ui::IconName::AiOpenAi
|
||||
}
|
||||
|
||||
fn connect(
|
||||
&self,
|
||||
_root_dir: &Path,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Rc<dyn AgentConnection>>> {
|
||||
let project = project.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).codex.clone()
|
||||
})?;
|
||||
|
||||
let Some(command) =
|
||||
AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
|
||||
else {
|
||||
anyhow::bail!("Failed to find codex binary");
|
||||
};
|
||||
|
||||
let client: Arc<ContextServer> = ContextServer::stdio(
|
||||
ContextServerId("codex-mcp-server".into()),
|
||||
ContextServerCommand {
|
||||
path: command.path,
|
||||
args: command.args,
|
||||
env: command.env,
|
||||
},
|
||||
)
|
||||
.into();
|
||||
ContextServer::start(client.clone(), cx).await?;
|
||||
|
||||
let (notification_tx, mut notification_rx) = mpsc::unbounded();
|
||||
client
|
||||
.client()
|
||||
.context("Failed to subscribe")?
|
||||
.on_notification(acp::SESSION_UPDATE_METHOD_NAME, {
|
||||
move |notification, _cx| {
|
||||
let notification_tx = notification_tx.clone();
|
||||
log::trace!(
|
||||
"ACP Notification: {}",
|
||||
serde_json::to_string_pretty(¬ification).unwrap()
|
||||
);
|
||||
|
||||
if let Some(notification) =
|
||||
serde_json::from_value::<acp::SessionNotification>(notification)
|
||||
.log_err()
|
||||
{
|
||||
notification_tx.unbounded_send(notification).ok();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let sessions = Rc::new(RefCell::new(HashMap::default()));
|
||||
|
||||
let notification_handler_task = cx.spawn({
|
||||
let sessions = sessions.clone();
|
||||
async move |cx| {
|
||||
while let Some(notification) = notification_rx.next().await {
|
||||
CodexConnection::handle_session_notification(
|
||||
notification,
|
||||
sessions.clone(),
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let connection = CodexConnection {
|
||||
client,
|
||||
sessions,
|
||||
_notification_handler_task: notification_handler_task,
|
||||
};
|
||||
Ok(Rc::new(connection) as _)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct CodexConnection {
|
||||
client: Arc<context_server::ContextServer>,
|
||||
sessions: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
|
||||
_notification_handler_task: Task<()>,
|
||||
}
|
||||
|
||||
struct CodexSession {
|
||||
thread: WeakEntity<AcpThread>,
|
||||
cancel_tx: Option<oneshot::Sender<()>>,
|
||||
_mcp_server: ZedMcpServer,
|
||||
}
|
||||
|
||||
impl AgentConnection for CodexConnection {
|
||||
fn name(&self) -> &'static str {
|
||||
"Codex"
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let client = self.client.client();
|
||||
let sessions = self.sessions.clone();
|
||||
let cwd = cwd.to_path_buf();
|
||||
cx.spawn(async move |cx| {
|
||||
let client = client.context("MCP server is not initialized yet")?;
|
||||
let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
|
||||
|
||||
let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
|
||||
|
||||
let response = client
|
||||
.request::<requests::CallTool>(context_server::types::CallToolParams {
|
||||
name: acp::NEW_SESSION_TOOL_NAME.into(),
|
||||
arguments: Some(serde_json::to_value(acp::NewSessionArguments {
|
||||
mcp_servers: [(
|
||||
mcp_server::SERVER_NAME.to_string(),
|
||||
mcp_server.server_config()?,
|
||||
)]
|
||||
.into(),
|
||||
client_tools: acp::ClientTools {
|
||||
request_permission: Some(acp::McpToolId {
|
||||
mcp_server: mcp_server::SERVER_NAME.into(),
|
||||
tool_name: mcp_server::RequestPermissionTool::NAME.into(),
|
||||
}),
|
||||
read_text_file: Some(acp::McpToolId {
|
||||
mcp_server: mcp_server::SERVER_NAME.into(),
|
||||
tool_name: mcp_server::ReadTextFileTool::NAME.into(),
|
||||
}),
|
||||
write_text_file: Some(acp::McpToolId {
|
||||
mcp_server: mcp_server::SERVER_NAME.into(),
|
||||
tool_name: mcp_server::WriteTextFileTool::NAME.into(),
|
||||
}),
|
||||
},
|
||||
cwd,
|
||||
})?),
|
||||
meta: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
if response.is_error.unwrap_or_default() {
|
||||
return Err(anyhow!(response.text_contents()));
|
||||
}
|
||||
|
||||
let result = serde_json::from_value::<acp::NewSessionOutput>(
|
||||
response.structured_content.context("Empty response")?,
|
||||
)?;
|
||||
|
||||
let thread =
|
||||
cx.new(|cx| AcpThread::new(self.clone(), project, result.session_id.clone(), cx))?;
|
||||
|
||||
thread_tx.send(thread.downgrade())?;
|
||||
|
||||
let session = CodexSession {
|
||||
thread: thread.downgrade(),
|
||||
cancel_tx: None,
|
||||
_mcp_server: mcp_server,
|
||||
};
|
||||
sessions.borrow_mut().insert(result.session_id, session);
|
||||
|
||||
Ok(thread)
|
||||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
|
||||
Task::ready(Err(anyhow!("Authentication not supported")))
|
||||
}
|
||||
|
||||
fn prompt(
|
||||
&self,
|
||||
params: agent_client_protocol::PromptArguments,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<()>> {
|
||||
let client = self.client.client();
|
||||
let sessions = self.sessions.clone();
|
||||
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let client = client.context("MCP server is not initialized yet")?;
|
||||
|
||||
let (new_cancel_tx, cancel_rx) = oneshot::channel();
|
||||
{
|
||||
let mut sessions = sessions.borrow_mut();
|
||||
let session = sessions
|
||||
.get_mut(¶ms.session_id)
|
||||
.context("Session not found")?;
|
||||
session.cancel_tx.replace(new_cancel_tx);
|
||||
}
|
||||
|
||||
let result = client
|
||||
.request_with::<requests::CallTool>(
|
||||
context_server::types::CallToolParams {
|
||||
name: acp::PROMPT_TOOL_NAME.into(),
|
||||
arguments: Some(serde_json::to_value(params)?),
|
||||
meta: None,
|
||||
},
|
||||
Some(cancel_rx),
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(err) = &result
|
||||
&& err.is::<context_server::client::RequestCanceled>()
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let response = result?;
|
||||
|
||||
if response.is_error.unwrap_or_default() {
|
||||
return Err(anyhow!(response.text_contents()));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
|
||||
let mut sessions = self.sessions.borrow_mut();
|
||||
|
||||
if let Some(cancel_tx) = sessions
|
||||
.get_mut(session_id)
|
||||
.and_then(|session| session.cancel_tx.take())
|
||||
{
|
||||
cancel_tx.send(()).ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CodexConnection {
|
||||
pub fn handle_session_notification(
|
||||
notification: acp::SessionNotification,
|
||||
threads: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
|
||||
cx: &mut AsyncApp,
|
||||
) {
|
||||
let threads = threads.borrow();
|
||||
let Some(thread) = threads
|
||||
.get(¬ification.session_id)
|
||||
.and_then(|session| session.thread.upgrade())
|
||||
else {
|
||||
log::error!(
|
||||
"Thread not found for session ID: {}",
|
||||
notification.session_id
|
||||
);
|
||||
return;
|
||||
};
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(notification.update, cx)
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CodexConnection {
|
||||
fn drop(&mut self) {
|
||||
self.client.stop().log_err();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use super::*;
|
||||
use crate::AgentServerCommand;
|
||||
use std::path::Path;
|
||||
|
||||
crate::common_e2e_tests!(Codex, allow_option_id = "approve");
|
||||
|
||||
pub fn local_command() -> AgentServerCommand {
|
||||
let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
|
||||
.join("../../../codex/codex-rs/target/debug/codex");
|
||||
|
||||
AgentServerCommand {
|
||||
path: cli_path,
|
||||
args: vec!["mcp".into()],
|
||||
env: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,10 @@
|
||||
use std::{
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use std::{path::Path, sync::Arc, time::Duration};
|
||||
|
||||
use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
|
||||
use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
|
||||
use agent_client_protocol as acp;
|
||||
|
||||
use acp_thread::{
|
||||
AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallStatus,
|
||||
};
|
||||
use agentic_coding_protocol as acp;
|
||||
use futures::{FutureExt, StreamExt, channel::mpsc, select};
|
||||
use gpui::{Entity, TestAppContext};
|
||||
use indoc::indoc;
|
||||
@@ -57,25 +54,19 @@ pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut Tes
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send(
|
||||
vec![
|
||||
acp::ContentBlock::Text(acp::TextContent {
|
||||
text: "Read the file ".into(),
|
||||
annotations: None,
|
||||
}),
|
||||
acp::ContentBlock::ResourceLink(acp::ResourceLink {
|
||||
uri: "foo.rs".into(),
|
||||
name: "foo.rs".into(),
|
||||
annotations: None,
|
||||
description: None,
|
||||
mime_type: None,
|
||||
size: None,
|
||||
title: None,
|
||||
}),
|
||||
acp::ContentBlock::Text(acp::TextContent {
|
||||
text: " and tell me what the content of the println! is".into(),
|
||||
annotations: None,
|
||||
}),
|
||||
],
|
||||
acp::SendUserMessageParams {
|
||||
chunks: vec![
|
||||
acp::UserMessageChunk::Text {
|
||||
text: "Read the file ".into(),
|
||||
},
|
||||
acp::UserMessageChunk::Path {
|
||||
path: Path::new("foo.rs").into(),
|
||||
},
|
||||
acp::UserMessageChunk::Text {
|
||||
text: " and tell me what the content of the println! is".into(),
|
||||
},
|
||||
],
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})
|
||||
@@ -83,28 +74,21 @@ pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut Tes
|
||||
.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, cx| {
|
||||
assert_eq!(thread.entries().len(), 3);
|
||||
assert!(matches!(
|
||||
thread.entries()[0],
|
||||
AgentThreadEntry::UserMessage(_)
|
||||
));
|
||||
let assistant_message = &thread
|
||||
.entries()
|
||||
.iter()
|
||||
.rev()
|
||||
.find_map(|entry| match entry {
|
||||
AgentThreadEntry::AssistantMessage(msg) => Some(msg),
|
||||
_ => None,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_)));
|
||||
let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else {
|
||||
panic!("Expected AssistantMessage")
|
||||
};
|
||||
assert!(
|
||||
assistant_message.to_markdown(cx).contains("Hello, world!"),
|
||||
"unexpected assistant message: {:?}",
|
||||
assistant_message.to_markdown(cx)
|
||||
);
|
||||
});
|
||||
|
||||
drop(tempdir);
|
||||
}
|
||||
|
||||
pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
|
||||
@@ -147,7 +131,6 @@ pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestApp
|
||||
|
||||
pub async fn test_tool_call_with_confirmation(
|
||||
server: impl AgentServer + 'static,
|
||||
allow_option_id: acp::PermissionOptionId,
|
||||
cx: &mut TestAppContext,
|
||||
) {
|
||||
let fs = init_test(cx).await;
|
||||
@@ -178,8 +161,11 @@ pub async fn test_tool_call_with_confirmation(
|
||||
let tool_call_id = thread.read_with(cx, |thread, _cx| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
id,
|
||||
content,
|
||||
status: ToolCallStatus::WaitingForConfirmation { .. },
|
||||
status:
|
||||
ToolCallStatus::WaitingForConfirmation {
|
||||
confirmation: ToolCallConfirmation::Execute { root_command, .. },
|
||||
..
|
||||
},
|
||||
..
|
||||
}) = &thread
|
||||
.entries()
|
||||
@@ -190,18 +176,13 @@ pub async fn test_tool_call_with_confirmation(
|
||||
panic!();
|
||||
};
|
||||
|
||||
assert!(content.iter().any(|c| c.to_markdown(_cx).contains("touch")));
|
||||
assert!(root_command.contains("touch"));
|
||||
|
||||
id.clone()
|
||||
*id
|
||||
});
|
||||
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.authorize_tool_call(
|
||||
tool_call_id,
|
||||
allow_option_id,
|
||||
acp::PermissionOptionKind::AllowOnce,
|
||||
cx,
|
||||
);
|
||||
thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
|
||||
|
||||
assert!(thread.entries().iter().any(|entry| matches!(
|
||||
entry,
|
||||
@@ -216,7 +197,7 @@ pub async fn test_tool_call_with_confirmation(
|
||||
|
||||
thread.read_with(cx, |thread, cx| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
content,
|
||||
content: Some(ToolCallContent::Markdown { markdown }),
|
||||
status: ToolCallStatus::Allowed { .. },
|
||||
..
|
||||
}) = thread
|
||||
@@ -228,10 +209,13 @@ pub async fn test_tool_call_with_confirmation(
|
||||
panic!();
|
||||
};
|
||||
|
||||
assert!(
|
||||
content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
|
||||
"Expected content to contain 'Hello'"
|
||||
);
|
||||
markdown.read_with(cx, |md, _cx| {
|
||||
assert!(
|
||||
md.source().contains("Hello"),
|
||||
r#"Expected '{}' to contain "Hello""#,
|
||||
md.source()
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -265,20 +249,26 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
|
||||
thread.read_with(cx, |thread, _cx| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
id,
|
||||
content,
|
||||
status: ToolCallStatus::WaitingForConfirmation { .. },
|
||||
status:
|
||||
ToolCallStatus::WaitingForConfirmation {
|
||||
confirmation: ToolCallConfirmation::Execute { root_command, .. },
|
||||
..
|
||||
},
|
||||
..
|
||||
}) = &thread.entries()[first_tool_call_ix]
|
||||
else {
|
||||
panic!("{:?}", thread.entries()[1]);
|
||||
};
|
||||
|
||||
assert!(content.iter().any(|c| c.to_markdown(_cx).contains("touch")));
|
||||
assert!(root_command.contains("touch"));
|
||||
|
||||
id.clone()
|
||||
*id
|
||||
});
|
||||
|
||||
let _ = thread.update(cx, |thread, cx| thread.cancel(cx));
|
||||
thread
|
||||
.update(cx, |thread, cx| thread.cancel(cx))
|
||||
.await
|
||||
.unwrap();
|
||||
full_turn.await.unwrap();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
@@ -306,7 +296,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! common_e2e_tests {
|
||||
($server:expr, allow_option_id = $allow_option_id:expr) => {
|
||||
($server:expr) => {
|
||||
mod common_e2e {
|
||||
use super::*;
|
||||
|
||||
@@ -331,12 +321,7 @@ macro_rules! common_e2e_tests {
|
||||
#[::gpui::test]
|
||||
#[cfg_attr(not(feature = "e2e"), ignore)]
|
||||
async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) {
|
||||
$crate::e2e_tests::test_tool_call_with_confirmation(
|
||||
$server,
|
||||
::agent_client_protocol::PermissionOptionId($allow_option_id.into()),
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
$crate::e2e_tests::test_tool_call_with_confirmation($server, cx).await;
|
||||
}
|
||||
|
||||
#[::gpui::test]
|
||||
@@ -368,9 +353,6 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
|
||||
gemini: Some(AgentServerSettings {
|
||||
command: crate::gemini::tests::local_command(),
|
||||
}),
|
||||
codex: Some(AgentServerSettings {
|
||||
command: crate::codex::tests::local_command(),
|
||||
}),
|
||||
},
|
||||
cx,
|
||||
);
|
||||
@@ -387,16 +369,15 @@ pub async fn new_test_thread(
|
||||
current_dir: impl AsRef<Path>,
|
||||
cx: &mut TestAppContext,
|
||||
) -> Entity<AcpThread> {
|
||||
let connection = cx
|
||||
.update(|cx| server.connect(current_dir.as_ref(), &project, cx))
|
||||
let thread = cx
|
||||
.update(|cx| server.new_thread(current_dir.as_ref(), &project, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let thread = connection
|
||||
.new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async())
|
||||
thread
|
||||
.update(cx, |thread, _| thread.initialize())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
thread
|
||||
}
|
||||
|
||||
@@ -429,24 +410,3 @@ pub async fn run_until_first_tool_call(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_zed_path() -> PathBuf {
|
||||
let mut zed_path = std::env::current_exe().unwrap();
|
||||
|
||||
while zed_path
|
||||
.file_name()
|
||||
.map_or(true, |name| name.to_string_lossy() != "debug")
|
||||
{
|
||||
if !zed_path.pop() {
|
||||
panic!("Could not find target directory");
|
||||
}
|
||||
}
|
||||
|
||||
zed_path.push("zed");
|
||||
|
||||
if !zed_path.exists() {
|
||||
panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
|
||||
}
|
||||
|
||||
zed_path
|
||||
}
|
||||
|
||||
@@ -1,17 +1,9 @@
|
||||
use anyhow::anyhow;
|
||||
use std::cell::RefCell;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::{AgentServer, AgentServerCommand, AgentServerVersion};
|
||||
use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate};
|
||||
use agentic_coding_protocol as acp_old;
|
||||
use crate::stdio_agent_server::StdioAgentServer;
|
||||
use crate::{AgentServerCommand, AgentServerVersion};
|
||||
use anyhow::{Context as _, Result};
|
||||
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||
use gpui::{AsyncApp, Entity};
|
||||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
use ui::App;
|
||||
|
||||
use crate::AllAgentServersSettings;
|
||||
|
||||
@@ -20,7 +12,7 @@ pub struct Gemini;
|
||||
|
||||
const ACP_ARG: &str = "--experimental-acp";
|
||||
|
||||
impl AgentServer for Gemini {
|
||||
impl StdioAgentServer for Gemini {
|
||||
fn name(&self) -> &'static str {
|
||||
"Gemini"
|
||||
}
|
||||
@@ -33,88 +25,14 @@ impl AgentServer for Gemini {
|
||||
"Ask questions, edit files, run commands.\nBe specific for the best results."
|
||||
}
|
||||
|
||||
fn supports_always_allow(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn logo(&self) -> ui::IconName {
|
||||
ui::IconName::AiGemini
|
||||
}
|
||||
|
||||
fn connect(
|
||||
&self,
|
||||
root_dir: &Path,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Rc<dyn AgentConnection>>> {
|
||||
let root_dir = root_dir.to_path_buf();
|
||||
let project = project.clone();
|
||||
let this = self.clone();
|
||||
let name = self.name();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let command = this.command(&project, cx).await?;
|
||||
|
||||
let mut child = util::command::new_smol_command(&command.path)
|
||||
.args(command.args.iter())
|
||||
.current_dir(root_dir)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::inherit())
|
||||
.kill_on_drop(true)
|
||||
.spawn()?;
|
||||
|
||||
let stdin = child.stdin.take().unwrap();
|
||||
let stdout = child.stdout.take().unwrap();
|
||||
|
||||
let foreground_executor = cx.foreground_executor().clone();
|
||||
|
||||
let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid()));
|
||||
|
||||
let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
|
||||
OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()),
|
||||
stdin,
|
||||
stdout,
|
||||
move |fut| foreground_executor.spawn(fut).detach(),
|
||||
);
|
||||
|
||||
let io_task = cx.background_spawn(async move {
|
||||
io_fut.await.log_err();
|
||||
});
|
||||
|
||||
let child_status = cx.background_spawn(async move {
|
||||
let result = match child.status().await {
|
||||
Err(e) => Err(anyhow!(e)),
|
||||
Ok(result) if result.success() => Ok(()),
|
||||
Ok(result) => {
|
||||
if let Some(AgentServerVersion::Unsupported {
|
||||
error_message,
|
||||
upgrade_message,
|
||||
upgrade_command,
|
||||
}) = this.version(&command).await.log_err()
|
||||
{
|
||||
Err(anyhow!(LoadError::Unsupported {
|
||||
error_message,
|
||||
upgrade_message,
|
||||
upgrade_command
|
||||
}))
|
||||
} else {
|
||||
Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
|
||||
}
|
||||
}
|
||||
};
|
||||
drop(io_task);
|
||||
result
|
||||
});
|
||||
|
||||
let connection: Rc<dyn AgentConnection> = Rc::new(OldAcpAgentConnection {
|
||||
name,
|
||||
connection,
|
||||
child_status,
|
||||
});
|
||||
|
||||
Ok(connection)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Gemini {
|
||||
async fn command(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
@@ -188,7 +106,7 @@ pub(crate) mod tests {
|
||||
use crate::AgentServerCommand;
|
||||
use std::path::Path;
|
||||
|
||||
crate::common_e2e_tests!(Gemini, allow_option_id = "0");
|
||||
crate::common_e2e_tests!(Gemini);
|
||||
|
||||
pub fn local_command() -> AgentServerCommand {
|
||||
let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
|
||||
|
||||
@@ -1,207 +0,0 @@
|
||||
use acp_thread::AcpThread;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::Result;
|
||||
use context_server::listener::{McpServerTool, ToolResponse};
|
||||
use context_server::types::{
|
||||
Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities,
|
||||
ToolsCapabilities, requests,
|
||||
};
|
||||
use futures::channel::oneshot;
|
||||
use gpui::{App, AsyncApp, Task, WeakEntity};
|
||||
use indoc::indoc;
|
||||
|
||||
pub struct ZedMcpServer {
|
||||
server: context_server::listener::McpServer,
|
||||
}
|
||||
|
||||
pub const SERVER_NAME: &str = "zed";
|
||||
|
||||
impl ZedMcpServer {
|
||||
pub async fn new(
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<Self> {
|
||||
let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
|
||||
mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
|
||||
|
||||
mcp_server.add_tool(RequestPermissionTool {
|
||||
thread_rx: thread_rx.clone(),
|
||||
});
|
||||
mcp_server.add_tool(ReadTextFileTool {
|
||||
thread_rx: thread_rx.clone(),
|
||||
});
|
||||
mcp_server.add_tool(WriteTextFileTool {
|
||||
thread_rx: thread_rx.clone(),
|
||||
});
|
||||
|
||||
Ok(Self { server: mcp_server })
|
||||
}
|
||||
|
||||
pub fn server_config(&self) -> Result<acp::McpServerConfig> {
|
||||
#[cfg(not(test))]
|
||||
let zed_path = anyhow::Context::context(
|
||||
std::env::current_exe(),
|
||||
"finding current executable path for use in mcp_server",
|
||||
)?;
|
||||
|
||||
#[cfg(test)]
|
||||
let zed_path = crate::e2e_tests::get_zed_path();
|
||||
|
||||
Ok(acp::McpServerConfig {
|
||||
command: zed_path,
|
||||
args: vec![
|
||||
"--nc".into(),
|
||||
self.server.socket_path().display().to_string(),
|
||||
],
|
||||
env: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_initialize(_: InitializeParams, cx: &App) -> Task<Result<InitializeResponse>> {
|
||||
cx.foreground_executor().spawn(async move {
|
||||
Ok(InitializeResponse {
|
||||
protocol_version: ProtocolVersion("2025-06-18".into()),
|
||||
capabilities: ServerCapabilities {
|
||||
experimental: None,
|
||||
logging: None,
|
||||
completions: None,
|
||||
prompts: None,
|
||||
resources: None,
|
||||
tools: Some(ToolsCapabilities {
|
||||
list_changed: Some(false),
|
||||
}),
|
||||
},
|
||||
server_info: Implementation {
|
||||
name: SERVER_NAME.into(),
|
||||
version: "0.1.0".into(),
|
||||
},
|
||||
meta: None,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Tools
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RequestPermissionTool {
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
}
|
||||
|
||||
impl McpServerTool for RequestPermissionTool {
|
||||
type Input = acp::RequestPermissionArguments;
|
||||
type Output = acp::RequestPermissionOutput;
|
||||
|
||||
const NAME: &'static str = "Confirmation";
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
indoc! {"
|
||||
Request permission for tool calls.
|
||||
|
||||
This tool is meant to be called programmatically by the agent loop, not the LLM.
|
||||
"}
|
||||
}
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<ToolResponse<Self::Output>> {
|
||||
let mut thread_rx = self.thread_rx.clone();
|
||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||
anyhow::bail!("Thread closed");
|
||||
};
|
||||
|
||||
let result = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_permission(input.tool_call, input.options, cx)
|
||||
})?
|
||||
.await;
|
||||
|
||||
let outcome = match result {
|
||||
Ok(option_id) => acp::RequestPermissionOutcome::Selected { option_id },
|
||||
Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Canceled,
|
||||
};
|
||||
|
||||
Ok(ToolResponse {
|
||||
content: vec![],
|
||||
structured_content: acp::RequestPermissionOutput { outcome },
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ReadTextFileTool {
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
}
|
||||
|
||||
impl McpServerTool for ReadTextFileTool {
|
||||
type Input = acp::ReadTextFileArguments;
|
||||
type Output = acp::ReadTextFileOutput;
|
||||
|
||||
const NAME: &'static str = "Read";
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Reads the content of the given file in the project including unsaved changes."
|
||||
}
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<ToolResponse<Self::Output>> {
|
||||
let mut thread_rx = self.thread_rx.clone();
|
||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||
anyhow::bail!("Thread closed");
|
||||
};
|
||||
|
||||
let content = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.read_text_file(input.path, input.line, input.limit, false, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
Ok(ToolResponse {
|
||||
content: vec![],
|
||||
structured_content: acp::ReadTextFileOutput { content },
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WriteTextFileTool {
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
}
|
||||
|
||||
impl McpServerTool for WriteTextFileTool {
|
||||
type Input = acp::WriteTextFileArguments;
|
||||
type Output = ();
|
||||
|
||||
const NAME: &'static str = "Write";
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Write to a file replacing its contents"
|
||||
}
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<ToolResponse<Self::Output>> {
|
||||
let mut thread_rx = self.thread_rx.clone();
|
||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||
anyhow::bail!("Thread closed");
|
||||
};
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.write_text_file(input.path, input.content, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
Ok(ToolResponse {
|
||||
content: vec![],
|
||||
structured_content: (),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -13,7 +13,6 @@ pub fn init(cx: &mut App) {
|
||||
pub struct AllAgentServersSettings {
|
||||
pub gemini: Option<AgentServerSettings>,
|
||||
pub claude: Option<AgentServerSettings>,
|
||||
pub codex: Option<AgentServerSettings>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)]
|
||||
@@ -30,21 +29,13 @@ impl settings::Settings for AllAgentServersSettings {
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
|
||||
let mut settings = AllAgentServersSettings::default();
|
||||
|
||||
for AllAgentServersSettings {
|
||||
gemini,
|
||||
claude,
|
||||
codex,
|
||||
} in sources.defaults_and_customizations()
|
||||
{
|
||||
for AllAgentServersSettings { gemini, claude } in sources.defaults_and_customizations() {
|
||||
if gemini.is_some() {
|
||||
settings.gemini = gemini.clone();
|
||||
}
|
||||
if claude.is_some() {
|
||||
settings.claude = claude.clone();
|
||||
}
|
||||
if codex.is_some() {
|
||||
settings.codex = codex.clone();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(settings)
|
||||
|
||||
119
crates/agent_servers/src/stdio_agent_server.rs
Normal file
@@ -0,0 +1,119 @@
|
||||
use crate::{AgentServer, AgentServerCommand, AgentServerVersion};
|
||||
use acp_thread::{AcpClientDelegate, AcpThread, LoadError};
|
||||
use agentic_coding_protocol as acp;
|
||||
use anyhow::{Result, anyhow};
|
||||
use gpui::{App, AsyncApp, Entity, Task, prelude::*};
|
||||
use project::Project;
|
||||
use std::path::Path;
|
||||
use util::ResultExt;
|
||||
|
||||
pub trait StdioAgentServer: Send + Clone {
|
||||
fn logo(&self) -> ui::IconName;
|
||||
fn name(&self) -> &'static str;
|
||||
fn empty_state_headline(&self) -> &'static str;
|
||||
fn empty_state_message(&self) -> &'static str;
|
||||
fn supports_always_allow(&self) -> bool;
|
||||
|
||||
fn command(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> impl Future<Output = Result<AgentServerCommand>>;
|
||||
|
||||
fn version(
|
||||
&self,
|
||||
command: &AgentServerCommand,
|
||||
) -> impl Future<Output = Result<AgentServerVersion>> + Send;
|
||||
}
|
||||
|
||||
impl<T: StdioAgentServer + 'static> AgentServer for T {
|
||||
fn name(&self) -> &'static str {
|
||||
self.name()
|
||||
}
|
||||
|
||||
fn empty_state_headline(&self) -> &'static str {
|
||||
self.empty_state_headline()
|
||||
}
|
||||
|
||||
fn empty_state_message(&self) -> &'static str {
|
||||
self.empty_state_message()
|
||||
}
|
||||
|
||||
fn logo(&self) -> ui::IconName {
|
||||
self.logo()
|
||||
}
|
||||
|
||||
fn supports_always_allow(&self) -> bool {
|
||||
self.supports_always_allow()
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
&self,
|
||||
root_dir: &Path,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let root_dir = root_dir.to_path_buf();
|
||||
let project = project.clone();
|
||||
let this = self.clone();
|
||||
let title = self.name().into();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let command = this.command(&project, cx).await?;
|
||||
|
||||
let mut child = util::command::new_smol_command(&command.path)
|
||||
.args(command.args.iter())
|
||||
.current_dir(root_dir)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::inherit())
|
||||
.kill_on_drop(true)
|
||||
.spawn()?;
|
||||
|
||||
let stdin = child.stdin.take().unwrap();
|
||||
let stdout = child.stdout.take().unwrap();
|
||||
|
||||
cx.new(|cx| {
|
||||
let foreground_executor = cx.foreground_executor().clone();
|
||||
|
||||
let (connection, io_fut) = acp::AgentConnection::connect_to_agent(
|
||||
AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
|
||||
stdin,
|
||||
stdout,
|
||||
move |fut| foreground_executor.spawn(fut).detach(),
|
||||
);
|
||||
|
||||
let io_task = cx.background_spawn(async move {
|
||||
io_fut.await.log_err();
|
||||
});
|
||||
|
||||
let child_status = cx.background_spawn(async move {
|
||||
let result = match child.status().await {
|
||||
Err(e) => Err(anyhow!(e)),
|
||||
Ok(result) if result.success() => Ok(()),
|
||||
Ok(result) => {
|
||||
if let Some(AgentServerVersion::Unsupported {
|
||||
error_message,
|
||||
upgrade_message,
|
||||
upgrade_command,
|
||||
}) = this.version(&command).await.log_err()
|
||||
{
|
||||
Err(anyhow!(LoadError::Unsupported {
|
||||
error_message,
|
||||
upgrade_message,
|
||||
upgrade_command
|
||||
}))
|
||||
} else {
|
||||
Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
|
||||
}
|
||||
}
|
||||
};
|
||||
drop(io_task);
|
||||
result
|
||||
});
|
||||
|
||||
AcpThread::new(connection, title, Some(child_status), project.clone(), cx)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -17,10 +17,10 @@ test-support = ["gpui/test-support", "language/test-support"]
|
||||
|
||||
[dependencies]
|
||||
acp_thread.workspace = true
|
||||
agent-client-protocol.workspace = true
|
||||
agent.workspace = true
|
||||
agent_servers.workspace = true
|
||||
agentic-coding-protocol.workspace = true
|
||||
agent_settings.workspace = true
|
||||
agent_servers.workspace = true
|
||||
ai_onboarding.workspace = true
|
||||
anyhow.workspace = true
|
||||
assistant_context.workspace = true
|
||||
@@ -32,7 +32,6 @@ buffer_diff.workspace = true
|
||||
chrono.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
component.workspace = true
|
||||
context_server.workspace = true
|
||||
db.workspace = true
|
||||
@@ -54,7 +53,6 @@ itertools.workspace = true
|
||||
jsonschema.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
language_models.workspace = true
|
||||
log.workspace = true
|
||||
lsp.workspace = true
|
||||
markdown.workspace = true
|
||||
@@ -89,7 +87,6 @@ theme.workspace = true
|
||||
time.workspace = true
|
||||
time_format.workspace = true
|
||||
ui.workspace = true
|
||||
ui_input.workspace = true
|
||||
urlencoding.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use acp_thread::{AgentConnection, Plan};
|
||||
use acp_thread::Plan;
|
||||
use agent_servers::AgentServer;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::BTreeMap;
|
||||
@@ -7,7 +7,7 @@ use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use agent_client_protocol as acp;
|
||||
use agentic_coding_protocol::{self as acp};
|
||||
use assistant_tool::ActionLog;
|
||||
use buffer_diff::BufferDiff;
|
||||
use collections::{HashMap, HashSet};
|
||||
@@ -16,6 +16,7 @@ use editor::{
|
||||
EditorStyle, MinimapVisibility, MultiBuffer, PathKey,
|
||||
};
|
||||
use file_icons::FileIcons;
|
||||
use futures::channel::oneshot;
|
||||
use gpui::{
|
||||
Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId,
|
||||
FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, SharedString, StyleRefinement,
|
||||
@@ -38,7 +39,8 @@ use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage};
|
||||
|
||||
use ::acp_thread::{
|
||||
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, Diff,
|
||||
LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus,
|
||||
LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallConfirmation, ToolCallContent,
|
||||
ToolCallId, ToolCallStatus,
|
||||
};
|
||||
|
||||
use crate::acp::completion_provider::{ContextPickerCompletionProvider, MentionSet};
|
||||
@@ -62,13 +64,12 @@ pub struct AcpThreadView {
|
||||
last_error: Option<Entity<Markdown>>,
|
||||
list_state: ListState,
|
||||
auth_task: Option<Task<()>>,
|
||||
expanded_tool_calls: HashSet<acp::ToolCallId>,
|
||||
expanded_tool_calls: HashSet<ToolCallId>,
|
||||
expanded_thinking_blocks: HashSet<(usize, usize)>,
|
||||
edits_expanded: bool,
|
||||
plan_expanded: bool,
|
||||
editor_expanded: bool,
|
||||
message_history: Rc<RefCell<MessageHistory<Vec<acp::ContentBlock>>>>,
|
||||
_cancel_task: Option<Task<()>>,
|
||||
message_history: Rc<RefCell<MessageHistory<acp::SendUserMessageParams>>>,
|
||||
}
|
||||
|
||||
enum ThreadState {
|
||||
@@ -81,16 +82,22 @@ enum ThreadState {
|
||||
},
|
||||
LoadError(LoadError),
|
||||
Unauthenticated {
|
||||
connection: Rc<dyn AgentConnection>,
|
||||
thread: Entity<AcpThread>,
|
||||
},
|
||||
}
|
||||
|
||||
struct AlwaysAllowOption {
|
||||
id: &'static str,
|
||||
label: SharedString,
|
||||
outcome: acp::ToolCallConfirmationOutcome,
|
||||
}
|
||||
|
||||
impl AcpThreadView {
|
||||
pub fn new(
|
||||
agent: Rc<dyn AgentServer>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
message_history: Rc<RefCell<MessageHistory<Vec<acp::ContentBlock>>>>,
|
||||
message_history: Rc<RefCell<MessageHistory<acp::SendUserMessageParams>>>,
|
||||
min_lines: usize,
|
||||
max_lines: Option<usize>,
|
||||
window: &mut Window,
|
||||
@@ -184,7 +191,6 @@ impl AcpThreadView {
|
||||
plan_expanded: false,
|
||||
editor_expanded: false,
|
||||
message_history,
|
||||
_cancel_task: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -202,9 +208,9 @@ impl AcpThreadView {
|
||||
.map(|worktree| worktree.read(cx).abs_path())
|
||||
.unwrap_or_else(|| paths::home_dir().as_path().into());
|
||||
|
||||
let connect_task = agent.connect(&root_dir, &project, cx);
|
||||
let task = agent.new_thread(&root_dir, &project, cx);
|
||||
let load_task = cx.spawn_in(window, async move |this, cx| {
|
||||
let connection = match connect_task.await {
|
||||
let thread = match task.await {
|
||||
Ok(thread) => thread,
|
||||
Err(err) => {
|
||||
this.update(cx, |this, cx| {
|
||||
@@ -216,30 +222,48 @@ impl AcpThreadView {
|
||||
}
|
||||
};
|
||||
|
||||
let result = match connection
|
||||
.clone()
|
||||
.new_thread(project.clone(), &root_dir, cx)
|
||||
.await
|
||||
{
|
||||
let init_response = async {
|
||||
let resp = thread
|
||||
.read_with(cx, |thread, _cx| thread.initialize())?
|
||||
.await?;
|
||||
anyhow::Ok(resp)
|
||||
};
|
||||
|
||||
let result = match init_response.await {
|
||||
Err(e) => {
|
||||
let mut cx = cx.clone();
|
||||
if e.downcast_ref::<acp_thread::Unauthenticated>().is_some() {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.thread_state = ThreadState::Unauthenticated { connection };
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
return;
|
||||
if e.downcast_ref::<oneshot::Canceled>().is_some() {
|
||||
let child_status = thread
|
||||
.update(&mut cx, |thread, _| thread.child_status())
|
||||
.ok()
|
||||
.flatten();
|
||||
if let Some(child_status) = child_status {
|
||||
match child_status.await {
|
||||
Ok(_) => Err(e),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
} else {
|
||||
Err(e)
|
||||
}
|
||||
} else {
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
Ok(session_id) => Ok(session_id),
|
||||
Ok(response) => {
|
||||
if !response.is_authenticated {
|
||||
this.update(cx, |this, _| {
|
||||
this.thread_state = ThreadState::Unauthenticated { thread };
|
||||
})
|
||||
.ok();
|
||||
return;
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
};
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
match result {
|
||||
Ok(thread) => {
|
||||
Ok(()) => {
|
||||
let thread_subscription =
|
||||
cx.subscribe_in(&thread, window, Self::handle_thread_event);
|
||||
|
||||
@@ -281,10 +305,10 @@ impl AcpThreadView {
|
||||
|
||||
pub fn thread(&self) -> Option<&Entity<AcpThread>> {
|
||||
match &self.thread_state {
|
||||
ThreadState::Ready { thread, .. } => Some(thread),
|
||||
ThreadState::Unauthenticated { .. }
|
||||
| ThreadState::Loading { .. }
|
||||
| ThreadState::LoadError(..) => None,
|
||||
ThreadState::Ready { thread, .. } | ThreadState::Unauthenticated { thread } => {
|
||||
Some(thread)
|
||||
}
|
||||
ThreadState::Loading { .. } | ThreadState::LoadError(..) => None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -301,7 +325,7 @@ impl AcpThreadView {
|
||||
self.last_error.take();
|
||||
|
||||
if let Some(thread) = self.thread() {
|
||||
self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx)));
|
||||
thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -338,7 +362,7 @@ impl AcpThreadView {
|
||||
self.last_error.take();
|
||||
|
||||
let mut ix = 0;
|
||||
let mut chunks: Vec<acp::ContentBlock> = Vec::new();
|
||||
let mut chunks: Vec<acp::UserMessageChunk> = Vec::new();
|
||||
let project = self.project.clone();
|
||||
self.message_editor.update(cx, |editor, cx| {
|
||||
let text = editor.text(cx);
|
||||
@@ -350,19 +374,12 @@ impl AcpThreadView {
|
||||
{
|
||||
let crease_range = crease.range().to_offset(&snapshot.buffer_snapshot);
|
||||
if crease_range.start > ix {
|
||||
chunks.push(text[ix..crease_range.start].into());
|
||||
chunks.push(acp::UserMessageChunk::Text {
|
||||
text: text[ix..crease_range.start].to_string(),
|
||||
});
|
||||
}
|
||||
if let Some(abs_path) = project.read(cx).absolute_path(&project_path, cx) {
|
||||
let path_str = abs_path.display().to_string();
|
||||
chunks.push(acp::ContentBlock::ResourceLink(acp::ResourceLink {
|
||||
uri: path_str.clone(),
|
||||
name: path_str,
|
||||
annotations: None,
|
||||
description: None,
|
||||
mime_type: None,
|
||||
size: None,
|
||||
title: None,
|
||||
}));
|
||||
chunks.push(acp::UserMessageChunk::Path { path: abs_path });
|
||||
}
|
||||
ix = crease_range.end;
|
||||
}
|
||||
@@ -371,7 +388,9 @@ impl AcpThreadView {
|
||||
if ix < text.len() {
|
||||
let last_chunk = text[ix..].trim();
|
||||
if !last_chunk.is_empty() {
|
||||
chunks.push(last_chunk.into());
|
||||
chunks.push(acp::UserMessageChunk::Text {
|
||||
text: last_chunk.into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -382,7 +401,8 @@ impl AcpThreadView {
|
||||
}
|
||||
|
||||
let Some(thread) = self.thread() else { return };
|
||||
let task = thread.update(cx, |thread, cx| thread.send(chunks.clone(), cx));
|
||||
let message = acp::SendUserMessageParams { chunks };
|
||||
let task = thread.update(cx, |thread, cx| thread.send(message.clone(), cx));
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let result = task.await;
|
||||
@@ -404,7 +424,7 @@ impl AcpThreadView {
|
||||
editor.remove_creases(mention_set.lock().drain(), cx)
|
||||
});
|
||||
|
||||
self.message_history.borrow_mut().push(chunks);
|
||||
self.message_history.borrow_mut().push(message);
|
||||
}
|
||||
|
||||
fn previous_history_message(
|
||||
@@ -470,7 +490,7 @@ impl AcpThreadView {
|
||||
message_editor: Entity<Editor>,
|
||||
mention_set: Arc<Mutex<MentionSet>>,
|
||||
project: Entity<Project>,
|
||||
message: Option<&Vec<acp::ContentBlock>>,
|
||||
message: Option<&acp::SendUserMessageParams>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> bool {
|
||||
@@ -483,19 +503,18 @@ impl AcpThreadView {
|
||||
let mut text = String::new();
|
||||
let mut mentions = Vec::new();
|
||||
|
||||
for chunk in message {
|
||||
for chunk in &message.chunks {
|
||||
match chunk {
|
||||
acp::ContentBlock::Text(text_content) => {
|
||||
text.push_str(&text_content.text);
|
||||
acp::UserMessageChunk::Text { text: chunk } => {
|
||||
text.push_str(&chunk);
|
||||
}
|
||||
acp::ContentBlock::ResourceLink(resource_link) => {
|
||||
let path = Path::new(&resource_link.uri);
|
||||
acp::UserMessageChunk::Path { path } => {
|
||||
let start = text.len();
|
||||
let content = MentionPath::new(&path).to_string();
|
||||
let content = MentionPath::new(path).to_string();
|
||||
text.push_str(&content);
|
||||
let end = text.len();
|
||||
if let Some(project_path) =
|
||||
project.read(cx).project_path_for_absolute_path(&path, cx)
|
||||
project.read(cx).project_path_for_absolute_path(path, cx)
|
||||
{
|
||||
let filename: SharedString = path
|
||||
.file_name()
|
||||
@@ -506,9 +525,6 @@ impl AcpThreadView {
|
||||
mentions.push((start..end, project_path, filename));
|
||||
}
|
||||
}
|
||||
acp::ContentBlock::Image(_)
|
||||
| acp::ContentBlock::Audio(_)
|
||||
| acp::ContentBlock::Resource(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -574,79 +590,71 @@ impl AcpThreadView {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(multibuffers) = self.entry_diff_multibuffers(entry_ix, cx) else {
|
||||
let Some(multibuffer) = self.entry_diff_multibuffer(entry_ix, cx) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let multibuffers = multibuffers.collect::<Vec<_>>();
|
||||
|
||||
for multibuffer in multibuffers {
|
||||
if self.diff_editors.contains_key(&multibuffer.entity_id()) {
|
||||
return;
|
||||
}
|
||||
|
||||
let editor = cx.new(|cx| {
|
||||
let mut editor = Editor::new(
|
||||
EditorMode::Full {
|
||||
scale_ui_elements_with_buffer_font_size: false,
|
||||
show_active_line_background: false,
|
||||
sized_by_content: true,
|
||||
},
|
||||
multibuffer.clone(),
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
editor.set_show_gutter(false, cx);
|
||||
editor.disable_inline_diagnostics();
|
||||
editor.disable_expand_excerpt_buttons(cx);
|
||||
editor.set_show_vertical_scrollbar(false, cx);
|
||||
editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
|
||||
editor.set_soft_wrap_mode(SoftWrap::None, cx);
|
||||
editor.scroll_manager.set_forbid_vertical_scroll(true);
|
||||
editor.set_show_indent_guides(false, cx);
|
||||
editor.set_read_only(true);
|
||||
editor.set_show_breakpoints(false, cx);
|
||||
editor.set_show_code_actions(false, cx);
|
||||
editor.set_show_git_diff_gutter(false, cx);
|
||||
editor.set_expand_all_diff_hunks(cx);
|
||||
editor.set_text_style_refinement(TextStyleRefinement {
|
||||
font_size: Some(
|
||||
TextSize::Small
|
||||
.rems(cx)
|
||||
.to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
|
||||
.into(),
|
||||
),
|
||||
..Default::default()
|
||||
});
|
||||
editor
|
||||
});
|
||||
let entity_id = multibuffer.entity_id();
|
||||
cx.observe_release(&multibuffer, move |this, _, _| {
|
||||
this.diff_editors.remove(&entity_id);
|
||||
})
|
||||
.detach();
|
||||
|
||||
self.diff_editors.insert(entity_id, editor);
|
||||
if self.diff_editors.contains_key(&multibuffer.entity_id()) {
|
||||
return;
|
||||
}
|
||||
|
||||
let editor = cx.new(|cx| {
|
||||
let mut editor = Editor::new(
|
||||
EditorMode::Full {
|
||||
scale_ui_elements_with_buffer_font_size: false,
|
||||
show_active_line_background: false,
|
||||
sized_by_content: true,
|
||||
},
|
||||
multibuffer.clone(),
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
editor.set_show_gutter(false, cx);
|
||||
editor.disable_inline_diagnostics();
|
||||
editor.disable_expand_excerpt_buttons(cx);
|
||||
editor.set_show_vertical_scrollbar(false, cx);
|
||||
editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
|
||||
editor.set_soft_wrap_mode(SoftWrap::None, cx);
|
||||
editor.scroll_manager.set_forbid_vertical_scroll(true);
|
||||
editor.set_show_indent_guides(false, cx);
|
||||
editor.set_read_only(true);
|
||||
editor.set_show_breakpoints(false, cx);
|
||||
editor.set_show_code_actions(false, cx);
|
||||
editor.set_show_git_diff_gutter(false, cx);
|
||||
editor.set_expand_all_diff_hunks(cx);
|
||||
editor.set_text_style_refinement(TextStyleRefinement {
|
||||
font_size: Some(
|
||||
TextSize::Small
|
||||
.rems(cx)
|
||||
.to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
|
||||
.into(),
|
||||
),
|
||||
..Default::default()
|
||||
});
|
||||
editor
|
||||
});
|
||||
let entity_id = multibuffer.entity_id();
|
||||
cx.observe_release(&multibuffer, move |this, _, _| {
|
||||
this.diff_editors.remove(&entity_id);
|
||||
})
|
||||
.detach();
|
||||
|
||||
self.diff_editors.insert(entity_id, editor);
|
||||
}
|
||||
|
||||
fn entry_diff_multibuffers(
|
||||
&self,
|
||||
entry_ix: usize,
|
||||
cx: &App,
|
||||
) -> Option<impl Iterator<Item = Entity<MultiBuffer>>> {
|
||||
fn entry_diff_multibuffer(&self, entry_ix: usize, cx: &App) -> Option<Entity<MultiBuffer>> {
|
||||
let entry = self.thread()?.read(cx).entries().get(entry_ix)?;
|
||||
Some(entry.diffs().map(|diff| diff.multibuffer.clone()))
|
||||
entry.diff().map(|diff| diff.multibuffer.clone())
|
||||
}
|
||||
|
||||
fn authenticate(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let ThreadState::Unauthenticated { ref connection } = self.thread_state else {
|
||||
let Some(thread) = self.thread().cloned() else {
|
||||
return;
|
||||
};
|
||||
|
||||
self.last_error.take();
|
||||
let authenticate = connection.authenticate(cx);
|
||||
let authenticate = thread.read(cx).authenticate();
|
||||
self.auth_task = Some(cx.spawn_in(window, {
|
||||
let project = self.project.clone();
|
||||
let agent = self.agent.clone();
|
||||
@@ -676,16 +684,15 @@ impl AcpThreadView {
|
||||
|
||||
fn authorize_tool_call(
|
||||
&mut self,
|
||||
tool_call_id: acp::ToolCallId,
|
||||
option_id: acp::PermissionOptionId,
|
||||
option_kind: acp::PermissionOptionKind,
|
||||
id: ToolCallId,
|
||||
outcome: acp::ToolCallConfirmationOutcome,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(thread) = self.thread() else {
|
||||
return;
|
||||
};
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.authorize_tool_call(tool_call_id, option_id, option_kind, cx);
|
||||
thread.authorize_tool_call(id, outcome, cx);
|
||||
});
|
||||
cx.notify();
|
||||
}
|
||||
@@ -712,12 +719,10 @@ impl AcpThreadView {
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.text_xs()
|
||||
.children(message.content.markdown().map(|md| {
|
||||
self.render_markdown(
|
||||
md.clone(),
|
||||
user_message_markdown_style(window, cx),
|
||||
)
|
||||
})),
|
||||
.child(self.render_markdown(
|
||||
message.content.clone(),
|
||||
user_message_markdown_style(window, cx),
|
||||
)),
|
||||
)
|
||||
.into_any(),
|
||||
AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) => {
|
||||
@@ -725,28 +730,20 @@ impl AcpThreadView {
|
||||
let message_body = v_flex()
|
||||
.w_full()
|
||||
.gap_2p5()
|
||||
.children(chunks.iter().enumerate().filter_map(
|
||||
|(chunk_ix, chunk)| match chunk {
|
||||
AssistantMessageChunk::Message { block } => {
|
||||
block.markdown().map(|md| {
|
||||
self.render_markdown(md.clone(), style.clone())
|
||||
.into_any_element()
|
||||
})
|
||||
}
|
||||
AssistantMessageChunk::Thought { block } => {
|
||||
block.markdown().map(|md| {
|
||||
self.render_thinking_block(
|
||||
index,
|
||||
chunk_ix,
|
||||
md.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.into_any_element()
|
||||
})
|
||||
}
|
||||
},
|
||||
))
|
||||
.children(chunks.iter().enumerate().map(|(chunk_ix, chunk)| {
|
||||
match chunk {
|
||||
AssistantMessageChunk::Text { chunk } => self
|
||||
.render_markdown(chunk.clone(), style.clone())
|
||||
.into_any_element(),
|
||||
AssistantMessageChunk::Thought { chunk } => self.render_thinking_block(
|
||||
index,
|
||||
chunk_ix,
|
||||
chunk.clone(),
|
||||
window,
|
||||
cx,
|
||||
),
|
||||
}
|
||||
}))
|
||||
.into_any();
|
||||
|
||||
v_flex()
|
||||
@@ -872,12 +869,9 @@ impl AcpThreadView {
|
||||
let header_id = SharedString::from(format!("tool-call-header-{}", entry_ix));
|
||||
|
||||
let status_icon = match &tool_call.status {
|
||||
ToolCallStatus::WaitingForConfirmation { .. } => None,
|
||||
ToolCallStatus::Allowed {
|
||||
status: acp::ToolCallStatus::Pending,
|
||||
}
|
||||
| ToolCallStatus::WaitingForConfirmation { .. } => None,
|
||||
ToolCallStatus::Allowed {
|
||||
status: acp::ToolCallStatus::InProgress,
|
||||
status: acp::ToolCallStatus::Running,
|
||||
..
|
||||
} => Some(
|
||||
Icon::new(IconName::ArrowCircle)
|
||||
@@ -891,13 +885,13 @@ impl AcpThreadView {
|
||||
.into_any(),
|
||||
),
|
||||
ToolCallStatus::Allowed {
|
||||
status: acp::ToolCallStatus::Completed,
|
||||
status: acp::ToolCallStatus::Finished,
|
||||
..
|
||||
} => None,
|
||||
ToolCallStatus::Rejected
|
||||
| ToolCallStatus::Canceled
|
||||
| ToolCallStatus::Allowed {
|
||||
status: acp::ToolCallStatus::Failed,
|
||||
status: acp::ToolCallStatus::Error,
|
||||
..
|
||||
} => Some(
|
||||
Icon::new(IconName::X)
|
||||
@@ -915,9 +909,34 @@ impl AcpThreadView {
|
||||
.any(|content| matches!(content, ToolCallContent::Diff { .. })),
|
||||
};
|
||||
|
||||
let is_collapsible = !tool_call.content.is_empty() && !needs_confirmation;
|
||||
let is_collapsible = tool_call.content.is_some() && !needs_confirmation;
|
||||
let is_open = !is_collapsible || self.expanded_tool_calls.contains(&tool_call.id);
|
||||
|
||||
let content = if is_open {
|
||||
match &tool_call.status {
|
||||
ToolCallStatus::WaitingForConfirmation { confirmation, .. } => {
|
||||
Some(self.render_tool_call_confirmation(
|
||||
tool_call.id,
|
||||
confirmation,
|
||||
tool_call.content.as_ref(),
|
||||
window,
|
||||
cx,
|
||||
))
|
||||
}
|
||||
ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => {
|
||||
tool_call.content.as_ref().map(|content| {
|
||||
div()
|
||||
.py_1p5()
|
||||
.child(self.render_tool_call_content(content, window, cx))
|
||||
.into_any_element()
|
||||
})
|
||||
}
|
||||
ToolCallStatus::Rejected => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
v_flex()
|
||||
.when(needs_confirmation, |this| {
|
||||
this.rounded_lg()
|
||||
@@ -957,19 +976,9 @@ impl AcpThreadView {
|
||||
})
|
||||
.gap_1p5()
|
||||
.child(
|
||||
Icon::new(match tool_call.kind {
|
||||
acp::ToolKind::Read => IconName::ToolRead,
|
||||
acp::ToolKind::Edit => IconName::ToolPencil,
|
||||
acp::ToolKind::Delete => IconName::ToolDeleteFile,
|
||||
acp::ToolKind::Move => IconName::ArrowRightLeft,
|
||||
acp::ToolKind::Search => IconName::ToolSearch,
|
||||
acp::ToolKind::Execute => IconName::ToolTerminal,
|
||||
acp::ToolKind::Think => IconName::ToolBulb,
|
||||
acp::ToolKind::Fetch => IconName::ToolWeb,
|
||||
acp::ToolKind::Other => IconName::ToolHammer,
|
||||
})
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Muted),
|
||||
Icon::new(tool_call.icon)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(if tool_call.locations.len() == 1 {
|
||||
let name = tool_call.locations[0]
|
||||
@@ -1014,16 +1023,16 @@ impl AcpThreadView {
|
||||
.gap_0p5()
|
||||
.when(is_collapsible, |this| {
|
||||
this.child(
|
||||
Disclosure::new(("expand", entry_ix), is_open)
|
||||
Disclosure::new(("expand", tool_call.id.0), is_open)
|
||||
.opened_icon(IconName::ChevronUp)
|
||||
.closed_icon(IconName::ChevronDown)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call.id.clone();
|
||||
let id = tool_call.id;
|
||||
move |this: &mut Self, _, _, cx: &mut Context<Self>| {
|
||||
if is_open {
|
||||
this.expanded_tool_calls.remove(&id);
|
||||
} else {
|
||||
this.expanded_tool_calls.insert(id.clone());
|
||||
this.expanded_tool_calls.insert(id);
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
@@ -1033,12 +1042,12 @@ impl AcpThreadView {
|
||||
.children(status_icon),
|
||||
)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call.id.clone();
|
||||
let id = tool_call.id;
|
||||
move |this: &mut Self, _, _, cx: &mut Context<Self>| {
|
||||
if is_open {
|
||||
this.expanded_tool_calls.remove(&id);
|
||||
} else {
|
||||
this.expanded_tool_calls.insert(id.clone());
|
||||
this.expanded_tool_calls.insert(id);
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
@@ -1046,7 +1055,7 @@ impl AcpThreadView {
|
||||
)
|
||||
.when(is_open, |this| {
|
||||
this.child(
|
||||
v_flex()
|
||||
div()
|
||||
.text_xs()
|
||||
.when(is_collapsible, |this| {
|
||||
this.mt_1()
|
||||
@@ -1055,45 +1064,7 @@ impl AcpThreadView {
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.rounded_lg()
|
||||
})
|
||||
.map(|this| {
|
||||
if is_open {
|
||||
match &tool_call.status {
|
||||
ToolCallStatus::WaitingForConfirmation { options, .. } => this
|
||||
.children(tool_call.content.iter().map(|content| {
|
||||
div()
|
||||
.py_1p5()
|
||||
.child(
|
||||
self.render_tool_call_content(
|
||||
content, window, cx,
|
||||
),
|
||||
)
|
||||
.into_any_element()
|
||||
}))
|
||||
.child(self.render_permission_buttons(
|
||||
options,
|
||||
entry_ix,
|
||||
tool_call.id.clone(),
|
||||
tool_call.content.is_empty(),
|
||||
cx,
|
||||
)),
|
||||
ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => {
|
||||
this.children(tool_call.content.iter().map(|content| {
|
||||
div()
|
||||
.py_1p5()
|
||||
.child(
|
||||
self.render_tool_call_content(
|
||||
content, window, cx,
|
||||
),
|
||||
)
|
||||
.into_any_element()
|
||||
}))
|
||||
}
|
||||
ToolCallStatus::Rejected => this,
|
||||
}
|
||||
} else {
|
||||
this
|
||||
}
|
||||
}),
|
||||
.children(content),
|
||||
)
|
||||
})
|
||||
}
|
||||
@@ -1105,20 +1076,14 @@ impl AcpThreadView {
|
||||
cx: &Context<Self>,
|
||||
) -> AnyElement {
|
||||
match content {
|
||||
ToolCallContent::ContentBlock { content } => {
|
||||
if let Some(md) = content.markdown() {
|
||||
div()
|
||||
.p_2()
|
||||
.child(
|
||||
self.render_markdown(
|
||||
md.clone(),
|
||||
default_markdown_style(false, window, cx),
|
||||
),
|
||||
)
|
||||
.into_any_element()
|
||||
} else {
|
||||
Empty.into_any_element()
|
||||
}
|
||||
ToolCallContent::Markdown { markdown } => {
|
||||
div()
|
||||
.p_2()
|
||||
.child(self.render_markdown(
|
||||
markdown.clone(),
|
||||
default_markdown_style(false, window, cx),
|
||||
))
|
||||
.into_any_element()
|
||||
}
|
||||
ToolCallContent::Diff {
|
||||
diff: Diff { multibuffer, .. },
|
||||
@@ -1127,56 +1092,223 @@ impl AcpThreadView {
|
||||
}
|
||||
}
|
||||
|
||||
fn render_permission_buttons(
|
||||
fn render_tool_call_confirmation(
|
||||
&self,
|
||||
options: &[acp::PermissionOption],
|
||||
entry_ix: usize,
|
||||
tool_call_id: acp::ToolCallId,
|
||||
empty_content: bool,
|
||||
tool_call_id: ToolCallId,
|
||||
confirmation: &ToolCallConfirmation,
|
||||
content: Option<&ToolCallContent>,
|
||||
window: &Window,
|
||||
cx: &Context<Self>,
|
||||
) -> AnyElement {
|
||||
let confirmation_container = v_flex().mt_1().py_1p5();
|
||||
|
||||
match confirmation {
|
||||
ToolCallConfirmation::Edit { description } => confirmation_container
|
||||
.child(
|
||||
div()
|
||||
.px_2()
|
||||
.children(description.clone().map(|description| {
|
||||
self.render_markdown(
|
||||
description,
|
||||
default_markdown_style(false, window, cx),
|
||||
)
|
||||
})),
|
||||
)
|
||||
.children(content.map(|content| self.render_tool_call_content(content, window, cx)))
|
||||
.child(self.render_confirmation_buttons(
|
||||
&[AlwaysAllowOption {
|
||||
id: "always_allow",
|
||||
label: "Always Allow Edits".into(),
|
||||
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow,
|
||||
}],
|
||||
tool_call_id,
|
||||
cx,
|
||||
))
|
||||
.into_any(),
|
||||
ToolCallConfirmation::Execute {
|
||||
command,
|
||||
root_command,
|
||||
description,
|
||||
} => confirmation_container
|
||||
.child(v_flex().px_2().pb_1p5().child(command.clone()).children(
|
||||
description.clone().map(|description| {
|
||||
self.render_markdown(description, default_markdown_style(false, window, cx))
|
||||
.on_url_click({
|
||||
let workspace = self.workspace.clone();
|
||||
move |text, window, cx| {
|
||||
Self::open_link(text, &workspace, window, cx);
|
||||
}
|
||||
})
|
||||
}),
|
||||
))
|
||||
.children(content.map(|content| self.render_tool_call_content(content, window, cx)))
|
||||
.child(self.render_confirmation_buttons(
|
||||
&[AlwaysAllowOption {
|
||||
id: "always_allow",
|
||||
label: format!("Always Allow {root_command}").into(),
|
||||
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow,
|
||||
}],
|
||||
tool_call_id,
|
||||
cx,
|
||||
))
|
||||
.into_any(),
|
||||
ToolCallConfirmation::Mcp {
|
||||
server_name,
|
||||
tool_name: _,
|
||||
tool_display_name,
|
||||
description,
|
||||
} => confirmation_container
|
||||
.child(
|
||||
v_flex()
|
||||
.px_2()
|
||||
.pb_1p5()
|
||||
.child(format!("{server_name} - {tool_display_name}"))
|
||||
.children(description.clone().map(|description| {
|
||||
self.render_markdown(
|
||||
description,
|
||||
default_markdown_style(false, window, cx),
|
||||
)
|
||||
})),
|
||||
)
|
||||
.children(content.map(|content| self.render_tool_call_content(content, window, cx)))
|
||||
.child(self.render_confirmation_buttons(
|
||||
&[
|
||||
AlwaysAllowOption {
|
||||
id: "always_allow_server",
|
||||
label: format!("Always Allow {server_name}").into(),
|
||||
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer,
|
||||
},
|
||||
AlwaysAllowOption {
|
||||
id: "always_allow_tool",
|
||||
label: format!("Always Allow {tool_display_name}").into(),
|
||||
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllowTool,
|
||||
},
|
||||
],
|
||||
tool_call_id,
|
||||
cx,
|
||||
))
|
||||
.into_any(),
|
||||
ToolCallConfirmation::Fetch { description, urls } => confirmation_container
|
||||
.child(
|
||||
v_flex()
|
||||
.px_2()
|
||||
.pb_1p5()
|
||||
.gap_1()
|
||||
.children(urls.iter().map(|url| {
|
||||
h_flex().child(
|
||||
Button::new(url.clone(), url)
|
||||
.icon(IconName::ArrowUpRight)
|
||||
.icon_color(Color::Muted)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.on_click({
|
||||
let url = url.clone();
|
||||
move |_, _, cx| cx.open_url(&url)
|
||||
}),
|
||||
)
|
||||
}))
|
||||
.children(description.clone().map(|description| {
|
||||
self.render_markdown(
|
||||
description,
|
||||
default_markdown_style(false, window, cx),
|
||||
)
|
||||
})),
|
||||
)
|
||||
.children(content.map(|content| self.render_tool_call_content(content, window, cx)))
|
||||
.child(self.render_confirmation_buttons(
|
||||
&[AlwaysAllowOption {
|
||||
id: "always_allow",
|
||||
label: "Always Allow".into(),
|
||||
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow,
|
||||
}],
|
||||
tool_call_id,
|
||||
cx,
|
||||
))
|
||||
.into_any(),
|
||||
ToolCallConfirmation::Other { description } => confirmation_container
|
||||
.child(v_flex().px_2().pb_1p5().child(self.render_markdown(
|
||||
description.clone(),
|
||||
default_markdown_style(false, window, cx),
|
||||
)))
|
||||
.children(content.map(|content| self.render_tool_call_content(content, window, cx)))
|
||||
.child(self.render_confirmation_buttons(
|
||||
&[AlwaysAllowOption {
|
||||
id: "always_allow",
|
||||
label: "Always Allow".into(),
|
||||
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow,
|
||||
}],
|
||||
tool_call_id,
|
||||
cx,
|
||||
))
|
||||
.into_any(),
|
||||
}
|
||||
}
|
||||
|
||||
fn render_confirmation_buttons(
|
||||
&self,
|
||||
always_allow_options: &[AlwaysAllowOption],
|
||||
tool_call_id: ToolCallId,
|
||||
cx: &Context<Self>,
|
||||
) -> Div {
|
||||
h_flex()
|
||||
.py_1p5()
|
||||
.pt_1p5()
|
||||
.px_1p5()
|
||||
.gap_1()
|
||||
.justify_end()
|
||||
.when(!empty_content, |this| {
|
||||
this.border_t_1()
|
||||
.border_color(self.tool_card_border_color(cx))
|
||||
})
|
||||
.children(options.iter().map(|option| {
|
||||
let option_id = SharedString::from(option.id.0.clone());
|
||||
Button::new((option_id, entry_ix), option.label.clone())
|
||||
.map(|this| match option.kind {
|
||||
acp::PermissionOptionKind::AllowOnce => {
|
||||
this.icon(IconName::Check).icon_color(Color::Success)
|
||||
}
|
||||
acp::PermissionOptionKind::AllowAlways => {
|
||||
this.icon(IconName::CheckDouble).icon_color(Color::Success)
|
||||
}
|
||||
acp::PermissionOptionKind::RejectOnce => {
|
||||
this.icon(IconName::X).icon_color(Color::Error)
|
||||
}
|
||||
acp::PermissionOptionKind::RejectAlways => {
|
||||
this.icon(IconName::X).icon_color(Color::Error)
|
||||
}
|
||||
})
|
||||
.border_t_1()
|
||||
.border_color(self.tool_card_border_color(cx))
|
||||
.when(self.agent.supports_always_allow(), |this| {
|
||||
this.children(always_allow_options.into_iter().map(|always_allow_option| {
|
||||
let outcome = always_allow_option.outcome;
|
||||
Button::new(
|
||||
(always_allow_option.id, tool_call_id.0),
|
||||
always_allow_option.label.clone(),
|
||||
)
|
||||
.icon(IconName::CheckDouble)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Success)
|
||||
.on_click(cx.listener({
|
||||
let tool_call_id = tool_call_id.clone();
|
||||
let option_id = option.id.clone();
|
||||
let option_kind = option.kind;
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(id, outcome, cx);
|
||||
}
|
||||
}))
|
||||
}))
|
||||
})
|
||||
.child(
|
||||
Button::new(("allow", tool_call_id.0), "Allow")
|
||||
.icon(IconName::Check)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Success)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
tool_call_id.clone(),
|
||||
option_id.clone(),
|
||||
option_kind,
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Allow,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
}))
|
||||
}))
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(("reject", tool_call_id.0), "Reject")
|
||||
.icon(IconName::X)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Error)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Reject,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_diff_editor(&self, multibuffer: &Entity<MultiBuffer>) -> AnyElement {
|
||||
@@ -2113,11 +2245,12 @@ impl AcpThreadView {
|
||||
.languages
|
||||
.language_for_name("Markdown");
|
||||
|
||||
let (thread_summary, markdown) = if let Some(thread) = self.thread() {
|
||||
let thread = thread.read(cx);
|
||||
(thread.title().to_string(), thread.to_markdown(cx))
|
||||
} else {
|
||||
return Task::ready(Ok(()));
|
||||
let (thread_summary, markdown) = match &self.thread_state {
|
||||
ThreadState::Ready { thread, .. } | ThreadState::Unauthenticated { thread } => {
|
||||
let thread = thread.read(cx);
|
||||
(thread.title().to_string(), thread.to_markdown(cx))
|
||||
}
|
||||
ThreadState::Loading { .. } | ThreadState::LoadError(..) => return Task::ready(Ok(())),
|
||||
};
|
||||
|
||||
window.spawn(cx, async move |cx| {
|
||||
|
||||
@@ -3895,7 +3895,7 @@ mod tests {
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.set_default_model(
|
||||
Some(ConfiguredModel {
|
||||
provider: Arc::new(FakeLanguageModelProvider::default()),
|
||||
provider: Arc::new(FakeLanguageModelProvider),
|
||||
model,
|
||||
}),
|
||||
cx,
|
||||
@@ -3979,7 +3979,7 @@ mod tests {
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.set_default_model(
|
||||
Some(ConfiguredModel {
|
||||
provider: Arc::new(FakeLanguageModelProvider::default()),
|
||||
provider: Arc::new(FakeLanguageModelProvider),
|
||||
model: model.clone(),
|
||||
}),
|
||||
cx,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
mod add_llm_provider_modal;
|
||||
mod configure_context_server_modal;
|
||||
mod manage_profiles_modal;
|
||||
mod tool_picker;
|
||||
@@ -29,7 +28,7 @@ use proto::Plan;
|
||||
use settings::{Settings, update_settings_file};
|
||||
use ui::{
|
||||
Chip, ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, PopoverMenu,
|
||||
Scrollbar, ScrollbarState, Switch, SwitchColor, SwitchField, Tooltip, prelude::*,
|
||||
Scrollbar, ScrollbarState, Switch, SwitchColor, Tooltip, prelude::*,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
use workspace::Workspace;
|
||||
@@ -38,10 +37,7 @@ use zed_actions::ExtensionCategoryFilter;
|
||||
pub(crate) use configure_context_server_modal::ConfigureContextServerModal;
|
||||
pub(crate) use manage_profiles_modal::ManageProfilesModal;
|
||||
|
||||
use crate::{
|
||||
AddContextServer,
|
||||
agent_configuration::add_llm_provider_modal::{AddLlmProviderModal, LlmCompatibleProvider},
|
||||
};
|
||||
use crate::AddContextServer;
|
||||
|
||||
pub struct AgentConfiguration {
|
||||
fs: Arc<dyn Fs>,
|
||||
@@ -185,15 +181,7 @@ impl AgentConfiguration {
|
||||
None
|
||||
};
|
||||
|
||||
let is_signed_in = self
|
||||
.workspace
|
||||
.read_with(cx, |workspace, _| {
|
||||
workspace.client().status().borrow().is_connected()
|
||||
})
|
||||
.unwrap_or(false);
|
||||
|
||||
v_flex()
|
||||
.w_full()
|
||||
.when(is_expanded, |this| this.mb_2())
|
||||
.child(
|
||||
div()
|
||||
@@ -224,7 +212,6 @@ impl AgentConfiguration {
|
||||
.hover(|hover| hover.bg(cx.theme().colors().element_hover))
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.gap_2()
|
||||
.child(
|
||||
Icon::new(provider.icon())
|
||||
@@ -233,15 +220,14 @@ impl AgentConfiguration {
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.gap_1()
|
||||
.child(
|
||||
Label::new(provider_name.clone())
|
||||
.size(LabelSize::Large),
|
||||
)
|
||||
.map(|this| {
|
||||
if is_zed_provider && is_signed_in {
|
||||
this.child(
|
||||
if is_zed_provider {
|
||||
this.gap_2().child(
|
||||
self.render_zed_plan_info(current_plan, cx),
|
||||
)
|
||||
} else {
|
||||
@@ -317,78 +303,21 @@ impl AgentConfiguration {
|
||||
let providers = LanguageModelRegistry::read_global(cx).providers();
|
||||
|
||||
v_flex()
|
||||
.w_full()
|
||||
.child(
|
||||
h_flex()
|
||||
v_flex()
|
||||
.p(DynamicSpacing::Base16.rems(cx))
|
||||
.pr(DynamicSpacing::Base20.rems(cx))
|
||||
.pb_0()
|
||||
.mb_2p5()
|
||||
.items_start()
|
||||
.justify_between()
|
||||
.gap_0p5()
|
||||
.child(Headline::new("LLM Providers"))
|
||||
.child(
|
||||
v_flex()
|
||||
.w_full()
|
||||
.gap_0p5()
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.gap_2()
|
||||
.justify_between()
|
||||
.child(Headline::new("LLM Providers"))
|
||||
.child(
|
||||
PopoverMenu::new("add-provider-popover")
|
||||
.trigger(
|
||||
Button::new("add-provider", "Add Provider")
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon(IconName::Plus)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Muted)
|
||||
.label_size(LabelSize::Small),
|
||||
)
|
||||
.anchor(gpui::Corner::TopRight)
|
||||
.menu({
|
||||
let workspace = self.workspace.clone();
|
||||
move |window, cx| {
|
||||
Some(ContextMenu::build(
|
||||
window,
|
||||
cx,
|
||||
|menu, _window, _cx| {
|
||||
menu.header("Compatible APIs").entry(
|
||||
"OpenAI",
|
||||
None,
|
||||
{
|
||||
let workspace =
|
||||
workspace.clone();
|
||||
move |window, cx| {
|
||||
workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
AddLlmProviderModal::toggle(
|
||||
LlmCompatibleProvider::OpenAi,
|
||||
workspace,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
},
|
||||
)
|
||||
},
|
||||
))
|
||||
}
|
||||
}),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
Label::new("Add at least one provider to use AI-powered features.")
|
||||
.color(Color::Muted),
|
||||
),
|
||||
Label::new("Add at least one provider to use AI-powered features.")
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.w_full()
|
||||
.pl(DynamicSpacing::Base08.rems(cx))
|
||||
.pr(DynamicSpacing::Base20.rems(cx))
|
||||
.children(
|
||||
@@ -401,74 +330,119 @@ impl AgentConfiguration {
|
||||
|
||||
fn render_command_permission(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let always_allow_tool_actions = AgentSettings::get_global(cx).always_allow_tool_actions;
|
||||
let fs = self.fs.clone();
|
||||
|
||||
SwitchField::new(
|
||||
"always-allow-tool-actions-switch",
|
||||
"Allow running commands without asking for confirmation",
|
||||
"The agent can perform potentially destructive actions without asking for your confirmation.",
|
||||
always_allow_tool_actions,
|
||||
move |state, _window, cx| {
|
||||
let allow = state == &ToggleState::Selected;
|
||||
update_settings_file::<AgentSettings>(fs.clone(), cx, move |settings, _| {
|
||||
settings.set_always_allow_tool_actions(allow);
|
||||
});
|
||||
},
|
||||
)
|
||||
h_flex()
|
||||
.gap_4()
|
||||
.justify_between()
|
||||
.flex_wrap()
|
||||
.child(
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.max_w_5_6()
|
||||
.child(Label::new("Allow running editing tools without asking for confirmation"))
|
||||
.child(
|
||||
Label::new(
|
||||
"The agent can perform potentially destructive actions without asking for your confirmation.",
|
||||
)
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
Switch::new(
|
||||
"always-allow-tool-actions-switch",
|
||||
always_allow_tool_actions.into(),
|
||||
)
|
||||
.color(SwitchColor::Accent)
|
||||
.on_click({
|
||||
let fs = self.fs.clone();
|
||||
move |state, _window, cx| {
|
||||
let allow = state == &ToggleState::Selected;
|
||||
update_settings_file::<AgentSettings>(
|
||||
fs.clone(),
|
||||
cx,
|
||||
move |settings, _| {
|
||||
settings.set_always_allow_tool_actions(allow);
|
||||
},
|
||||
);
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_single_file_review(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let single_file_review = AgentSettings::get_global(cx).single_file_review;
|
||||
let fs = self.fs.clone();
|
||||
|
||||
SwitchField::new(
|
||||
"single-file-review",
|
||||
"Enable single-file agent reviews",
|
||||
"Agent edits are also displayed in single-file editors for review.",
|
||||
single_file_review,
|
||||
move |state, _window, cx| {
|
||||
let allow = state == &ToggleState::Selected;
|
||||
update_settings_file::<AgentSettings>(fs.clone(), cx, move |settings, _| {
|
||||
settings.set_single_file_review(allow);
|
||||
});
|
||||
},
|
||||
)
|
||||
h_flex()
|
||||
.gap_4()
|
||||
.justify_between()
|
||||
.flex_wrap()
|
||||
.child(
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.max_w_5_6()
|
||||
.child(Label::new("Enable single-file agent reviews"))
|
||||
.child(
|
||||
Label::new(
|
||||
"Agent edits are also displayed in single-file editors for review.",
|
||||
)
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
Switch::new("single-file-review-switch", single_file_review.into())
|
||||
.color(SwitchColor::Accent)
|
||||
.on_click({
|
||||
let fs = self.fs.clone();
|
||||
move |state, _window, cx| {
|
||||
let allow = state == &ToggleState::Selected;
|
||||
update_settings_file::<AgentSettings>(
|
||||
fs.clone(),
|
||||
cx,
|
||||
move |settings, _| {
|
||||
settings.set_single_file_review(allow);
|
||||
},
|
||||
);
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_sound_notification(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let play_sound_when_agent_done = AgentSettings::get_global(cx).play_sound_when_agent_done;
|
||||
let fs = self.fs.clone();
|
||||
|
||||
SwitchField::new(
|
||||
"sound-notification",
|
||||
"Play sound when finished generating",
|
||||
"Hear a notification sound when the agent is done generating changes or needs your input.",
|
||||
play_sound_when_agent_done,
|
||||
move |state, _window, cx| {
|
||||
let allow = state == &ToggleState::Selected;
|
||||
update_settings_file::<AgentSettings>(fs.clone(), cx, move |settings, _| {
|
||||
settings.set_play_sound_when_agent_done(allow);
|
||||
});
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
fn render_modifier_to_send(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let use_modifier_to_send = AgentSettings::get_global(cx).use_modifier_to_send;
|
||||
let fs = self.fs.clone();
|
||||
|
||||
SwitchField::new(
|
||||
"modifier-send",
|
||||
"Use modifier to submit a message",
|
||||
"Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.",
|
||||
use_modifier_to_send,
|
||||
move |state, _window, cx| {
|
||||
let allow = state == &ToggleState::Selected;
|
||||
update_settings_file::<AgentSettings>(fs.clone(), cx, move |settings, _| {
|
||||
settings.set_use_modifier_to_send(allow);
|
||||
});
|
||||
},
|
||||
)
|
||||
h_flex()
|
||||
.gap_4()
|
||||
.justify_between()
|
||||
.flex_wrap()
|
||||
.child(
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.max_w_5_6()
|
||||
.child(Label::new("Play sound when finished generating"))
|
||||
.child(
|
||||
Label::new(
|
||||
"Hear a notification sound when the agent is done generating changes or needs your input.",
|
||||
)
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
Switch::new("play-sound-notification-switch", play_sound_when_agent_done.into())
|
||||
.color(SwitchColor::Accent)
|
||||
.on_click({
|
||||
let fs = self.fs.clone();
|
||||
move |state, _window, cx| {
|
||||
let allow = state == &ToggleState::Selected;
|
||||
update_settings_file::<AgentSettings>(
|
||||
fs.clone(),
|
||||
cx,
|
||||
move |settings, _| {
|
||||
settings.set_play_sound_when_agent_done(allow);
|
||||
},
|
||||
);
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_general_settings_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
@@ -482,7 +456,6 @@ impl AgentConfiguration {
|
||||
.child(self.render_command_permission(cx))
|
||||
.child(self.render_single_file_review(cx))
|
||||
.child(self.render_sound_notification(cx))
|
||||
.child(self.render_modifier_to_send(cx))
|
||||
}
|
||||
|
||||
fn render_zed_plan_info(&self, plan: Option<Plan>, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
|
||||
@@ -1,639 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use collections::HashSet;
|
||||
use fs::Fs;
|
||||
use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, Task};
|
||||
use language_model::LanguageModelRegistry;
|
||||
use language_models::{
|
||||
AllLanguageModelSettings, OpenAiCompatibleSettingsContent,
|
||||
provider::open_ai_compatible::AvailableModel,
|
||||
};
|
||||
use settings::update_settings_file;
|
||||
use ui::{Banner, KeyBinding, Modal, ModalFooter, ModalHeader, Section, prelude::*};
|
||||
use ui_input::SingleLineInput;
|
||||
use workspace::{ModalView, Workspace};
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum LlmCompatibleProvider {
|
||||
OpenAi,
|
||||
}
|
||||
|
||||
impl LlmCompatibleProvider {
|
||||
fn name(&self) -> &'static str {
|
||||
match self {
|
||||
LlmCompatibleProvider::OpenAi => "OpenAI",
|
||||
}
|
||||
}
|
||||
|
||||
fn api_url(&self) -> &'static str {
|
||||
match self {
|
||||
LlmCompatibleProvider::OpenAi => "https://api.openai.com/v1",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct AddLlmProviderInput {
|
||||
provider_name: Entity<SingleLineInput>,
|
||||
api_url: Entity<SingleLineInput>,
|
||||
api_key: Entity<SingleLineInput>,
|
||||
models: Vec<ModelInput>,
|
||||
}
|
||||
|
||||
impl AddLlmProviderInput {
|
||||
fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut App) -> Self {
|
||||
let provider_name = single_line_input("Provider Name", provider.name(), None, window, cx);
|
||||
let api_url = single_line_input("API URL", provider.api_url(), None, window, cx);
|
||||
let api_key = single_line_input(
|
||||
"API Key",
|
||||
"000000000000000000000000000000000000000000000000",
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
Self {
|
||||
provider_name,
|
||||
api_url,
|
||||
api_key,
|
||||
models: vec![ModelInput::new(window, cx)],
|
||||
}
|
||||
}
|
||||
|
||||
fn add_model(&mut self, window: &mut Window, cx: &mut App) {
|
||||
self.models.push(ModelInput::new(window, cx));
|
||||
}
|
||||
|
||||
fn remove_model(&mut self, index: usize) {
|
||||
self.models.remove(index);
|
||||
}
|
||||
}
|
||||
|
||||
struct ModelInput {
|
||||
name: Entity<SingleLineInput>,
|
||||
max_completion_tokens: Entity<SingleLineInput>,
|
||||
max_output_tokens: Entity<SingleLineInput>,
|
||||
max_tokens: Entity<SingleLineInput>,
|
||||
}
|
||||
|
||||
impl ModelInput {
|
||||
fn new(window: &mut Window, cx: &mut App) -> Self {
|
||||
let model_name = single_line_input(
|
||||
"Model Name",
|
||||
"e.g. gpt-4o, claude-opus-4, gemini-2.5-pro",
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
let max_completion_tokens = single_line_input(
|
||||
"Max Completion Tokens",
|
||||
"200000",
|
||||
Some("200000"),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
let max_output_tokens = single_line_input(
|
||||
"Max Output Tokens",
|
||||
"Max Output Tokens",
|
||||
Some("32000"),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
let max_tokens = single_line_input("Max Tokens", "Max Tokens", Some("200000"), window, cx);
|
||||
Self {
|
||||
name: model_name,
|
||||
max_completion_tokens,
|
||||
max_output_tokens,
|
||||
max_tokens,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse(&self, cx: &App) -> Result<AvailableModel, SharedString> {
|
||||
let name = self.name.read(cx).text(cx);
|
||||
if name.is_empty() {
|
||||
return Err(SharedString::from("Model Name cannot be empty"));
|
||||
}
|
||||
Ok(AvailableModel {
|
||||
name,
|
||||
display_name: None,
|
||||
max_completion_tokens: Some(
|
||||
self.max_completion_tokens
|
||||
.read(cx)
|
||||
.text(cx)
|
||||
.parse::<u64>()
|
||||
.map_err(|_| SharedString::from("Max Completion Tokens must be a number"))?,
|
||||
),
|
||||
max_output_tokens: Some(
|
||||
self.max_output_tokens
|
||||
.read(cx)
|
||||
.text(cx)
|
||||
.parse::<u64>()
|
||||
.map_err(|_| SharedString::from("Max Output Tokens must be a number"))?,
|
||||
),
|
||||
max_tokens: self
|
||||
.max_tokens
|
||||
.read(cx)
|
||||
.text(cx)
|
||||
.parse::<u64>()
|
||||
.map_err(|_| SharedString::from("Max Tokens must be a number"))?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn single_line_input(
|
||||
label: impl Into<SharedString>,
|
||||
placeholder: impl Into<SharedString>,
|
||||
text: Option<&str>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Entity<SingleLineInput> {
|
||||
cx.new(|cx| {
|
||||
let input = SingleLineInput::new(window, cx, placeholder).label(label);
|
||||
if let Some(text) = text {
|
||||
input
|
||||
.editor()
|
||||
.update(cx, |editor, cx| editor.set_text(text, window, cx));
|
||||
}
|
||||
input
|
||||
})
|
||||
}
|
||||
|
||||
fn save_provider_to_settings(
|
||||
input: &AddLlmProviderInput,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<(), SharedString>> {
|
||||
let provider_name: Arc<str> = input.provider_name.read(cx).text(cx).into();
|
||||
if provider_name.is_empty() {
|
||||
return Task::ready(Err("Provider Name cannot be empty".into()));
|
||||
}
|
||||
|
||||
if LanguageModelRegistry::read_global(cx)
|
||||
.providers()
|
||||
.iter()
|
||||
.any(|provider| {
|
||||
provider.id().0.as_ref() == provider_name.as_ref()
|
||||
|| provider.name().0.as_ref() == provider_name.as_ref()
|
||||
})
|
||||
{
|
||||
return Task::ready(Err(
|
||||
"Provider Name is already taken by another provider".into()
|
||||
));
|
||||
}
|
||||
|
||||
let api_url = input.api_url.read(cx).text(cx);
|
||||
if api_url.is_empty() {
|
||||
return Task::ready(Err("API URL cannot be empty".into()));
|
||||
}
|
||||
|
||||
let api_key = input.api_key.read(cx).text(cx);
|
||||
if api_key.is_empty() {
|
||||
return Task::ready(Err("API Key cannot be empty".into()));
|
||||
}
|
||||
|
||||
let mut models = Vec::new();
|
||||
let mut model_names: HashSet<String> = HashSet::default();
|
||||
for model in &input.models {
|
||||
match model.parse(cx) {
|
||||
Ok(model) => {
|
||||
if !model_names.insert(model.name.clone()) {
|
||||
return Task::ready(Err("Model Names must be unique".into()));
|
||||
}
|
||||
models.push(model)
|
||||
}
|
||||
Err(err) => return Task::ready(Err(err)),
|
||||
}
|
||||
}
|
||||
|
||||
let fs = <dyn Fs>::global(cx);
|
||||
let task = cx.write_credentials(&api_url, "Bearer", api_key.as_bytes());
|
||||
cx.spawn(async move |cx| {
|
||||
task.await
|
||||
.map_err(|_| "Failed to write API key to keychain")?;
|
||||
cx.update(|cx| {
|
||||
update_settings_file::<AllLanguageModelSettings>(fs, cx, |settings, _cx| {
|
||||
settings.openai_compatible.get_or_insert_default().insert(
|
||||
provider_name,
|
||||
OpenAiCompatibleSettingsContent {
|
||||
api_url,
|
||||
available_models: models,
|
||||
},
|
||||
);
|
||||
});
|
||||
})
|
||||
.ok();
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub struct AddLlmProviderModal {
|
||||
provider: LlmCompatibleProvider,
|
||||
input: AddLlmProviderInput,
|
||||
focus_handle: FocusHandle,
|
||||
last_error: Option<SharedString>,
|
||||
}
|
||||
|
||||
impl AddLlmProviderModal {
|
||||
pub fn toggle(
|
||||
provider: LlmCompatibleProvider,
|
||||
workspace: &mut Workspace,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Workspace>,
|
||||
) {
|
||||
workspace.toggle_modal(window, cx, |window, cx| Self::new(provider, window, cx));
|
||||
}
|
||||
|
||||
fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut Context<Self>) -> Self {
|
||||
Self {
|
||||
input: AddLlmProviderInput::new(provider, window, cx),
|
||||
provider,
|
||||
last_error: None,
|
||||
focus_handle: cx.focus_handle(),
|
||||
}
|
||||
}
|
||||
|
||||
fn confirm(&mut self, _: &menu::Confirm, _: &mut Window, cx: &mut Context<Self>) {
|
||||
let task = save_provider_to_settings(&self.input, cx);
|
||||
cx.spawn(async move |this, cx| {
|
||||
let result = task.await;
|
||||
this.update(cx, |this, cx| match result {
|
||||
Ok(_) => {
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
Err(error) => {
|
||||
this.last_error = Some(error);
|
||||
cx.notify();
|
||||
}
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
|
||||
fn render_section(&self) -> Section {
|
||||
Section::new()
|
||||
.child(self.input.provider_name.clone())
|
||||
.child(self.input.api_url.clone())
|
||||
.child(self.input.api_key.clone())
|
||||
}
|
||||
|
||||
fn render_model_section(&self, cx: &mut Context<Self>) -> Section {
|
||||
Section::new().child(
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_between()
|
||||
.child(Label::new("Models").size(LabelSize::Small))
|
||||
.child(
|
||||
Button::new("add-model", "Add Model")
|
||||
.icon(IconName::Plus)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Muted)
|
||||
.label_size(LabelSize::Small)
|
||||
.on_click(cx.listener(|this, _, window, cx| {
|
||||
this.input.add_model(window, cx);
|
||||
cx.notify();
|
||||
})),
|
||||
),
|
||||
)
|
||||
.children(
|
||||
self.input
|
||||
.models
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(ix, _)| self.render_model(ix, cx)),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_model(&self, ix: usize, cx: &mut Context<Self>) -> impl IntoElement + use<> {
|
||||
let has_more_than_one_model = self.input.models.len() > 1;
|
||||
let model = &self.input.models[ix];
|
||||
|
||||
v_flex()
|
||||
.p_2()
|
||||
.gap_2()
|
||||
.rounded_sm()
|
||||
.border_1()
|
||||
.border_dashed()
|
||||
.border_color(cx.theme().colors().border.opacity(0.6))
|
||||
.bg(cx.theme().colors().element_active.opacity(0.15))
|
||||
.child(model.name.clone())
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(model.max_completion_tokens.clone())
|
||||
.child(model.max_output_tokens.clone()),
|
||||
)
|
||||
.child(model.max_tokens.clone())
|
||||
.when(has_more_than_one_model, |this| {
|
||||
this.child(
|
||||
Button::new(("remove-model", ix), "Remove Model")
|
||||
.icon(IconName::Trash)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Muted)
|
||||
.label_size(LabelSize::Small)
|
||||
.style(ButtonStyle::Outlined)
|
||||
.full_width()
|
||||
.on_click(cx.listener(move |this, _, _window, cx| {
|
||||
this.input.remove_model(ix);
|
||||
cx.notify();
|
||||
})),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl EventEmitter<DismissEvent> for AddLlmProviderModal {}
|
||||
|
||||
impl Focusable for AddLlmProviderModal {
|
||||
fn focus_handle(&self, _cx: &App) -> FocusHandle {
|
||||
self.focus_handle.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl ModalView for AddLlmProviderModal {}
|
||||
|
||||
impl Render for AddLlmProviderModal {
|
||||
fn render(&mut self, window: &mut ui::Window, cx: &mut ui::Context<Self>) -> impl IntoElement {
|
||||
let focus_handle = self.focus_handle(cx);
|
||||
|
||||
div()
|
||||
.id("add-llm-provider-modal")
|
||||
.key_context("AddLlmProviderModal")
|
||||
.w(rems(34.))
|
||||
.elevation_3(cx)
|
||||
.on_action(cx.listener(Self::cancel))
|
||||
.capture_any_mouse_down(cx.listener(|this, _, window, cx| {
|
||||
this.focus_handle(cx).focus(window);
|
||||
}))
|
||||
.child(
|
||||
Modal::new("configure-context-server", None)
|
||||
.header(ModalHeader::new().headline("Add LLM Provider").description(
|
||||
match self.provider {
|
||||
LlmCompatibleProvider::OpenAi => {
|
||||
"This provider will use an OpenAI compatible API."
|
||||
}
|
||||
},
|
||||
))
|
||||
.when_some(self.last_error.clone(), |this, error| {
|
||||
this.section(
|
||||
Section::new().child(
|
||||
Banner::new()
|
||||
.severity(ui::Severity::Warning)
|
||||
.child(div().text_xs().child(error)),
|
||||
),
|
||||
)
|
||||
})
|
||||
.child(
|
||||
v_flex()
|
||||
.id("modal_content")
|
||||
.max_h_128()
|
||||
.overflow_y_scroll()
|
||||
.gap_2()
|
||||
.child(self.render_section())
|
||||
.child(self.render_model_section(cx)),
|
||||
)
|
||||
.footer(
|
||||
ModalFooter::new().end_slot(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.child(
|
||||
Button::new("cancel", "Cancel")
|
||||
.key_binding(
|
||||
KeyBinding::for_action_in(
|
||||
&menu::Cancel,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.map(|kb| kb.size(rems_from_px(12.))),
|
||||
)
|
||||
.on_click(cx.listener(|this, _event, window, cx| {
|
||||
this.cancel(&menu::Cancel, window, cx)
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new("save-server", "Save Provider")
|
||||
.key_binding(
|
||||
KeyBinding::for_action_in(
|
||||
&menu::Confirm,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.map(|kb| kb.size(rems_from_px(12.))),
|
||||
)
|
||||
.on_click(cx.listener(|this, _event, window, cx| {
|
||||
this.confirm(&menu::Confirm, window, cx)
|
||||
})),
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use editor::EditorSettings;
|
||||
use fs::FakeFs;
|
||||
use gpui::{TestAppContext, VisualTestContext};
|
||||
use language::language_settings;
|
||||
use language_model::{
|
||||
LanguageModelProviderId, LanguageModelProviderName,
|
||||
fake_provider::FakeLanguageModelProvider,
|
||||
};
|
||||
use project::Project;
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use util::path;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_save_provider_invalid_inputs(cx: &mut TestAppContext) {
|
||||
let cx = setup_test(cx).await;
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors("", "someurl", "somekey", vec![], cx,).await,
|
||||
Some("Provider Name cannot be empty".into())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors("someprovider", "", "somekey", vec![], cx,).await,
|
||||
Some("API URL cannot be empty".into())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors("someprovider", "someurl", "", vec![], cx,).await,
|
||||
Some("API Key cannot be empty".into())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors(
|
||||
"someprovider",
|
||||
"someurl",
|
||||
"somekey",
|
||||
vec![("", "200000", "200000", "32000")],
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
Some("Model Name cannot be empty".into())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors(
|
||||
"someprovider",
|
||||
"someurl",
|
||||
"somekey",
|
||||
vec![("somemodel", "abc", "200000", "32000")],
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
Some("Max Tokens must be a number".into())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors(
|
||||
"someprovider",
|
||||
"someurl",
|
||||
"somekey",
|
||||
vec![("somemodel", "200000", "abc", "32000")],
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
Some("Max Completion Tokens must be a number".into())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors(
|
||||
"someprovider",
|
||||
"someurl",
|
||||
"somekey",
|
||||
vec![("somemodel", "200000", "200000", "abc")],
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
Some("Max Output Tokens must be a number".into())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors(
|
||||
"someprovider",
|
||||
"someurl",
|
||||
"somekey",
|
||||
vec![
|
||||
("somemodel", "200000", "200000", "32000"),
|
||||
("somemodel", "200000", "200000", "32000"),
|
||||
],
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
Some("Model Names must be unique".into())
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_save_provider_name_conflict(cx: &mut TestAppContext) {
|
||||
let cx = setup_test(cx).await;
|
||||
|
||||
cx.update(|_window, cx| {
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.register_provider(
|
||||
FakeLanguageModelProvider::new(
|
||||
LanguageModelProviderId::new("someprovider"),
|
||||
LanguageModelProviderName::new("Some Provider"),
|
||||
),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors(
|
||||
"someprovider",
|
||||
"someurl",
|
||||
"someapikey",
|
||||
vec![("somemodel", "200000", "200000", "32000")],
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
Some("Provider Name is already taken by another provider".into())
|
||||
);
|
||||
}
|
||||
|
||||
async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext {
|
||||
cx.update(|cx| {
|
||||
let store = SettingsStore::test(cx);
|
||||
cx.set_global(store);
|
||||
workspace::init_settings(cx);
|
||||
Project::init_settings(cx);
|
||||
theme::init(theme::LoadThemes::JustBase, cx);
|
||||
language_settings::init(cx);
|
||||
EditorSettings::register(cx);
|
||||
language_model::init_settings(cx);
|
||||
language_models::init_settings(cx);
|
||||
});
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
cx.update(|cx| <dyn Fs>::set_global(fs.clone(), cx));
|
||||
let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
|
||||
let (_, cx) =
|
||||
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
|
||||
|
||||
cx
|
||||
}
|
||||
|
||||
async fn save_provider_validation_errors(
|
||||
provider_name: &str,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
models: Vec<(&str, &str, &str, &str)>,
|
||||
cx: &mut VisualTestContext,
|
||||
) -> Option<SharedString> {
|
||||
fn set_text(
|
||||
input: &Entity<SingleLineInput>,
|
||||
text: &str,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) {
|
||||
input.update(cx, |input, cx| {
|
||||
input.editor().update(cx, |editor, cx| {
|
||||
editor.set_text(text, window, cx);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
let task = cx.update(|window, cx| {
|
||||
let mut input = AddLlmProviderInput::new(LlmCompatibleProvider::OpenAi, window, cx);
|
||||
set_text(&input.provider_name, provider_name, window, cx);
|
||||
set_text(&input.api_url, api_url, window, cx);
|
||||
set_text(&input.api_key, api_key, window, cx);
|
||||
|
||||
for (i, (name, max_tokens, max_completion_tokens, max_output_tokens)) in
|
||||
models.iter().enumerate()
|
||||
{
|
||||
if i >= input.models.len() {
|
||||
input.models.push(ModelInput::new(window, cx));
|
||||
}
|
||||
let model = &mut input.models[i];
|
||||
set_text(&model.name, name, window, cx);
|
||||
set_text(&model.max_tokens, max_tokens, window, cx);
|
||||
set_text(
|
||||
&model.max_completion_tokens,
|
||||
max_completion_tokens,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
set_text(&model.max_output_tokens, max_output_tokens, window, cx);
|
||||
}
|
||||
save_provider_to_settings(&input, cx)
|
||||
});
|
||||
|
||||
task.await.err()
|
||||
}
|
||||
}
|
||||
@@ -1506,7 +1506,8 @@ impl AgentDiff {
|
||||
.read(cx)
|
||||
.entries()
|
||||
.last()
|
||||
.map_or(false, |entry| entry.diffs().next().is_some())
|
||||
.and_then(|entry| entry.diff())
|
||||
.is_some()
|
||||
{
|
||||
self.update_reviewing_editors(workspace, window, cx);
|
||||
}
|
||||
@@ -1516,7 +1517,8 @@ impl AgentDiff {
|
||||
.read(cx)
|
||||
.entries()
|
||||
.get(*ix)
|
||||
.map_or(false, |entry| entry.diffs().next().is_some())
|
||||
.and_then(|entry| entry.diff())
|
||||
.is_some()
|
||||
{
|
||||
self.update_reviewing_editors(workspace, window, cx);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::cell::RefCell;
|
||||
use std::ops::{Not, Range};
|
||||
use std::ops::Range;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
@@ -43,7 +43,7 @@ use anyhow::{Result, anyhow};
|
||||
use assistant_context::{AssistantContext, ContextEvent, ContextSummary};
|
||||
use assistant_slash_command::SlashCommandWorkingSet;
|
||||
use assistant_tool::ToolWorkingSet;
|
||||
use client::{DisableAiSettings, UserStore, zed_urls};
|
||||
use client::{UserStore, zed_urls};
|
||||
use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer};
|
||||
use feature_flags::{self, FeatureFlagAppExt};
|
||||
use fs::Fs;
|
||||
@@ -440,7 +440,7 @@ pub struct AgentPanel {
|
||||
local_timezone: UtcOffset,
|
||||
active_view: ActiveView,
|
||||
acp_message_history:
|
||||
Rc<RefCell<crate::acp::MessageHistory<Vec<agent_client_protocol::ContentBlock>>>>,
|
||||
Rc<RefCell<crate::acp::MessageHistory<agentic_coding_protocol::SendUserMessageParams>>>,
|
||||
previous_view: Option<ActiveView>,
|
||||
history_store: Entity<HistoryStore>,
|
||||
history: Entity<ThreadHistory>,
|
||||
@@ -564,17 +564,6 @@ impl AgentPanel {
|
||||
let inline_assist_context_store =
|
||||
cx.new(|_cx| ContextStore::new(project.downgrade(), Some(thread_store.downgrade())));
|
||||
|
||||
let thread_id = thread.read(cx).id().clone();
|
||||
|
||||
let history_store = cx.new(|cx| {
|
||||
HistoryStore::new(
|
||||
thread_store.clone(),
|
||||
context_store.clone(),
|
||||
[HistoryEntryId::Thread(thread_id)],
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let message_editor = cx.new(|cx| {
|
||||
MessageEditor::new(
|
||||
fs.clone(),
|
||||
@@ -584,13 +573,22 @@ impl AgentPanel {
|
||||
prompt_store.clone(),
|
||||
thread_store.downgrade(),
|
||||
context_store.downgrade(),
|
||||
Some(history_store.downgrade()),
|
||||
thread.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let thread_id = thread.read(cx).id().clone();
|
||||
let history_store = cx.new(|cx| {
|
||||
HistoryStore::new(
|
||||
thread_store.clone(),
|
||||
context_store.clone(),
|
||||
[HistoryEntryId::Thread(thread_id)],
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
cx.observe(&history_store, |_, _, cx| cx.notify()).detach();
|
||||
|
||||
let active_thread = cx.new(|cx| {
|
||||
@@ -746,7 +744,6 @@ impl AgentPanel {
|
||||
if workspace
|
||||
.panel::<Self>(cx)
|
||||
.is_some_and(|panel| panel.read(cx).enabled(cx))
|
||||
&& !DisableAiSettings::get_global(cx).disable_ai
|
||||
{
|
||||
workspace.toggle_panel_focus::<Self>(window, cx);
|
||||
}
|
||||
@@ -853,7 +850,6 @@ impl AgentPanel {
|
||||
self.prompt_store.clone(),
|
||||
self.thread_store.downgrade(),
|
||||
self.context_store.downgrade(),
|
||||
Some(self.history_store.downgrade()),
|
||||
thread.clone(),
|
||||
window,
|
||||
cx,
|
||||
@@ -1127,7 +1123,6 @@ impl AgentPanel {
|
||||
self.prompt_store.clone(),
|
||||
self.thread_store.downgrade(),
|
||||
self.context_store.downgrade(),
|
||||
Some(self.history_store.downgrade()),
|
||||
thread.clone(),
|
||||
window,
|
||||
cx,
|
||||
@@ -1686,7 +1681,7 @@ impl Panel for AgentPanel {
|
||||
}
|
||||
|
||||
fn enabled(&self, cx: &App) -> bool {
|
||||
DisableAiSettings::get_global(cx).disable_ai.not() && AgentSettings::get_global(cx).enabled
|
||||
AgentSettings::get_global(cx).enabled
|
||||
}
|
||||
|
||||
fn is_zoomed(&self, _window: &Window, _cx: &App) -> bool {
|
||||
@@ -1905,110 +1900,85 @@ impl AgentPanel {
|
||||
)
|
||||
.anchor(Corner::TopRight)
|
||||
.with_handle(self.new_thread_menu_handle.clone())
|
||||
.menu({
|
||||
let focus_handle = focus_handle.clone();
|
||||
move |window, cx| {
|
||||
let active_thread = active_thread.clone();
|
||||
Some(ContextMenu::build(window, cx, |mut menu, _window, cx| {
|
||||
menu = menu
|
||||
.context(focus_handle.clone())
|
||||
.when(cx.has_flag::<feature_flags::AcpFeatureFlag>(), |this| {
|
||||
this.header("Zed Agent")
|
||||
})
|
||||
.item(
|
||||
ContextMenuEntry::new("New Thread")
|
||||
.icon(IconName::NewThread)
|
||||
.icon_color(Color::Muted)
|
||||
.action(NewThread::default().boxed_clone())
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
NewThread::default().boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
.item(
|
||||
ContextMenuEntry::new("New Text Thread")
|
||||
.icon(IconName::NewTextThread)
|
||||
.icon_color(Color::Muted)
|
||||
.action(NewTextThread.boxed_clone())
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(NewTextThread.boxed_clone(), cx);
|
||||
}),
|
||||
)
|
||||
.when_some(active_thread, |this, active_thread| {
|
||||
let thread = active_thread.read(cx);
|
||||
.menu(move |window, cx| {
|
||||
let active_thread = active_thread.clone();
|
||||
Some(ContextMenu::build(window, cx, |mut menu, _window, cx| {
|
||||
menu = menu
|
||||
.when(cx.has_flag::<feature_flags::AcpFeatureFlag>(), |this| {
|
||||
this.header("Zed Agent")
|
||||
})
|
||||
.item(
|
||||
ContextMenuEntry::new("New Thread")
|
||||
.icon(IconName::NewThread)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(NewThread::default().boxed_clone(), cx);
|
||||
}),
|
||||
)
|
||||
.item(
|
||||
ContextMenuEntry::new("New Text Thread")
|
||||
.icon(IconName::NewTextThread)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(NewTextThread.boxed_clone(), cx);
|
||||
}),
|
||||
)
|
||||
.when_some(active_thread, |this, active_thread| {
|
||||
let thread = active_thread.read(cx);
|
||||
|
||||
if !thread.is_empty() {
|
||||
let thread_id = thread.id().clone();
|
||||
this.item(
|
||||
ContextMenuEntry::new("New From Summary")
|
||||
.icon(IconName::NewFromSummary)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
Box::new(NewThread {
|
||||
from_thread_id: Some(thread_id.clone()),
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
} else {
|
||||
this
|
||||
}
|
||||
})
|
||||
.when(cx.has_flag::<feature_flags::AcpFeatureFlag>(), |this| {
|
||||
this.separator()
|
||||
.header("External Agents")
|
||||
.item(
|
||||
ContextMenuEntry::new("New Gemini Thread")
|
||||
.icon(IconName::AiGemini)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
NewExternalAgentThread {
|
||||
agent: Some(crate::ExternalAgent::Gemini),
|
||||
}
|
||||
.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
.item(
|
||||
ContextMenuEntry::new("New Claude Code Thread")
|
||||
.icon(IconName::AiClaude)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
NewExternalAgentThread {
|
||||
agent: Some(
|
||||
crate::ExternalAgent::ClaudeCode,
|
||||
),
|
||||
}
|
||||
.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
.item(
|
||||
ContextMenuEntry::new("New Codex Thread")
|
||||
.icon(IconName::AiOpenAi)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
NewExternalAgentThread {
|
||||
agent: Some(crate::ExternalAgent::Codex),
|
||||
}
|
||||
.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
});
|
||||
menu
|
||||
}))
|
||||
}
|
||||
if !thread.is_empty() {
|
||||
let thread_id = thread.id().clone();
|
||||
this.item(
|
||||
ContextMenuEntry::new("New From Summary")
|
||||
.icon(IconName::NewFromSummary)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
Box::new(NewThread {
|
||||
from_thread_id: Some(thread_id.clone()),
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
} else {
|
||||
this
|
||||
}
|
||||
})
|
||||
.when(cx.has_flag::<feature_flags::AcpFeatureFlag>(), |this| {
|
||||
this.separator()
|
||||
.header("External Agents")
|
||||
.item(
|
||||
ContextMenuEntry::new("New Gemini Thread")
|
||||
.icon(IconName::AiGemini)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
NewExternalAgentThread {
|
||||
agent: Some(crate::ExternalAgent::Gemini),
|
||||
}
|
||||
.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
.item(
|
||||
ContextMenuEntry::new("New Claude Code Thread")
|
||||
.icon(IconName::AiClaude)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
NewExternalAgentThread {
|
||||
agent: Some(crate::ExternalAgent::ClaudeCode),
|
||||
}
|
||||
.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
});
|
||||
menu
|
||||
}))
|
||||
});
|
||||
|
||||
let agent_panel_menu = PopoverMenu::new("agent-options-menu")
|
||||
@@ -2030,69 +2000,65 @@ impl AgentPanel {
|
||||
)
|
||||
.anchor(Corner::TopRight)
|
||||
.with_handle(self.agent_panel_menu_handle.clone())
|
||||
.menu({
|
||||
let focus_handle = focus_handle.clone();
|
||||
move |window, cx| {
|
||||
Some(ContextMenu::build(window, cx, |mut menu, _window, _| {
|
||||
menu = menu.context(focus_handle.clone());
|
||||
if let Some(usage) = usage {
|
||||
menu = menu
|
||||
.header_with_link("Prompt Usage", "Manage", account_url.clone())
|
||||
.custom_entry(
|
||||
move |_window, cx| {
|
||||
let used_percentage = match usage.limit {
|
||||
UsageLimit::Limited(limit) => {
|
||||
Some((usage.amount as f32 / limit as f32) * 100.)
|
||||
}
|
||||
UsageLimit::Unlimited => None,
|
||||
};
|
||||
|
||||
h_flex()
|
||||
.flex_1()
|
||||
.gap_1p5()
|
||||
.children(used_percentage.map(|percent| {
|
||||
ProgressBar::new("usage", percent, 100., cx)
|
||||
}))
|
||||
.child(
|
||||
Label::new(match usage.limit {
|
||||
UsageLimit::Limited(limit) => {
|
||||
format!("{} / {limit}", usage.amount)
|
||||
}
|
||||
UsageLimit::Unlimited => {
|
||||
format!("{} / ∞", usage.amount)
|
||||
}
|
||||
})
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.into_any_element()
|
||||
},
|
||||
move |_, cx| cx.open_url(&zed_urls::account_url(cx)),
|
||||
)
|
||||
.separator()
|
||||
}
|
||||
|
||||
.menu(move |window, cx| {
|
||||
Some(ContextMenu::build(window, cx, |mut menu, _window, _| {
|
||||
if let Some(usage) = usage {
|
||||
menu = menu
|
||||
.header("MCP Servers")
|
||||
.action(
|
||||
"View Server Extensions",
|
||||
Box::new(zed_actions::Extensions {
|
||||
category_filter: Some(
|
||||
zed_actions::ExtensionCategoryFilter::ContextServers,
|
||||
),
|
||||
id: None,
|
||||
}),
|
||||
.header_with_link("Prompt Usage", "Manage", account_url.clone())
|
||||
.custom_entry(
|
||||
move |_window, cx| {
|
||||
let used_percentage = match usage.limit {
|
||||
UsageLimit::Limited(limit) => {
|
||||
Some((usage.amount as f32 / limit as f32) * 100.)
|
||||
}
|
||||
UsageLimit::Unlimited => None,
|
||||
};
|
||||
|
||||
h_flex()
|
||||
.flex_1()
|
||||
.gap_1p5()
|
||||
.children(used_percentage.map(|percent| {
|
||||
ProgressBar::new("usage", percent, 100., cx)
|
||||
}))
|
||||
.child(
|
||||
Label::new(match usage.limit {
|
||||
UsageLimit::Limited(limit) => {
|
||||
format!("{} / {limit}", usage.amount)
|
||||
}
|
||||
UsageLimit::Unlimited => {
|
||||
format!("{} / ∞", usage.amount)
|
||||
}
|
||||
})
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.into_any_element()
|
||||
},
|
||||
move |_, cx| cx.open_url(&zed_urls::account_url(cx)),
|
||||
)
|
||||
.action("Add Custom Server…", Box::new(AddContextServer))
|
||||
.separator();
|
||||
.separator()
|
||||
}
|
||||
|
||||
menu = menu
|
||||
.action("Rules…", Box::new(OpenRulesLibrary::default()))
|
||||
.action("Settings", Box::new(OpenConfiguration))
|
||||
.action(zoom_in_label, Box::new(ToggleZoom));
|
||||
menu
|
||||
}))
|
||||
}
|
||||
menu = menu
|
||||
.header("MCP Servers")
|
||||
.action(
|
||||
"View Server Extensions",
|
||||
Box::new(zed_actions::Extensions {
|
||||
category_filter: Some(
|
||||
zed_actions::ExtensionCategoryFilter::ContextServers,
|
||||
),
|
||||
id: None,
|
||||
}),
|
||||
)
|
||||
.action("Add Custom Server…", Box::new(AddContextServer))
|
||||
.separator();
|
||||
|
||||
menu = menu
|
||||
.action("Rules…", Box::new(OpenRulesLibrary::default()))
|
||||
.action("Settings", Box::new(OpenConfiguration))
|
||||
.action(zoom_in_label, Box::new(ToggleZoom));
|
||||
menu
|
||||
}))
|
||||
});
|
||||
|
||||
h_flex()
|
||||
@@ -2305,21 +2271,20 @@ impl AgentPanel {
|
||||
}
|
||||
|
||||
match &self.active_view {
|
||||
ActiveView::Thread { .. } | ActiveView::TextThread { .. } => {
|
||||
let history_is_empty = self
|
||||
.history_store
|
||||
.update(cx, |store, cx| store.recent_entries(1, cx).is_empty());
|
||||
|
||||
let has_configured_non_zed_providers = LanguageModelRegistry::read_global(cx)
|
||||
.providers()
|
||||
.iter()
|
||||
.any(|provider| {
|
||||
provider.is_authenticated(cx)
|
||||
&& provider.id() != language_model::ZED_CLOUD_PROVIDER_ID
|
||||
});
|
||||
|
||||
history_is_empty || !has_configured_non_zed_providers
|
||||
}
|
||||
ActiveView::Thread { thread, .. } => thread
|
||||
.read(cx)
|
||||
.thread()
|
||||
.read(cx)
|
||||
.configured_model()
|
||||
.map_or(true, |model| {
|
||||
model.provider.id() == language_model::ZED_CLOUD_PROVIDER_ID
|
||||
}),
|
||||
ActiveView::TextThread { .. } => LanguageModelRegistry::global(cx)
|
||||
.read(cx)
|
||||
.default_model()
|
||||
.map_or(true, |model| {
|
||||
model.provider.id() == language_model::ZED_CLOUD_PROVIDER_ID
|
||||
}),
|
||||
ActiveView::ExternalAgentThread { .. }
|
||||
| ActiveView::History
|
||||
| ActiveView::Configuration => false,
|
||||
@@ -2340,8 +2305,9 @@ impl AgentPanel {
|
||||
|
||||
Some(
|
||||
div()
|
||||
.size_full()
|
||||
.when(thread_view, |this| {
|
||||
this.size_full().bg(cx.theme().colors().panel_background)
|
||||
this.bg(cx.theme().colors().panel_background)
|
||||
})
|
||||
.when(text_thread_view, |this| {
|
||||
this.bg(cx.theme().colors().editor_background)
|
||||
@@ -2666,25 +2632,6 @@ impl AgentPanel {
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
.child(
|
||||
NewThreadButton::new(
|
||||
"new-codex-thread-btn",
|
||||
"New Codex Thread",
|
||||
IconName::AiOpenAi,
|
||||
)
|
||||
.on_click(
|
||||
|window, cx| {
|
||||
window.dispatch_action(
|
||||
Box::new(NewExternalAgentThread {
|
||||
agent: Some(
|
||||
crate::ExternalAgent::Codex,
|
||||
),
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
},
|
||||
),
|
||||
),
|
||||
)
|
||||
}),
|
||||
|
||||
@@ -31,8 +31,7 @@ use std::sync::Arc;
|
||||
use agent::{Thread, ThreadId};
|
||||
use agent_settings::{AgentProfileId, AgentSettings, LanguageModelSelection};
|
||||
use assistant_slash_command::SlashCommandRegistry;
|
||||
use client::{Client, DisableAiSettings};
|
||||
use command_palette_hooks::CommandPaletteFilter;
|
||||
use client::Client;
|
||||
use feature_flags::FeatureFlagAppExt as _;
|
||||
use fs::Fs;
|
||||
use gpui::{Action, App, Entity, actions};
|
||||
@@ -44,7 +43,6 @@ use prompt_store::PromptBuilder;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use std::any::TypeId;
|
||||
|
||||
pub use crate::active_thread::ActiveThread;
|
||||
use crate::agent_configuration::{ConfigureContextServerModal, ManageProfilesModal};
|
||||
@@ -54,7 +52,6 @@ use crate::slash_command_settings::SlashCommandSettings;
|
||||
pub use agent_diff::{AgentDiffPane, AgentDiffToolbar};
|
||||
pub use text_thread_editor::{AgentPanelDelegate, TextThreadEditor};
|
||||
pub use ui::preview::{all_agent_previews, get_agent_preview};
|
||||
use zed_actions;
|
||||
|
||||
actions!(
|
||||
agent,
|
||||
@@ -150,7 +147,6 @@ enum ExternalAgent {
|
||||
#[default]
|
||||
Gemini,
|
||||
ClaudeCode,
|
||||
Codex,
|
||||
}
|
||||
|
||||
impl ExternalAgent {
|
||||
@@ -158,7 +154,6 @@ impl ExternalAgent {
|
||||
match self {
|
||||
ExternalAgent::Gemini => Rc::new(agent_servers::Gemini),
|
||||
ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
|
||||
ExternalAgent::Codex => Rc::new(agent_servers::Codex),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -202,6 +197,11 @@ impl ModelUsageContext {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init_settings(cx: &mut App) {
|
||||
AgentSettings::register(cx);
|
||||
SlashCommandSettings::register(cx);
|
||||
}
|
||||
|
||||
/// Initializes the `agent` crate.
|
||||
pub fn init(
|
||||
fs: Arc<dyn Fs>,
|
||||
@@ -211,8 +211,7 @@ pub fn init(
|
||||
is_eval: bool,
|
||||
cx: &mut App,
|
||||
) {
|
||||
AgentSettings::register(cx);
|
||||
SlashCommandSettings::register(cx);
|
||||
init_settings(cx);
|
||||
|
||||
assistant_context::init(client.clone(), cx);
|
||||
rules_library::init(cx);
|
||||
@@ -246,69 +245,6 @@ pub fn init(
|
||||
})
|
||||
.detach();
|
||||
cx.observe_new(ManageProfilesModal::register).detach();
|
||||
|
||||
// Update command palette filter based on AI settings
|
||||
update_command_palette_filter(cx);
|
||||
|
||||
// Watch for settings changes
|
||||
cx.observe_global::<SettingsStore>(|app_cx| {
|
||||
// When settings change, update the command palette filter
|
||||
update_command_palette_filter(app_cx);
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn update_command_palette_filter(cx: &mut App) {
|
||||
let disable_ai = DisableAiSettings::get_global(cx).disable_ai;
|
||||
CommandPaletteFilter::update_global(cx, |filter, _| {
|
||||
if disable_ai {
|
||||
filter.hide_namespace("agent");
|
||||
filter.hide_namespace("assistant");
|
||||
filter.hide_namespace("copilot");
|
||||
filter.hide_namespace("zed_predict_onboarding");
|
||||
|
||||
filter.hide_namespace("edit_prediction");
|
||||
|
||||
use editor::actions::{
|
||||
AcceptEditPrediction, AcceptPartialEditPrediction, NextEditPrediction,
|
||||
PreviousEditPrediction, ShowEditPrediction, ToggleEditPrediction,
|
||||
};
|
||||
let edit_prediction_actions = [
|
||||
TypeId::of::<AcceptEditPrediction>(),
|
||||
TypeId::of::<AcceptPartialEditPrediction>(),
|
||||
TypeId::of::<ShowEditPrediction>(),
|
||||
TypeId::of::<NextEditPrediction>(),
|
||||
TypeId::of::<PreviousEditPrediction>(),
|
||||
TypeId::of::<ToggleEditPrediction>(),
|
||||
];
|
||||
filter.hide_action_types(&edit_prediction_actions);
|
||||
filter.hide_action_types(&[TypeId::of::<zed_actions::OpenZedPredictOnboarding>()]);
|
||||
} else {
|
||||
filter.show_namespace("agent");
|
||||
filter.show_namespace("assistant");
|
||||
filter.show_namespace("copilot");
|
||||
filter.show_namespace("zed_predict_onboarding");
|
||||
|
||||
filter.show_namespace("edit_prediction");
|
||||
|
||||
use editor::actions::{
|
||||
AcceptEditPrediction, AcceptPartialEditPrediction, NextEditPrediction,
|
||||
PreviousEditPrediction, ShowEditPrediction, ToggleEditPrediction,
|
||||
};
|
||||
let edit_prediction_actions = [
|
||||
TypeId::of::<AcceptEditPrediction>(),
|
||||
TypeId::of::<AcceptPartialEditPrediction>(),
|
||||
TypeId::of::<ShowEditPrediction>(),
|
||||
TypeId::of::<NextEditPrediction>(),
|
||||
TypeId::of::<PreviousEditPrediction>(),
|
||||
TypeId::of::<ToggleEditPrediction>(),
|
||||
];
|
||||
filter.show_action_types(edit_prediction_actions.iter());
|
||||
|
||||
filter
|
||||
.show_action_types([TypeId::of::<zed_actions::OpenZedPredictOnboarding>()].iter());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn init_language_model_settings(cx: &mut App) {
|
||||
|
||||
@@ -16,7 +16,7 @@ use agent::{
|
||||
};
|
||||
use agent_settings::AgentSettings;
|
||||
use anyhow::{Context as _, Result};
|
||||
use client::{DisableAiSettings, telemetry::Telemetry};
|
||||
use client::telemetry::Telemetry;
|
||||
use collections::{HashMap, HashSet, VecDeque, hash_map};
|
||||
use editor::SelectionEffects;
|
||||
use editor::{
|
||||
@@ -57,17 +57,6 @@ pub fn init(
|
||||
cx: &mut App,
|
||||
) {
|
||||
cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry));
|
||||
|
||||
cx.observe_global::<SettingsStore>(|cx| {
|
||||
if DisableAiSettings::get_global(cx).disable_ai {
|
||||
// Hide any active inline assist UI when AI is disabled
|
||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||
assistant.cancel_all_active_completions(cx);
|
||||
});
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
cx.observe_new(|_workspace: &mut Workspace, window, cx| {
|
||||
let Some(window) = window else {
|
||||
return;
|
||||
@@ -152,26 +141,6 @@ impl InlineAssistant {
|
||||
.detach();
|
||||
}
|
||||
|
||||
/// Hides all active inline assists when AI is disabled
|
||||
pub fn cancel_all_active_completions(&mut self, cx: &mut App) {
|
||||
// Cancel all active completions in editors
|
||||
for (editor_handle, _) in self.assists_by_editor.iter() {
|
||||
if let Some(editor) = editor_handle.upgrade() {
|
||||
let windows = cx.windows();
|
||||
if !windows.is_empty() {
|
||||
let window = windows[0];
|
||||
let _ = window.update(cx, |_, window, cx| {
|
||||
editor.update(cx, |editor, cx| {
|
||||
if editor.has_active_inline_completion() {
|
||||
editor.cancel(&Default::default(), window, cx);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_workspace_event(
|
||||
&mut self,
|
||||
workspace: Entity<Workspace>,
|
||||
@@ -207,7 +176,7 @@ impl InlineAssistant {
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) {
|
||||
let is_assistant2_enabled = !DisableAiSettings::get_global(cx).disable_ai;
|
||||
let is_assistant2_enabled = true;
|
||||
|
||||
if let Some(editor) = item.act_as::<Editor>(cx) {
|
||||
editor.update(cx, |editor, cx| {
|
||||
@@ -230,13 +199,6 @@ impl InlineAssistant {
|
||||
cx,
|
||||
);
|
||||
|
||||
if DisableAiSettings::get_global(cx).disable_ai {
|
||||
// Cancel any active completions
|
||||
if editor.has_active_inline_completion() {
|
||||
editor.cancel(&Default::default(), window, cx);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the Assistant1 code action provider, as it still might be registered.
|
||||
editor.remove_code_action_provider("assistant".into(), window, cx);
|
||||
} else {
|
||||
@@ -257,7 +219,7 @@ impl InlineAssistant {
|
||||
cx: &mut Context<Workspace>,
|
||||
) {
|
||||
let settings = AgentSettings::get_global(cx);
|
||||
if !settings.enabled || DisableAiSettings::get_global(cx).disable_ai {
|
||||
if !settings.enabled {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ use crate::ui::{
|
||||
MaxModeTooltip,
|
||||
preview::{AgentPreview, UsageCallout},
|
||||
};
|
||||
use agent::history_store::HistoryStore;
|
||||
use agent::{
|
||||
context::{AgentContextKey, ContextLoadResult, load_context},
|
||||
context_store::ContextStoreEvent,
|
||||
@@ -30,9 +29,8 @@ use fs::Fs;
|
||||
use futures::future::Shared;
|
||||
use futures::{FutureExt as _, future};
|
||||
use gpui::{
|
||||
Animation, AnimationExt, App, Entity, EventEmitter, Focusable, IntoElement, KeyContext,
|
||||
Subscription, Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point,
|
||||
pulsating_between,
|
||||
Animation, AnimationExt, App, Entity, EventEmitter, Focusable, KeyContext, Subscription, Task,
|
||||
TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
|
||||
};
|
||||
use language::{Buffer, Language, Point};
|
||||
use language_model::{
|
||||
@@ -82,7 +80,6 @@ pub struct MessageEditor {
|
||||
user_store: Entity<UserStore>,
|
||||
context_store: Entity<ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
history_store: Option<WeakEntity<HistoryStore>>,
|
||||
context_strip: Entity<ContextStrip>,
|
||||
context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
|
||||
model_selector: Entity<AgentModelSelector>,
|
||||
@@ -164,7 +161,6 @@ impl MessageEditor {
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
text_thread_store: WeakEntity<TextThreadStore>,
|
||||
history_store: Option<WeakEntity<HistoryStore>>,
|
||||
thread: Entity<Thread>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
@@ -237,7 +233,6 @@ impl MessageEditor {
|
||||
workspace,
|
||||
context_store,
|
||||
prompt_store,
|
||||
history_store,
|
||||
context_strip,
|
||||
context_picker_menu_handle,
|
||||
load_context_task: None,
|
||||
@@ -630,7 +625,7 @@ impl MessageEditor {
|
||||
.unwrap_or(false);
|
||||
|
||||
IconButton::new("follow-agent", IconName::Crosshair)
|
||||
.disabled(!is_model_selected)
|
||||
.disabled(is_model_selected)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Muted)
|
||||
.toggle_state(following)
|
||||
@@ -915,10 +910,6 @@ impl MessageEditor {
|
||||
.on_click({
|
||||
let focus_handle = focus_handle.clone();
|
||||
move |_event, window, cx| {
|
||||
telemetry::event!(
|
||||
"Agent Message Sent",
|
||||
agent = "zed",
|
||||
);
|
||||
focus_handle.dispatch_action(
|
||||
&Chat, window, cx,
|
||||
);
|
||||
@@ -1666,36 +1657,32 @@ impl Render for MessageEditor {
|
||||
|
||||
let line_height = TextSize::Small.rems(cx).to_pixels(window.rem_size()) * 1.5;
|
||||
|
||||
let has_configured_providers = LanguageModelRegistry::read_global(cx)
|
||||
.providers()
|
||||
.iter()
|
||||
.filter(|provider| {
|
||||
provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID
|
||||
})
|
||||
.count()
|
||||
> 0;
|
||||
let in_pro_trial = matches!(
|
||||
self.user_store.read(cx).current_plan(),
|
||||
Some(proto::Plan::ZedProTrial)
|
||||
);
|
||||
|
||||
let is_signed_out = self
|
||||
.workspace
|
||||
.read_with(cx, |workspace, _| {
|
||||
workspace.client().status().borrow().is_signed_out()
|
||||
})
|
||||
.unwrap_or(true);
|
||||
let pro_user = matches!(
|
||||
self.user_store.read(cx).current_plan(),
|
||||
Some(proto::Plan::ZedPro)
|
||||
);
|
||||
|
||||
let has_history = self
|
||||
.history_store
|
||||
.as_ref()
|
||||
.and_then(|hs| hs.update(cx, |hs, cx| hs.entries(cx).len() > 0).ok())
|
||||
.unwrap_or(false)
|
||||
|| self
|
||||
.thread
|
||||
.read_with(cx, |thread, _| thread.messages().len() > 0);
|
||||
let configured_providers: Vec<(IconName, SharedString)> =
|
||||
LanguageModelRegistry::read_global(cx)
|
||||
.providers()
|
||||
.iter()
|
||||
.filter(|provider| {
|
||||
provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID
|
||||
})
|
||||
.map(|provider| (provider.icon(), provider.name().0.clone()))
|
||||
.collect();
|
||||
let has_existing_providers = configured_providers.len() > 0;
|
||||
|
||||
v_flex()
|
||||
.size_full()
|
||||
.bg(cx.theme().colors().panel_background)
|
||||
.when(
|
||||
!has_history && is_signed_out && has_configured_providers,
|
||||
has_existing_providers && !in_pro_trial && !pro_user,
|
||||
|this| this.child(cx.new(ApiKeysWithProviders::new)),
|
||||
)
|
||||
.when(changed_buffers.len() > 0, |parent| {
|
||||
@@ -1787,7 +1774,6 @@ impl AgentPreview for MessageEditor {
|
||||
None,
|
||||
thread_store.downgrade(),
|
||||
text_thread_store.downgrade(),
|
||||
None,
|
||||
thread,
|
||||
window,
|
||||
cx,
|
||||
|
||||
@@ -5,6 +5,7 @@ mod end_trial_upsell;
|
||||
mod new_thread_button;
|
||||
mod onboarding_modal;
|
||||
pub mod preview;
|
||||
mod upsell;
|
||||
|
||||
pub use agent_notification::*;
|
||||
pub use burn_mode_tooltip::*;
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||
use ai_onboarding::{AgentPanelOnboardingCard, BulletItem};
|
||||
use client::zed_urls;
|
||||
use gpui::{AnyElement, App, IntoElement, RenderOnce, Window};
|
||||
use ui::{Divider, List, Tooltip, prelude::*};
|
||||
use ui::{Divider, List, prelude::*};
|
||||
|
||||
#[derive(IntoElement, RegisterComponent)]
|
||||
pub struct EndTrialUpsell {
|
||||
@@ -33,19 +33,14 @@ impl RenderOnce for EndTrialUpsell {
|
||||
)
|
||||
.child(
|
||||
List::new()
|
||||
.child(BulletItem::new("500 prompts with Claude models"))
|
||||
.child(BulletItem::new(
|
||||
"Unlimited edit predictions with Zeta, our open-source model",
|
||||
)),
|
||||
.child(BulletItem::new("500 prompts per month with Claude models"))
|
||||
.child(BulletItem::new("Unlimited edit predictions")),
|
||||
)
|
||||
.child(
|
||||
Button::new("cta-button", "Upgrade to Zed Pro")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||
.on_click(move |_, _window, cx| {
|
||||
telemetry::event!("Upgrade To Pro Clicked", state = "end-of-trial");
|
||||
cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))
|
||||
}),
|
||||
.on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
|
||||
);
|
||||
|
||||
let free_section = v_flex()
|
||||
@@ -60,43 +55,37 @@ impl RenderOnce for EndTrialUpsell {
|
||||
.color(Color::Muted)
|
||||
.buffer_font(cx),
|
||||
)
|
||||
.child(
|
||||
Label::new("(Current Plan)")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Custom(cx.theme().colors().text_muted.opacity(0.6)))
|
||||
.buffer_font(cx),
|
||||
)
|
||||
.child(Divider::horizontal()),
|
||||
)
|
||||
.child(
|
||||
List::new()
|
||||
.child(BulletItem::new("50 prompts with the Claude models"))
|
||||
.child(BulletItem::new("2,000 accepted edit predictions")),
|
||||
.child(BulletItem::new(
|
||||
"50 prompts per month with the Claude models",
|
||||
))
|
||||
.child(BulletItem::new(
|
||||
"2000 accepted edit predictions using our open-source Zeta model",
|
||||
)),
|
||||
)
|
||||
.child(
|
||||
Button::new("dismiss-button", "Stay on Free")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Outlined)
|
||||
.on_click({
|
||||
let callback = self.dismiss_upsell.clone();
|
||||
move |_, window, cx| callback(window, cx)
|
||||
}),
|
||||
);
|
||||
|
||||
AgentPanelOnboardingCard::new()
|
||||
.child(Headline::new("Your Zed Pro Trial has expired"))
|
||||
.child(Headline::new("Your Zed Pro trial has expired."))
|
||||
.child(
|
||||
Label::new("You've been automatically reset to the Free plan.")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted)
|
||||
.mb_2(),
|
||||
.mb_1(),
|
||||
)
|
||||
.child(pro_section)
|
||||
.child(free_section)
|
||||
.child(
|
||||
h_flex().absolute().top_4().right_4().child(
|
||||
IconButton::new("dismiss_onboarding", IconName::Close)
|
||||
.icon_size(IconSize::Small)
|
||||
.tooltip(Tooltip::text("Dismiss"))
|
||||
.on_click({
|
||||
let callback = self.dismiss_upsell.clone();
|
||||
move |_, window, cx| {
|
||||
telemetry::event!("Banner Dismissed", source = "AI Onboarding");
|
||||
callback(window, cx)
|
||||
}
|
||||
}),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
163
crates/agent_ui/src/ui/upsell.rs
Normal file
@@ -0,0 +1,163 @@
|
||||
use component::{Component, ComponentScope, single_example};
|
||||
use gpui::{
|
||||
AnyElement, App, ClickEvent, IntoElement, ParentElement, RenderOnce, SharedString, Styled,
|
||||
Window,
|
||||
};
|
||||
use theme::ActiveTheme;
|
||||
use ui::{
|
||||
Button, ButtonCommon, ButtonStyle, Checkbox, Clickable, Color, Label, LabelCommon,
|
||||
RegisterComponent, ToggleState, h_flex, v_flex,
|
||||
};
|
||||
|
||||
/// A component that displays an upsell message with a call-to-action button
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// let upsell = Upsell::new(
|
||||
/// "Upgrade to Zed Pro",
|
||||
/// "Get access to advanced AI features and more",
|
||||
/// "Upgrade Now",
|
||||
/// Box::new(|_, _window, cx| {
|
||||
/// cx.open_url("https://zed.dev/pricing");
|
||||
/// }),
|
||||
/// Box::new(|_, _window, cx| {
|
||||
/// // Handle dismiss
|
||||
/// }),
|
||||
/// Box::new(|checked, window, cx| {
|
||||
/// // Handle don't show again
|
||||
/// }),
|
||||
/// );
|
||||
/// ```
|
||||
#[derive(IntoElement, RegisterComponent)]
|
||||
pub struct Upsell {
|
||||
title: SharedString,
|
||||
message: SharedString,
|
||||
cta_text: SharedString,
|
||||
on_click: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>,
|
||||
on_dismiss: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>,
|
||||
on_dont_show_again: Box<dyn Fn(bool, &mut Window, &mut App) + 'static>,
|
||||
}
|
||||
|
||||
impl Upsell {
|
||||
/// Create a new upsell component
|
||||
pub fn new(
|
||||
title: impl Into<SharedString>,
|
||||
message: impl Into<SharedString>,
|
||||
cta_text: impl Into<SharedString>,
|
||||
on_click: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>,
|
||||
on_dismiss: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>,
|
||||
on_dont_show_again: Box<dyn Fn(bool, &mut Window, &mut App) + 'static>,
|
||||
) -> Self {
|
||||
Self {
|
||||
title: title.into(),
|
||||
message: message.into(),
|
||||
cta_text: cta_text.into(),
|
||||
on_click,
|
||||
on_dismiss,
|
||||
on_dont_show_again,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RenderOnce for Upsell {
|
||||
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
|
||||
v_flex()
|
||||
.w_full()
|
||||
.p_4()
|
||||
.gap_3()
|
||||
.bg(cx.theme().colors().surface_background)
|
||||
.rounded_md()
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.child(
|
||||
v_flex()
|
||||
.gap_1()
|
||||
.child(
|
||||
Label::new(self.title)
|
||||
.size(ui::LabelSize::Large)
|
||||
.weight(gpui::FontWeight::BOLD),
|
||||
)
|
||||
.child(Label::new(self.message).color(Color::Muted)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.justify_between()
|
||||
.items_center()
|
||||
.child(
|
||||
h_flex()
|
||||
.items_center()
|
||||
.gap_1()
|
||||
.child(
|
||||
Checkbox::new("dont-show-again", ToggleState::Unselected).on_click(
|
||||
move |_, window, cx| {
|
||||
(self.on_dont_show_again)(true, window, cx);
|
||||
},
|
||||
),
|
||||
)
|
||||
.child(
|
||||
Label::new("Don't show again")
|
||||
.color(Color::Muted)
|
||||
.size(ui::LabelSize::Small),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Button::new("dismiss-button", "No Thanks")
|
||||
.style(ButtonStyle::Subtle)
|
||||
.on_click(self.on_dismiss),
|
||||
)
|
||||
.child(
|
||||
Button::new("cta-button", self.cta_text)
|
||||
.style(ButtonStyle::Filled)
|
||||
.on_click(self.on_click),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Component for Upsell {
|
||||
fn scope() -> ComponentScope {
|
||||
ComponentScope::Agent
|
||||
}
|
||||
|
||||
fn name() -> &'static str {
|
||||
"Upsell"
|
||||
}
|
||||
|
||||
fn description() -> Option<&'static str> {
|
||||
Some("A promotional component that displays a message with a call-to-action.")
|
||||
}
|
||||
|
||||
fn preview(window: &mut Window, cx: &mut App) -> Option<AnyElement> {
|
||||
let examples = vec![
|
||||
single_example(
|
||||
"Default",
|
||||
Upsell::new(
|
||||
"Upgrade to Zed Pro",
|
||||
"Get unlimited access to AI features and more with Zed Pro. Unlock advanced AI capabilities and other premium features.",
|
||||
"Upgrade Now",
|
||||
Box::new(|_, _, _| {}),
|
||||
Box::new(|_, _, _| {}),
|
||||
Box::new(|_, _, _| {}),
|
||||
).render(window, cx).into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Short Message",
|
||||
Upsell::new(
|
||||
"Try Zed Pro for free",
|
||||
"Start your 7-day trial today.",
|
||||
"Start Trial",
|
||||
Box::new(|_, _, _| {}),
|
||||
Box::new(|_, _, _| {}),
|
||||
Box::new(|_, _, _| {}),
|
||||
).render(window, cx).into_any_element(),
|
||||
),
|
||||
];
|
||||
|
||||
Some(v_flex().gap_4().children(examples).into_any_element())
|
||||
}
|
||||
}
|
||||
@@ -22,7 +22,6 @@ language_model.workspace = true
|
||||
proto.workspace = true
|
||||
serde.workspace = true
|
||||
smallvec.workspace = true
|
||||
telemetry.workspace = true
|
||||
ui.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zed_actions.workspace = true
|
||||
|
||||
@@ -38,6 +38,10 @@ impl ApiKeysWithProviders {
|
||||
.map(|provider| (provider.icon(), provider.name().0.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn has_providers(&self) -> bool {
|
||||
!self.configured_providers.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for ApiKeysWithProviders {
|
||||
@@ -49,10 +53,11 @@ impl Render for ApiKeysWithProviders {
|
||||
.map(|(icon, name)| {
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.child(Icon::new(icon).size(IconSize::XSmall).color(Color::Muted))
|
||||
.child(Icon::new(icon).size(IconSize::Small).color(Color::Muted))
|
||||
.child(Label::new(name))
|
||||
});
|
||||
div()
|
||||
|
||||
h_flex()
|
||||
.mx_2p5()
|
||||
.p_1()
|
||||
.pb_0()
|
||||
@@ -80,24 +85,8 @@ impl Render for ApiKeysWithProviders {
|
||||
.border_x_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.bg(cx.theme().colors().panel_background)
|
||||
.child(
|
||||
h_flex()
|
||||
.min_w_0()
|
||||
.gap_2()
|
||||
.child(
|
||||
Icon::new(IconName::Info)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted)
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.w_full()
|
||||
.child(
|
||||
Label::new("Start now using API keys from your environment for the following providers:")
|
||||
.color(Color::Muted)
|
||||
)
|
||||
)
|
||||
)
|
||||
.child(Icon::new(IconName::Info).size(IconSize::XSmall).color(Color::Muted))
|
||||
.child(Label::new("Or start now using API keys from your environment for the following providers:").color(Color::Muted))
|
||||
.children(configured_providers_list)
|
||||
)
|
||||
}
|
||||
@@ -129,7 +118,7 @@ impl RenderOnce for ApiKeysWithoutProviders {
|
||||
.child(Divider::horizontal()),
|
||||
)
|
||||
.child(List::new().child(BulletItem::new(
|
||||
"Add your own keys to use AI without signing in.",
|
||||
"You can also use AI in Zed by bringing your own API keys",
|
||||
)))
|
||||
.child(
|
||||
Button::new("configure-providers", "Configure Providers")
|
||||
|
||||
@@ -61,11 +61,6 @@ impl Render for AgentPanelOnboarding {
|
||||
Some(proto::Plan::ZedProTrial)
|
||||
);
|
||||
|
||||
let is_pro_user = matches!(
|
||||
self.user_store.read(cx).current_plan(),
|
||||
Some(proto::Plan::ZedPro)
|
||||
);
|
||||
|
||||
AgentPanelOnboardingCard::new()
|
||||
.child(
|
||||
ZedAiOnboarding::new(
|
||||
@@ -80,7 +75,7 @@ impl Render for AgentPanelOnboarding {
|
||||
}),
|
||||
)
|
||||
.map(|this| {
|
||||
if enrolled_in_trial || is_pro_user || self.configured_providers.len() >= 1 {
|
||||
if enrolled_in_trial || self.configured_providers.len() >= 1 {
|
||||
this
|
||||
} else {
|
||||
this.child(ApiKeysWithoutProviders::new())
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
mod agent_api_keys_onboarding;
|
||||
mod agent_panel_onboarding_card;
|
||||
mod agent_panel_onboarding_content;
|
||||
mod ai_upsell_card;
|
||||
mod edit_prediction_onboarding_content;
|
||||
mod young_account_banner;
|
||||
|
||||
pub use agent_api_keys_onboarding::{ApiKeysWithProviders, ApiKeysWithoutProviders};
|
||||
pub use agent_panel_onboarding_card::AgentPanelOnboardingCard;
|
||||
pub use agent_panel_onboarding_content::AgentPanelOnboarding;
|
||||
pub use ai_upsell_card::AiUpsellCard;
|
||||
pub use edit_prediction_onboarding_content::EditPredictionOnboarding;
|
||||
pub use young_account_banner::YoungAccountBanner;
|
||||
|
||||
@@ -18,7 +16,6 @@ use client::{Client, UserStore, zed_urls};
|
||||
use gpui::{AnyElement, Entity, IntoElement, ParentElement, SharedString};
|
||||
use ui::{Divider, List, ListItem, RegisterComponent, TintColor, Tooltip, prelude::*};
|
||||
|
||||
#[derive(IntoElement)]
|
||||
pub struct BulletItem {
|
||||
label: SharedString,
|
||||
}
|
||||
@@ -31,32 +28,22 @@ impl BulletItem {
|
||||
}
|
||||
}
|
||||
|
||||
impl RenderOnce for BulletItem {
|
||||
fn render(self, window: &mut Window, _cx: &mut App) -> impl IntoElement {
|
||||
let line_height = 0.85 * window.line_height();
|
||||
impl IntoElement for BulletItem {
|
||||
type Element = AnyElement;
|
||||
|
||||
fn into_element(self) -> Self::Element {
|
||||
ListItem::new("list-item")
|
||||
.selectable(false)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.min_w_0()
|
||||
.gap_1()
|
||||
.items_start()
|
||||
.child(
|
||||
h_flex().h(line_height).justify_center().child(
|
||||
Icon::new(IconName::Dash)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Hidden),
|
||||
),
|
||||
)
|
||||
.child(div().w_full().min_w_0().child(Label::new(self.label))),
|
||||
.start_slot(
|
||||
Icon::new(IconName::Dash)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Hidden),
|
||||
)
|
||||
.child(div().w_full().child(Label::new(self.label)))
|
||||
.into_any_element()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq)]
|
||||
pub enum SignInStatus {
|
||||
SignedIn,
|
||||
SigningIn,
|
||||
@@ -154,18 +141,22 @@ impl ZedAiOnboarding {
|
||||
)
|
||||
.child(
|
||||
List::new()
|
||||
.child(BulletItem::new("50 prompts per month with Claude models"))
|
||||
.child(BulletItem::new(
|
||||
"2,000 accepted edit predictions with Zeta, our open-source model",
|
||||
"50 prompts per month with the Claude models",
|
||||
))
|
||||
.child(BulletItem::new(
|
||||
"2000 accepted edit predictions using our open-source Zeta model",
|
||||
)),
|
||||
)
|
||||
}
|
||||
|
||||
fn pro_trial_definition(&self) -> impl IntoElement {
|
||||
List::new()
|
||||
.child(BulletItem::new("150 prompts with Claude models"))
|
||||
.child(BulletItem::new(
|
||||
"Unlimited accepted edit predictions with Zeta, our open-source model",
|
||||
"150 prompts per month with the Claude models",
|
||||
))
|
||||
.child(BulletItem::new(
|
||||
"Unlimited accepted edit predictions using our open-source Zeta model",
|
||||
))
|
||||
}
|
||||
|
||||
@@ -187,16 +178,15 @@ impl ZedAiOnboarding {
|
||||
List::new()
|
||||
.child(BulletItem::new("500 prompts per month with Claude models"))
|
||||
.child(BulletItem::new(
|
||||
"Unlimited accepted edit predictions with Zeta, our open-source model",
|
||||
"Unlimited accepted edit predictions using our open-source Zeta model",
|
||||
))
|
||||
.child(BulletItem::new("$20 USD per month")),
|
||||
.child(BulletItem::new("USD $20 per month")),
|
||||
)
|
||||
.child(
|
||||
Button::new("pro", "Get Started")
|
||||
Button::new("pro", "Start with Pro")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||
.on_click(move |_, _window, cx| {
|
||||
telemetry::event!("Upgrade To Pro Clicked", state = "young-account");
|
||||
cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))
|
||||
}),
|
||||
)
|
||||
@@ -216,15 +206,14 @@ impl ZedAiOnboarding {
|
||||
List::new()
|
||||
.child(self.pro_trial_definition())
|
||||
.child(BulletItem::new(
|
||||
"Try it out for 14 days for free, no credit card required",
|
||||
"Try it out for 14 days with no charge and no credit card required",
|
||||
)),
|
||||
)
|
||||
.child(
|
||||
Button::new("pro", "Start Free Trial")
|
||||
Button::new("pro", "Start Pro Trial")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||
.on_click(move |_, _window, cx| {
|
||||
telemetry::event!("Start Trial Clicked", state = "post-sign-in");
|
||||
cx.open_url(&zed_urls::start_trial_url(cx))
|
||||
}),
|
||||
)
|
||||
@@ -236,33 +225,28 @@ impl ZedAiOnboarding {
|
||||
v_flex()
|
||||
.gap_1()
|
||||
.w_full()
|
||||
.child(Headline::new("Accept Terms of Service"))
|
||||
.child(Headline::new("Before starting…"))
|
||||
.child(
|
||||
Label::new("We don’t sell your data, track you across the web, or compromise your privacy.")
|
||||
Label::new("Make sure you have read and accepted Zed AI's terms of service.")
|
||||
.color(Color::Muted)
|
||||
.mb_2(),
|
||||
)
|
||||
.child(
|
||||
Button::new("terms_of_service", "Review Terms of Service")
|
||||
Button::new("terms_of_service", "View and Read the Terms of Service")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Outlined)
|
||||
.icon(IconName::ArrowUpRight)
|
||||
.icon_color(Color::Muted)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.on_click(move |_, _window, cx| {
|
||||
telemetry::event!("Review Terms of Service Clicked");
|
||||
cx.open_url(&zed_urls::terms_of_service(cx))
|
||||
}),
|
||||
.on_click(move |_, _window, cx| cx.open_url(&zed_urls::terms_of_service(cx))),
|
||||
)
|
||||
.child(
|
||||
Button::new("accept_terms", "Accept")
|
||||
Button::new("accept_terms", "I've read it and accept it")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Tinted(TintColor::Accent))
|
||||
.on_click({
|
||||
let callback = self.accept_terms_of_service.clone();
|
||||
move |_, window, cx| {
|
||||
telemetry::event!("Terms of Service Accepted");
|
||||
(callback)(window, cx)}
|
||||
move |_, window, cx| (callback)(window, cx)
|
||||
}),
|
||||
)
|
||||
.into_any_element()
|
||||
@@ -275,22 +259,19 @@ impl ZedAiOnboarding {
|
||||
.gap_1()
|
||||
.child(Headline::new("Welcome to Zed AI"))
|
||||
.child(
|
||||
Label::new("Sign in to try Zed Pro for 14 days, no credit card required.")
|
||||
Label::new("Sign in to start using AI in Zed with a free trial of the Pro plan, which includes:")
|
||||
.color(Color::Muted)
|
||||
.mb_2(),
|
||||
)
|
||||
.child(self.pro_trial_definition())
|
||||
.child(
|
||||
Button::new("sign_in", "Try Zed Pro for Free")
|
||||
Button::new("sign_in", "Sign in to Start Trial")
|
||||
.disabled(signing_in)
|
||||
.full_width()
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||
.on_click({
|
||||
let callback = self.sign_in.clone();
|
||||
move |_, window, cx| {
|
||||
telemetry::event!("Start Trial Clicked", state = "pre-sign-in");
|
||||
callback(window, cx)
|
||||
}
|
||||
move |_, window, cx| callback(window, cx)
|
||||
}),
|
||||
)
|
||||
.into_any_element()
|
||||
@@ -303,6 +284,11 @@ impl ZedAiOnboarding {
|
||||
.relative()
|
||||
.gap_1()
|
||||
.child(Headline::new("Welcome to Zed AI"))
|
||||
.child(
|
||||
Label::new("Choose how you want to start.")
|
||||
.color(Color::Muted)
|
||||
.mb_2(),
|
||||
)
|
||||
.map(|this| {
|
||||
if self.account_too_young {
|
||||
this.child(young_account_banner)
|
||||
@@ -317,13 +303,7 @@ impl ZedAiOnboarding {
|
||||
IconButton::new("dismiss_onboarding", IconName::Close)
|
||||
.icon_size(IconSize::Small)
|
||||
.tooltip(Tooltip::text("Dismiss"))
|
||||
.on_click(move |_, window, cx| {
|
||||
telemetry::event!(
|
||||
"Banner Dismissed",
|
||||
source = "AI Onboarding",
|
||||
);
|
||||
callback(window, cx)
|
||||
}),
|
||||
.on_click(move |_, window, cx| callback(window, cx)),
|
||||
),
|
||||
)
|
||||
},
|
||||
@@ -338,7 +318,7 @@ impl ZedAiOnboarding {
|
||||
v_flex()
|
||||
.relative()
|
||||
.gap_1()
|
||||
.child(Headline::new("Welcome to the Zed Pro Trial"))
|
||||
.child(Headline::new("Welcome to the Zed Pro free trial"))
|
||||
.child(
|
||||
Label::new("Here's what you get for the next 14 days:")
|
||||
.color(Color::Muted)
|
||||
@@ -360,13 +340,7 @@ impl ZedAiOnboarding {
|
||||
IconButton::new("dismiss_onboarding", IconName::Close)
|
||||
.icon_size(IconSize::Small)
|
||||
.tooltip(Tooltip::text("Dismiss"))
|
||||
.on_click(move |_, window, cx| {
|
||||
telemetry::event!(
|
||||
"Banner Dismissed",
|
||||
source = "AI Onboarding",
|
||||
);
|
||||
callback(window, cx)
|
||||
}),
|
||||
.on_click(move |_, window, cx| callback(window, cx)),
|
||||
),
|
||||
)
|
||||
},
|
||||
@@ -386,9 +360,7 @@ impl ZedAiOnboarding {
|
||||
.child(
|
||||
List::new()
|
||||
.child(BulletItem::new("500 prompts with Claude models"))
|
||||
.child(BulletItem::new(
|
||||
"Unlimited edit predictions with Zeta, our open-source model",
|
||||
)),
|
||||
.child(BulletItem::new("Unlimited edit predictions")),
|
||||
)
|
||||
.child(
|
||||
Button::new("pro", "Continue with Zed Pro")
|
||||
@@ -396,10 +368,7 @@ impl ZedAiOnboarding {
|
||||
.style(ButtonStyle::Outlined)
|
||||
.on_click({
|
||||
let callback = self.continue_with_zed_ai.clone();
|
||||
move |_, window, cx| {
|
||||
telemetry::event!("Banner Dismissed", source = "AI Onboarding");
|
||||
callback(window, cx)
|
||||
}
|
||||
move |_, window, cx| callback(window, cx)
|
||||
}),
|
||||
)
|
||||
.into_any_element()
|
||||
|
||||
@@ -1,201 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use client::{Client, zed_urls};
|
||||
use gpui::{AnyElement, App, IntoElement, RenderOnce, Window};
|
||||
use ui::{Divider, List, Vector, VectorName, prelude::*};
|
||||
|
||||
use crate::{BulletItem, SignInStatus};
|
||||
|
||||
#[derive(IntoElement, RegisterComponent)]
|
||||
pub struct AiUpsellCard {
|
||||
pub sign_in_status: SignInStatus,
|
||||
pub sign_in: Arc<dyn Fn(&mut Window, &mut App)>,
|
||||
}
|
||||
|
||||
impl AiUpsellCard {
|
||||
pub fn new(client: Arc<Client>) -> Self {
|
||||
let status = *client.status().borrow();
|
||||
|
||||
Self {
|
||||
sign_in_status: status.into(),
|
||||
sign_in: Arc::new(move |_window, cx| {
|
||||
cx.spawn({
|
||||
let client = client.clone();
|
||||
async move |cx| {
|
||||
client.authenticate_and_connect(true, cx).await;
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RenderOnce for AiUpsellCard {
|
||||
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
|
||||
let pro_section = v_flex()
|
||||
.w_full()
|
||||
.gap_1()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Label::new("Pro")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Accent)
|
||||
.buffer_font(cx),
|
||||
)
|
||||
.child(Divider::horizontal()),
|
||||
)
|
||||
.child(
|
||||
List::new()
|
||||
.child(BulletItem::new("500 prompts with Claude models"))
|
||||
.child(BulletItem::new(
|
||||
"Unlimited edit predictions with Zeta, our open-source model",
|
||||
)),
|
||||
);
|
||||
|
||||
let free_section = v_flex()
|
||||
.w_full()
|
||||
.gap_1()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Label::new("Free")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted)
|
||||
.buffer_font(cx),
|
||||
)
|
||||
.child(Divider::horizontal()),
|
||||
)
|
||||
.child(
|
||||
List::new()
|
||||
.child(BulletItem::new("50 prompts with the Claude models"))
|
||||
.child(BulletItem::new("2,000 accepted edit predictions")),
|
||||
);
|
||||
|
||||
let grid_bg = h_flex().absolute().inset_0().w_full().h(px(240.)).child(
|
||||
Vector::new(VectorName::Grid, rems_from_px(500.), rems_from_px(240.))
|
||||
.color(Color::Custom(cx.theme().colors().border.opacity(0.05))),
|
||||
);
|
||||
|
||||
let gradient_bg = div()
|
||||
.absolute()
|
||||
.inset_0()
|
||||
.size_full()
|
||||
.bg(gpui::linear_gradient(
|
||||
180.,
|
||||
gpui::linear_color_stop(
|
||||
cx.theme().colors().elevated_surface_background.opacity(0.8),
|
||||
0.,
|
||||
),
|
||||
gpui::linear_color_stop(
|
||||
cx.theme().colors().elevated_surface_background.opacity(0.),
|
||||
0.8,
|
||||
),
|
||||
));
|
||||
|
||||
const DESCRIPTION: &str = "Zed offers a complete agentic experience, with robust editing and reviewing features to collaborate with AI.";
|
||||
|
||||
let footer_buttons = match self.sign_in_status {
|
||||
SignInStatus::SignedIn => v_flex()
|
||||
.items_center()
|
||||
.gap_1()
|
||||
.child(
|
||||
Button::new("sign_in", "Start 14-day Free Pro Trial")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||
.on_click(move |_, _window, cx| {
|
||||
telemetry::event!("Start Trial Clicked", state = "post-sign-in");
|
||||
cx.open_url(&zed_urls::start_trial_url(cx))
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
Label::new("No credit card required")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.into_any_element(),
|
||||
_ => Button::new("sign_in", "Sign In")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||
.on_click({
|
||||
let callback = self.sign_in.clone();
|
||||
move |_, window, cx| {
|
||||
telemetry::event!("Start Trial Clicked", state = "pre-sign-in");
|
||||
callback(window, cx)
|
||||
}
|
||||
})
|
||||
.into_any_element(),
|
||||
};
|
||||
|
||||
v_flex()
|
||||
.relative()
|
||||
.p_6()
|
||||
.pt_4()
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.rounded_lg()
|
||||
.overflow_hidden()
|
||||
.child(grid_bg)
|
||||
.child(gradient_bg)
|
||||
.child(Headline::new("Try Zed AI"))
|
||||
.child(Label::new(DESCRIPTION).color(Color::Muted).mb_2())
|
||||
.child(
|
||||
h_flex()
|
||||
.mt_1p5()
|
||||
.mb_2p5()
|
||||
.items_start()
|
||||
.gap_12()
|
||||
.child(free_section)
|
||||
.child(pro_section),
|
||||
)
|
||||
.child(footer_buttons)
|
||||
}
|
||||
}
|
||||
|
||||
impl Component for AiUpsellCard {
|
||||
fn scope() -> ComponentScope {
|
||||
ComponentScope::Agent
|
||||
}
|
||||
|
||||
fn name() -> &'static str {
|
||||
"AI Upsell Card"
|
||||
}
|
||||
|
||||
fn sort_name() -> &'static str {
|
||||
"AI Upsell Card"
|
||||
}
|
||||
|
||||
fn description() -> Option<&'static str> {
|
||||
Some("A card presenting the Zed AI product during user's first-open onboarding flow.")
|
||||
}
|
||||
|
||||
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
|
||||
Some(
|
||||
v_flex()
|
||||
.p_4()
|
||||
.gap_4()
|
||||
.children(vec![example_group(vec![
|
||||
single_example(
|
||||
"Signed Out State",
|
||||
AiUpsellCard {
|
||||
sign_in_status: SignInStatus::SignedOut,
|
||||
sign_in: Arc::new(|_, _| {}),
|
||||
}
|
||||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Signed In State",
|
||||
AiUpsellCard {
|
||||
sign_in_status: SignInStatus::SignedIn,
|
||||
sign_in: Arc::new(|_, _| {}),
|
||||
}
|
||||
.into_any_element(),
|
||||
),
|
||||
])])
|
||||
.into_any_element(),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,7 @@ pub struct YoungAccountBanner;
|
||||
|
||||
impl RenderOnce for YoungAccountBanner {
|
||||
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
|
||||
const YOUNG_ACCOUNT_DISCLAIMER: &str = "To prevent abuse of our service, we cannot offer plans to GitHub accounts created fewer than 30 days ago. To request an exception, reach out to billing-support@zed.dev.";
|
||||
const YOUNG_ACCOUNT_DISCLAIMER: &str = "To prevent abuse of our service, we cannot offer plans to GitHub accounts created fewer than 30 days ago. To request an exception, reach out to billing@zed.dev.";
|
||||
|
||||
let label = div()
|
||||
.w_full()
|
||||
|
||||
@@ -1323,7 +1323,7 @@ fn setup_context_editor_with_fake_model(
|
||||
) -> (Entity<AssistantContext>, Arc<FakeLanguageModel>) {
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.executor().clone()));
|
||||
|
||||
let fake_provider = Arc::new(FakeLanguageModelProvider::default());
|
||||
let fake_provider = Arc::new(FakeLanguageModelProvider);
|
||||
let fake_model = Arc::new(fake_provider.test_model());
|
||||
|
||||
cx.update(|cx| {
|
||||
|
||||
@@ -767,11 +767,6 @@ impl ContextStore {
|
||||
fn reload(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let fs = self.fs.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
pub static ZED_STATELESS: LazyLock<bool> =
|
||||
LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty()));
|
||||
if *ZED_STATELESS {
|
||||
return Ok(());
|
||||
}
|
||||
fs.create_dir(contexts_dir()).await?;
|
||||
|
||||
let mut paths = fs.read_dir(contexts_dir()).await?;
|
||||
|
||||
@@ -51,13 +51,23 @@ impl ActionLog {
|
||||
Some(self.tracked_buffers.get(buffer)?.snapshot.clone())
|
||||
}
|
||||
|
||||
pub fn has_unnotified_user_edits(&self) -> bool {
|
||||
self.tracked_buffers
|
||||
.values()
|
||||
.any(|tracked| tracked.has_unnotified_user_edits)
|
||||
}
|
||||
|
||||
/// Return a unified diff patch with user edits made since last read or notification
|
||||
pub fn unnotified_user_edits(&self, cx: &Context<Self>) -> Option<String> {
|
||||
let diffs = self
|
||||
if !self.has_unnotified_user_edits() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let unified_diff = self
|
||||
.tracked_buffers
|
||||
.values()
|
||||
.filter_map(|tracked| {
|
||||
if !tracked.may_have_unnotified_user_edits {
|
||||
if !tracked.has_unnotified_user_edits {
|
||||
return None;
|
||||
}
|
||||
|
||||
@@ -85,13 +95,9 @@ impl ActionLog {
|
||||
|
||||
Some(result)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n\n");
|
||||
|
||||
if diffs.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let unified_diff = diffs.join("\n\n");
|
||||
Some(unified_diff)
|
||||
}
|
||||
|
||||
@@ -100,7 +106,7 @@ impl ActionLog {
|
||||
pub fn flush_unnotified_user_edits(&mut self, cx: &Context<Self>) -> Option<String> {
|
||||
let patch = self.unnotified_user_edits(cx);
|
||||
self.tracked_buffers.values_mut().for_each(|tracked| {
|
||||
tracked.may_have_unnotified_user_edits = false;
|
||||
tracked.has_unnotified_user_edits = false;
|
||||
tracked.last_seen_base = tracked.diff_base.clone();
|
||||
});
|
||||
patch
|
||||
@@ -179,7 +185,7 @@ impl ActionLog {
|
||||
version: buffer.read(cx).version(),
|
||||
diff,
|
||||
diff_update: diff_update_tx,
|
||||
may_have_unnotified_user_edits: false,
|
||||
has_unnotified_user_edits: false,
|
||||
_open_lsp_handle: open_lsp_handle,
|
||||
_maintain_diff: cx.spawn({
|
||||
let buffer = buffer.clone();
|
||||
@@ -331,34 +337,27 @@ impl ActionLog {
|
||||
let new_snapshot = buffer_snapshot.clone();
|
||||
let unreviewed_edits = tracked_buffer.unreviewed_edits.clone();
|
||||
let edits = diff_snapshots(&old_snapshot, &new_snapshot);
|
||||
let mut has_user_changes = false;
|
||||
if let ChangeAuthor::User = author
|
||||
&& !edits.is_empty()
|
||||
{
|
||||
tracked_buffer.has_unnotified_user_edits = true;
|
||||
}
|
||||
async move {
|
||||
if let ChangeAuthor::User = author {
|
||||
has_user_changes = apply_non_conflicting_edits(
|
||||
apply_non_conflicting_edits(
|
||||
&unreviewed_edits,
|
||||
edits,
|
||||
&mut base_text,
|
||||
new_snapshot.as_rope(),
|
||||
);
|
||||
}
|
||||
|
||||
(Arc::new(base_text.to_string()), base_text, has_user_changes)
|
||||
(Arc::new(base_text.to_string()), base_text)
|
||||
}
|
||||
});
|
||||
|
||||
anyhow::Ok(rebase)
|
||||
})??;
|
||||
let (new_base_text, new_diff_base, has_user_changes) = rebase.await;
|
||||
|
||||
this.update(cx, |this, _| {
|
||||
let tracked_buffer = this
|
||||
.tracked_buffers
|
||||
.get_mut(buffer)
|
||||
.context("buffer not tracked")
|
||||
.unwrap();
|
||||
tracked_buffer.may_have_unnotified_user_edits |= has_user_changes;
|
||||
})?;
|
||||
|
||||
let (new_base_text, new_diff_base) = rebase.await;
|
||||
Self::update_diff(
|
||||
this,
|
||||
buffer,
|
||||
@@ -830,12 +829,11 @@ fn apply_non_conflicting_edits(
|
||||
edits: Vec<Edit<u32>>,
|
||||
old_text: &mut Rope,
|
||||
new_text: &Rope,
|
||||
) -> bool {
|
||||
) {
|
||||
let mut old_edits = patch.edits().iter().cloned().peekable();
|
||||
let mut new_edits = edits.into_iter().peekable();
|
||||
let mut applied_delta = 0i32;
|
||||
let mut rebased_delta = 0i32;
|
||||
let mut has_made_changes = false;
|
||||
|
||||
while let Some(mut new_edit) = new_edits.next() {
|
||||
let mut conflict = false;
|
||||
@@ -885,10 +883,8 @@ fn apply_non_conflicting_edits(
|
||||
&new_text.chunks_in_range(new_bytes).collect::<String>(),
|
||||
);
|
||||
applied_delta += new_edit.new_len() as i32 - new_edit.old_len() as i32;
|
||||
has_made_changes = true;
|
||||
}
|
||||
}
|
||||
has_made_changes
|
||||
}
|
||||
|
||||
fn diff_snapshots(
|
||||
@@ -962,7 +958,7 @@ struct TrackedBuffer {
|
||||
diff: Entity<BufferDiff>,
|
||||
snapshot: text::BufferSnapshot,
|
||||
diff_update: mpsc::UnboundedSender<(ChangeAuthor, text::BufferSnapshot)>,
|
||||
may_have_unnotified_user_edits: bool,
|
||||
has_unnotified_user_edits: bool,
|
||||
_open_lsp_handle: OpenLspBufferHandle,
|
||||
_maintain_diff: Task<()>,
|
||||
_subscription: Subscription,
|
||||
|
||||
@@ -216,12 +216,7 @@ pub trait Tool: 'static + Send + Sync {
|
||||
|
||||
/// Returns true if the tool needs the users's confirmation
|
||||
/// before having permission to run.
|
||||
fn needs_confirmation(
|
||||
&self,
|
||||
input: &serde_json::Value,
|
||||
project: &Entity<Project>,
|
||||
cx: &App,
|
||||
) -> bool;
|
||||
fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
|
||||
|
||||
/// Returns true if the tool may perform edits.
|
||||
fn may_perform_edits(&self) -> bool;
|
||||
|
||||
@@ -375,12 +375,7 @@ mod tests {
|
||||
false
|
||||
}
|
||||
|
||||
fn needs_confirmation(
|
||||
&self,
|
||||
_input: &serde_json::Value,
|
||||
_project: &Entity<Project>,
|
||||
_cx: &App,
|
||||
) -> bool {
|
||||
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
chrono.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
component.workspace = true
|
||||
derive_more.workspace = true
|
||||
|
||||
@@ -20,13 +20,14 @@ mod thinking_tool;
|
||||
mod ui;
|
||||
mod web_search_tool;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use assistant_tool::ToolRegistry;
|
||||
use copy_path_tool::CopyPathTool;
|
||||
use gpui::{App, Entity};
|
||||
use http_client::HttpClientWithUrl;
|
||||
use language_model::LanguageModelRegistry;
|
||||
use move_path_tool::MovePathTool;
|
||||
use std::sync::Arc;
|
||||
use web_search_tool::WebSearchTool;
|
||||
|
||||
pub(crate) use templates::*;
|
||||
|
||||
@@ -44,7 +44,7 @@ impl Tool for CopyPathTool {
|
||||
"copy_path".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ impl Tool for CreateDirectoryTool {
|
||||
include_str!("./create_directory_tool/description.md").into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ impl Tool for DeletePathTool {
|
||||
"delete_path".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ impl Tool for DiagnosticsTool {
|
||||
"diagnostics".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
|
||||
@@ -25,7 +25,6 @@ use language::{
|
||||
};
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
|
||||
use paths;
|
||||
use project::{
|
||||
Project, ProjectPath,
|
||||
lsp_store::{FormatTrigger, LspFormatTarget},
|
||||
@@ -127,47 +126,8 @@ impl Tool for EditFileTool {
|
||||
"edit_file".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(
|
||||
&self,
|
||||
input: &serde_json::Value,
|
||||
project: &Entity<Project>,
|
||||
cx: &App,
|
||||
) -> bool {
|
||||
if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
|
||||
return false;
|
||||
}
|
||||
|
||||
let Ok(input) = serde_json::from_value::<EditFileToolInput>(input.clone()) else {
|
||||
// If it's not valid JSON, it's going to error and confirming won't do anything.
|
||||
return false;
|
||||
};
|
||||
|
||||
// If any path component matches the local settings folder, then this could affect
|
||||
// the editor in ways beyond the project source, so prompt.
|
||||
let local_settings_folder = paths::local_settings_folder_relative_path();
|
||||
let path = Path::new(&input.path);
|
||||
if path
|
||||
.components()
|
||||
.any(|component| component.as_os_str() == local_settings_folder.as_os_str())
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// It's also possible that the global config dir is configured to be inside the project,
|
||||
// so check for that edge case too.
|
||||
if let Ok(canonical_path) = std::fs::canonicalize(&input.path) {
|
||||
if canonical_path.starts_with(paths::config_dir()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if path is inside the global config directory
|
||||
// First check if it's already inside project - if not, try to canonicalize
|
||||
let project_path = project.read(cx).find_project_path(&input.path, cx);
|
||||
|
||||
// If the path is inside the project, and it's not one of the above edge cases,
|
||||
// then no confirmation is necessary. Otherwise, confirmation is necessary.
|
||||
project_path.is_none()
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn may_perform_edits(&self) -> bool {
|
||||
@@ -188,25 +148,7 @@ impl Tool for EditFileTool {
|
||||
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<EditFileToolInput>(input.clone()) {
|
||||
Ok(input) => {
|
||||
let path = Path::new(&input.path);
|
||||
let mut description = input.display_description.clone();
|
||||
|
||||
// Add context about why confirmation may be needed
|
||||
let local_settings_folder = paths::local_settings_folder_relative_path();
|
||||
if path
|
||||
.components()
|
||||
.any(|c| c.as_os_str() == local_settings_folder.as_os_str())
|
||||
{
|
||||
description.push_str(" (local settings)");
|
||||
} else if let Ok(canonical_path) = std::fs::canonicalize(&input.path) {
|
||||
if canonical_path.starts_with(paths::config_dir()) {
|
||||
description.push_str(" (global settings)");
|
||||
}
|
||||
}
|
||||
|
||||
description
|
||||
}
|
||||
Ok(input) => input.display_description,
|
||||
Err(_) => "Editing file".to_string(),
|
||||
}
|
||||
}
|
||||
@@ -336,9 +278,6 @@ impl Tool for EditFileTool {
|
||||
.unwrap_or(false);
|
||||
|
||||
if format_on_save_enabled {
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.buffer_edited(buffer.clone(), cx);
|
||||
})?;
|
||||
let format_task = project.update(cx, |project, cx| {
|
||||
project.format(
|
||||
HashSet::from_iter([buffer.clone()]),
|
||||
@@ -1233,20 +1172,19 @@ async fn build_buffer_diff(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use ::fs::Fs;
|
||||
use client::TelemetrySettings;
|
||||
use fs::{FakeFs, Fs};
|
||||
use gpui::{TestAppContext, UpdateGlobal};
|
||||
use language_model::fake_provider::FakeLanguageModel;
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use std::fs;
|
||||
use util::path;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/root", json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
@@ -1336,7 +1274,7 @@ mod tests {
|
||||
) -> anyhow::Result<ProjectPath> {
|
||||
init_test(cx);
|
||||
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/root",
|
||||
json!({
|
||||
@@ -1443,21 +1381,6 @@ mod tests {
|
||||
cx.set_global(settings_store);
|
||||
language::init(cx);
|
||||
TelemetrySettings::register(cx);
|
||||
agent_settings::AgentSettings::register(cx);
|
||||
Project::init_settings(cx);
|
||||
});
|
||||
}
|
||||
|
||||
fn init_test_with_config(cx: &mut TestAppContext, data_dir: &Path) {
|
||||
cx.update(|cx| {
|
||||
// Set custom data directory (config will be under data_dir/config)
|
||||
paths::set_custom_data_dir(data_dir.to_str().unwrap());
|
||||
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
language::init(cx);
|
||||
TelemetrySettings::register(cx);
|
||||
agent_settings::AgentSettings::register(cx);
|
||||
Project::init_settings(cx);
|
||||
});
|
||||
}
|
||||
@@ -1466,7 +1389,7 @@ mod tests {
|
||||
async fn test_format_on_save(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/root", json!({"src": {}})).await;
|
||||
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
@@ -1665,7 +1588,7 @@ mod tests {
|
||||
async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/root", json!({"src": {}})).await;
|
||||
|
||||
// Create a simple file with trailing whitespace
|
||||
@@ -1797,641 +1720,4 @@ mod tests {
|
||||
"Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_needs_confirmation(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let tool = Arc::new(EditFileTool);
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/root", json!({})).await;
|
||||
|
||||
// Test 1: Path with .zed component should require confirmation
|
||||
let input_with_zed = json!({
|
||||
"display_description": "Edit settings",
|
||||
"path": ".zed/settings.json",
|
||||
"mode": "edit"
|
||||
});
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
cx.update(|cx| {
|
||||
assert!(
|
||||
tool.needs_confirmation(&input_with_zed, &project, cx),
|
||||
"Path with .zed component should require confirmation"
|
||||
);
|
||||
});
|
||||
|
||||
// Test 2: Absolute path should require confirmation
|
||||
let input_absolute = json!({
|
||||
"display_description": "Edit file",
|
||||
"path": "/etc/hosts",
|
||||
"mode": "edit"
|
||||
});
|
||||
cx.update(|cx| {
|
||||
assert!(
|
||||
tool.needs_confirmation(&input_absolute, &project, cx),
|
||||
"Absolute path should require confirmation"
|
||||
);
|
||||
});
|
||||
|
||||
// Test 3: Relative path without .zed should not require confirmation
|
||||
let input_relative = json!({
|
||||
"display_description": "Edit file",
|
||||
"path": "root/src/main.rs",
|
||||
"mode": "edit"
|
||||
});
|
||||
cx.update(|cx| {
|
||||
assert!(
|
||||
!tool.needs_confirmation(&input_relative, &project, cx),
|
||||
"Relative path without .zed should not require confirmation"
|
||||
);
|
||||
});
|
||||
|
||||
// Test 4: Path with .zed in the middle should require confirmation
|
||||
let input_zed_middle = json!({
|
||||
"display_description": "Edit settings",
|
||||
"path": "root/.zed/tasks.json",
|
||||
"mode": "edit"
|
||||
});
|
||||
cx.update(|cx| {
|
||||
assert!(
|
||||
tool.needs_confirmation(&input_zed_middle, &project, cx),
|
||||
"Path with .zed in any component should require confirmation"
|
||||
);
|
||||
});
|
||||
|
||||
// Test 5: When always_allow_tool_actions is enabled, no confirmation needed
|
||||
cx.update(|cx| {
|
||||
let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
|
||||
settings.always_allow_tool_actions = true;
|
||||
agent_settings::AgentSettings::override_global(settings, cx);
|
||||
|
||||
assert!(
|
||||
!tool.needs_confirmation(&input_with_zed, &project, cx),
|
||||
"When always_allow_tool_actions is true, no confirmation should be needed"
|
||||
);
|
||||
assert!(
|
||||
!tool.needs_confirmation(&input_absolute, &project, cx),
|
||||
"When always_allow_tool_actions is true, no confirmation should be needed for absolute paths"
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_ui_text_shows_correct_context(cx: &mut TestAppContext) {
|
||||
// Set up a custom config directory for testing
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
init_test_with_config(cx, temp_dir.path());
|
||||
|
||||
let tool = Arc::new(EditFileTool);
|
||||
|
||||
// Test ui_text shows context for various paths
|
||||
let test_cases = vec![
|
||||
(
|
||||
json!({
|
||||
"display_description": "Update config",
|
||||
"path": ".zed/settings.json",
|
||||
"mode": "edit"
|
||||
}),
|
||||
"Update config (local settings)",
|
||||
".zed path should show local settings context",
|
||||
),
|
||||
(
|
||||
json!({
|
||||
"display_description": "Fix bug",
|
||||
"path": "src/.zed/local.json",
|
||||
"mode": "edit"
|
||||
}),
|
||||
"Fix bug (local settings)",
|
||||
"Nested .zed path should show local settings context",
|
||||
),
|
||||
(
|
||||
json!({
|
||||
"display_description": "Update readme",
|
||||
"path": "README.md",
|
||||
"mode": "edit"
|
||||
}),
|
||||
"Update readme",
|
||||
"Normal path should not show additional context",
|
||||
),
|
||||
(
|
||||
json!({
|
||||
"display_description": "Edit config",
|
||||
"path": "config.zed",
|
||||
"mode": "edit"
|
||||
}),
|
||||
"Edit config",
|
||||
".zed as extension should not show context",
|
||||
),
|
||||
];
|
||||
|
||||
for (input, expected_text, description) in test_cases {
|
||||
cx.update(|_cx| {
|
||||
let ui_text = tool.ui_text(&input);
|
||||
assert_eq!(ui_text, expected_text, "Failed for case: {}", description);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_needs_confirmation_outside_project(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let tool = Arc::new(EditFileTool);
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
|
||||
// Create a project in /project directory
|
||||
fs.insert_tree("/project", json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
|
||||
// Test file outside project requires confirmation
|
||||
let input_outside = json!({
|
||||
"display_description": "Edit file",
|
||||
"path": "/outside/file.txt",
|
||||
"mode": "edit"
|
||||
});
|
||||
cx.update(|cx| {
|
||||
assert!(
|
||||
tool.needs_confirmation(&input_outside, &project, cx),
|
||||
"File outside project should require confirmation"
|
||||
);
|
||||
});
|
||||
|
||||
// Test file inside project doesn't require confirmation
|
||||
let input_inside = json!({
|
||||
"display_description": "Edit file",
|
||||
"path": "project/file.txt",
|
||||
"mode": "edit"
|
||||
});
|
||||
cx.update(|cx| {
|
||||
assert!(
|
||||
!tool.needs_confirmation(&input_inside, &project, cx),
|
||||
"File inside project should not require confirmation"
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_needs_confirmation_config_paths(cx: &mut TestAppContext) {
|
||||
// Set up a custom data directory for testing
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
init_test_with_config(cx, temp_dir.path());
|
||||
|
||||
let tool = Arc::new(EditFileTool);
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/home/user/myproject", json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/home/user/myproject").as_ref()], cx).await;
|
||||
|
||||
// Get the actual local settings folder name
|
||||
let local_settings_folder = paths::local_settings_folder_relative_path();
|
||||
|
||||
// Test various config path patterns
|
||||
let test_cases = vec![
|
||||
(
|
||||
format!("{}/settings.json", local_settings_folder.display()),
|
||||
true,
|
||||
"Top-level local settings file".to_string(),
|
||||
),
|
||||
(
|
||||
format!(
|
||||
"myproject/{}/settings.json",
|
||||
local_settings_folder.display()
|
||||
),
|
||||
true,
|
||||
"Local settings in project path".to_string(),
|
||||
),
|
||||
(
|
||||
format!("src/{}/config.toml", local_settings_folder.display()),
|
||||
true,
|
||||
"Local settings in subdirectory".to_string(),
|
||||
),
|
||||
(
|
||||
".zed.backup/file.txt".to_string(),
|
||||
true,
|
||||
".zed.backup is outside project".to_string(),
|
||||
),
|
||||
(
|
||||
"my.zed/file.txt".to_string(),
|
||||
true,
|
||||
"my.zed is outside project".to_string(),
|
||||
),
|
||||
(
|
||||
"myproject/src/file.zed".to_string(),
|
||||
false,
|
||||
".zed as file extension".to_string(),
|
||||
),
|
||||
(
|
||||
"myproject/normal/path/file.rs".to_string(),
|
||||
false,
|
||||
"Normal file without config paths".to_string(),
|
||||
),
|
||||
];
|
||||
|
||||
for (path, should_confirm, description) in test_cases {
|
||||
let input = json!({
|
||||
"display_description": "Edit file",
|
||||
"path": path,
|
||||
"mode": "edit"
|
||||
});
|
||||
cx.update(|cx| {
|
||||
assert_eq!(
|
||||
tool.needs_confirmation(&input, &project, cx),
|
||||
should_confirm,
|
||||
"Failed for case: {} - path: {}",
|
||||
description,
|
||||
path
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_needs_confirmation_global_config(cx: &mut TestAppContext) {
|
||||
// Set up a custom data directory for testing
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
init_test_with_config(cx, temp_dir.path());
|
||||
|
||||
let tool = Arc::new(EditFileTool);
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
|
||||
// Create test files in the global config directory
|
||||
let global_config_dir = paths::config_dir();
|
||||
fs::create_dir_all(&global_config_dir).unwrap();
|
||||
let global_settings_path = global_config_dir.join("settings.json");
|
||||
fs::write(&global_settings_path, "{}").unwrap();
|
||||
|
||||
fs.insert_tree("/project", json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
|
||||
// Test global config paths
|
||||
let test_cases = vec![
|
||||
(
|
||||
global_settings_path.to_str().unwrap().to_string(),
|
||||
true,
|
||||
"Global settings file should require confirmation",
|
||||
),
|
||||
(
|
||||
global_config_dir
|
||||
.join("keymap.json")
|
||||
.to_str()
|
||||
.unwrap()
|
||||
.to_string(),
|
||||
true,
|
||||
"Global keymap file should require confirmation",
|
||||
),
|
||||
(
|
||||
"project/normal_file.rs".to_string(),
|
||||
false,
|
||||
"Normal project file should not require confirmation",
|
||||
),
|
||||
];
|
||||
|
||||
for (path, should_confirm, description) in test_cases {
|
||||
let input = json!({
|
||||
"display_description": "Edit file",
|
||||
"path": path,
|
||||
"mode": "edit"
|
||||
});
|
||||
cx.update(|cx| {
|
||||
assert_eq!(
|
||||
tool.needs_confirmation(&input, &project, cx),
|
||||
should_confirm,
|
||||
"Failed for case: {}",
|
||||
description
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let tool = Arc::new(EditFileTool);
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
|
||||
// Create multiple worktree directories
|
||||
fs.insert_tree(
|
||||
"/workspace/frontend",
|
||||
json!({
|
||||
"src": {
|
||||
"main.js": "console.log('frontend');"
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
fs.insert_tree(
|
||||
"/workspace/backend",
|
||||
json!({
|
||||
"src": {
|
||||
"main.rs": "fn main() {}"
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
fs.insert_tree(
|
||||
"/workspace/shared",
|
||||
json!({
|
||||
".zed": {
|
||||
"settings.json": "{}"
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Create project with multiple worktrees
|
||||
let project = Project::test(
|
||||
fs.clone(),
|
||||
[
|
||||
path!("/workspace/frontend").as_ref(),
|
||||
path!("/workspace/backend").as_ref(),
|
||||
path!("/workspace/shared").as_ref(),
|
||||
],
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
|
||||
// Test files in different worktrees
|
||||
let test_cases = vec![
|
||||
("frontend/src/main.js", false, "File in first worktree"),
|
||||
("backend/src/main.rs", false, "File in second worktree"),
|
||||
(
|
||||
"shared/.zed/settings.json",
|
||||
true,
|
||||
".zed file in third worktree",
|
||||
),
|
||||
("/etc/hosts", true, "Absolute path outside all worktrees"),
|
||||
(
|
||||
"../outside/file.txt",
|
||||
true,
|
||||
"Relative path outside worktrees",
|
||||
),
|
||||
];
|
||||
|
||||
for (path, should_confirm, description) in test_cases {
|
||||
let input = json!({
|
||||
"display_description": "Edit file",
|
||||
"path": path,
|
||||
"mode": "edit"
|
||||
});
|
||||
cx.update(|cx| {
|
||||
assert_eq!(
|
||||
tool.needs_confirmation(&input, &project, cx),
|
||||
should_confirm,
|
||||
"Failed for case: {} - path: {}",
|
||||
description,
|
||||
path
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let tool = Arc::new(EditFileTool);
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/project",
|
||||
json!({
|
||||
".zed": {
|
||||
"settings.json": "{}"
|
||||
},
|
||||
"src": {
|
||||
".zed": {
|
||||
"local.json": "{}"
|
||||
}
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
|
||||
// Test edge cases
|
||||
let test_cases = vec![
|
||||
// Empty path - find_project_path returns Some for empty paths
|
||||
("", false, "Empty path is treated as project root"),
|
||||
// Root directory
|
||||
("/", true, "Root directory should be outside project"),
|
||||
// Parent directory references - find_project_path resolves these
|
||||
(
|
||||
"project/../other",
|
||||
false,
|
||||
"Path with .. is resolved by find_project_path",
|
||||
),
|
||||
(
|
||||
"project/./src/file.rs",
|
||||
false,
|
||||
"Path with . should work normally",
|
||||
),
|
||||
// Windows-style paths (if on Windows)
|
||||
#[cfg(target_os = "windows")]
|
||||
("C:\\Windows\\System32\\hosts", true, "Windows system path"),
|
||||
#[cfg(target_os = "windows")]
|
||||
("project\\src\\main.rs", false, "Windows-style project path"),
|
||||
];
|
||||
|
||||
for (path, should_confirm, description) in test_cases {
|
||||
let input = json!({
|
||||
"display_description": "Edit file",
|
||||
"path": path,
|
||||
"mode": "edit"
|
||||
});
|
||||
cx.update(|cx| {
|
||||
assert_eq!(
|
||||
tool.needs_confirmation(&input, &project, cx),
|
||||
should_confirm,
|
||||
"Failed for case: {} - path: {}",
|
||||
description,
|
||||
path
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_ui_text_with_all_path_types(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let tool = Arc::new(EditFileTool);
|
||||
|
||||
// Test UI text for various scenarios
|
||||
let test_cases = vec![
|
||||
(
|
||||
json!({
|
||||
"display_description": "Update config",
|
||||
"path": ".zed/settings.json",
|
||||
"mode": "edit"
|
||||
}),
|
||||
"Update config (local settings)",
|
||||
".zed path should show local settings context",
|
||||
),
|
||||
(
|
||||
json!({
|
||||
"display_description": "Fix bug",
|
||||
"path": "src/.zed/local.json",
|
||||
"mode": "edit"
|
||||
}),
|
||||
"Fix bug (local settings)",
|
||||
"Nested .zed path should show local settings context",
|
||||
),
|
||||
(
|
||||
json!({
|
||||
"display_description": "Update readme",
|
||||
"path": "README.md",
|
||||
"mode": "edit"
|
||||
}),
|
||||
"Update readme",
|
||||
"Normal path should not show additional context",
|
||||
),
|
||||
(
|
||||
json!({
|
||||
"display_description": "Edit config",
|
||||
"path": "config.zed",
|
||||
"mode": "edit"
|
||||
}),
|
||||
"Edit config",
|
||||
".zed as extension should not show context",
|
||||
),
|
||||
];
|
||||
|
||||
for (input, expected_text, description) in test_cases {
|
||||
cx.update(|_cx| {
|
||||
let ui_text = tool.ui_text(&input);
|
||||
assert_eq!(ui_text, expected_text, "Failed for case: {}", description);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let tool = Arc::new(EditFileTool);
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/project",
|
||||
json!({
|
||||
"existing.txt": "content",
|
||||
".zed": {
|
||||
"settings.json": "{}"
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
|
||||
// Test different EditFileMode values
|
||||
let modes = vec![
|
||||
EditFileMode::Edit,
|
||||
EditFileMode::Create,
|
||||
EditFileMode::Overwrite,
|
||||
];
|
||||
|
||||
for mode in modes {
|
||||
// Test .zed path with different modes
|
||||
let input_zed = json!({
|
||||
"display_description": "Edit settings",
|
||||
"path": "project/.zed/settings.json",
|
||||
"mode": mode
|
||||
});
|
||||
cx.update(|cx| {
|
||||
assert!(
|
||||
tool.needs_confirmation(&input_zed, &project, cx),
|
||||
".zed path should require confirmation regardless of mode: {:?}",
|
||||
mode
|
||||
);
|
||||
});
|
||||
|
||||
// Test outside path with different modes
|
||||
let input_outside = json!({
|
||||
"display_description": "Edit file",
|
||||
"path": "/outside/file.txt",
|
||||
"mode": mode
|
||||
});
|
||||
cx.update(|cx| {
|
||||
assert!(
|
||||
tool.needs_confirmation(&input_outside, &project, cx),
|
||||
"Outside path should require confirmation regardless of mode: {:?}",
|
||||
mode
|
||||
);
|
||||
});
|
||||
|
||||
// Test normal path with different modes
|
||||
let input_normal = json!({
|
||||
"display_description": "Edit file",
|
||||
"path": "project/normal.txt",
|
||||
"mode": mode
|
||||
});
|
||||
cx.update(|cx| {
|
||||
assert!(
|
||||
!tool.needs_confirmation(&input_normal, &project, cx),
|
||||
"Normal path should not require confirmation regardless of mode: {:?}",
|
||||
mode
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_always_allow_tool_actions_bypasses_all_checks(cx: &mut TestAppContext) {
|
||||
// Set up with custom directories for deterministic testing
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
init_test_with_config(cx, temp_dir.path());
|
||||
|
||||
let tool = Arc::new(EditFileTool);
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/project", json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
|
||||
// Enable always_allow_tool_actions
|
||||
cx.update(|cx| {
|
||||
let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
|
||||
settings.always_allow_tool_actions = true;
|
||||
agent_settings::AgentSettings::override_global(settings, cx);
|
||||
});
|
||||
|
||||
// Test that all paths that normally require confirmation are bypassed
|
||||
let global_settings_path = paths::config_dir().join("settings.json");
|
||||
fs::create_dir_all(paths::config_dir()).unwrap();
|
||||
fs::write(&global_settings_path, "{}").unwrap();
|
||||
|
||||
let test_cases = vec![
|
||||
".zed/settings.json",
|
||||
"project/.zed/config.toml",
|
||||
global_settings_path.to_str().unwrap(),
|
||||
"/etc/hosts",
|
||||
"/absolute/path/file.txt",
|
||||
"../outside/project.txt",
|
||||
];
|
||||
|
||||
for path in test_cases {
|
||||
let input = json!({
|
||||
"display_description": "Edit file",
|
||||
"path": path,
|
||||
"mode": "edit"
|
||||
});
|
||||
cx.update(|cx| {
|
||||
assert!(
|
||||
!tool.needs_confirmation(&input, &project, cx),
|
||||
"Path {} should not require confirmation when always_allow_tool_actions is true",
|
||||
path
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
// Disable always_allow_tool_actions and verify confirmation is required again
|
||||
cx.update(|cx| {
|
||||
let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
|
||||
settings.always_allow_tool_actions = false;
|
||||
agent_settings::AgentSettings::override_global(settings, cx);
|
||||
});
|
||||
|
||||
// Verify .zed path requires confirmation again
|
||||
let input = json!({
|
||||
"display_description": "Edit file",
|
||||
"path": ".zed/settings.json",
|
||||
"mode": "edit"
|
||||
});
|
||||
cx.update(|cx| {
|
||||
assert!(
|
||||
tool.needs_confirmation(&input, &project, cx),
|
||||
".zed path should require confirmation when always_allow_tool_actions is false"
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,7 +116,7 @@ impl Tool for FetchTool {
|
||||
"fetch".to_string()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ impl Tool for FindPathTool {
|
||||
"find_path".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ impl Tool for GrepTool {
|
||||
"grep".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ impl Tool for ListDirectoryTool {
|
||||
"list_directory".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ impl Tool for MovePathTool {
|
||||
"move_path".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ impl Tool for NowTool {
|
||||
"now".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ impl Tool for OpenTool {
|
||||
"open".to_string()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
true
|
||||
}
|
||||
fn may_perform_edits(&self) -> bool {
|
||||
|
||||
@@ -19,7 +19,7 @@ impl Tool for ProjectNotificationsTool {
|
||||
"project_notifications".to_string()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
fn may_perform_edits(&self) -> bool {
|
||||
@@ -200,7 +200,7 @@ mod tests {
|
||||
|
||||
// Run the tool before any changes
|
||||
let tool = Arc::new(ProjectNotificationsTool);
|
||||
let provider = Arc::new(FakeLanguageModelProvider::default());
|
||||
let provider = Arc::new(FakeLanguageModelProvider);
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(provider.test_model());
|
||||
let request = Arc::new(LanguageModelRequest::default());
|
||||
let tool_input = json!({});
|
||||
|
||||
@@ -54,7 +54,7 @@ impl Tool for ReadFileTool {
|
||||
"read_file".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ impl Tool for TerminalTool {
|
||||
Self::NAME.to_string()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ impl Tool for ThinkingTool {
|
||||
"thinking".to_string()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ impl Tool for WebSearchTool {
|
||||
"web_search".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
|
||||
@@ -17,5 +17,7 @@ default = []
|
||||
[dependencies]
|
||||
aws-smithy-runtime-api.workspace = true
|
||||
aws-smithy-types.workspace = true
|
||||
futures.workspace = true
|
||||
http_client.workspace = true
|
||||
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
|
||||
workspace-hack.workspace = true
|
||||
|
||||
@@ -11,11 +11,14 @@ use aws_smithy_runtime_api::client::result::ConnectorError;
|
||||
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
|
||||
use aws_smithy_runtime_api::http::{Headers, StatusCode};
|
||||
use aws_smithy_types::body::SdkBody;
|
||||
use http_client::AsyncBody;
|
||||
use futures::AsyncReadExt;
|
||||
use http_client::{AsyncBody, Inner};
|
||||
use http_client::{HttpClient, Request};
|
||||
use tokio::runtime::Handle;
|
||||
|
||||
struct AwsHttpConnector {
|
||||
client: Arc<dyn HttpClient>,
|
||||
handle: Handle,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AwsHttpConnector {
|
||||
@@ -39,17 +42,18 @@ impl AwsConnector for AwsHttpConnector {
|
||||
.client
|
||||
.send(Request::from_parts(parts, convert_to_async_body(body)));
|
||||
|
||||
let handle = self.handle.clone();
|
||||
|
||||
HttpConnectorFuture::new(async move {
|
||||
let response = match response.await {
|
||||
Ok(response) => response,
|
||||
Err(err) => return Err(ConnectorError::other(err.into(), None)),
|
||||
};
|
||||
let (parts, body) = response.into_parts();
|
||||
let body = convert_to_sdk_body(body, handle).await;
|
||||
|
||||
let mut response = HttpResponse::new(
|
||||
StatusCode::try_from(parts.status.as_u16()).unwrap(),
|
||||
convert_to_sdk_body(body),
|
||||
);
|
||||
let mut response =
|
||||
HttpResponse::new(StatusCode::try_from(parts.status.as_u16()).unwrap(), body);
|
||||
|
||||
let headers = match Headers::try_from(parts.headers) {
|
||||
Ok(headers) => headers,
|
||||
@@ -66,6 +70,7 @@ impl AwsConnector for AwsHttpConnector {
|
||||
#[derive(Clone)]
|
||||
pub struct AwsHttpClient {
|
||||
client: Arc<dyn HttpClient>,
|
||||
handler: Handle,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AwsHttpClient {
|
||||
@@ -75,8 +80,11 @@ impl std::fmt::Debug for AwsHttpClient {
|
||||
}
|
||||
|
||||
impl AwsHttpClient {
|
||||
pub fn new(client: Arc<dyn HttpClient>) -> Self {
|
||||
Self { client }
|
||||
pub fn new(client: Arc<dyn HttpClient>, handle: Handle) -> Self {
|
||||
Self {
|
||||
client,
|
||||
handler: handle,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,12 +96,25 @@ impl AwsClient for AwsHttpClient {
|
||||
) -> SharedHttpConnector {
|
||||
SharedHttpConnector::new(AwsHttpConnector {
|
||||
client: self.client.clone(),
|
||||
handle: self.handler.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn convert_to_sdk_body(body: AsyncBody) -> SdkBody {
|
||||
SdkBody::from_body_1_x(body)
|
||||
pub async fn convert_to_sdk_body(body: AsyncBody, handle: Handle) -> SdkBody {
|
||||
match body.0 {
|
||||
Inner::Empty => SdkBody::empty(),
|
||||
Inner::Bytes(bytes) => SdkBody::from(bytes.into_inner()),
|
||||
Inner::AsyncReader(mut reader) => {
|
||||
let buffer = handle.spawn(async move {
|
||||
let mut buffer = Vec::new();
|
||||
let _ = reader.read_to_end(&mut buffer).await;
|
||||
buffer
|
||||
});
|
||||
|
||||
SdkBody::from(buffer.await.unwrap_or_default())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn convert_to_async_body(body: SdkBody) -> AsyncBody {
|
||||
|
||||
@@ -343,7 +343,8 @@ impl BufferDiffInner {
|
||||
..
|
||||
} in hunks.iter().cloned()
|
||||
{
|
||||
let preceding_pending_hunks = old_pending_hunks.slice(&buffer_range.start, Bias::Left);
|
||||
let preceding_pending_hunks =
|
||||
old_pending_hunks.slice(&buffer_range.start, Bias::Left, buffer);
|
||||
pending_hunks.append(preceding_pending_hunks, buffer);
|
||||
|
||||
// Skip all overlapping or adjacent old pending hunks
|
||||
@@ -354,7 +355,7 @@ impl BufferDiffInner {
|
||||
.cmp(&buffer_range.end, buffer)
|
||||
.is_le()
|
||||
}) {
|
||||
old_pending_hunks.next();
|
||||
old_pending_hunks.next(buffer);
|
||||
}
|
||||
|
||||
if (stage && secondary_status == DiffHunkSecondaryStatus::NoSecondaryHunk)
|
||||
@@ -378,10 +379,10 @@ impl BufferDiffInner {
|
||||
);
|
||||
}
|
||||
// append the remainder
|
||||
pending_hunks.append(old_pending_hunks.suffix(), buffer);
|
||||
pending_hunks.append(old_pending_hunks.suffix(buffer), buffer);
|
||||
|
||||
let mut unstaged_hunk_cursor = unstaged_diff.hunks.cursor::<DiffHunkSummary>(buffer);
|
||||
unstaged_hunk_cursor.next();
|
||||
unstaged_hunk_cursor.next(buffer);
|
||||
|
||||
// then, iterate over all pending hunks (both new ones and the existing ones) and compute the edits
|
||||
let mut prev_unstaged_hunk_buffer_end = 0;
|
||||
@@ -396,7 +397,8 @@ impl BufferDiffInner {
|
||||
}) = pending_hunks_iter.next()
|
||||
{
|
||||
// Advance unstaged_hunk_cursor to skip unstaged hunks before current hunk
|
||||
let skipped_unstaged = unstaged_hunk_cursor.slice(&buffer_range.start, Bias::Left);
|
||||
let skipped_unstaged =
|
||||
unstaged_hunk_cursor.slice(&buffer_range.start, Bias::Left, buffer);
|
||||
|
||||
if let Some(unstaged_hunk) = skipped_unstaged.last() {
|
||||
prev_unstaged_hunk_base_text_end = unstaged_hunk.diff_base_byte_range.end;
|
||||
@@ -423,7 +425,7 @@ impl BufferDiffInner {
|
||||
buffer_offset_range.end =
|
||||
buffer_offset_range.end.max(unstaged_hunk_offset_range.end);
|
||||
|
||||
unstaged_hunk_cursor.next();
|
||||
unstaged_hunk_cursor.next(buffer);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -512,7 +514,7 @@ impl BufferDiffInner {
|
||||
});
|
||||
|
||||
let anchor_iter = iter::from_fn(move || {
|
||||
cursor.next();
|
||||
cursor.next(buffer);
|
||||
cursor.item()
|
||||
})
|
||||
.flat_map(move |hunk| {
|
||||
@@ -529,12 +531,12 @@ impl BufferDiffInner {
|
||||
});
|
||||
|
||||
let mut pending_hunks_cursor = self.pending_hunks.cursor::<DiffHunkSummary>(buffer);
|
||||
pending_hunks_cursor.next();
|
||||
pending_hunks_cursor.next(buffer);
|
||||
|
||||
let mut secondary_cursor = None;
|
||||
if let Some(secondary) = secondary.as_ref() {
|
||||
let mut cursor = secondary.hunks.cursor::<DiffHunkSummary>(buffer);
|
||||
cursor.next();
|
||||
cursor.next(buffer);
|
||||
secondary_cursor = Some(cursor);
|
||||
}
|
||||
|
||||
@@ -562,7 +564,7 @@ impl BufferDiffInner {
|
||||
.cmp(&pending_hunks_cursor.start().buffer_range.start, buffer)
|
||||
.is_gt()
|
||||
{
|
||||
pending_hunks_cursor.seek_forward(&start_anchor, Bias::Left);
|
||||
pending_hunks_cursor.seek_forward(&start_anchor, Bias::Left, buffer);
|
||||
}
|
||||
|
||||
if let Some(pending_hunk) = pending_hunks_cursor.item() {
|
||||
@@ -588,7 +590,7 @@ impl BufferDiffInner {
|
||||
.cmp(&secondary_cursor.start().buffer_range.start, buffer)
|
||||
.is_gt()
|
||||
{
|
||||
secondary_cursor.seek_forward(&start_anchor, Bias::Left);
|
||||
secondary_cursor.seek_forward(&start_anchor, Bias::Left, buffer);
|
||||
}
|
||||
|
||||
if let Some(secondary_hunk) = secondary_cursor.item() {
|
||||
@@ -633,7 +635,7 @@ impl BufferDiffInner {
|
||||
});
|
||||
|
||||
iter::from_fn(move || {
|
||||
cursor.prev();
|
||||
cursor.prev(buffer);
|
||||
|
||||
let hunk = cursor.item()?;
|
||||
let range = hunk.buffer_range.to_point(buffer);
|
||||
@@ -651,8 +653,8 @@ impl BufferDiffInner {
|
||||
fn compare(&self, old: &Self, new_snapshot: &text::BufferSnapshot) -> Option<Range<Anchor>> {
|
||||
let mut new_cursor = self.hunks.cursor::<()>(new_snapshot);
|
||||
let mut old_cursor = old.hunks.cursor::<()>(new_snapshot);
|
||||
old_cursor.next();
|
||||
new_cursor.next();
|
||||
old_cursor.next(new_snapshot);
|
||||
new_cursor.next(new_snapshot);
|
||||
let mut start = None;
|
||||
let mut end = None;
|
||||
|
||||
@@ -667,7 +669,7 @@ impl BufferDiffInner {
|
||||
Ordering::Less => {
|
||||
start.get_or_insert(new_hunk.buffer_range.start);
|
||||
end.replace(new_hunk.buffer_range.end);
|
||||
new_cursor.next();
|
||||
new_cursor.next(new_snapshot);
|
||||
}
|
||||
Ordering::Equal => {
|
||||
if new_hunk != old_hunk {
|
||||
@@ -684,25 +686,25 @@ impl BufferDiffInner {
|
||||
}
|
||||
}
|
||||
|
||||
new_cursor.next();
|
||||
old_cursor.next();
|
||||
new_cursor.next(new_snapshot);
|
||||
old_cursor.next(new_snapshot);
|
||||
}
|
||||
Ordering::Greater => {
|
||||
start.get_or_insert(old_hunk.buffer_range.start);
|
||||
end.replace(old_hunk.buffer_range.end);
|
||||
old_cursor.next();
|
||||
old_cursor.next(new_snapshot);
|
||||
}
|
||||
}
|
||||
}
|
||||
(Some(new_hunk), None) => {
|
||||
start.get_or_insert(new_hunk.buffer_range.start);
|
||||
end.replace(new_hunk.buffer_range.end);
|
||||
new_cursor.next();
|
||||
new_cursor.next(new_snapshot);
|
||||
}
|
||||
(None, Some(old_hunk)) => {
|
||||
start.get_or_insert(old_hunk.buffer_range.start);
|
||||
end.replace(old_hunk.buffer_range.end);
|
||||
old_cursor.next();
|
||||
old_cursor.next(new_snapshot);
|
||||
}
|
||||
(None, None) => break,
|
||||
}
|
||||
|
||||
@@ -333,7 +333,7 @@ impl ChannelChat {
|
||||
if first_id <= message_id {
|
||||
let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>(&());
|
||||
let message_id = ChannelMessageId::Saved(message_id);
|
||||
cursor.seek(&message_id, Bias::Left);
|
||||
cursor.seek(&message_id, Bias::Left, &());
|
||||
return ControlFlow::Break(
|
||||
if cursor
|
||||
.item()
|
||||
@@ -499,7 +499,7 @@ impl ChannelChat {
|
||||
|
||||
pub fn message(&self, ix: usize) -> &ChannelMessage {
|
||||
let mut cursor = self.messages.cursor::<Count>(&());
|
||||
cursor.seek(&Count(ix), Bias::Right);
|
||||
cursor.seek(&Count(ix), Bias::Right, &());
|
||||
cursor.item().unwrap()
|
||||
}
|
||||
|
||||
@@ -516,13 +516,13 @@ impl ChannelChat {
|
||||
|
||||
pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
|
||||
let mut cursor = self.messages.cursor::<Count>(&());
|
||||
cursor.seek(&Count(range.start), Bias::Right);
|
||||
cursor.seek(&Count(range.start), Bias::Right, &());
|
||||
cursor.take(range.len())
|
||||
}
|
||||
|
||||
pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
|
||||
let mut cursor = self.messages.cursor::<ChannelMessageId>(&());
|
||||
cursor.seek(&ChannelMessageId::Pending(0), Bias::Left);
|
||||
cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &());
|
||||
cursor
|
||||
}
|
||||
|
||||
@@ -588,9 +588,9 @@ impl ChannelChat {
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>(&());
|
||||
let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left);
|
||||
let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
|
||||
let start_ix = old_cursor.start().1.0;
|
||||
let removed_messages = old_cursor.slice(&last_message.id, Bias::Right);
|
||||
let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
|
||||
let removed_count = removed_messages.summary().count;
|
||||
let new_count = messages.summary().count;
|
||||
let end_ix = start_ix + removed_count;
|
||||
@@ -599,10 +599,10 @@ impl ChannelChat {
|
||||
|
||||
let mut ranges = Vec::<Range<usize>>::new();
|
||||
if new_messages.last().unwrap().is_pending() {
|
||||
new_messages.append(old_cursor.suffix(), &());
|
||||
new_messages.append(old_cursor.suffix(&()), &());
|
||||
} else {
|
||||
new_messages.append(
|
||||
old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left),
|
||||
old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()),
|
||||
&(),
|
||||
);
|
||||
|
||||
@@ -617,7 +617,7 @@ impl ChannelChat {
|
||||
} else {
|
||||
new_messages.push(message.clone(), &());
|
||||
}
|
||||
old_cursor.next();
|
||||
old_cursor.next(&());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -641,12 +641,12 @@ impl ChannelChat {
|
||||
|
||||
fn message_removed(&mut self, id: u64, cx: &mut Context<Self>) {
|
||||
let mut cursor = self.messages.cursor::<ChannelMessageId>(&());
|
||||
let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left);
|
||||
let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left, &());
|
||||
if let Some(item) = cursor.item() {
|
||||
if item.id == ChannelMessageId::Saved(id) {
|
||||
let deleted_message_ix = messages.summary().count;
|
||||
cursor.next();
|
||||
messages.append(cursor.suffix(), &());
|
||||
cursor.next(&());
|
||||
messages.append(cursor.suffix(&()), &());
|
||||
drop(cursor);
|
||||
self.messages = messages;
|
||||
|
||||
@@ -680,7 +680,7 @@ impl ChannelChat {
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let mut cursor = self.messages.cursor::<ChannelMessageId>(&());
|
||||
let mut messages = cursor.slice(&id, Bias::Left);
|
||||
let mut messages = cursor.slice(&id, Bias::Left, &());
|
||||
let ix = messages.summary().count;
|
||||
|
||||
if let Some(mut message_to_update) = cursor.item().cloned() {
|
||||
@@ -688,10 +688,10 @@ impl ChannelChat {
|
||||
message_to_update.mentions = mentions;
|
||||
message_to_update.edited_at = edited_at;
|
||||
messages.push(message_to_update, &());
|
||||
cursor.next();
|
||||
cursor.next(&());
|
||||
}
|
||||
|
||||
messages.append(cursor.suffix(), &());
|
||||
messages.append(cursor.suffix(&()), &());
|
||||
drop(cursor);
|
||||
self.messages = messages;
|
||||
|
||||
|
||||
@@ -151,7 +151,6 @@ impl Settings for ProxySettings {
|
||||
|
||||
pub fn init_settings(cx: &mut App) {
|
||||
TelemetrySettings::register(cx);
|
||||
DisableAiSettings::register(cx);
|
||||
ClientSettings::register(cx);
|
||||
ProxySettings::register(cx);
|
||||
}
|
||||
@@ -549,33 +548,6 @@ impl settings::Settings for TelemetrySettings {
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether to disable all AI features in Zed.
|
||||
///
|
||||
/// Default: false
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct DisableAiSettings {
|
||||
pub disable_ai: bool,
|
||||
}
|
||||
|
||||
impl settings::Settings for DisableAiSettings {
|
||||
const KEY: Option<&'static str> = Some("disable_ai");
|
||||
|
||||
type FileContent = Option<bool>;
|
||||
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
|
||||
Ok(Self {
|
||||
disable_ai: sources
|
||||
.user
|
||||
.or(sources.server)
|
||||
.copied()
|
||||
.flatten()
|
||||
.unwrap_or(sources.default.ok_or_else(Self::missing_default)?),
|
||||
})
|
||||
}
|
||||
|
||||
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
|
||||
}
|
||||
|
||||
impl Client {
|
||||
pub fn new(
|
||||
clock: Arc<dyn SystemClock>,
|
||||
|
||||
@@ -358,13 +358,13 @@ impl Telemetry {
|
||||
worktree_id: WorktreeId,
|
||||
updated_entries_set: &UpdatedEntriesSet,
|
||||
) {
|
||||
let Some(project_types) = self.detect_project_types(worktree_id, updated_entries_set)
|
||||
let Some(project_type_names) = self.detect_project_types(worktree_id, updated_entries_set)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
for project_type in project_types {
|
||||
telemetry::event!("Project Opened", project_type = project_type);
|
||||
for project_type_name in project_type_names {
|
||||
telemetry::event!("Project Opened", project_type = project_type_name);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -765,14 +765,12 @@ impl UserStore {
|
||||
|
||||
pub fn current_plan(&self) -> Option<proto::Plan> {
|
||||
#[cfg(debug_assertions)]
|
||||
if let Ok(plan) = std::env::var("ZED_SIMULATE_PLAN").as_ref() {
|
||||
if let Ok(plan) = std::env::var("ZED_SIMULATE_ZED_PRO_PLAN").as_ref() {
|
||||
return match plan.as_str() {
|
||||
"free" => Some(proto::Plan::Free),
|
||||
"trial" => Some(proto::Plan::ZedProTrial),
|
||||
"pro" => Some(proto::Plan::ZedPro),
|
||||
_ => {
|
||||
panic!("ZED_SIMULATE_PLAN must be one of 'free', 'trial', or 'pro'");
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -11,9 +11,7 @@ use crate::{
|
||||
db::{User, UserId},
|
||||
rpc,
|
||||
};
|
||||
use ::rpc::proto;
|
||||
use anyhow::Context as _;
|
||||
use axum::extract;
|
||||
use axum::{
|
||||
Extension, Json, Router,
|
||||
body::Body,
|
||||
@@ -25,7 +23,6 @@ use axum::{
|
||||
routing::{get, post},
|
||||
};
|
||||
use axum_extra::response::ErasedJson;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use tower::ServiceBuilder;
|
||||
@@ -104,8 +101,8 @@ pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
|
||||
.route("/users/look_up", get(look_up_user))
|
||||
.route("/users/:id/access_tokens", post(create_access_token))
|
||||
.route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens))
|
||||
.route("/users/:id/update_plan", post(update_plan))
|
||||
.route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
|
||||
.merge(billing::router())
|
||||
.merge(contributors::router())
|
||||
.layer(
|
||||
ServiceBuilder::new()
|
||||
@@ -350,78 +347,3 @@ async fn refresh_llm_tokens(
|
||||
|
||||
Ok(Json(RefreshLlmTokensResponse {}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct UpdatePlanBody {
|
||||
pub plan: zed_llm_client::Plan,
|
||||
pub subscription_period: SubscriptionPeriod,
|
||||
pub usage: zed_llm_client::CurrentUsage,
|
||||
pub trial_started_at: Option<DateTime<Utc>>,
|
||||
pub is_usage_based_billing_enabled: bool,
|
||||
pub is_account_too_young: bool,
|
||||
pub has_overdue_invoices: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
|
||||
struct SubscriptionPeriod {
|
||||
pub started_at: DateTime<Utc>,
|
||||
pub ended_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct UpdatePlanResponse {}
|
||||
|
||||
async fn update_plan(
|
||||
Path(user_id): Path<UserId>,
|
||||
Extension(rpc_server): Extension<Arc<rpc::Server>>,
|
||||
extract::Json(body): extract::Json<UpdatePlanBody>,
|
||||
) -> Result<Json<UpdatePlanResponse>> {
|
||||
let plan = match body.plan {
|
||||
zed_llm_client::Plan::ZedFree => proto::Plan::Free,
|
||||
zed_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
|
||||
zed_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
|
||||
};
|
||||
|
||||
let update_user_plan = proto::UpdateUserPlan {
|
||||
plan: plan.into(),
|
||||
trial_started_at: body
|
||||
.trial_started_at
|
||||
.map(|trial_started_at| trial_started_at.timestamp() as u64),
|
||||
is_usage_based_billing_enabled: Some(body.is_usage_based_billing_enabled),
|
||||
usage: Some(proto::SubscriptionUsage {
|
||||
model_requests_usage_amount: body.usage.model_requests.used,
|
||||
model_requests_usage_limit: Some(usage_limit_to_proto(body.usage.model_requests.limit)),
|
||||
edit_predictions_usage_amount: body.usage.edit_predictions.used,
|
||||
edit_predictions_usage_limit: Some(usage_limit_to_proto(
|
||||
body.usage.edit_predictions.limit,
|
||||
)),
|
||||
}),
|
||||
subscription_period: Some(proto::SubscriptionPeriod {
|
||||
started_at: body.subscription_period.started_at.timestamp() as u64,
|
||||
ended_at: body.subscription_period.ended_at.timestamp() as u64,
|
||||
}),
|
||||
account_too_young: Some(body.is_account_too_young),
|
||||
has_overdue_invoices: Some(body.has_overdue_invoices),
|
||||
};
|
||||
|
||||
rpc_server
|
||||
.update_plan_for_user(user_id, update_user_plan)
|
||||
.await?;
|
||||
|
||||
Ok(Json(UpdatePlanResponse {}))
|
||||
}
|
||||
|
||||
fn usage_limit_to_proto(limit: zed_llm_client::UsageLimit) -> proto::UsageLimit {
|
||||
proto::UsageLimit {
|
||||
variant: Some(match limit {
|
||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
|
||||
limit: limit as u32,
|
||||
})
|
||||
}
|
||||
zed_llm_client::UsageLimit::Unlimited => {
|
||||
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,23 @@
|
||||
use anyhow::{Context as _, bail};
|
||||
use axum::{Extension, Json, Router, extract, routing::post};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::{HashMap, HashSet};
|
||||
use reqwest::StatusCode;
|
||||
use sea_orm::ActiveValue;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use stripe::{CancellationDetailsReason, EventObject, EventType, ListEvents, SubscriptionStatus};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{str::FromStr, sync::Arc, time::Duration};
|
||||
use stripe::{
|
||||
BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
|
||||
CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
|
||||
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
|
||||
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
|
||||
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
|
||||
CreateBillingPortalSessionFlowDataType, CustomerId, EventObject, EventType, ListEvents,
|
||||
PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
|
||||
};
|
||||
use util::{ResultExt, maybe};
|
||||
use zed_llm_client::LanguageModelProvider;
|
||||
|
||||
use crate::AppState;
|
||||
use crate::db::billing_subscription::{
|
||||
StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
|
||||
};
|
||||
@@ -15,18 +25,450 @@ use crate::llm::db::subscription_usage_meter::{self, CompletionMode};
|
||||
use crate::rpc::{ResultExt as _, Server};
|
||||
use crate::stripe_client::{
|
||||
StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
|
||||
StripeSubscriptionId,
|
||||
StripeSubscriptionId, UpdateCustomerParams,
|
||||
};
|
||||
use crate::{AppState, Error, Result};
|
||||
use crate::{db::UserId, llm::db::LlmDatabase};
|
||||
use crate::{
|
||||
db::{
|
||||
CreateBillingCustomerParams, CreateBillingSubscriptionParams,
|
||||
BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
|
||||
CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
|
||||
UpdateBillingSubscriptionParams, billing_customer,
|
||||
},
|
||||
stripe_billing::StripeBilling,
|
||||
};
|
||||
|
||||
pub fn router() -> Router {
|
||||
Router::new()
|
||||
.route("/billing/subscriptions", post(create_billing_subscription))
|
||||
.route(
|
||||
"/billing/subscriptions/manage",
|
||||
post(manage_billing_subscription),
|
||||
)
|
||||
.route(
|
||||
"/billing/subscriptions/sync",
|
||||
post(sync_billing_subscription),
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum ProductCode {
|
||||
ZedPro,
|
||||
ZedProTrial,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CreateBillingSubscriptionBody {
|
||||
github_user_id: i32,
|
||||
product: ProductCode,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct CreateBillingSubscriptionResponse {
|
||||
checkout_session_url: String,
|
||||
}
|
||||
|
||||
/// Initiates a Stripe Checkout session for creating a billing subscription.
|
||||
async fn create_billing_subscription(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
|
||||
) -> Result<Json<CreateBillingSubscriptionResponse>> {
|
||||
let user = app
|
||||
.db
|
||||
.get_user_by_github_user_id(body.github_user_id)
|
||||
.await?
|
||||
.context("user not found")?;
|
||||
|
||||
let Some(stripe_billing) = app.stripe_billing.clone() else {
|
||||
log::error!("failed to retrieve Stripe billing object");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
};
|
||||
|
||||
if let Some(existing_subscription) = app.db.get_active_billing_subscription(user.id).await? {
|
||||
let is_checkout_allowed = body.product == ProductCode::ZedProTrial
|
||||
&& existing_subscription.kind == Some(SubscriptionKind::ZedFree);
|
||||
|
||||
if !is_checkout_allowed {
|
||||
return Err(Error::http(
|
||||
StatusCode::CONFLICT,
|
||||
"user already has an active subscription".into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
|
||||
if let Some(existing_billing_customer) = &existing_billing_customer {
|
||||
if existing_billing_customer.has_overdue_invoices {
|
||||
return Err(Error::http(
|
||||
StatusCode::PAYMENT_REQUIRED,
|
||||
"user has overdue invoices".into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let customer_id = if let Some(existing_customer) = &existing_billing_customer {
|
||||
let customer_id = StripeCustomerId(existing_customer.stripe_customer_id.clone().into());
|
||||
if let Some(email) = user.email_address.as_deref() {
|
||||
stripe_billing
|
||||
.client()
|
||||
.update_customer(&customer_id, UpdateCustomerParams { email: Some(email) })
|
||||
.await
|
||||
// Update of email address is best-effort - continue checkout even if it fails
|
||||
.context("error updating stripe customer email address")
|
||||
.log_err();
|
||||
}
|
||||
customer_id
|
||||
} else {
|
||||
stripe_billing
|
||||
.find_or_create_customer_by_email(user.email_address.as_deref())
|
||||
.await?
|
||||
};
|
||||
|
||||
let success_url = format!(
|
||||
"{}/account?checkout_complete=1",
|
||||
app.config.zed_dot_dev_url()
|
||||
);
|
||||
|
||||
let checkout_session_url = match body.product {
|
||||
ProductCode::ZedPro => {
|
||||
stripe_billing
|
||||
.checkout_with_zed_pro(&customer_id, &user.github_login, &success_url)
|
||||
.await?
|
||||
}
|
||||
ProductCode::ZedProTrial => {
|
||||
if let Some(existing_billing_customer) = &existing_billing_customer {
|
||||
if existing_billing_customer.trial_started_at.is_some() {
|
||||
return Err(Error::http(
|
||||
StatusCode::FORBIDDEN,
|
||||
"user already used free trial".into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let feature_flags = app.db.get_user_flags(user.id).await?;
|
||||
|
||||
stripe_billing
|
||||
.checkout_with_zed_pro_trial(
|
||||
&customer_id,
|
||||
&user.github_login,
|
||||
feature_flags,
|
||||
&success_url,
|
||||
)
|
||||
.await?
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Json(CreateBillingSubscriptionResponse {
|
||||
checkout_session_url,
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum ManageSubscriptionIntent {
|
||||
/// The user intends to manage their subscription.
|
||||
///
|
||||
/// This will open the Stripe billing portal without putting the user in a specific flow.
|
||||
ManageSubscription,
|
||||
/// The user intends to update their payment method.
|
||||
UpdatePaymentMethod,
|
||||
/// The user intends to upgrade to Zed Pro.
|
||||
UpgradeToPro,
|
||||
/// The user intends to cancel their subscription.
|
||||
Cancel,
|
||||
/// The user intends to stop the cancellation of their subscription.
|
||||
StopCancellation,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ManageBillingSubscriptionBody {
|
||||
github_user_id: i32,
|
||||
intent: ManageSubscriptionIntent,
|
||||
/// The ID of the subscription to manage.
|
||||
subscription_id: BillingSubscriptionId,
|
||||
redirect_to: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ManageBillingSubscriptionResponse {
|
||||
billing_portal_session_url: Option<String>,
|
||||
}
|
||||
|
||||
/// Initiates a Stripe customer portal session for managing a billing subscription.
|
||||
async fn manage_billing_subscription(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
|
||||
) -> Result<Json<ManageBillingSubscriptionResponse>> {
|
||||
let user = app
|
||||
.db
|
||||
.get_user_by_github_user_id(body.github_user_id)
|
||||
.await?
|
||||
.context("user not found")?;
|
||||
|
||||
let Some(stripe_client) = app.real_stripe_client.clone() else {
|
||||
log::error!("failed to retrieve Stripe client");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
};
|
||||
|
||||
let Some(stripe_billing) = app.stripe_billing.clone() else {
|
||||
log::error!("failed to retrieve Stripe billing object");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
};
|
||||
|
||||
let customer = app
|
||||
.db
|
||||
.get_billing_customer_by_user_id(user.id)
|
||||
.await?
|
||||
.context("billing customer not found")?;
|
||||
let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
|
||||
.context("failed to parse customer ID")?;
|
||||
|
||||
let subscription = app
|
||||
.db
|
||||
.get_billing_subscription_by_id(body.subscription_id)
|
||||
.await?
|
||||
.context("subscription not found")?;
|
||||
let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
|
||||
.context("failed to parse subscription ID")?;
|
||||
|
||||
if body.intent == ManageSubscriptionIntent::StopCancellation {
|
||||
let updated_stripe_subscription = Subscription::update(
|
||||
&stripe_client,
|
||||
&subscription_id,
|
||||
stripe::UpdateSubscription {
|
||||
cancel_at_period_end: Some(false),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
app.db
|
||||
.update_billing_subscription(
|
||||
subscription.id,
|
||||
&UpdateBillingSubscriptionParams {
|
||||
stripe_cancel_at: ActiveValue::set(
|
||||
updated_stripe_subscription
|
||||
.cancel_at
|
||||
.and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
|
||||
.map(|time| time.naive_utc()),
|
||||
),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
return Ok(Json(ManageBillingSubscriptionResponse {
|
||||
billing_portal_session_url: None,
|
||||
}));
|
||||
}
|
||||
|
||||
let flow = match body.intent {
|
||||
ManageSubscriptionIntent::ManageSubscription => None,
|
||||
ManageSubscriptionIntent::UpgradeToPro => {
|
||||
let zed_pro_price_id: stripe::PriceId =
|
||||
stripe_billing.zed_pro_price_id().await?.try_into()?;
|
||||
let zed_free_price_id: stripe::PriceId =
|
||||
stripe_billing.zed_free_price_id().await?.try_into()?;
|
||||
|
||||
let stripe_subscription =
|
||||
Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
|
||||
|
||||
let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
|
||||
&& stripe_subscription.items.data.iter().any(|item| {
|
||||
item.price
|
||||
.as_ref()
|
||||
.map_or(false, |price| price.id == zed_pro_price_id)
|
||||
});
|
||||
if is_on_zed_pro_trial {
|
||||
let payment_methods = PaymentMethod::list(
|
||||
&stripe_client,
|
||||
&stripe::ListPaymentMethods {
|
||||
customer: Some(stripe_subscription.customer.id()),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let has_payment_method = !payment_methods.data.is_empty();
|
||||
if !has_payment_method {
|
||||
return Err(Error::http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"missing payment method".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
|
||||
Subscription::update(
|
||||
&stripe_client,
|
||||
&stripe_subscription.id,
|
||||
stripe::UpdateSubscription {
|
||||
trial_end: Some(stripe::Scheduled::now()),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
return Ok(Json(ManageBillingSubscriptionResponse {
|
||||
billing_portal_session_url: None,
|
||||
}));
|
||||
}
|
||||
|
||||
let subscription_item_to_update = stripe_subscription
|
||||
.items
|
||||
.data
|
||||
.iter()
|
||||
.find_map(|item| {
|
||||
let price = item.price.as_ref()?;
|
||||
|
||||
if price.id == zed_free_price_id {
|
||||
Some(item.id.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.context("No subscription item to update")?;
|
||||
|
||||
Some(CreateBillingPortalSessionFlowData {
|
||||
type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
|
||||
subscription_update_confirm: Some(
|
||||
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
|
||||
subscription: subscription.stripe_subscription_id,
|
||||
items: vec![
|
||||
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
|
||||
id: subscription_item_to_update.to_string(),
|
||||
price: Some(zed_pro_price_id.to_string()),
|
||||
quantity: Some(1),
|
||||
},
|
||||
],
|
||||
discounts: None,
|
||||
},
|
||||
),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData {
|
||||
type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate,
|
||||
after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
|
||||
type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
|
||||
redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
|
||||
return_url: format!(
|
||||
"{}{path}",
|
||||
app.config.zed_dot_dev_url(),
|
||||
path = body.redirect_to.unwrap_or_else(|| "/account".to_string())
|
||||
),
|
||||
}),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
}),
|
||||
ManageSubscriptionIntent::Cancel => {
|
||||
if subscription.kind == Some(SubscriptionKind::ZedFree) {
|
||||
return Err(Error::http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"free subscription cannot be canceled".into(),
|
||||
));
|
||||
}
|
||||
|
||||
Some(CreateBillingPortalSessionFlowData {
|
||||
type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
|
||||
after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
|
||||
type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
|
||||
redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
|
||||
return_url: format!("{}/account", app.config.zed_dot_dev_url()),
|
||||
}),
|
||||
..Default::default()
|
||||
}),
|
||||
subscription_cancel: Some(
|
||||
stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
|
||||
subscription: subscription.stripe_subscription_id,
|
||||
retention: None,
|
||||
},
|
||||
),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
ManageSubscriptionIntent::StopCancellation => unreachable!(),
|
||||
};
|
||||
|
||||
let mut params = CreateBillingPortalSession::new(customer_id);
|
||||
params.flow_data = flow;
|
||||
let return_url = format!("{}/account", app.config.zed_dot_dev_url());
|
||||
params.return_url = Some(&return_url);
|
||||
|
||||
let session = BillingPortalSession::create(&stripe_client, params).await?;
|
||||
|
||||
Ok(Json(ManageBillingSubscriptionResponse {
|
||||
billing_portal_session_url: Some(session.url),
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SyncBillingSubscriptionBody {
|
||||
github_user_id: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct SyncBillingSubscriptionResponse {
|
||||
stripe_customer_id: String,
|
||||
}
|
||||
|
||||
async fn sync_billing_subscription(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
extract::Json(body): extract::Json<SyncBillingSubscriptionBody>,
|
||||
) -> Result<Json<SyncBillingSubscriptionResponse>> {
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
log::error!("failed to retrieve Stripe client");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
};
|
||||
|
||||
let user = app
|
||||
.db
|
||||
.get_user_by_github_user_id(body.github_user_id)
|
||||
.await?
|
||||
.context("user not found")?;
|
||||
|
||||
let billing_customer = app
|
||||
.db
|
||||
.get_billing_customer_by_user_id(user.id)
|
||||
.await?
|
||||
.context("billing customer not found")?;
|
||||
let stripe_customer_id = StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
|
||||
|
||||
let subscriptions = stripe_client
|
||||
.list_subscriptions_for_customer(&stripe_customer_id)
|
||||
.await?;
|
||||
|
||||
for subscription in subscriptions {
|
||||
let subscription_id = subscription.id.clone();
|
||||
|
||||
sync_subscription(&app, &stripe_client, subscription)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"failed to sync subscription {subscription_id} for user {}",
|
||||
user.id,
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(Json(SyncBillingSubscriptionResponse {
|
||||
stripe_customer_id: billing_customer.stripe_customer_id.clone(),
|
||||
}))
|
||||
}
|
||||
|
||||
/// The amount of time we wait in between each poll of Stripe events.
|
||||
///
|
||||
/// This value should strike a balance between:
|
||||
@@ -460,7 +902,7 @@ async fn handle_customer_subscription_event(
|
||||
|
||||
// When the user's subscription changes, push down any changes to their plan.
|
||||
rpc_server
|
||||
.update_plan_for_user_legacy(billing_customer.user_id)
|
||||
.update_plan_for_user(billing_customer.user_id)
|
||||
.await
|
||||
.trace_err();
|
||||
|
||||
|
||||
@@ -433,8 +433,6 @@ impl Server {
|
||||
.add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::Stage>)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::Unstage>)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::Stash>)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::StashPop>)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::Commit>)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::GitInit>)
|
||||
.add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
|
||||
@@ -831,7 +829,7 @@ impl Server {
|
||||
// This arrangement ensures we will attempt to process earlier messages first, but fall
|
||||
// back to processing messages arrived later in the spirit of making progress.
|
||||
let mut foreground_message_handlers = FuturesUnordered::new();
|
||||
let concurrent_handlers = Arc::new(Semaphore::new(512));
|
||||
let concurrent_handlers = Arc::new(Semaphore::new(256));
|
||||
loop {
|
||||
let next_message = async {
|
||||
let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
|
||||
@@ -1004,26 +1002,7 @@ impl Server {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update_plan_for_user(
|
||||
self: &Arc<Self>,
|
||||
user_id: UserId,
|
||||
update_user_plan: proto::UpdateUserPlan,
|
||||
) -> Result<()> {
|
||||
let pool = self.connection_pool.lock();
|
||||
for connection_id in pool.user_connection_ids(user_id) {
|
||||
self.peer
|
||||
.send(connection_id, update_user_plan.clone())
|
||||
.trace_err();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
|
||||
/// message on the Collab server.
|
||||
///
|
||||
/// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
|
||||
pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
|
||||
pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
|
||||
let user = self
|
||||
.app_state
|
||||
.db
|
||||
@@ -1039,7 +1018,14 @@ impl Server {
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.update_plan_for_user(user_id, update_user_plan).await
|
||||
let pool = self.connection_pool.lock();
|
||||
for connection_id in pool.user_connection_ids(user_id) {
|
||||
self.peer
|
||||
.send(connection_id, update_user_plan.clone())
|
||||
.trace_err();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::anyhow;
|
||||
use anyhow::{Context as _, anyhow};
|
||||
use chrono::Utc;
|
||||
use collections::HashMap;
|
||||
use stripe::SubscriptionStatus;
|
||||
@@ -9,13 +9,18 @@ use uuid::Uuid;
|
||||
|
||||
use crate::Result;
|
||||
use crate::db::billing_subscription::SubscriptionKind;
|
||||
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
|
||||
use crate::stripe_client::{
|
||||
RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateMeterEventParams,
|
||||
RealStripeClient, StripeAutomaticTax, StripeBillingAddressCollection,
|
||||
StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient,
|
||||
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
|
||||
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
|
||||
StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams,
|
||||
StripeCustomerId, StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId,
|
||||
StripeCustomerId, StripeCustomerUpdate, StripeCustomerUpdateAddress, StripeCustomerUpdateName,
|
||||
StripeMeter, StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId,
|
||||
StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
|
||||
UpdateSubscriptionParams,
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, StripeTaxIdCollection,
|
||||
UpdateSubscriptionItems, UpdateSubscriptionParams,
|
||||
};
|
||||
|
||||
pub struct StripeBilling {
|
||||
@@ -25,6 +30,8 @@ pub struct StripeBilling {
|
||||
|
||||
#[derive(Default)]
|
||||
struct StripeBillingState {
|
||||
meters_by_event_name: HashMap<String, StripeMeter>,
|
||||
price_ids_by_meter_id: HashMap<String, StripePriceId>,
|
||||
prices_by_lookup_key: HashMap<String, StripePrice>,
|
||||
}
|
||||
|
||||
@@ -53,11 +60,24 @@ impl StripeBilling {
|
||||
|
||||
let mut state = self.state.write().await;
|
||||
|
||||
let prices = self.client.list_prices().await?;
|
||||
let (meters, prices) =
|
||||
futures::try_join!(self.client.list_meters(), self.client.list_prices())?;
|
||||
|
||||
for meter in meters {
|
||||
state
|
||||
.meters_by_event_name
|
||||
.insert(meter.event_name.clone(), meter);
|
||||
}
|
||||
|
||||
for price in prices {
|
||||
if let Some(lookup_key) = price.lookup_key.clone() {
|
||||
state.prices_by_lookup_key.insert(lookup_key, price);
|
||||
state.prices_by_lookup_key.insert(lookup_key, price.clone());
|
||||
}
|
||||
|
||||
if let Some(recurring) = price.recurring {
|
||||
if let Some(meter) = recurring.meter {
|
||||
state.price_ids_by_meter_id.insert(meter, price.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -209,6 +229,95 @@ impl StripeBilling {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn checkout_with_zed_pro(
|
||||
&self,
|
||||
customer_id: &StripeCustomerId,
|
||||
github_login: &str,
|
||||
success_url: &str,
|
||||
) -> Result<String> {
|
||||
let zed_pro_price_id = self.zed_pro_price_id().await?;
|
||||
|
||||
let mut params = StripeCreateCheckoutSessionParams::default();
|
||||
params.mode = Some(StripeCheckoutSessionMode::Subscription);
|
||||
params.customer = Some(customer_id);
|
||||
params.client_reference_id = Some(github_login);
|
||||
params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems {
|
||||
price: Some(zed_pro_price_id.to_string()),
|
||||
quantity: Some(1),
|
||||
}]);
|
||||
params.success_url = Some(success_url);
|
||||
params.billing_address_collection = Some(StripeBillingAddressCollection::Required);
|
||||
params.customer_update = Some(StripeCustomerUpdate {
|
||||
address: Some(StripeCustomerUpdateAddress::Auto),
|
||||
name: Some(StripeCustomerUpdateName::Auto),
|
||||
shipping: None,
|
||||
});
|
||||
params.tax_id_collection = Some(StripeTaxIdCollection { enabled: true });
|
||||
|
||||
let session = self.client.create_checkout_session(params).await?;
|
||||
Ok(session.url.context("no checkout session URL")?)
|
||||
}
|
||||
|
||||
pub async fn checkout_with_zed_pro_trial(
|
||||
&self,
|
||||
customer_id: &StripeCustomerId,
|
||||
github_login: &str,
|
||||
feature_flags: Vec<String>,
|
||||
success_url: &str,
|
||||
) -> Result<String> {
|
||||
let zed_pro_price_id = self.zed_pro_price_id().await?;
|
||||
|
||||
let eligible_for_extended_trial = feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
|
||||
|
||||
let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
|
||||
|
||||
let mut subscription_metadata = std::collections::HashMap::new();
|
||||
if eligible_for_extended_trial {
|
||||
subscription_metadata.insert(
|
||||
"promo_feature_flag".to_string(),
|
||||
AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
let mut params = StripeCreateCheckoutSessionParams::default();
|
||||
params.subscription_data = Some(StripeCreateCheckoutSessionSubscriptionData {
|
||||
trial_period_days: Some(trial_period_days),
|
||||
trial_settings: Some(StripeSubscriptionTrialSettings {
|
||||
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
|
||||
missing_payment_method:
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
|
||||
},
|
||||
}),
|
||||
metadata: if !subscription_metadata.is_empty() {
|
||||
Some(subscription_metadata)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
});
|
||||
params.mode = Some(StripeCheckoutSessionMode::Subscription);
|
||||
params.payment_method_collection =
|
||||
Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired);
|
||||
params.customer = Some(customer_id);
|
||||
params.client_reference_id = Some(github_login);
|
||||
params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems {
|
||||
price: Some(zed_pro_price_id.to_string()),
|
||||
quantity: Some(1),
|
||||
}]);
|
||||
params.success_url = Some(success_url);
|
||||
params.billing_address_collection = Some(StripeBillingAddressCollection::Required);
|
||||
params.customer_update = Some(StripeCustomerUpdate {
|
||||
address: Some(StripeCustomerUpdateAddress::Auto),
|
||||
name: Some(StripeCustomerUpdateName::Auto),
|
||||
shipping: None,
|
||||
});
|
||||
params.tax_id_collection = Some(StripeTaxIdCollection { enabled: true });
|
||||
|
||||
let session = self.client.create_checkout_session(params).await?;
|
||||
Ok(session.url.context("no checkout session URL")?)
|
||||
}
|
||||
|
||||
pub async fn subscribe_to_zed_free(
|
||||
&self,
|
||||
customer_id: StripeCustomerId,
|
||||
|
||||
@@ -3,11 +3,17 @@ use std::sync::Arc;
|
||||
use chrono::{Duration, Utc};
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
|
||||
use crate::stripe_billing::StripeBilling;
|
||||
use crate::stripe_client::{
|
||||
FakeStripeClient, StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId,
|
||||
StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
|
||||
StripeSubscriptionItemId, UpdateSubscriptionItems,
|
||||
FakeStripeClient, StripeBillingAddressCollection, StripeCheckoutSessionMode,
|
||||
StripeCheckoutSessionPaymentMethodCollection, StripeCreateCheckoutSessionLineItems,
|
||||
StripeCreateCheckoutSessionSubscriptionData, StripeCustomerId, StripeCustomerUpdate,
|
||||
StripeCustomerUpdateAddress, StripeCustomerUpdateName, StripeMeter, StripeMeterId, StripePrice,
|
||||
StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId,
|
||||
StripeSubscriptionItem, StripeSubscriptionItemId, StripeSubscriptionTrialSettings,
|
||||
StripeSubscriptionTrialSettingsEndBehavior,
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
|
||||
};
|
||||
|
||||
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
|
||||
@@ -358,3 +364,240 @@ async fn test_bill_model_request_usage() {
|
||||
);
|
||||
assert_eq!(create_meter_event_calls[0].value, 73);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_checkout_with_zed_pro() {
|
||||
let (stripe_billing, stripe_client) = make_stripe_billing();
|
||||
|
||||
let customer_id = StripeCustomerId("cus_test".into());
|
||||
let github_login = "zeduser1";
|
||||
let success_url = "https://example.com/success";
|
||||
|
||||
// It returns an error when the Zed Pro price doesn't exist.
|
||||
{
|
||||
let result = stripe_billing
|
||||
.checkout_with_zed_pro(&customer_id, github_login, success_url)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert_eq!(
|
||||
result.err().unwrap().to_string(),
|
||||
r#"no price ID found for "zed-pro""#
|
||||
);
|
||||
}
|
||||
|
||||
// Successful checkout.
|
||||
{
|
||||
let price = StripePrice {
|
||||
id: StripePriceId("price_1".into()),
|
||||
unit_amount: Some(2000),
|
||||
lookup_key: Some("zed-pro".to_string()),
|
||||
recurring: None,
|
||||
};
|
||||
stripe_client
|
||||
.prices
|
||||
.lock()
|
||||
.insert(price.id.clone(), price.clone());
|
||||
|
||||
stripe_billing.initialize().await.unwrap();
|
||||
|
||||
let checkout_url = stripe_billing
|
||||
.checkout_with_zed_pro(&customer_id, github_login, success_url)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
|
||||
|
||||
let create_checkout_session_calls = stripe_client
|
||||
.create_checkout_session_calls
|
||||
.lock()
|
||||
.drain(..)
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(create_checkout_session_calls.len(), 1);
|
||||
let call = create_checkout_session_calls.into_iter().next().unwrap();
|
||||
assert_eq!(call.customer, Some(customer_id));
|
||||
assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
|
||||
assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
|
||||
assert_eq!(
|
||||
call.line_items,
|
||||
Some(vec![StripeCreateCheckoutSessionLineItems {
|
||||
price: Some(price.id.to_string()),
|
||||
quantity: Some(1)
|
||||
}])
|
||||
);
|
||||
assert_eq!(call.payment_method_collection, None);
|
||||
assert_eq!(call.subscription_data, None);
|
||||
assert_eq!(call.success_url.as_deref(), Some(success_url));
|
||||
assert_eq!(
|
||||
call.billing_address_collection,
|
||||
Some(StripeBillingAddressCollection::Required)
|
||||
);
|
||||
assert_eq!(
|
||||
call.customer_update,
|
||||
Some(StripeCustomerUpdate {
|
||||
address: Some(StripeCustomerUpdateAddress::Auto),
|
||||
name: Some(StripeCustomerUpdateName::Auto),
|
||||
shipping: None,
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_checkout_with_zed_pro_trial() {
|
||||
let (stripe_billing, stripe_client) = make_stripe_billing();
|
||||
|
||||
let customer_id = StripeCustomerId("cus_test".into());
|
||||
let github_login = "zeduser1";
|
||||
let success_url = "https://example.com/success";
|
||||
|
||||
// It returns an error when the Zed Pro price doesn't exist.
|
||||
{
|
||||
let result = stripe_billing
|
||||
.checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert_eq!(
|
||||
result.err().unwrap().to_string(),
|
||||
r#"no price ID found for "zed-pro""#
|
||||
);
|
||||
}
|
||||
|
||||
let price = StripePrice {
|
||||
id: StripePriceId("price_1".into()),
|
||||
unit_amount: Some(2000),
|
||||
lookup_key: Some("zed-pro".to_string()),
|
||||
recurring: None,
|
||||
};
|
||||
stripe_client
|
||||
.prices
|
||||
.lock()
|
||||
.insert(price.id.clone(), price.clone());
|
||||
|
||||
stripe_billing.initialize().await.unwrap();
|
||||
|
||||
// Successful checkout.
|
||||
{
|
||||
let checkout_url = stripe_billing
|
||||
.checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
|
||||
|
||||
let create_checkout_session_calls = stripe_client
|
||||
.create_checkout_session_calls
|
||||
.lock()
|
||||
.drain(..)
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(create_checkout_session_calls.len(), 1);
|
||||
let call = create_checkout_session_calls.into_iter().next().unwrap();
|
||||
assert_eq!(call.customer.as_ref(), Some(&customer_id));
|
||||
assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
|
||||
assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
|
||||
assert_eq!(
|
||||
call.line_items,
|
||||
Some(vec![StripeCreateCheckoutSessionLineItems {
|
||||
price: Some(price.id.to_string()),
|
||||
quantity: Some(1)
|
||||
}])
|
||||
);
|
||||
assert_eq!(
|
||||
call.payment_method_collection,
|
||||
Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
|
||||
);
|
||||
assert_eq!(
|
||||
call.subscription_data,
|
||||
Some(StripeCreateCheckoutSessionSubscriptionData {
|
||||
trial_period_days: Some(14),
|
||||
trial_settings: Some(StripeSubscriptionTrialSettings {
|
||||
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
|
||||
missing_payment_method:
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
|
||||
},
|
||||
}),
|
||||
metadata: None,
|
||||
})
|
||||
);
|
||||
assert_eq!(call.success_url.as_deref(), Some(success_url));
|
||||
assert_eq!(
|
||||
call.billing_address_collection,
|
||||
Some(StripeBillingAddressCollection::Required)
|
||||
);
|
||||
assert_eq!(
|
||||
call.customer_update,
|
||||
Some(StripeCustomerUpdate {
|
||||
address: Some(StripeCustomerUpdateAddress::Auto),
|
||||
name: Some(StripeCustomerUpdateName::Auto),
|
||||
shipping: None,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// Successful checkout with extended trial.
|
||||
{
|
||||
let checkout_url = stripe_billing
|
||||
.checkout_with_zed_pro_trial(
|
||||
&customer_id,
|
||||
github_login,
|
||||
vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()],
|
||||
success_url,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
|
||||
|
||||
let create_checkout_session_calls = stripe_client
|
||||
.create_checkout_session_calls
|
||||
.lock()
|
||||
.drain(..)
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(create_checkout_session_calls.len(), 1);
|
||||
let call = create_checkout_session_calls.into_iter().next().unwrap();
|
||||
assert_eq!(call.customer, Some(customer_id));
|
||||
assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
|
||||
assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
|
||||
assert_eq!(
|
||||
call.line_items,
|
||||
Some(vec![StripeCreateCheckoutSessionLineItems {
|
||||
price: Some(price.id.to_string()),
|
||||
quantity: Some(1)
|
||||
}])
|
||||
);
|
||||
assert_eq!(
|
||||
call.payment_method_collection,
|
||||
Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
|
||||
);
|
||||
assert_eq!(
|
||||
call.subscription_data,
|
||||
Some(StripeCreateCheckoutSessionSubscriptionData {
|
||||
trial_period_days: Some(60),
|
||||
trial_settings: Some(StripeSubscriptionTrialSettings {
|
||||
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
|
||||
missing_payment_method:
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
|
||||
},
|
||||
}),
|
||||
metadata: Some(std::collections::HashMap::from_iter([(
|
||||
"promo_feature_flag".into(),
|
||||
AGENT_EXTENDED_TRIAL_FEATURE_FLAG.into()
|
||||
)])),
|
||||
})
|
||||
);
|
||||
assert_eq!(call.success_url.as_deref(), Some(success_url));
|
||||
assert_eq!(
|
||||
call.billing_address_collection,
|
||||
Some(StripeBillingAddressCollection::Required)
|
||||
);
|
||||
assert_eq!(
|
||||
call.customer_update,
|
||||
Some(StripeCustomerUpdate {
|
||||
address: Some(StripeCustomerUpdateAddress::Auto),
|
||||
name: Some(StripeCustomerUpdateName::Auto),
|
||||
shipping: None,
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::HashMap;
|
||||
use futures::{FutureExt, StreamExt, channel::oneshot, future, select};
|
||||
use futures::{FutureExt, StreamExt, channel::oneshot, select};
|
||||
use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task};
|
||||
use parking_lot::Mutex;
|
||||
use postage::barrier;
|
||||
@@ -10,19 +10,15 @@ use smol::channel;
|
||||
use std::{
|
||||
fmt,
|
||||
path::PathBuf,
|
||||
pin::pin,
|
||||
sync::{
|
||||
Arc,
|
||||
atomic::{AtomicI32, Ordering::SeqCst},
|
||||
},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use util::{ResultExt, TryFutureExt};
|
||||
use util::TryFutureExt;
|
||||
|
||||
use crate::{
|
||||
transport::{StdioTransport, Transport},
|
||||
types::{CancelledParams, ClientNotification, Notification as _, notifications::Cancelled},
|
||||
};
|
||||
use crate::transport::{StdioTransport, Transport};
|
||||
|
||||
const JSON_RPC_VERSION: &str = "2.0";
|
||||
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
@@ -36,7 +32,6 @@ pub const INTERNAL_ERROR: i32 = -32603;
|
||||
|
||||
type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
|
||||
type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncApp)>;
|
||||
type RequestHandler = Box<dyn Send + FnMut(RequestId, &RawValue, AsyncApp)>;
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
@@ -83,15 +78,6 @@ pub struct Request<'a, T> {
|
||||
pub params: T,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct AnyRequest<'a> {
|
||||
pub jsonrpc: &'a str,
|
||||
pub id: RequestId,
|
||||
pub method: &'a str,
|
||||
#[serde(skip_serializing_if = "is_null_value")]
|
||||
pub params: Option<&'a RawValue>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct AnyResponse<'a> {
|
||||
jsonrpc: &'a str,
|
||||
@@ -190,23 +176,15 @@ impl Client {
|
||||
Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
|
||||
let response_handlers =
|
||||
Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
|
||||
let request_handlers = Arc::new(Mutex::new(HashMap::<_, RequestHandler>::default()));
|
||||
|
||||
let receive_input_task = cx.spawn({
|
||||
let notification_handlers = notification_handlers.clone();
|
||||
let response_handlers = response_handlers.clone();
|
||||
let request_handlers = request_handlers.clone();
|
||||
let transport = transport.clone();
|
||||
async move |cx| {
|
||||
Self::handle_input(
|
||||
transport,
|
||||
notification_handlers,
|
||||
request_handlers,
|
||||
response_handlers,
|
||||
cx,
|
||||
)
|
||||
.log_err()
|
||||
.await
|
||||
Self::handle_input(transport, notification_handlers, response_handlers, cx)
|
||||
.log_err()
|
||||
.await
|
||||
}
|
||||
});
|
||||
let receive_err_task = cx.spawn({
|
||||
@@ -252,24 +230,13 @@ impl Client {
|
||||
async fn handle_input(
|
||||
transport: Arc<dyn Transport>,
|
||||
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
|
||||
request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
|
||||
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut receiver = transport.receive();
|
||||
|
||||
while let Some(message) = receiver.next().await {
|
||||
log::trace!("recv: {}", &message);
|
||||
if let Ok(request) = serde_json::from_str::<AnyRequest>(&message) {
|
||||
let mut request_handlers = request_handlers.lock();
|
||||
if let Some(handler) = request_handlers.get_mut(request.method) {
|
||||
handler(
|
||||
request.id,
|
||||
request.params.unwrap_or(RawValue::NULL),
|
||||
cx.clone(),
|
||||
);
|
||||
}
|
||||
} else if let Ok(response) = serde_json::from_str::<AnyResponse>(&message) {
|
||||
if let Ok(response) = serde_json::from_str::<AnyResponse>(&message) {
|
||||
if let Some(handlers) = response_handlers.lock().as_mut() {
|
||||
if let Some(handler) = handlers.remove(&response.id) {
|
||||
handler(Ok(message.to_string()));
|
||||
@@ -280,8 +247,6 @@ impl Client {
|
||||
if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) {
|
||||
handler(notification.params.unwrap_or(Value::Null), cx.clone());
|
||||
}
|
||||
} else {
|
||||
log::error!("Unhandled JSON from context_server: {}", message);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -329,17 +294,6 @@ impl Client {
|
||||
&self,
|
||||
method: &str,
|
||||
params: impl Serialize,
|
||||
) -> Result<T> {
|
||||
self.request_with(method, params, None, Some(REQUEST_TIMEOUT))
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn request_with<T: DeserializeOwned>(
|
||||
&self,
|
||||
method: &str,
|
||||
params: impl Serialize,
|
||||
cancel_rx: Option<oneshot::Receiver<()>>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<T> {
|
||||
let id = self.next_id.fetch_add(1, SeqCst);
|
||||
let request = serde_json::to_string(&Request {
|
||||
@@ -375,23 +329,7 @@ impl Client {
|
||||
handle_response?;
|
||||
send?;
|
||||
|
||||
let mut timeout_fut = pin!(
|
||||
match timeout {
|
||||
Some(timeout) => future::Either::Left(executor.timer(timeout)),
|
||||
None => future::Either::Right(future::pending()),
|
||||
}
|
||||
.fuse()
|
||||
);
|
||||
let mut cancel_fut = pin!(
|
||||
match cancel_rx {
|
||||
Some(rx) => future::Either::Left(async {
|
||||
rx.await.log_err();
|
||||
}),
|
||||
None => future::Either::Right(future::pending()),
|
||||
}
|
||||
.fuse()
|
||||
);
|
||||
|
||||
let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse();
|
||||
select! {
|
||||
response = rx.fuse() => {
|
||||
let elapsed = started.elapsed();
|
||||
@@ -410,18 +348,8 @@ impl Client {
|
||||
Err(_) => anyhow::bail!("cancelled")
|
||||
}
|
||||
}
|
||||
_ = cancel_fut => {
|
||||
self.notify(
|
||||
Cancelled::METHOD,
|
||||
ClientNotification::Cancelled(CancelledParams {
|
||||
request_id: RequestId::Int(id),
|
||||
reason: None
|
||||
})
|
||||
).log_err();
|
||||
anyhow::bail!(RequestCanceled)
|
||||
}
|
||||
_ = timeout_fut => {
|
||||
log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", timeout.unwrap());
|
||||
_ = timeout => {
|
||||
log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT);
|
||||
anyhow::bail!("Context server request timeout");
|
||||
}
|
||||
}
|
||||
@@ -451,17 +379,6 @@ impl Client {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RequestCanceled;
|
||||
|
||||
impl std::error::Error for RequestCanceled {}
|
||||
|
||||
impl std::fmt::Display for RequestCanceled {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str("Context server request was canceled")
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ContextServerId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
self.0.fmt(f)
|
||||
|
||||
@@ -9,8 +9,6 @@ use futures::{
|
||||
};
|
||||
use gpui::{App, AppContext, AsyncApp, Task};
|
||||
use net::async_net::{UnixListener, UnixStream};
|
||||
use schemars::JsonSchema;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde_json::{json, value::RawValue};
|
||||
use smol::stream::StreamExt;
|
||||
use std::{
|
||||
@@ -22,32 +20,16 @@ use util::ResultExt;
|
||||
|
||||
use crate::{
|
||||
client::{CspResult, RequestId, Response},
|
||||
types::{
|
||||
CallToolParams, CallToolResponse, ListToolsResponse, Request, Tool, ToolAnnotations,
|
||||
ToolResponseContent,
|
||||
requests::{CallTool, ListTools},
|
||||
},
|
||||
types::Request,
|
||||
};
|
||||
|
||||
pub struct McpServer {
|
||||
socket_path: PathBuf,
|
||||
tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
|
||||
handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
|
||||
handlers: Rc<RefCell<HashMap<&'static str, McpHandler>>>,
|
||||
_server_task: Task<()>,
|
||||
}
|
||||
|
||||
struct RegisteredTool {
|
||||
tool: Tool,
|
||||
handler: ToolHandler,
|
||||
}
|
||||
|
||||
type ToolHandler = Box<
|
||||
dyn Fn(
|
||||
Option<serde_json::Value>,
|
||||
&mut AsyncApp,
|
||||
) -> Task<Result<ToolResponse<serde_json::Value>>>,
|
||||
>;
|
||||
type RequestHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
|
||||
type McpHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
|
||||
|
||||
impl McpServer {
|
||||
pub fn new(cx: &AsyncApp) -> Task<Result<Self>> {
|
||||
@@ -61,14 +43,12 @@ impl McpServer {
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let (temp_dir, socket_path, listener) = task.await?;
|
||||
let tools = Rc::new(RefCell::new(HashMap::default()));
|
||||
let handlers = Rc::new(RefCell::new(HashMap::default()));
|
||||
let server_task = cx.spawn({
|
||||
let tools = tools.clone();
|
||||
let handlers = handlers.clone();
|
||||
async move |cx| {
|
||||
while let Ok((stream, _)) = listener.accept().await {
|
||||
Self::serve_connection(stream, tools.clone(), handlers.clone(), cx);
|
||||
Self::serve_connection(stream, handlers.clone(), cx);
|
||||
}
|
||||
drop(temp_dir)
|
||||
}
|
||||
@@ -76,56 +56,11 @@ impl McpServer {
|
||||
Ok(Self {
|
||||
socket_path,
|
||||
_server_task: server_task,
|
||||
tools,
|
||||
handlers: handlers,
|
||||
handlers: handlers.clone(),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
|
||||
let output_schema = schemars::schema_for!(T::Output);
|
||||
let unit_schema = schemars::schema_for!(());
|
||||
|
||||
let registered_tool = RegisteredTool {
|
||||
tool: Tool {
|
||||
name: T::NAME.into(),
|
||||
description: Some(tool.description().into()),
|
||||
input_schema: schemars::schema_for!(T::Input).into(),
|
||||
output_schema: if output_schema == unit_schema {
|
||||
None
|
||||
} else {
|
||||
Some(output_schema.into())
|
||||
},
|
||||
annotations: Some(tool.annotations()),
|
||||
},
|
||||
handler: Box::new({
|
||||
let tool = tool.clone();
|
||||
move |input_value, cx| {
|
||||
let input = match input_value {
|
||||
Some(input) => serde_json::from_value(input),
|
||||
None => serde_json::from_value(serde_json::Value::Null),
|
||||
};
|
||||
|
||||
let tool = tool.clone();
|
||||
match input {
|
||||
Ok(input) => cx.spawn(async move |cx| {
|
||||
let output = tool.run(input, cx).await?;
|
||||
|
||||
Ok(ToolResponse {
|
||||
content: output.content,
|
||||
structured_content: serde_json::to_value(output.structured_content)
|
||||
.unwrap_or_default(),
|
||||
})
|
||||
}),
|
||||
Err(err) => Task::ready(Err(err.into())),
|
||||
}
|
||||
}
|
||||
}),
|
||||
};
|
||||
|
||||
self.tools.borrow_mut().insert(T::NAME, registered_tool);
|
||||
}
|
||||
|
||||
pub fn handle_request<R: Request>(
|
||||
&mut self,
|
||||
f: impl Fn(R::Params, &App) -> Task<Result<R::Response>> + 'static,
|
||||
@@ -185,8 +120,7 @@ impl McpServer {
|
||||
|
||||
fn serve_connection(
|
||||
stream: UnixStream,
|
||||
tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
|
||||
handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
|
||||
handlers: Rc<RefCell<HashMap<&'static str, McpHandler>>>,
|
||||
cx: &mut AsyncApp,
|
||||
) {
|
||||
let (read, write) = smol::io::split(stream);
|
||||
@@ -201,13 +135,7 @@ impl McpServer {
|
||||
let Some(request_id) = request.id.clone() else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if request.method == CallTool::METHOD {
|
||||
Self::handle_call_tool(request_id, request.params, &tools, &outgoing_tx, cx)
|
||||
.await;
|
||||
} else if request.method == ListTools::METHOD {
|
||||
Self::handle_list_tools(request.id.unwrap(), &tools, &outgoing_tx);
|
||||
} else if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
|
||||
if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
|
||||
let outgoing_tx = outgoing_tx.clone();
|
||||
|
||||
if let Some(task) = cx
|
||||
@@ -221,126 +149,25 @@ impl McpServer {
|
||||
.detach();
|
||||
}
|
||||
} else {
|
||||
Self::send_err(
|
||||
request_id,
|
||||
format!("unhandled method {}", request.method),
|
||||
&outgoing_tx,
|
||||
);
|
||||
outgoing_tx
|
||||
.unbounded_send(
|
||||
serde_json::to_string(&Response::<()> {
|
||||
jsonrpc: "2.0",
|
||||
id: request.id.unwrap(),
|
||||
value: CspResult::Error(Some(crate::client::Error {
|
||||
message: format!("unhandled method {}", request.method),
|
||||
code: -32601,
|
||||
})),
|
||||
})
|
||||
.unwrap(),
|
||||
)
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn handle_list_tools(
|
||||
request_id: RequestId,
|
||||
tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
|
||||
outgoing_tx: &UnboundedSender<String>,
|
||||
) {
|
||||
let response = ListToolsResponse {
|
||||
tools: tools.borrow().values().map(|t| t.tool.clone()).collect(),
|
||||
next_cursor: None,
|
||||
meta: None,
|
||||
};
|
||||
|
||||
outgoing_tx
|
||||
.unbounded_send(
|
||||
serde_json::to_string(&Response {
|
||||
jsonrpc: "2.0",
|
||||
id: request_id,
|
||||
value: CspResult::Ok(Some(response)),
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
)
|
||||
.ok();
|
||||
}
|
||||
|
||||
async fn handle_call_tool(
|
||||
request_id: RequestId,
|
||||
params: Option<Box<RawValue>>,
|
||||
tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
|
||||
outgoing_tx: &UnboundedSender<String>,
|
||||
cx: &mut AsyncApp,
|
||||
) {
|
||||
let result: Result<CallToolParams, serde_json::Error> = match params.as_ref() {
|
||||
Some(params) => serde_json::from_str(params.get()),
|
||||
None => serde_json::from_value(serde_json::Value::Null),
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(params) => {
|
||||
if let Some(tool) = tools.borrow().get(¶ms.name.as_ref()) {
|
||||
let outgoing_tx = outgoing_tx.clone();
|
||||
|
||||
let task = (tool.handler)(params.arguments, cx);
|
||||
cx.spawn(async move |_| {
|
||||
let response = match task.await {
|
||||
Ok(result) => CallToolResponse {
|
||||
content: result.content,
|
||||
is_error: Some(false),
|
||||
meta: None,
|
||||
structured_content: if result.structured_content.is_null() {
|
||||
None
|
||||
} else {
|
||||
Some(result.structured_content)
|
||||
},
|
||||
},
|
||||
Err(err) => CallToolResponse {
|
||||
content: vec![ToolResponseContent::Text {
|
||||
text: err.to_string(),
|
||||
}],
|
||||
is_error: Some(true),
|
||||
meta: None,
|
||||
structured_content: None,
|
||||
},
|
||||
};
|
||||
|
||||
outgoing_tx
|
||||
.unbounded_send(
|
||||
serde_json::to_string(&Response {
|
||||
jsonrpc: "2.0",
|
||||
id: request_id,
|
||||
value: CspResult::Ok(Some(response)),
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
)
|
||||
.ok();
|
||||
})
|
||||
.detach();
|
||||
} else {
|
||||
Self::send_err(
|
||||
request_id,
|
||||
format!("Tool not found: {}", params.name),
|
||||
&outgoing_tx,
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
Self::send_err(request_id, err.to_string(), &outgoing_tx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn send_err(
|
||||
request_id: RequestId,
|
||||
message: impl Into<String>,
|
||||
outgoing_tx: &UnboundedSender<String>,
|
||||
) {
|
||||
outgoing_tx
|
||||
.unbounded_send(
|
||||
serde_json::to_string(&Response::<()> {
|
||||
jsonrpc: "2.0",
|
||||
id: request_id,
|
||||
value: CspResult::Error(Some(crate::client::Error {
|
||||
message: message.into(),
|
||||
code: -32601,
|
||||
})),
|
||||
})
|
||||
.unwrap(),
|
||||
)
|
||||
.ok();
|
||||
}
|
||||
|
||||
async fn handle_io(
|
||||
mut outgoing_rx: UnboundedReceiver<String>,
|
||||
incoming_tx: UnboundedSender<RawRequest>,
|
||||
@@ -389,37 +216,7 @@ impl McpServer {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait McpServerTool {
|
||||
type Input: DeserializeOwned + JsonSchema;
|
||||
type Output: Serialize + JsonSchema;
|
||||
|
||||
const NAME: &'static str;
|
||||
|
||||
fn description(&self) -> &'static str;
|
||||
|
||||
fn annotations(&self) -> ToolAnnotations {
|
||||
ToolAnnotations {
|
||||
title: None,
|
||||
read_only_hint: None,
|
||||
destructive_hint: None,
|
||||
idempotent_hint: None,
|
||||
open_world_hint: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn run(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
cx: &mut AsyncApp,
|
||||
) -> impl Future<Output = Result<ToolResponse<Self::Output>>>;
|
||||
}
|
||||
|
||||
pub struct ToolResponse<T> {
|
||||
pub content: Vec<ToolResponseContent>,
|
||||
pub structured_content: T,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct RawRequest {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
id: Option<RequestId>,
|
||||
|
||||
@@ -5,12 +5,7 @@
|
||||
//! read/write messages and the types from types.rs for serialization/deserialization
|
||||
//! of messages.
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use futures::channel::oneshot;
|
||||
use gpui::AsyncApp;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::client::Client;
|
||||
use crate::types::{self, Notification, Request};
|
||||
@@ -100,25 +95,7 @@ impl InitializedContextServerProtocol {
|
||||
self.inner.request(T::METHOD, params).await
|
||||
}
|
||||
|
||||
pub async fn request_with<T: Request>(
|
||||
&self,
|
||||
params: T::Params,
|
||||
cancel_rx: Option<oneshot::Receiver<()>>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<T::Response> {
|
||||
self.inner
|
||||
.request_with(T::METHOD, params, cancel_rx, timeout)
|
||||
.await
|
||||
}
|
||||
|
||||
pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
|
||||
self.inner.notify(T::METHOD, params)
|
||||
}
|
||||
|
||||
pub fn on_notification<F>(&self, method: &'static str, f: F)
|
||||
where
|
||||
F: 'static + Send + FnMut(Value, AsyncApp),
|
||||
{
|
||||
self.inner.on_notification(method, f);
|
||||
}
|
||||
}
|
||||
|
||||