Compare commits

..

14 Commits

Author SHA1 Message Date
Mikayla Maki
f9a7019049 Switch to using text style instead of refinement 2025-04-10 16:17:44 -06:00
Mikayla Maki
0e35f5bf20 Merge branch 'main' into fix-text-hover-style 2025-04-10 16:07:57 -06:00
Mikayla Maki
7ac19ff1ed remove redundant code 2025-04-10 16:06:37 -06:00
Jason Lee
7d19c452b4 update example 2025-03-15 19:16:19 +08:00
Jason Lee
2b76bd8e35 . 2025-03-15 18:33:09 +08:00
Jason Lee
070464e4c5 Merge remote-tracking branch 'origin/main' into fix-text-hover-style 2025-03-15 17:46:16 +08:00
Jason Lee
69be667c02 Merge remote-tracking branch 'origin/main' into fix-text-hover-style 2025-03-15 17:46:03 +08:00
Jason Lee
0155974e59 . 2025-02-13 10:55:16 +08:00
Jason Lee
575782ea88 . 2025-02-13 10:50:24 +08:00
Jason Lee
0d6b2052ea . 2025-02-13 10:45:19 +08:00
Jason Lee
1e16283820 . 2025-02-13 10:42:49 +08:00
Jason Lee
0a61f38231 . 2025-02-12 19:47:50 +08:00
Jason Lee
1e2b16a1a3 Fix text style on hover 2025-02-12 19:29:42 +08:00
Jason Lee
cb21b53ca9 Fix hello_world example 2025-02-12 16:25:21 +08:00
544 changed files with 15361 additions and 24019 deletions

View File

@@ -20,7 +20,6 @@ platforms = [
[traversal-excludes]
workspace-members = [
"remote_server",
"eval",
]
third-party = [
{ name = "reqwest", version = "0.11.27" },

View File

@@ -1,19 +0,0 @@
name: Other [Staff Only]
description: Zed Staff Only
body:
- type: textarea
attributes:
label: Summary
value: |
<!-- Please insert a one line summary of the issue below -->
SUMMARY_SENTENCE_HERE
### Description
IF YOU DO NOT WORK FOR ZED INDUSTRIES DO NOT CREATE ISSUES WITH THIS TEMPLATE.
THEY WILL BE AUTO-CLOSED AND MAY RESULT IN YOU BEING BANNED FROM THE ZED ISSUE TRACKER.
FEATURE REQUESTS / SUPPORT REQUESTS SHOULD BE OPENED AS DISCUSSIONS:
https://github.com/zed-industries/zed/discussions/new/choose
validations:
required: true

View File

@@ -325,7 +325,7 @@ jobs:
cache-provider: "buildjet"
- name: Install Clang & Mold
run: ./script/headless && ./script/install-mold 2.34.0
run: ./script/remote-server && ./script/install-mold 2.34.0
- name: Configure CI
run: |
@@ -594,7 +594,7 @@ jobs:
timeout-minutes: 60
name: Linux x86_x64 release bundle
runs-on:
- buildjet-16vcpu-ubuntu-2004 # ubuntu 20.04 for minimal glibc
- buildjet-16vcpu-ubuntu-2004
if: |
startsWith(github.ref, 'refs/tags/v')
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
@@ -622,23 +622,26 @@ jobs:
- name: Create Linux .tar.gz bundle
run: script/bundle-linux
- name: Upload Artifact to Workflow - zed (run-bundling)
- name: Upload Linux bundle to workflow run if main branch or specific label
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
if: contains(github.event.pull_request.labels.*.name, 'run-bundling')
if: |
github.ref == 'refs/heads/main'
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
with:
name: zed-${{ github.event.pull_request.head.sha || github.sha }}-x86_64-unknown-linux-gnu.tar.gz
path: target/release/zed-*.tar.gz
- name: Upload Artifact to Workflow - zed-remote-server (run-bundling)
- name: Upload Linux remote server to workflow run if main branch or specific label
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
if: contains(github.event.pull_request.labels.*.name, 'run-bundling')
if: |
github.ref == 'refs/heads/main'
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
with:
name: zed-remote-server-${{ github.event.pull_request.head.sha || github.sha }}-x86_64-unknown-linux-gnu.gz
path: target/zed-remote-server-linux-x86_64.gz
- name: Upload Artifacts to release
- name: Upload app bundle to release
uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 # v1
if: ${{ !(contains(github.event.pull_request.labels.*.name, 'run-bundling')) }}
with:
draft: true
prerelease: ${{ env.RELEASE_CHANNEL == 'preview' }}
@@ -677,26 +680,29 @@ jobs:
# This exports RELEASE_CHANNEL into env (GITHUB_ENV)
script/determine-release-channel
- name: Create and upload Linux .tar.gz bundles
- name: Create and upload Linux .tar.gz bundle
run: script/bundle-linux
- name: Upload Artifact to Workflow - zed (run-bundling)
- name: Upload Linux bundle to workflow run if main branch or specific label
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
if: contains(github.event.pull_request.labels.*.name, 'run-bundling')
if: |
github.ref == 'refs/heads/main'
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
with:
name: zed-${{ github.event.pull_request.head.sha || github.sha }}-aarch64-unknown-linux-gnu.tar.gz
path: target/release/zed-*.tar.gz
- name: Upload Artifact to Workflow - zed-remote-server (run-bundling)
- name: Upload Linux remote server to workflow run if main branch or specific label
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
if: contains(github.event.pull_request.labels.*.name, 'run-bundling')
if: |
github.ref == 'refs/heads/main'
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
with:
name: zed-remote-server-${{ github.event.pull_request.head.sha || github.sha }}-aarch64-unknown-linux-gnu.gz
path: target/zed-remote-server-linux-aarch64.gz
- name: Upload Artifacts to release
- name: Upload app bundle to release
uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 # v1
if: ${{ !(contains(github.event.pull_request.labels.*.name, 'run-bundling')) }}
with:
draft: true
prerelease: ${{ env.RELEASE_CHANNEL == 'preview' }}

View File

@@ -117,10 +117,12 @@ jobs:
export ZED_KUBE_NAMESPACE=production
export ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT=10
export ZED_API_LOAD_BALANCER_SIZE_UNIT=2
export ZED_LLM_LOAD_BALANCER_SIZE_UNIT=2
elif [[ $GITHUB_REF_NAME = "collab-staging" ]]; then
export ZED_KUBE_NAMESPACE=staging
export ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT=1
export ZED_API_LOAD_BALANCER_SIZE_UNIT=1
export ZED_LLM_LOAD_BALANCER_SIZE_UNIT=1
else
echo "cowardly refusing to deploy from an unknown branch"
exit 1
@@ -145,3 +147,9 @@ jobs:
envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f -
kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch
echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}"
export ZED_SERVICE_NAME=llm
export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_LLM_LOAD_BALANCER_SIZE_UNIT
envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f -
kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch
echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}"

View File

@@ -1,90 +0,0 @@
name: Run Eval Daily
on:
schedule:
- cron: "0 2 * * *"
workflow_dispatch:
env:
CARGO_TERM_COLOR: always
CARGO_INCREMENTAL: 0
RUST_BACKTRACE: 1
jobs:
build_binary:
name: Build Eval Binary
if: github.repository_owner == 'zed-industries'
runs-on: ubuntu-latest
steps:
- name: Checkout repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
with:
clean: false
- name: Install Dependencies
run: |
sudo apt-get update
sudo apt-get install -y libx11-dev libxcb-shape0-dev libxcb-xfixes0-dev libxcb1-dev libxcb-render0-dev libxcb-randr0-dev libxcb-xtest0-dev libxcb-keysyms1-dev libxext-dev libxrender-dev libxrandr-dev libxtst-dev libxfixes-dev pkg-config libasound2-dev libx11-xcb-dev libxkbcommon-dev libxkbcommon-x11-dev
./script/headless && ./script/install-mold 2.34.0
- name: Setup Rust
uses: dtolnay/rust-toolchain@stable
- name: Build entire Zed application
run: cargo build --release
- name: Build eval binary separately to ensure it's available
run: cargo build --release -p eval
- name: Upload eval binary
uses: actions/upload-artifact@v4
with:
name: eval-binary
path: target/release/eval
run_eval:
name: Run Eval - ${{ matrix.exercise }}
needs: build_binary
if: github.repository_owner == 'zed-industries'
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
exercise:
- find_and_replace_diff_card
- metal_i64_support
steps:
- name: Checkout repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
with:
clean: false
- name: Install Dependencies
run: |
sudo apt-get update
sudo apt-get install -y libx11-dev libxcb-shape0-dev libxcb-xfixes0-dev libxcb1-dev libxcb-render0-dev libxcb-randr0-dev libxcb-xtest0-dev libxcb-keysyms1-dev libxext-dev libxrender-dev libxrandr-dev libxtst-dev libxfixes-dev pkg-config libasound2-dev libx11-xcb-dev libxkbcommon-dev libxkbcommon-x11-dev
./script/headless && ./script/install-mold 2.34.0
- name: Create required directories
run: |
mkdir -p /home/runner/.config/zed
touch /home/runner/.config/zed/settings.json
mkdir -p ./crates/eval/repos
mkdir -p ./crates/eval/worktrees
mkdir -p ./crates/eval/runs
- name: Set up Anthropic API key
run: echo "ANTHROPIC_API_KEY=${{ secrets.ANTHROPIC_API_KEY }}" >> $GITHUB_ENV
- name: Download eval binary
uses: actions/download-artifact@v4
with:
name: eval-binary
path: ./target/release/
- name: Make eval binary executable
run: chmod +x ./target/release/eval
- name: Run eval for ${{ matrix.exercise }}
run: ./target/release/eval --cohort-id dailyrun${{ github.run_id }} ${{ matrix.exercise }}

View File

@@ -1,13 +1,13 @@
[
{
"label": "Debug Zed (CodeLLDB)",
"adapter": "CodeLLDB",
"label": "Debug Zed with LLDB",
"adapter": "LLDB",
"program": "$ZED_WORKTREE_ROOT/target/debug/zed",
"request": "launch",
"cwd": "$ZED_WORKTREE_ROOT"
},
{
"label": "Debug Zed (GDB)",
"label": "Debug Zed with GDB",
"adapter": "GDB",
"program": "$ZED_WORKTREE_ROOT/target/debug/zed",
"request": "launch",

291
Cargo.lock generated
View File

@@ -52,6 +52,7 @@ dependencies = [
name = "agent"
version = "0.1.0"
dependencies = [
"agent_rules",
"anyhow",
"assistant_context_editor",
"assistant_settings",
@@ -115,7 +116,6 @@ dependencies = [
"terminal_view",
"text",
"theme",
"thiserror 2.0.12",
"time",
"time_format",
"ui",
@@ -125,7 +125,57 @@ dependencies = [
"workspace",
"workspace-hack",
"zed_actions",
"zed_llm_client",
]
[[package]]
name = "agent_eval"
version = "0.1.0"
dependencies = [
"agent",
"anyhow",
"assistant_tool",
"assistant_tools",
"clap",
"client",
"collections",
"context_server",
"dap",
"env_logger 0.11.8",
"fs",
"futures 0.3.31",
"gpui",
"gpui_tokio",
"language",
"language_model",
"language_models",
"node_runtime",
"project",
"prompt_store",
"release_channel",
"reqwest_client",
"serde",
"serde_json",
"serde_json_lenient",
"settings",
"smol",
"tempfile",
"util",
"walkdir",
"workspace-hack",
]
[[package]]
name = "agent_rules"
version = "0.1.0"
dependencies = [
"anyhow",
"fs",
"gpui",
"indoc",
"prompt_store",
"util",
"workspace-hack",
"worktree",
]
[[package]]
@@ -325,8 +375,9 @@ dependencies = [
"schemars",
"serde",
"serde_json",
"strum 0.27.1",
"strum",
"thiserror 2.0.12",
"util",
"workspace-hack",
]
@@ -338,9 +389,9 @@ checksum = "34cd60c5e3152cef0a592f1b296f1cc93715d89d2551d85315828c3a09575ff4"
[[package]]
name = "anyhow"
version = "1.0.98"
version = "1.0.97"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487"
checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f"
[[package]]
name = "approx"
@@ -449,6 +500,7 @@ dependencies = [
"smol",
"tempfile",
"util",
"which 6.0.3",
"workspace-hack",
]
@@ -567,7 +619,7 @@ dependencies = [
"settings",
"smallvec",
"smol",
"strum 0.27.1",
"strum",
"telemetry_events",
"text",
"theme",
@@ -704,35 +756,26 @@ dependencies = [
"assistant_tool",
"chrono",
"collections",
"component",
"feature_flags",
"futures 0.3.31",
"gpui",
"html_to_markdown",
"http_client",
"indoc",
"itertools 0.14.0",
"language",
"language_model",
"linkme",
"open",
"pretty_assertions",
"project",
"rand 0.8.5",
"regex",
"schemars",
"serde",
"serde_json",
"settings",
"tree-sitter-rust",
"ui",
"unindent",
"util",
"web_search",
"workspace",
"workspace-hack",
"worktree",
"zed_llm_client",
]
[[package]]
@@ -1190,18 +1233,6 @@ dependencies = [
"workspace-hack",
]
[[package]]
name = "auto_update_helper"
version = "0.1.0"
dependencies = [
"anyhow",
"log",
"simplelog",
"windows 0.61.1",
"winresource",
"workspace-hack",
]
[[package]]
name = "auto_update_ui"
version = "0.1.0"
@@ -1890,7 +1921,7 @@ dependencies = [
"schemars",
"serde",
"serde_json",
"strum 0.27.1",
"strum",
"thiserror 2.0.12",
"tokio",
"workspace-hack",
@@ -2951,6 +2982,7 @@ dependencies = [
name = "collab"
version = "0.44.0"
dependencies = [
"anthropic",
"anyhow",
"assistant",
"assistant_context_editor",
@@ -3037,7 +3069,7 @@ dependencies = [
"settings",
"sha2",
"sqlx",
"strum 0.27.1",
"strum",
"subtle",
"supermaven_api",
"telemetry_events",
@@ -3057,7 +3089,6 @@ dependencies = [
"workspace",
"workspace-hack",
"worktree",
"zed_llm_client",
]
[[package]]
@@ -3195,18 +3226,14 @@ dependencies = [
name = "component_preview"
version = "0.1.0"
dependencies = [
"anyhow",
"client",
"collections",
"component",
"db",
"gpui",
"languages",
"notifications",
"project",
"serde",
"ui",
"ui_input",
"workspace",
"workspace-hack",
]
@@ -3370,7 +3397,7 @@ dependencies = [
"serde",
"serde_json",
"settings",
"strum 0.27.1",
"strum",
"task",
"theme",
"ui",
@@ -4011,6 +4038,7 @@ dependencies = [
"node_runtime",
"parking_lot",
"paths",
"regex",
"schemars",
"serde",
"serde_json",
@@ -4042,6 +4070,7 @@ dependencies = [
"gpui",
"language",
"paths",
"regex",
"serde",
"serde_json",
"task",
@@ -4185,7 +4214,6 @@ dependencies = [
"collections",
"command_palette_hooks",
"dap",
"db",
"editor",
"env_logger 0.11.8",
"feature_flags",
@@ -4204,7 +4232,6 @@ dependencies = [
"settings",
"sysinfo",
"task",
"tasks_ui",
"terminal_view",
"theme",
"ui",
@@ -4325,24 +4352,19 @@ dependencies = [
"anyhow",
"client",
"collections",
"component",
"ctor",
"editor",
"env_logger 0.11.8",
"gpui",
"indoc",
"language",
"linkme",
"log",
"lsp",
"markdown",
"pretty_assertions",
"project",
"rand 0.8.5",
"serde",
"serde_json",
"settings",
"text",
"theme",
"ui",
"unindent",
@@ -4487,7 +4509,7 @@ dependencies = [
"optfield",
"proc-macro2",
"quote",
"strum 0.26.3",
"strum",
"syn 2.0.100",
]
@@ -4889,41 +4911,57 @@ dependencies = [
"anyhow",
"assistant_tool",
"assistant_tools",
"async-watch",
"chrono",
"clap",
"client",
"collections",
"context_server",
"dap",
"dirs 5.0.1",
"env_logger 0.11.8",
"extension",
"fs",
"futures 0.3.31",
"gpui",
"gpui_tokio",
"handlebars 4.5.0",
"language",
"language_extension",
"language_model",
"language_models",
"languages",
"node_runtime",
"paths",
"project",
"prompt_store",
"release_channel",
"reqwest_client",
"serde",
"settings",
"smol",
"toml 0.8.20",
"workspace-hack",
]
[[package]]
name = "evals"
version = "0.1.0"
dependencies = [
"anyhow",
"clap",
"client",
"clock",
"collections",
"dap",
"env_logger 0.11.8",
"feature_flags",
"fs",
"gpui",
"http_client",
"language",
"languages",
"node_runtime",
"open_ai",
"project",
"reqwest_client",
"semantic_index",
"serde",
"serde_json",
"settings",
"shellexpand 2.1.2",
"telemetry",
"toml 0.8.20",
"unindent",
"smol",
"util",
"uuid",
"workspace-hack",
]
[[package]]
@@ -4987,10 +5025,10 @@ dependencies = [
"async-tar",
"async-trait",
"collections",
"convert_case 0.8.0",
"fs",
"futures 0.3.31",
"gpui",
"heck 0.5.0",
"http_client",
"language",
"log",
@@ -5102,7 +5140,7 @@ dependencies = [
"serde",
"settings",
"smallvec",
"strum 0.27.1",
"strum",
"telemetry",
"theme",
"ui",
@@ -5953,7 +5991,7 @@ dependencies = [
"serde_derive",
"serde_json",
"settings",
"strum 0.27.1",
"strum",
"telemetry",
"theme",
"time",
@@ -6046,7 +6084,7 @@ dependencies = [
"schemars",
"serde",
"serde_json",
"strum 0.27.1",
"strum",
"workspace-hack",
]
@@ -6152,7 +6190,7 @@ dependencies = [
"slotmap",
"smallvec",
"smol",
"strum 0.27.1",
"strum",
"sum_tree",
"taffy",
"thiserror 2.0.12",
@@ -6800,7 +6838,7 @@ name = "icons"
version = "0.1.0"
dependencies = [
"serde",
"strum 0.27.1",
"strum",
"workspace-hack",
]
@@ -7068,16 +7106,16 @@ dependencies = [
"paths",
"pretty_assertions",
"serde",
"strum 0.27.1",
"strum",
"util",
"workspace-hack",
]
[[package]]
name = "indexmap"
version = "2.8.0"
version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3954d50fe15b02142bf25d3b8bdadb634ec3948f103d04ffe3031bc8fe9d7058"
checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e"
dependencies = [
"equivalent",
"hashbrown 0.15.2",
@@ -7654,19 +7692,17 @@ dependencies = [
"serde",
"serde_json",
"smol",
"strum 0.27.1",
"strum",
"telemetry_events",
"thiserror 2.0.12",
"util",
"workspace-hack",
"zed_llm_client",
]
[[package]]
name = "language_model_selector"
version = "0.1.0"
dependencies = [
"collections",
"feature_flags",
"gpui",
"language_model",
@@ -7715,15 +7751,13 @@ dependencies = [
"serde_json",
"settings",
"smol",
"strum 0.27.1",
"strum",
"theme",
"thiserror 2.0.12",
"tiktoken-rs",
"tokio",
"ui",
"util",
"workspace-hack",
"zed_llm_client",
]
[[package]]
@@ -7939,9 +7973,9 @@ checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa"
[[package]]
name = "libmimalloc-sys"
version = "0.1.41"
version = "0.1.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b20daca3a4ac14dbdc753c5e90fc7b490a48a9131daed3c9a9ced7b2defd37b"
checksum = "ec9d6fac27761dabcd4ee73571cdb06b7022dc99089acbe5435691edffaac0f4"
dependencies = [
"cc",
"libc",
@@ -8612,9 +8646,9 @@ dependencies = [
[[package]]
name = "mimalloc"
version = "0.1.45"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03cb1f88093fe50061ca1195d336ffec131347c7b833db31f9ab62a2d1b7925f"
checksum = "995942f432bbb4822a7e9c3faa87a695185b0d09273ba85f097b54f4e458f2af"
dependencies = [
"libmimalloc-sys",
]
@@ -8688,7 +8722,7 @@ dependencies = [
"schemars",
"serde",
"serde_json",
"strum 0.27.1",
"strum",
"workspace-hack",
]
@@ -9535,7 +9569,7 @@ dependencies = [
"schemars",
"serde",
"serde_json",
"strum 0.27.1",
"strum",
"workspace-hack",
]
@@ -11729,7 +11763,6 @@ name = "remote_server"
version = "0.1.0"
dependencies = [
"anyhow",
"askpass",
"async-watch",
"backtrace",
"cargo_toml",
@@ -12115,7 +12148,7 @@ dependencies = [
"serde",
"serde_json",
"sha2",
"strum 0.27.1",
"strum",
"tracing",
"util",
"workspace-hack",
@@ -12643,7 +12676,7 @@ dependencies = [
"serde",
"serde_json",
"sqlx",
"strum 0.26.3",
"strum",
"thiserror 2.0.12",
"time",
"tracing",
@@ -13239,9 +13272,9 @@ dependencies = [
[[package]]
name = "smallvec"
version = "1.14.0"
version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd"
checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9"
dependencies = [
"serde",
]
@@ -13308,7 +13341,6 @@ dependencies = [
"fs",
"futures 0.3.31",
"gpui",
"indoc",
"parking_lot",
"paths",
"schemars",
@@ -13689,7 +13721,7 @@ dependencies = [
"settings",
"simplelog",
"story",
"strum 0.27.1",
"strum",
"theme",
"title_bar",
"ui",
@@ -13771,16 +13803,7 @@ version = "0.26.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06"
dependencies = [
"strum_macros 0.26.4",
]
[[package]]
name = "strum"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32"
dependencies = [
"strum_macros 0.27.1",
"strum_macros",
]
[[package]]
@@ -13796,19 +13819,6 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "strum_macros"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.100",
]
[[package]]
name = "subtle"
version = "2.6.1"
@@ -14230,7 +14240,9 @@ version = "0.1.0"
dependencies = [
"anyhow",
"collections",
"debugger_ui",
"editor",
"feature_flags",
"file_icons",
"fuzzy",
"gpui",
@@ -14424,7 +14436,7 @@ dependencies = [
"serde_json_lenient",
"serde_repr",
"settings",
"strum 0.27.1",
"strum",
"thiserror 2.0.12",
"util",
"uuid",
@@ -14458,7 +14470,7 @@ dependencies = [
"serde_json",
"serde_json_lenient",
"simplelog",
"strum 0.27.1",
"strum",
"theme",
"vscode_theme",
"workspace-hack",
@@ -15459,7 +15471,7 @@ dependencies = [
"settings",
"smallvec",
"story",
"strum 0.27.1",
"strum",
"theme",
"ui_macros",
"util",
@@ -16592,36 +16604,6 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "web_search"
version = "0.1.0"
dependencies = [
"anyhow",
"collections",
"gpui",
"serde",
"workspace-hack",
"zed_llm_client",
]
[[package]]
name = "web_search_providers"
version = "0.1.0"
dependencies = [
"anyhow",
"client",
"feature_flags",
"futures 0.3.31",
"gpui",
"http_client",
"language_model",
"serde",
"serde_json",
"web_search",
"workspace-hack",
"zed_llm_client",
]
[[package]]
name = "webpki-root-certs"
version = "0.26.8"
@@ -17660,7 +17642,7 @@ dependencies = [
"settings",
"smallvec",
"sqlez",
"strum 0.27.1",
"strum",
"task",
"telemetry",
"tempfile",
@@ -17668,7 +17650,6 @@ dependencies = [
"ui",
"util",
"uuid",
"windows 0.61.1",
"workspace-hack",
"zed_actions",
]
@@ -17805,7 +17786,7 @@ dependencies = [
"sqlx-macros-core",
"sqlx-postgres",
"sqlx-sqlite",
"strum 0.26.3",
"strum",
"subtle",
"syn 1.0.109",
"syn 2.0.100",
@@ -17831,8 +17812,6 @@ dependencies = [
"wasmtime-cranelift",
"wasmtime-environ",
"winapi",
"windows-core 0.61.0",
"windows-numerics",
"windows-sys 0.48.0",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
@@ -18177,13 +18156,12 @@ dependencies = [
[[package]]
name = "zed"
version = "0.184.0"
version = "0.182.0"
dependencies = [
"activity_indicator",
"agent",
"anyhow",
"ashpd",
"askpass",
"assets",
"assistant",
"assistant_context_editor",
@@ -18274,6 +18252,7 @@ dependencies = [
"settings",
"settings_ui",
"shellexpand 2.1.2",
"simplelog",
"smol",
"snippet_provider",
"snippets_ui",
@@ -18301,8 +18280,6 @@ dependencies = [
"uuid",
"vim",
"vim_mode_setting",
"web_search",
"web_search_providers",
"welcome",
"windows 0.61.1",
"winresource",
@@ -18310,7 +18287,6 @@ dependencies = [
"workspace-hack",
"zed_actions",
"zeta",
"zlog",
"zlog_settings",
]
@@ -18360,21 +18336,19 @@ dependencies = [
[[package]]
name = "zed_html"
version = "0.2.1"
version = "0.2.0"
dependencies = [
"zed_extension_api 0.1.0",
]
[[package]]
name = "zed_llm_client"
version = "0.6.0"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b91b8b05f1028157205026e525869eb860fa89bec87ea60b445efc91d05df31f"
checksum = "1bf21350eced858d129840589158a8f6895c4fa4327ae56dd8c7d6a98495bed4"
dependencies = [
"anyhow",
"serde",
"serde_json",
"strum 0.27.1",
"uuid",
]
@@ -18625,10 +18599,7 @@ dependencies = [
name = "zlog"
version = "0.1.0"
dependencies = [
"anyhow",
"chrono",
"log",
"tempfile",
"workspace-hack",
]

View File

@@ -3,11 +3,13 @@ resolver = "2"
members = [
"crates/activity_indicator",
"crates/agent",
"crates/agent_rules",
"crates/anthropic",
"crates/askpass",
"crates/assets",
"crates/assistant",
"crates/assistant_context_editor",
"crates/agent_eval",
"crates/assistant_settings",
"crates/assistant_slash_command",
"crates/assistant_slash_commands",
@@ -15,7 +17,6 @@ members = [
"crates/assistant_tools",
"crates/audio",
"crates/auto_update",
"crates/auto_update_helper",
"crates/auto_update_ui",
"crates/aws_http_client",
"crates/bedrock",
@@ -47,6 +48,7 @@ members = [
"crates/docs_preprocessor",
"crates/editor",
"crates/eval",
"crates/evals",
"crates/extension",
"crates/extension_api",
"crates/extension_cli",
@@ -164,8 +166,6 @@ members = [
"crates/util_macros",
"crates/vim",
"crates/vim_mode_setting",
"crates/web_search",
"crates/web_search_providers",
"crates/welcome",
"crates/workspace",
"crates/worktree",
@@ -211,12 +211,14 @@ edition = "2024"
activity_indicator = { path = "crates/activity_indicator" }
agent = { path = "crates/agent" }
agent_rules = { path = "crates/agent_rules" }
ai = { path = "crates/ai" }
anthropic = { path = "crates/anthropic" }
askpass = { path = "crates/askpass" }
assets = { path = "crates/assets" }
assistant = { path = "crates/assistant" }
assistant_context_editor = { path = "crates/assistant_context_editor" }
assistant_eval = { path = "crates/agent_eval" }
assistant_settings = { path = "crates/assistant_settings" }
assistant_slash_command = { path = "crates/assistant_slash_command" }
assistant_slash_commands = { path = "crates/assistant_slash_commands" }
@@ -224,7 +226,6 @@ assistant_tool = { path = "crates/assistant_tool" }
assistant_tools = { path = "crates/assistant_tools" }
audio = { path = "crates/audio" }
auto_update = { path = "crates/auto_update" }
auto_update_helper = { path = "crates/auto_update_helper" }
auto_update_ui = { path = "crates/auto_update_ui" }
aws_http_client = { path = "crates/aws_http_client" }
bedrock = { path = "crates/bedrock" }
@@ -371,8 +372,6 @@ util = { path = "crates/util" }
util_macros = { path = "crates/util_macros" }
vim = { path = "crates/vim" }
vim_mode_setting = { path = "crates/vim_mode_setting" }
web_search = { path = "crates/web_search" }
web_search_providers = { path = "crates/web_search_providers" }
welcome = { path = "crates/welcome" }
workspace = { path = "crates/workspace" }
worktree = { path = "crates/worktree" }
@@ -404,12 +403,8 @@ async-tungstenite = "0.29.1"
async-watch = "0.3.1"
async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] }
aws-config = { version = "1.6.1", features = ["behavior-version-latest"] }
aws-credential-types = { version = "1.2.2", features = [
"hardcoded-credentials",
] }
aws-sdk-bedrockruntime = { version = "1.80.0", features = [
"behavior-version-latest",
] }
aws-credential-types = { version = "1.2.2", features = ["hardcoded-credentials"] }
aws-sdk-bedrockruntime = { version = "1.80.0", features = ["behavior-version-latest"] }
aws-smithy-runtime-api = { version = "1.7.4", features = ["http-1x", "client"] }
aws-smithy-types = { version = "1.3.0", features = ["http-body-1-x"] }
base64 = "0.22"
@@ -448,7 +443,6 @@ futures-lite = "1.13"
git2 = { version = "0.20.1", default-features = false }
globset = "0.4"
handlebars = "4.3"
heck = "0.5"
heed = { version = "0.21.0", features = ["read-txn-no-tls"] }
hex = "0.4.3"
html5ever = "0.27.0"
@@ -539,7 +533,7 @@ smol = "2.0"
sqlformat = "0.2"
streaming-iterator = "0.1"
strsim = "0.11"
strum = { version = "0.27.0", features = ["derive"] }
strum = { version = "0.26.0", features = ["derive"] }
subtle = "2.5.0"
syn = { version = "1.0.72", features = ["full", "extra-traits"] }
sys-locale = "0.3.1"
@@ -604,7 +598,7 @@ wasmtime-wasi = "29"
which = "6.0.0"
wit-component = "0.221"
workspace-hack = "0.1.0"
zed_llm_client = "0.6.0"
zed_llm_client = "0.4"
zstd = "0.11"
metal = "0.29"
@@ -625,10 +619,12 @@ features = [
[workspace.dependencies.windows]
version = "0.61"
features = [
"Foundation_Collections",
"Foundation_Numerics",
"Storage_Search",
"Storage_Streams",
"System_Threading",
"UI_StartScreen",
"UI_ViewManagement",
"Wdk_System_SystemServices",
"Win32_Globalization",
@@ -655,7 +651,6 @@ features = [
"Win32_System_SystemInformation",
"Win32_System_SystemServices",
"Win32_System_Threading",
"Win32_System_Variant",
"Win32_System_WinRT",
"Win32_UI_Controls",
"Win32_UI_HiDpi",
@@ -663,7 +658,6 @@ features = [
"Win32_UI_Input_KeyboardAndMouse",
"Win32_UI_Shell",
"Win32_UI_Shell_Common",
"Win32_UI_Shell_PropertiesSystem",
"Win32_UI_WindowsAndMessaging",
]
@@ -695,6 +689,7 @@ breadcrumbs = { codegen-units = 1 }
collections = { codegen-units = 1 }
command_palette = { codegen-units = 1 }
command_palette_hooks = { codegen-units = 1 }
evals = { codegen-units = 1 }
extension_cli = { codegen-units = 1 }
feature_flags = { codegen-units = 1 }
file_icons = { codegen-units = 1 }
@@ -786,12 +781,4 @@ let_underscore_future = "allow"
too_many_arguments = "allow"
[workspace.metadata.cargo-machete]
ignored = [
"bindgen",
"cbindgen",
"prost_build",
"serde",
"component",
"linkme",
"workspace-hack",
]
ignored = ["bindgen", "cbindgen", "prost_build", "serde", "component", "linkme", "workspace-hack"]

View File

@@ -11,7 +11,7 @@ ENV CARGO_TERM_COLOR=always
COPY script/install-mold script/
RUN ./script/install-mold "2.34.0"
COPY script/headless script/
RUN ./script/headless
COPY script/remote-server script/
RUN ./script/remote-server
COPY . .

View File

@@ -0,0 +1,12 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="16" height="16" rx="2" fill="black" fill-opacity="0.2"/>
<g clip-path="url(#clip0_1916_18)">
<path d="M10.652 3.79999H8.816L12.164 12.2H14L10.652 3.79999Z" fill="#1F1F1E"/>
<path d="M5.348 3.79999L2 12.2H3.872L4.55672 10.436H8.05927L8.744 12.2H10.616L7.268 3.79999H5.348ZM5.16224 8.87599L6.308 5.92399L7.45374 8.87599H5.16224Z" fill="#1F1F1E"/>
</g>
<defs>
<clipPath id="clip0_1916_18">
<rect width="12" height="8.4" fill="white" transform="translate(2 3.79999)"/>
</clipPath>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 601 B

View File

@@ -1,5 +0,0 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M20 14H4C3.44772 14 3 14.4477 3 15V20C3 20.5523 3.44772 21 4 21H20C20.5523 21 21 20.5523 21 20V15C21 14.4477 20.5523 14 20 14Z" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M11 3H4C3.44772 3 3 3.44772 3 4V9C3 9.55228 3.44772 10 4 10H11C11.5523 10 12 9.55228 12 9V4C12 3.44772 11.5523 3 11 3Z" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M20 3H17C16.4477 3 16 3.44772 16 4V9C16 9.55228 16.4477 10 17 10H20C20.5523 10 21 9.55228 21 9V4C21 3.44772 20.5523 3 20 3Z" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 746 B

View File

@@ -134,9 +134,7 @@
"shift-f10": "editor::OpenContextMenu",
"ctrl-shift-e": "editor::ToggleEditPrediction",
"f9": "editor::ToggleBreakpoint",
"shift-f9": "editor::EditLogBreakpoint",
"ctrl-shift-backspace": "editor::GoToPreviousChange",
"ctrl-shift-alt-backspace": "editor::GoToNextChange"
"shift-f9": "editor::EditLogBreakpoint"
}
},
{
@@ -152,9 +150,7 @@
"context": "AgentDiff",
"bindings": {
"ctrl-y": "agent::Keep",
"ctrl-n": "agent::Reject",
"ctrl-shift-y": "agent::KeepAll",
"ctrl-shift-n": "agent::RejectAll"
"ctrl-n": "agent::Reject"
}
},
{
@@ -356,11 +352,11 @@
"alt-shift-left": "editor::SelectSmallerSyntaxNode", // Shrink Selection
"ctrl-shift-l": "editor::SelectAllMatches", // Select all occurrences of current selection
"ctrl-f2": "editor::SelectAllMatches", // Select all occurrences of current word
"ctrl-d": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand
"ctrl-shift-down": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch
"ctrl-shift-up": ["editor::SelectPrevious", { "replace_newest": false }], // editor.action.addSelectionToPreviousFindMatch
"ctrl-k ctrl-d": ["editor::SelectNext", { "replace_newest": true }], // editor.action.moveSelectionToNextFindMatch / find_under_expand_skip
"ctrl-k ctrl-shift-d": ["editor::SelectPrevious", { "replace_newest": true }], // editor.action.moveSelectionToPreviousFindMatch
"ctrl-d": ["editor::SelectNext", { "replace_newest": false }],
"ctrl-shift-down": ["editor::SelectNext", { "replace_newest": false }], // Add selection to Next Find Match
"ctrl-shift-up": ["editor::SelectPrevious", { "replace_newest": false }],
"ctrl-k ctrl-d": ["editor::SelectNext", { "replace_newest": true }],
"ctrl-k ctrl-shift-d": ["editor::SelectPrevious", { "replace_newest": true }],
"ctrl-k ctrl-i": "editor::Hover",
"ctrl-/": ["editor::ToggleComments", { "advance_downwards": false }],
"ctrl-u": "editor::UndoSelection",
@@ -629,14 +625,12 @@
"context": "AgentPanel",
"bindings": {
"ctrl-n": "agent::NewThread",
"ctrl-alt-n": "agent::NewTextThread",
"ctrl-alt-n": "agent::NewPromptEditor",
"ctrl-shift-h": "agent::OpenHistory",
"ctrl-alt-c": "agent::OpenConfiguration",
"ctrl-alt-p": "assistant::OpenPromptLibrary",
"ctrl-i": "agent::ToggleProfileSelector",
"ctrl-alt-/": "assistant::ToggleModelSelector",
"ctrl-shift-a": "agent::ToggleContextPicker",
"shift-escape": "agent::ExpandMessageEditor",
"ctrl-e": "agent::ChatMode",
"ctrl-alt-e": "agent::RemoveAllContext"
}
@@ -651,7 +645,7 @@
{
"context": "AgentPanel && prompt_editor",
"bindings": {
"cmd-n": "agent::NewTextThread",
"cmd-n": "agent::NewPromptEditor",
"cmd-alt-t": "agent::NewThread"
}
},
@@ -721,7 +715,7 @@
"alt-shift-copy": "workspace::CopyRelativePath",
"alt-ctrl-shift-c": "workspace::CopyRelativePath",
"alt-ctrl-r": "outline_panel::RevealInFileManager",
"space": "outline_panel::OpenSelectedEntry",
"space": "outline_panel::Open",
"shift-down": "menu::SelectNext",
"shift-up": "menu::SelectPrevious",
"alt-enter": "editor::OpenExcerpts",
@@ -785,7 +779,6 @@
"shift-tab": "git_panel::FocusEditor",
"escape": "git_panel::ToggleFocus",
"ctrl-enter": "git::Commit",
"ctrl-shift-enter": "git::Amend",
"alt-enter": "menu::SecondaryConfirm",
"delete": ["git::RestoreFile", { "skip_prompt": false }],
"backspace": ["git::RestoreFile", { "skip_prompt": false }],
@@ -794,20 +787,12 @@
"ctrl-delete": ["git::RestoreFile", { "skip_prompt": false }]
}
},
{
"context": "GitPanel && CommitEditor",
"use_key_equivalents": true,
"bindings": {
"escape": "git::Cancel"
}
},
{
"context": "GitCommit > Editor",
"bindings": {
"escape": "menu::Cancel",
"enter": "editor::Newline",
"ctrl-enter": "git::Commit",
"ctrl-shift-enter": "git::Amend",
"alt-l": "git::GenerateCommitMessage"
}
},
@@ -829,7 +814,6 @@
"context": "GitDiff > Editor",
"bindings": {
"ctrl-enter": "git::Commit",
"ctrl-shift-enter": "git::Amend",
"ctrl-space": "git::StageAll",
"ctrl-shift-space": "git::UnstageAll"
}
@@ -848,7 +832,6 @@
"shift-tab": "git_panel::FocusChanges",
"enter": "editor::Newline",
"ctrl-enter": "git::Commit",
"ctrl-shift-enter": "git::Amend",
"alt-up": "git_panel::FocusChanges",
"alt-l": "git::GenerateCommitMessage"
}

View File

@@ -242,9 +242,7 @@
"use_key_equivalents": true,
"bindings": {
"cmd-y": "agent::Keep",
"cmd-n": "agent::Reject",
"cmd-shift-y": "agent::KeepAll",
"cmd-shift-n": "agent::RejectAll"
"cmd-n": "agent::Reject"
}
},
{
@@ -283,14 +281,12 @@
"use_key_equivalents": true,
"bindings": {
"cmd-n": "agent::NewThread",
"cmd-alt-n": "agent::NewTextThread",
"cmd-alt-n": "agent::NewPromptEditor",
"cmd-shift-h": "agent::OpenHistory",
"cmd-alt-c": "agent::OpenConfiguration",
"cmd-alt-p": "assistant::OpenPromptLibrary",
"cmd-i": "agent::ToggleProfileSelector",
"cmd-alt-/": "assistant::ToggleModelSelector",
"cmd-shift-a": "agent::ToggleContextPicker",
"shift-escape": "agent::ExpandMessageEditor",
"cmd-e": "agent::ChatMode",
"cmd-alt-e": "agent::RemoveAllContext"
}
@@ -306,7 +302,7 @@
"context": "AgentPanel && prompt_editor",
"use_key_equivalents": true,
"bindings": {
"cmd-n": "agent::NewTextThread",
"cmd-n": "agent::NewPromptEditor",
"cmd-alt-t": "agent::NewThread"
}
},
@@ -492,15 +488,12 @@
"alt-shift-down": "editor::DuplicateLineDown",
"ctrl-shift-right": "editor::SelectLargerSyntaxNode", // Expand Selection
"ctrl-shift-left": "editor::SelectSmallerSyntaxNode", // Shrink Selection
"cmd-d": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand
"cmd-d": ["editor::SelectNext", { "replace_newest": false }], // Add selection to Next Find Match
"cmd-shift-l": "editor::SelectAllMatches", // Select all occurrences of current selection
"cmd-f2": "editor::SelectAllMatches", // Select all occurrences of current word
"cmd-k cmd-d": ["editor::SelectNext", { "replace_newest": true }], // editor.action.moveSelectionToNextFindMatch / find_under_expand_skip
// macOS binds `ctrl-cmd-d` to Show Dictionary which breaks these two binds
// To use `ctrl-cmd-d` or `ctrl-k ctrl-cmd-d` in Zed you must execute this command and then restart:
// defaults write com.apple.symbolichotkeys AppleSymbolicHotKeys -dict-add 70 '<dict><key>enabled</key><false/></dict>'
"ctrl-cmd-d": ["editor::SelectPrevious", { "replace_newest": false }], // editor.action.addSelectionToPreviousFindMatch
"cmd-k ctrl-cmd-d": ["editor::SelectPrevious", { "replace_newest": true }], // editor.action.moveSelectionToPreviousFindMatch
"ctrl-cmd-d": ["editor::SelectPrevious", { "replace_newest": false }],
"cmd-k cmd-d": ["editor::SelectNext", { "replace_newest": true }],
"cmd-k ctrl-cmd-d": ["editor::SelectPrevious", { "replace_newest": true }],
"cmd-k cmd-i": "editor::Hover",
"cmd-/": ["editor::ToggleComments", { "advance_downwards": false }],
"cmd-u": "editor::UndoSelection",
@@ -542,9 +535,7 @@
"cmd-\\": "pane::SplitRight",
"cmd-k v": "markdown::OpenPreviewToTheSide",
"cmd-shift-v": "markdown::OpenPreview",
"ctrl-cmd-c": "editor::DisplayCursorNames",
"cmd-shift-backspace": "editor::GoToPreviousChange",
"cmd-shift-alt-backspace": "editor::GoToNextChange"
"ctrl-cmd-c": "editor::DisplayCursorNames"
}
},
{
@@ -789,7 +780,7 @@
"cmd-alt-c": "workspace::CopyPath",
"alt-cmd-shift-c": "workspace::CopyRelativePath",
"alt-cmd-r": "outline_panel::RevealInFileManager",
"space": "outline_panel::OpenSelectedEntry",
"space": "outline_panel::Open",
"shift-down": "menu::SelectNext",
"shift-up": "menu::SelectPrevious",
"alt-enter": "editor::OpenExcerpts",
@@ -858,26 +849,17 @@
"shift-tab": "git_panel::FocusEditor",
"escape": "git_panel::ToggleFocus",
"cmd-enter": "git::Commit",
"cmd-shift-enter": "git::Amend",
"backspace": ["git::RestoreFile", { "skip_prompt": false }],
"delete": ["git::RestoreFile", { "skip_prompt": false }],
"cmd-backspace": ["git::RestoreFile", { "skip_prompt": true }],
"cmd-delete": ["git::RestoreFile", { "skip_prompt": true }]
}
},
{
"context": "GitPanel && CommitEditor",
"use_key_equivalents": true,
"bindings": {
"escape": "git::Cancel"
}
},
{
"context": "GitDiff > Editor",
"use_key_equivalents": true,
"bindings": {
"cmd-enter": "git::Commit",
"cmd-shift-enter": "git::Amend",
"cmd-ctrl-y": "git::StageAll",
"cmd-ctrl-shift-y": "git::UnstageAll"
}
@@ -888,7 +870,6 @@
"bindings": {
"enter": "editor::Newline",
"cmd-enter": "git::Commit",
"cmd-shift-enter": "git::Amend",
"tab": "git_panel::FocusChanges",
"shift-tab": "git_panel::FocusChanges",
"alt-up": "git_panel::FocusChanges",
@@ -918,7 +899,6 @@
"enter": "editor::Newline",
"escape": "menu::Cancel",
"cmd-enter": "git::Commit",
"cmd-shift-enter": "git::Amend",
"alt-tab": "git::GenerateCommitMessage"
}
},

View File

@@ -37,8 +37,6 @@
"ctrl-shift-a": "editor::SelectLargerSyntaxNode",
"ctrl-shift-d": "editor::DuplicateSelection",
"alt-f3": "editor::SelectAllMatches", // find_all_under
// "ctrl-f3": "", // find_under (cancels any selections)
// "cmd-alt-shift-g": "" // find_under_prev (cancels any selections)
"f9": "editor::SortLinesCaseSensitive",
"ctrl-f9": "editor::SortLinesCaseInsensitive",
"f12": "editor::GoToDefinition",
@@ -51,9 +49,7 @@
"ctrl-k ctrl-l": "editor::ConvertToLowerCase",
"shift-alt-m": "markdown::OpenPreviewToTheSide",
"ctrl-backspace": "editor::DeleteToPreviousWordStart",
"ctrl-delete": "editor::DeleteToNextWordEnd",
"f3": "editor::FindNextMatch",
"shift-f3": "editor::FindPreviousMatch"
"ctrl-delete": "editor::DeleteToNextWordEnd"
}
},
{

View File

@@ -38,8 +38,6 @@
"cmd-shift-a": "editor::SelectLargerSyntaxNode",
"cmd-shift-d": "editor::DuplicateSelection",
"ctrl-cmd-g": "editor::SelectAllMatches", // find_all_under
// "cmd-alt-g": "", // find_under (cancels any selections)
// "cmd-alt-shift-g": "" // find_under_prev (cancels any selections)
"f5": "editor::SortLinesCaseSensitive",
"ctrl-f5": "editor::SortLinesCaseInsensitive",
"shift-f12": "editor::FindAllReferences",
@@ -53,9 +51,7 @@
"cmd-shift-j": "editor::JoinLines",
"shift-alt-m": "markdown::OpenPreviewToTheSide",
"ctrl-backspace": "editor::DeleteToPreviousWordStart",
"ctrl-delete": "editor::DeleteToNextWordEnd",
"cmd-g": "editor::FindNextMatch",
"cmd-shift-g": "editor::FindPreviousMatch"
"ctrl-delete": "editor::DeleteToNextWordEnd"
}
},
{

View File

@@ -190,8 +190,7 @@
"ctrl-w g shift-d": "editor::GoToTypeDefinitionSplit",
"ctrl-w space": "editor::OpenExcerptsSplit",
"ctrl-w g space": "editor::OpenExcerptsSplit",
"ctrl-6": "pane::AlternateFile",
"ctrl-^": "pane::AlternateFile"
"ctrl-6": "pane::AlternateFile"
}
},
{
@@ -204,7 +203,6 @@
"c": "vim::PushChange",
"shift-c": "vim::ChangeToEndOfLine",
"d": "vim::PushDelete",
"delete": "vim::DeleteRight",
"shift-d": "vim::DeleteToEndOfLine",
"shift-j": "vim::JoinLines",
"g shift-j": "vim::JoinLinesNoWhitespace",
@@ -540,7 +538,6 @@
"bindings": {
"d": "vim::CurrentLine",
"s": "vim::PushDeleteSurrounds",
"v": "vim::PushForcedMotion", // "d v"
"o": "editor::ToggleSelectedDiffHunks", // "d o"
"shift-o": "git::ToggleStaged",
"p": "git::Restore", // "d p"
@@ -589,7 +586,6 @@
"context": "vim_operator == y",
"bindings": {
"y": "vim::CurrentLine",
"v": "vim::PushForcedMotion",
"s": ["vim::PushAddSurrounds", {}]
}
},
@@ -787,7 +783,8 @@
"{": "project_panel::SelectPrevDirectory",
"shift-g": "menu::SelectLast",
"g g": "menu::SelectFirst",
"-": "project_panel::SelectParent"
"-": "project_panel::SelectParent",
"ctrl-6": "pane::AlternateFile"
}
},
{

View File

@@ -1,74 +1,170 @@
You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices.
You are an AI assistant integrated into a code editor. You have the programming ability of an expert programmer who takes pride in writing high-quality code and is driven to the point of obsession about solving problems effectively. Your goal is to do one of the following two things:
## Communication
1. Help users answer questions and perform tasks related to their codebase.
2. Answer general-purpose questions unrelated to their particular codebase.
1. Be conversational but professional.
2. Refer to the USER in the second person and yourself in the first person.
3. Format your responses in markdown. Use backticks to format file, directory, function, and class names.
4. NEVER lie or make things up.
5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing.
It will be up to you to decide which of these you are doing based on what the user has told you. When unclear, ask clarifying questions to understand the user's intent before proceeding.
## Searching and Reading
You should only perform actions that modify the user's system if explicitly requested by the user:
- If the user asks a question about how to accomplish a task, provide guidance or information, and use read-only tools (e.g., search) to assist. You may suggest potential actions, but do not directly modify the user's system without explicit instruction.
- If the user clearly requests that you perform an action, carry out the action directly without explaining why you are doing so.
If you are unsure about the answer to the user's request or how to satiate their request, you should gather more information.
This can be done with additional tool calls, asking clarifying questions, etc.
When answering questions, it's okay to give incomplete examples containing comments about what would go there in a real version. When being asked to directly perform tasks on the code base, you must ALWAYS make fully working code. You may never "simplify" the code by omitting or deleting functionality you know the user has requested, and you must NEVER write comments like "in a full version, this would..." - instead, you must actually implement the real version. Don't be lazy!
For example, if you've performed a semantic search, and the results may not fully answer the user's request, or merit gathering more information, feel free to call more tools. Similarly, if you've performed an edit that may partially
satiate the user's query, but you're not confident, gather more information or use more tools before ending your turn.
Note that project files are automatically backed up. The user can always get them back later if anything goes wrong, so there's
no need to create backup files (e.g. `.bak` files) because these files will just take up unnecessary space on the user's disk.
Bias towards not asking the user for help if you can find the answer yourself.
When attempting to resolve issues around failing tests, never simply remove the failing tests. Unless the user explicitly asks you to remove tests, ALWAYS attempt to fix the code causing the tests to fail.
## Tool Use
Ignore "TODO"-type comments unless they're relevant to the user's explicit request or the user specifically asks you to address them. It is, however, okay to include them in codebase summaries.
1. Make sure to adhere to the tools schema.
2. Provide every required argument.
3. DO NOT use tools to access items that are already available in the context section.
4. Use only the tools that are currently available.
5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off.
<style>
Editing code:
- Make sure to take previous edits into account.
- The edits you perform might lead to errors or warnings. At the end of your changes, check whether you introduced any problems, and fix them before providing a summary of the changes you made.
- You may only attempt to fix these up to 3 times. If you have tried 3 times to fix them, and there are still problems remaining, you must not continue trying to fix them, and must instead tell the user that there are problems remaining - and ask if the user would like you to attempt to solve them further.
- Do not fix errors unrelated to your changes unless the user explicitly asks you to do so.
- Prefer to move files over recreating them. The move can be followed by minor edits if required.
- If you seem to be stuck, never go back and "simplify the implementation" by deleting the parts of the implementation you're stuck on and replacing them with comments. If you ever feel the urge to do this, instead immediately stop whatever you're doing (even if the code is in a broken state), report that you are stuck, explain what you're stuck on, and ask the user how to proceed.
## Fixing Diagnostics
Tool use:
- Make sure to adhere to the tools schema.
- Provide every required argument.
- DO NOT use tools to access items that are already available in the context section.
- Use only the tools that are currently available.
- DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off.
1. Make 1-2 attempts at fixing diagnostics, then defer to the user.
2. Never simplify code you've written just to solve diagnostics. Complete, mostly correct code is more valuable than perfect code that doesn't solve the problem.
Responding:
- Be concise and direct in your responses.
- Never apologize or thank the user.
- Don't comment that you have just realized or understood something.
- When you are going to make a tool call, tersely explain your reasoning for choosing to use that tool, with no flourishes or commentary beyond that information.
For example, rather than saying "You're absolutely right! Thank you for providing that context. Now I understand that we're missing a dependency, and I need to add it:" say "I'll add that missing dependency:" instead.
- Also, don't restate what a tool call is about to do (or just did).
For example, don't say "Now I'm going to check diagnostics to see if there are any warnings or errors," followed by running a tool which checks diagnostics and reports warnings or errors; instead, just request the tool call without saying anything.
- All tool results are provided to you automatically, so DO NOT thank the user when this happens.
## Debugging
Whenever you mention a code block, you MUST use ONLY the following format:
When debugging, only make code changes if you are certain that you can solve the problem.
Otherwise, follow debugging best practices:
1. Address the root cause instead of the symptoms.
2. Add descriptive logging statements and error messages to track variable and code state.
3. Add test functions and statements to isolate the problem.
```language path/to/Something.blah#L123-456
(code goes here)
```
## Calling External APIs
The `#L123-456` means the line number range 123 through 456, and the path/to/Something.blah
is a path in the project. (If there is no valid path in the project, then you can use
/dev/null/path.extension for its path.) This is the ONLY valid way to format code blocks, because the Markdown parser
does not understand the more common ```language syntax, or bare ``` blocks. It only
understands this path-based syntax, and if the path is missing, then it will error and you will have to do it over again.
1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission.
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file. If no such file exists or if the package is not present, use the latest version that is in your training data.
3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed)
Just to be really clear about this, if you ever find yourself writing three backticks followed by a language name, STOP!
You have made a mistake. You can only ever put paths after triple backticks!
## System Information
<example>
Based on all the information I've gathered, here's a summary of how this system works:
1. The README file is loaded into the system.
2. The system finds the first two headers, including everything in between. In this case, that would be:
Operating System: {{os}}
Default Shell: {{shell}}
```path/to/README.md#L8-12
# First Header
This is the info under the first header.
## Sub-header
```
3. Then the system finds the last header in the README:
```path/to/README.md#L27-29
## Last Header
This is the last header in the README.
```
4. Finally, it passes this information on to the next process.
</example>
<example>
In Markdown, hash marks signify headings. For example:
```/dev/null/example.md#L1-3
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</example>
Here are examples of ways you must never render code blocks:
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
```
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</bad_example_do_not_do_this>
This example is unacceptable because it does not include the path.
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
```markdown
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</bad_example_do_not_do_this>
This example is unacceptable because it has the language instead of the path.
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
# Level 1 heading
## Level 2 heading
### Level 3 heading
</bad_example_do_not_do_this>
This example is unacceptable because it uses indentation to mark the code block
instead of backticks with a path.
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
```markdown
/dev/null/example.md#L1-3
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</bad_example_do_not_do_this>
This example is unacceptable because the path is in the wrong place. The path must be directly after the opening backticks.
</style>
The user has opened a project that contains the following root directories/files. Whenever you specify a path in the project, it must be a relative path which begins with one of these root directories/files:
{{#each worktrees}}
- `{{root_name}}` (absolute path: `{{abs_path}}`)
{{/each}}
{{#if has_rules}}
## User's Custom Instructions
The following additional instructions are provided by the user, and should be followed to the best of your ability without interfering with the tool use guidelines.
There are rules that apply to these root directories:
{{#each worktrees}}
{{#if rules_file}}
`{{root_name}}/{{rules_file.path_in_worktree}}`:
``````
{{{rules_file.text}}}
``````
{{/if}}
{{/each}}
{{/if}}
<user_environment>
Operating System: {{os}} ({{arch}})
Shell: {{shell}}
</user_environment>

View File

@@ -80,8 +80,6 @@
// Values are clamped to the [0.0, 1.0] range.
"inactive_opacity": 1.0
},
// Layout mode of the bottom dock. Defaults to "contained"
"bottom_dock_layout": "contained",
// The direction that you want to split panes horizontally. Defaults to "up"
"pane_split_direction_horizontal": "up",
// The direction that you want to split panes horizontally. Defaults to "left"
@@ -214,14 +212,7 @@
// The default number of lines to expand excerpts in the multibuffer by.
"expand_excerpt_lines": 3,
// Globs to match against file paths to determine if a file is private.
"private_files": [
"**/.env*",
"**/*.pem",
"**/*.key",
"**/*.cert",
"**/*.crt",
"**/secrets.yml"
],
"private_files": ["**/.env*", "**/*.pem", "**/*.key", "**/*.cert", "**/*.crt", "**/secrets.yml"],
// Whether to use additional LSP queries to format (and amend) the code after
// every "trigger" symbol input, defined by LSP server capabilities.
"use_on_type_format": true,
@@ -594,6 +585,7 @@
//
// Default: main
"fallback_branch_name": "main",
"scrollbar": {
// When to show the scrollbar in the git panel.
//
@@ -632,14 +624,14 @@
// The provider to use.
"provider": "zed.dev",
// The model to use.
"model": "claude-3-7-sonnet-latest"
"model": "claude-3-5-sonnet-latest"
},
// The model to use when applying edits from the assistant.
"editor_model": {
// The provider to use.
"provider": "zed.dev",
// The model to use.
"model": "claude-3-7-sonnet-latest"
"model": "claude-3-5-sonnet-latest"
},
// When enabled, the agent can run potentially destructive actions without asking for your confirmation.
"always_allow_tool_actions": false,
@@ -650,7 +642,6 @@
// We don't know which of the context server tools are safe for the "Ask" profile, so we don't enable them by default.
// "enable_all_context_servers": true,
"tools": {
"contents": true,
"diagnostics": true,
"fetch": true,
"list_directory": false,
@@ -658,35 +649,32 @@
"path_search": true,
"read_file": true,
"regex_search": true,
"thinking": true,
"web_search": true
"thinking": true
}
},
"write": {
"name": "Write",
"enable_all_context_servers": true,
"tools": {
"batch_tool": false,
"code_actions": false,
"code_symbols": false,
"contents": false,
"terminal": true,
"batch_tool": true,
"code_actions": true,
"code_symbols": true,
"copy_path": false,
"create_file": true,
"delete_path": false,
"diagnostics": true,
"edit_file": true,
"find_replace_file": true,
"fetch": true,
"list_directory": true,
"list_directory": false,
"move_path": false,
"now": false,
"now": true,
"path_search": true,
"read_file": true,
"regex_search": true,
"rename": false,
"symbol_info": false,
"terminal": true,
"thinking": true,
"web_search": true
"rename": true,
"symbol_info": true,
"thinking": true
}
}
},
@@ -721,9 +709,7 @@
// The list of language servers to use (or disable) for all languages.
//
// This is typically customized on a per-language basis.
"language_servers": [
"..."
],
"language_servers": ["..."],
// When to automatically save edited buffers. This setting can
// take four values.
//
@@ -919,9 +905,7 @@
// for files that are not tracked by git, but are still important to your project. Note that globs
// that are overly broad can slow down Zed's file scanning. `file_scan_exclusions` takes
// precedence over these inclusions.
"file_scan_inclusions": [
".env*"
],
"file_scan_inclusions": [".env*"],
// Git gutter behavior configuration.
"git": {
// Control whether the git gutter is shown. May take 2 values:
@@ -973,15 +957,7 @@
// Any addition to this list will be merged with the default list.
// Globs are matched relative to the worktree root,
// except when starting with a slash (/) or equivalent in Windows.
"disabled_globs": [
"**/.env*",
"**/*.pem",
"**/*.key",
"**/*.cert",
"**/*.crt",
"**/.dev.vars",
"**/secrets.yml"
],
"disabled_globs": ["**/.env*", "**/*.pem", "**/*.key", "**/*.cert", "**/*.crt", "**/.dev.vars", "**/secrets.yml"],
// When to show edit predictions previews in buffer.
// This setting takes two possible values:
// 1. Display predictions inline when there are no language server completions available.
@@ -1114,12 +1090,7 @@
// Default directories to search for virtual environments, relative
// to the current working directory. We recommend overriding this
// in your project's settings, rather than globally.
"directories": [
".env",
"env",
".venv",
"venv"
],
"directories": [".env", "env", ".venv", "venv"],
// Can also be `csh`, `fish`, `nushell` and `power_shell`
"activate_script": "default"
}
@@ -1183,15 +1154,8 @@
// }
//
"file_types": {
"JSONC": [
"**/.zed/**/*.json",
"**/zed/**/*.json",
"**/Zed/**/*.json",
"**/.vscode/**/*.json"
],
"Shell Script": [
".env.*"
]
"JSONC": ["**/.zed/**/*.json", "**/zed/**/*.json", "**/Zed/**/*.json", "**/.vscode/**/*.json"],
"Shell Script": [".env.*"]
},
// By default use a recent system version of node, or install our own.
// You can override this to use a version of node that is not in $PATH with:
@@ -1264,15 +1228,10 @@
// Different settings for specific languages.
"languages": {
"Astro": {
"language_servers": [
"astro-language-server",
"..."
],
"language_servers": ["astro-language-server", "..."],
"prettier": {
"allowed": true,
"plugins": [
"prettier-plugin-astro"
]
"plugins": ["prettier-plugin-astro"]
}
},
"Blade": {
@@ -1308,19 +1267,10 @@
"ensure_final_newline_on_save": false
},
"Elixir": {
"language_servers": [
"elixir-ls",
"!next-ls",
"!lexical",
"..."
]
"language_servers": ["elixir-ls", "!next-ls", "!lexical", "..."]
},
"Erlang": {
"language_servers": [
"erlang-ls",
"!elp",
"..."
]
"language_servers": ["erlang-ls", "!elp", "..."]
},
"Git Commit": {
"allow_rewrap": "anywhere"
@@ -1336,12 +1286,7 @@
}
},
"HEEX": {
"language_servers": [
"elixir-ls",
"!next-ls",
"!lexical",
"..."
]
"language_servers": ["elixir-ls", "!next-ls", "!lexical", "..."]
},
"HTML": {
"prettier": {
@@ -1351,17 +1296,11 @@
"Java": {
"prettier": {
"allowed": true,
"plugins": [
"prettier-plugin-java"
]
"plugins": ["prettier-plugin-java"]
}
},
"JavaScript": {
"language_servers": [
"!typescript-language-server",
"vtsls",
"..."
],
"language_servers": ["!typescript-language-server", "vtsls", "..."],
"prettier": {
"allowed": true
}
@@ -1379,10 +1318,7 @@
"LaTeX": {
"format_on_save": "on",
"formatter": "language_server",
"language_servers": [
"texlab",
"..."
],
"language_servers": ["texlab", "..."],
"prettier": {
"allowed": false
}
@@ -1397,16 +1333,10 @@
}
},
"PHP": {
"language_servers": [
"phpactor",
"!intelephense",
"..."
],
"language_servers": ["phpactor", "!intelephense", "..."],
"prettier": {
"allowed": true,
"plugins": [
"@prettier/plugin-php"
],
"plugins": ["@prettier/plugin-php"],
"parser": "php"
}
},
@@ -1414,12 +1344,7 @@
"allow_rewrap": "anywhere"
},
"Ruby": {
"language_servers": [
"solargraph",
"!ruby-lsp",
"!rubocop",
"..."
]
"language_servers": ["solargraph", "!ruby-lsp", "!rubocop", "..."]
},
"SCSS": {
"prettier": {
@@ -1429,36 +1354,21 @@
"SQL": {
"prettier": {
"allowed": true,
"plugins": [
"prettier-plugin-sql"
]
"plugins": ["prettier-plugin-sql"]
}
},
"Starlark": {
"language_servers": [
"starpls",
"!buck2-lsp",
"..."
]
"language_servers": ["starpls", "!buck2-lsp", "..."]
},
"Svelte": {
"language_servers": [
"svelte-language-server",
"..."
],
"language_servers": ["svelte-language-server", "..."],
"prettier": {
"allowed": true,
"plugins": [
"prettier-plugin-svelte"
]
"plugins": ["prettier-plugin-svelte"]
}
},
"TSX": {
"language_servers": [
"!typescript-language-server",
"vtsls",
"..."
],
"language_servers": ["!typescript-language-server", "vtsls", "..."],
"prettier": {
"allowed": true
}
@@ -1469,20 +1379,13 @@
}
},
"TypeScript": {
"language_servers": [
"!typescript-language-server",
"vtsls",
"..."
],
"language_servers": ["!typescript-language-server", "vtsls", "..."],
"prettier": {
"allowed": true
}
},
"Vue.js": {
"language_servers": [
"vue-language-server",
"..."
],
"language_servers": ["vue-language-server", "..."],
"prettier": {
"allowed": true
}
@@ -1490,9 +1393,7 @@
"XML": {
"prettier": {
"allowed": true,
"plugins": [
"@prettier/plugin-xml"
]
"plugins": ["@prettier/plugin-xml"]
}
},
"YAML": {
@@ -1501,10 +1402,7 @@
}
},
"Zig": {
"language_servers": [
"zls",
"..."
]
"language_servers": ["zls", "..."]
}
},
// Different settings for specific language models.
@@ -1658,6 +1556,7 @@
// }
// ]
"ssh_connections": [],
// Configures context servers for use in the Assistant.
"context_servers": {},
"debugger": {

View File

@@ -19,6 +19,7 @@ test-support = [
]
[dependencies]
agent_rules.workspace = true
anyhow.workspace = true
assistant_context_editor.workspace = true
assistant_settings.workspace = true
@@ -80,7 +81,6 @@ terminal.workspace = true
terminal_view.workspace = true
text.workspace = true
theme.workspace = true
thiserror.workspace = true
time.workspace = true
time_format.workspace = true
ui.workspace = true
@@ -90,7 +90,6 @@ uuid.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
zed_actions.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
buffer_diff = { workspace = true, features = ["test-support"] }

View File

@@ -1,45 +1,40 @@
use crate::context::{AssistantContext, ContextId, format_context_as_string};
use crate::context::{AssistantContext, ContextId};
use crate::context_picker::MentionLink;
use crate::thread::{
LastRestoreCheckpoint, MessageId, MessageSegment, RequestKind, Thread, ThreadError,
ThreadEvent, ThreadFeedback,
};
use crate::thread_store::{RulesLoadingError, ThreadStore};
use crate::tool_use::{PendingToolUseStatus, ToolUse};
use crate::thread_store::ThreadStore;
use crate::tool_use::{PendingToolUseStatus, ToolUse, ToolUseStatus};
use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill};
use crate::{AssistantPanel, OpenActiveThreadAsMarkdown};
use anyhow::Context as _;
use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting};
use assistant_tool::ToolUseStatus;
use collections::{HashMap, HashSet};
use editor::scroll::Autoscroll;
use editor::{Editor, EditorElement, EditorEvent, EditorStyle, MultiBuffer};
use editor::{Editor, MultiBuffer};
use gpui::{
AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, ClipboardItem,
DefiniteLength, EdgesRefinement, Empty, Entity, EventEmitter, Focusable, Hsla, ListAlignment,
ListState, MouseButton, PlatformDisplay, ScrollHandle, Stateful, StyleRefinement, Subscription,
Task, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle,
DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Hsla, ListAlignment, ListState,
MouseButton, PlatformDisplay, ScrollHandle, Stateful, StyleRefinement, Subscription, Task,
TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle,
linear_color_stop, linear_gradient, list, percentage, pulsating_between,
};
use language::{Buffer, LanguageRegistry};
use language_model::{
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, RequestUsage, Role,
StopReason,
};
use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
use markdown::parser::CodeBlockKind;
use markdown::{Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown, without_fences};
use project::ProjectItem as _;
use rope::Point;
use settings::{Settings as _, update_settings_file};
use std::ops::Range;
use std::path::Path;
use std::rc::Rc;
use std::sync::Arc;
use std::time::Duration;
use text::ToPoint;
use theme::ThemeSettings;
use ui::{
Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, TextSize, Tooltip, prelude::*,
};
use ui::{Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, Tooltip, prelude::*};
use util::ResultExt as _;
use workspace::{OpenOptions, Workspace};
@@ -64,7 +59,6 @@ pub struct ActiveThread {
expanded_thinking_segments: HashMap<(MessageId, usize), bool>,
expanded_code_blocks: HashMap<(MessageId, usize), bool>,
last_error: Option<ThreadError>,
last_usage: Option<RequestUsage>,
notifications: Vec<WindowHandle<AgentNotification>>,
copied_code_block_ids: HashSet<(MessageId, usize)>,
_subscriptions: Vec<Subscription>,
@@ -180,37 +174,11 @@ fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
});
MarkdownStyle {
base_text_style: text_style.clone(),
base_text_style: text_style,
syntax: cx.theme().syntax().clone(),
selection_background_color: cx.theme().players().local().selection,
code_block_overflow_x_scroll: true,
table_overflow_x_scroll: true,
heading_level_styles: Some(HeadingLevelStyles {
h1: Some(TextStyleRefinement {
font_size: Some(rems(1.15).into()),
..Default::default()
}),
h2: Some(TextStyleRefinement {
font_size: Some(rems(1.1).into()),
..Default::default()
}),
h3: Some(TextStyleRefinement {
font_size: Some(rems(1.05).into()),
..Default::default()
}),
h4: Some(TextStyleRefinement {
font_size: Some(rems(1.).into()),
..Default::default()
}),
h5: Some(TextStyleRefinement {
font_size: Some(rems(0.95).into()),
..Default::default()
}),
h6: Some(TextStyleRefinement {
font_size: Some(rems(0.875).into()),
..Default::default()
}),
}),
code_block: StyleRefinement {
padding: EdgesRefinement {
top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
@@ -322,14 +290,12 @@ fn tool_use_markdown_style(window: &Window, cx: &mut App) -> MarkdownStyle {
}
}
const MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK: usize = 10;
fn render_markdown_code_block(
message_id: MessageId,
ix: usize,
kind: &CodeBlockKind,
parsed_markdown: &ParsedMarkdown,
metadata: CodeBlockMetadata,
codeblock_range: Range<usize>,
active_thread: Entity<ActiveThread>,
workspace: WeakEntity<Workspace>,
_window: &Window,
@@ -501,15 +467,22 @@ fn render_markdown_code_block(
.element_background
.blend(cx.theme().colors().editor_foreground.opacity(0.01));
const CODE_FENCES_LINE_COUNT: usize = 2;
const MAX_COLLAPSED_LINES: usize = 5;
let line_count = parsed_markdown.source()[codeblock_range.clone()]
.bytes()
.filter(|c| *c == b'\n')
.count()
.saturating_sub(CODE_FENCES_LINE_COUNT - 1);
let codeblock_header = h_flex()
.group("codeblock_header")
.py_1()
.pl_1p5()
.pr_1()
.p_1()
.gap_1()
.justify_between()
.border_b_1()
.border_color(cx.theme().colors().border.opacity(0.6))
.border_color(cx.theme().colors().border_variant)
.bg(codeblock_header_bg)
.rounded_t_md()
.children(label)
@@ -532,14 +505,16 @@ fn render_markdown_code_block(
.on_click({
let active_thread = active_thread.clone();
let parsed_markdown = parsed_markdown.clone();
let code_block_range = metadata.content_range.clone();
move |_event, _window, cx| {
active_thread.update(cx, |this, cx| {
this.copied_code_block_ids.insert((message_id, ix));
let code = parsed_markdown.source()[code_block_range.clone()]
.to_string();
cx.write_to_clipboard(ClipboardItem::new_string(code));
let code = without_fences(
&parsed_markdown.source()[codeblock_range.clone()],
)
.to_string();
cx.write_to_clipboard(ClipboardItem::new_string(code.clone()));
cx.spawn(async move |this, cx| {
cx.background_executor()
@@ -561,41 +536,38 @@ fn render_markdown_code_block(
}),
),
)
.when(
metadata.line_count > MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK,
|header| {
header.child(
IconButton::new(
("expand-collapse-code", ix),
if is_expanded {
IconName::ChevronUp
} else {
IconName::ChevronDown
},
)
.icon_color(Color::Muted)
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text(if is_expanded {
"Collapse Code"
.when(line_count > MAX_COLLAPSED_LINES, |header| {
header.child(
IconButton::new(
("expand-collapse-code", ix),
if is_expanded {
IconName::ChevronUp
} else {
"Expand Code"
}))
.on_click({
let active_thread = active_thread.clone();
move |_event, _window, cx| {
active_thread.update(cx, |this, cx| {
let is_expanded = this
.expanded_code_blocks
.entry((message_id, ix))
.or_insert(false);
*is_expanded = !*is_expanded;
cx.notify();
});
}
}),
IconName::ChevronDown
},
)
},
),
.icon_color(Color::Muted)
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text(if is_expanded {
"Collapse Code"
} else {
"Expand Code"
}))
.on_click({
let active_thread = active_thread.clone();
move |_event, _window, cx| {
active_thread.update(cx, |this, cx| {
let is_expanded = this
.expanded_code_blocks
.entry((message_id, ix))
.or_insert(false);
*is_expanded = !*is_expanded;
cx.notify();
});
}
}),
)
}),
);
v_flex()
@@ -603,19 +575,16 @@ fn render_markdown_code_block(
.overflow_hidden()
.rounded_lg()
.border_1()
.border_color(cx.theme().colors().border.opacity(0.6))
.border_color(cx.theme().colors().border_variant)
.bg(cx.theme().colors().editor_background)
.child(codeblock_header)
.when(
metadata.line_count > MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK,
|this| {
if is_expanded {
this.h_full()
} else {
this.max_h_80()
}
},
)
.when(line_count > MAX_COLLAPSED_LINES, |this| {
if is_expanded {
this.h_full()
} else {
this.max_h_40()
}
})
}
fn open_markdown_link(
@@ -688,9 +657,6 @@ fn open_markdown_link(
struct EditMessageState {
editor: Entity<Editor>,
last_estimated_token_count: Option<usize>,
_subscription: Subscription,
_update_token_count_task: Option<Task<anyhow::Result<()>>>,
}
impl ActiveThread {
@@ -706,7 +672,6 @@ impl ActiveThread {
let subscriptions = vec![
cx.observe(&thread, |_, _, cx| cx.notify()),
cx.subscribe_in(&thread, window, Self::handle_thread_event),
cx.subscribe(&thread_store, Self::handle_rules_loading_error),
];
let list_state = ListState::new(0, ListAlignment::Bottom, px(2048.), {
@@ -736,7 +701,6 @@ impl ActiveThread {
hide_scrollbar_task: None,
editing_message: None,
last_error: None,
last_usage: None,
copied_code_block_ids: HashSet::default(),
notifications: Vec::new(),
_subscriptions: subscriptions,
@@ -761,10 +725,6 @@ impl ActiveThread {
this
}
pub fn context_store(&self) -> &Entity<ContextStore> {
&self.context_store
}
pub fn thread(&self) -> &Entity<Thread> {
&self.thread
}
@@ -795,17 +755,6 @@ impl ActiveThread {
self.last_error.take();
}
pub fn last_usage(&self) -> Option<RequestUsage> {
self.last_usage
}
/// Returns the editing message id and the estimated token count in the content
pub fn editing_message_id(&self) -> Option<(MessageId, usize)> {
self.editing_message
.as_ref()
.map(|(id, state)| (*id, state.last_estimated_token_count.unwrap_or(0)))
}
fn push_message(
&mut self,
id: &MessageId,
@@ -883,17 +832,15 @@ impl ActiveThread {
ThreadEvent::ShowError(error) => {
self.last_error = Some(error.clone());
}
ThreadEvent::UsageUpdated(usage) => {
self.last_usage = Some(*usage);
}
ThreadEvent::StreamedCompletion
| ThreadEvent::SummaryGenerated
| ThreadEvent::SummaryChanged => {
self.save_thread(cx);
}
ThreadEvent::Stopped(reason) => match reason {
Ok(StopReason::EndTurn | StopReason::MaxTokens) => {
let thread = self.thread.read(cx);
ThreadEvent::DoneStreaming => {
let thread = self.thread.read(cx);
if !thread.is_generating() {
self.show_notification(
if thread.used_tools_since_last_user_message() {
"Finished running tools"
@@ -905,8 +852,7 @@ impl ActiveThread {
cx,
);
}
_ => {}
},
}
ThreadEvent::ToolConfirmationNeeded => {
self.show_notification("Waiting for tool confirmation", IconName::Info, window, cx);
}
@@ -972,8 +918,8 @@ impl ActiveThread {
&tool_use.input,
self.thread
.read(cx)
.output_for_tool(&tool_use.id)
.map(|output| output.clone().into())
.tool_result(&tool_use.id)
.map(|result| result.content.clone().into())
.unwrap_or("".into()),
cx,
);
@@ -983,19 +929,6 @@ impl ActiveThread {
}
}
fn handle_rules_loading_error(
&mut self,
_thread_store: Entity<ThreadStore>,
error: &RulesLoadingError,
cx: &mut Context<Self>,
) {
self.last_error = Some(ThreadError::Message {
header: "Error loading rules file".into(),
message: error.message.clone(),
});
cx.notify();
}
fn show_notification(
&mut self,
caption: impl Into<SharedString>,
@@ -1154,91 +1087,15 @@ impl ActiveThread {
editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
editor
});
let subscription = cx.subscribe(&editor, |this, _, event, cx| match event {
EditorEvent::BufferEdited => {
this.update_editing_message_token_count(true, cx);
}
_ => {}
});
self.editing_message = Some((
message_id,
EditMessageState {
editor: editor.clone(),
last_estimated_token_count: None,
_subscription: subscription,
_update_token_count_task: None,
},
));
self.update_editing_message_token_count(false, cx);
cx.notify();
}
fn update_editing_message_token_count(&mut self, debounce: bool, cx: &mut Context<Self>) {
let Some((message_id, state)) = self.editing_message.as_mut() else {
return;
};
cx.emit(ActiveThreadEvent::EditingMessageTokenCountChanged);
state._update_token_count_task.take();
let Some(default_model) = LanguageModelRegistry::read_global(cx).default_model() else {
state.last_estimated_token_count.take();
return;
};
let editor = state.editor.clone();
let thread = self.thread.clone();
let message_id = *message_id;
state._update_token_count_task = Some(cx.spawn(async move |this, cx| {
if debounce {
cx.background_executor()
.timer(Duration::from_millis(200))
.await;
}
let token_count = if let Some(task) = cx.update(|cx| {
let context = thread.read(cx).context_for_message(message_id);
let new_context = thread.read(cx).filter_new_context(context);
let context_text =
format_context_as_string(new_context, cx).unwrap_or(String::new());
let message_text = editor.read(cx).text(cx);
let content = context_text + &message_text;
if content.is_empty() {
return None;
}
let request = language_model::LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
role: language_model::Role::User,
content: vec![content.into()],
cache: false,
}],
tools: vec![],
stop: vec![],
temperature: None,
};
Some(default_model.model.count_tokens(request, cx))
})? {
task.await?
} else {
0
};
this.update(cx, |this, cx| {
let Some((_message_id, state)) = this.editing_message.as_mut() else {
return;
};
state.last_estimated_token_count = Some(token_count);
cx.emit(ActiveThreadEvent::EditingMessageTokenCountChanged);
})
}));
}
fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
self.editing_message.take();
cx.notify();
@@ -1507,23 +1364,16 @@ impl ActiveThread {
let editor_bg_color = colors.editor_background;
let bg_user_message_header = editor_bg_color.blend(active_color.opacity(0.25));
let open_as_markdown = IconButton::new(("open-as-markdown", ix), IconName::FileCode)
let open_as_markdown = IconButton::new("open-as-markdown", IconName::FileCode)
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(Color::Ignored)
.tooltip(Tooltip::text("Open Thread as Markdown"))
.on_click(|_, window, cx| {
.on_click(|_event, window, cx| {
window.dispatch_action(Box::new(OpenActiveThreadAsMarkdown), cx)
});
// For all items that should be aligned with the Assistant's response.
const RESPONSE_PADDING_X: Pixels = px(18.);
let feedback_container = h_flex()
.py_2()
.px(RESPONSE_PADDING_X)
.gap_1()
.justify_between();
let feedback_container = h_flex().py_2().px_4().gap_1().justify_between();
let feedback_items = match self.thread.read(cx).message_feedback(message_id) {
Some(feedback) => feedback_container
.child(
@@ -1636,36 +1486,12 @@ impl ActiveThread {
.when(!message_is_empty, |parent| {
parent.child(
if let Some(edit_message_editor) = edit_message_editor.clone() {
let settings = ThemeSettings::get_global(cx);
let font_size = TextSize::Small.rems(cx);
let line_height = font_size.to_pixels(window.rem_size()) * 1.5;
let text_style = TextStyle {
color: cx.theme().colors().text,
font_family: settings.buffer_font.family.clone(),
font_fallbacks: settings.buffer_font.fallbacks.clone(),
font_features: settings.buffer_font.features.clone(),
font_size: font_size.into(),
line_height: line_height.into(),
..Default::default()
};
div()
.key_context("EditMessageEditor")
.on_action(cx.listener(Self::cancel_editing_message))
.on_action(cx.listener(Self::confirm_editing_message))
.min_h_6()
.pt_1()
.child(EditorElement::new(
&edit_message_editor,
EditorStyle {
background: colors.editor_background,
local_player: cx.theme().players().local(),
text: text_style,
syntax: cx.theme().syntax().clone(),
..Default::default()
},
))
.child(edit_message_editor)
.into_any()
} else {
div()
@@ -1724,9 +1550,9 @@ impl ActiveThread {
this.pt_4()
}
})
.pb_4()
.pl_2()
.pr_2p5()
.pb_4()
.child(
v_flex()
.bg(colors.editor_background)
@@ -1787,9 +1613,6 @@ impl ActiveThread {
"confirm-edit-message",
"Regenerate",
)
.disabled(
edit_message_editor.read(cx).is_empty(cx),
)
.label_size(LabelSize::Small)
.key_binding(
KeyBinding::for_action_in(
@@ -1833,8 +1656,11 @@ impl ActiveThread {
),
Role::Assistant => v_flex()
.id(("message-container", ix))
.px(RESPONSE_PADDING_X)
.gap_2()
.ml_2()
.pl_2()
.pr_4()
.border_l_1()
.border_color(cx.theme().colors().border_variant)
.children(message_content)
.when(has_tool_uses, |parent| {
parent.children(
@@ -1851,15 +1677,6 @@ impl ActiveThread {
),
};
let after_editing_message = self
.editing_message
.as_ref()
.map_or(false, |(editing_message_id, _)| {
message_id > *editing_message_id
});
let panel_background = cx.theme().colors().panel_background;
v_flex()
.w_full()
.when_some(checkpoint, |parent, checkpoint| {
@@ -2023,18 +1840,6 @@ impl ActiveThread {
},
)
})
.when(after_editing_message, |parent| {
// Backdrop to dim out the whole thread below the editing user message
parent.relative().child(
div()
.occlude()
.absolute()
.inset_0()
.size_full()
.bg(panel_background)
.opacity(0.8),
)
})
.into_any()
}
@@ -2061,15 +1866,6 @@ impl ActiveThread {
None
};
let message_role = self
.thread
.read(cx)
.message(message_id)
.map(|m| m.role)
.unwrap_or(Role::User);
let is_assistant = message_role == Role::Assistant;
v_flex()
.text_ui(cx)
.gap_2()
@@ -2090,100 +1886,77 @@ impl ActiveThread {
cx,
)
.into_any_element(),
RenderedMessageSegment::Text(markdown) => {
let markdown_element = MarkdownElement::new(
markdown.clone(),
default_markdown_style(window, cx),
);
RenderedMessageSegment::Text(markdown) => div()
.child(
MarkdownElement::new(
markdown.clone(),
default_markdown_style(window, cx),
)
.code_block_renderer(markdown::CodeBlockRenderer::Custom {
render: Arc::new({
let workspace = workspace.clone();
let active_thread = cx.entity();
move |kind, parsed_markdown, range, window, cx| {
render_markdown_code_block(
message_id,
range.start,
kind,
parsed_markdown,
range,
active_thread.clone(),
workspace.clone(),
window,
cx,
)
}
}),
transform: Some(Arc::new({
let active_thread = cx.entity();
move |el, range, _, cx| {
let is_expanded = active_thread
.read(cx)
.expanded_code_blocks
.get(&(message_id, range.start))
.copied()
.unwrap_or(false);
let markdown_element = if is_assistant {
markdown_element.code_block_renderer(
markdown::CodeBlockRenderer::Custom {
render: Arc::new({
let workspace = workspace.clone();
let active_thread = cx.entity();
move |kind,
parsed_markdown,
range,
metadata,
window,
cx| {
render_markdown_code_block(
message_id,
range.start,
kind,
parsed_markdown,
metadata,
active_thread.clone(),
workspace.clone(),
window,
cx,
)
if is_expanded {
return el;
}
}),
transform: Some(Arc::new({
let active_thread = cx.entity();
move |el, range, metadata, _, cx| {
let is_expanded = active_thread
.read(cx)
.expanded_code_blocks
.get(&(message_id, range.start))
.copied()
.unwrap_or(false);
if is_expanded
|| metadata.line_count
<= MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK
{
return el;
}
el.child(
div()
.absolute()
.bottom_0()
.left_0()
.w_full()
.h_1_4()
.rounded_b_lg()
.bg(gpui::linear_gradient(
el.child(
div()
.absolute()
.bottom_0()
.left_0()
.w_full()
.h_1_4()
.rounded_b_lg()
.bg(gpui::linear_gradient(
0.,
gpui::linear_color_stop(
cx.theme().colors().editor_background,
0.,
gpui::linear_color_stop(
cx.theme()
.colors()
.editor_background,
0.,
),
gpui::linear_color_stop(
cx.theme()
.colors()
.editor_background
.opacity(0.),
1.,
),
)),
)
}
})),
},
)
} else {
markdown_element.code_block_renderer(
markdown::CodeBlockRenderer::Default {
copy_button: false,
border: true,
},
)
};
div()
.child(markdown_element.on_url_click({
),
gpui::linear_color_stop(
cx.theme()
.colors()
.editor_background
.opacity(0.),
1.,
),
)),
)
}
})),
})
.on_url_click({
let workspace = self.workspace.clone();
move |text, window, cx| {
open_markdown_link(text, workspace.clone(), window, cx);
}
}))
.into_any_element()
}
}),
)
.into_any_element(),
},
),
)
@@ -2443,10 +2216,6 @@ impl ActiveThread {
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement + use<> {
if let Some(card) = self.thread.read(cx).card_for_tool(&tool_use.id) {
return card.render(&tool_use.status, window, cx);
}
let is_open = self
.expanded_tool_uses
.get(&tool_use.id)
@@ -2511,10 +2280,6 @@ impl ActiveThread {
rendered.input.clone(),
tool_use_markdown_style(window, cx),
)
.code_block_renderer(markdown::CodeBlockRenderer::Default {
copy_button: false,
border: false,
})
.on_url_click({
let workspace = self.workspace.clone();
move |text, window, cx| {
@@ -2541,17 +2306,12 @@ impl ActiveThread {
rendered.output.clone(),
tool_use_markdown_style(window, cx),
)
.code_block_renderer(markdown::CodeBlockRenderer::Default {
copy_button: false,
border: false,
})
.on_url_click({
let workspace = self.workspace.clone();
move |text, window, cx| {
open_markdown_link(text, workspace.clone(), window, cx);
}
})
.into_any_element()
}),
)),
),
@@ -2608,7 +2368,6 @@ impl ActiveThread {
open_markdown_link(text, workspace.clone(), window, cx);
}
})
.into_any_element()
})),
),
),
@@ -2646,10 +2405,11 @@ impl ActiveThread {
))
};
v_flex().gap_1().mb_3().map(|element| {
div().map(|element| {
if !edit_tools {
element.child(
v_flex()
.my_2()
.child(
h_flex()
.group("disclosure-header")
@@ -2671,7 +2431,7 @@ impl ActiveThread {
.color(Color::Muted),
)
.child(
h_flex().pr_8().text_size(rems(0.8125)).children(
h_flex().pr_8().text_ui_sm(cx).children(
rendered_tool_use.map(|rendered| MarkdownElement::new(rendered.label, tool_use_markdown_style(window, cx)).on_url_click({let workspace = self.workspace.clone(); move |text, window, cx| {
open_markdown_link(text, workspace.clone(), window, cx);
}}))
@@ -2721,7 +2481,7 @@ impl ActiveThread {
)
} else {
v_flex()
.mb_2()
.my_3()
.rounded_lg()
.border_1()
.border_color(self.tool_card_border_color(cx))
@@ -2938,17 +2698,16 @@ impl ActiveThread {
)
})
}
}).into_any_element()
})
}
fn render_rules_item(&self, cx: &Context<Self>) -> AnyElement {
let project_context = self.thread.read(cx).project_context();
let project_context = project_context.borrow();
let Some(project_context) = project_context.as_ref() else {
let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref()
else {
return div().into_any();
};
let rules_files = project_context
let rules_files = system_prompt_context
.worktrees
.iter()
.filter_map(|worktree| worktree.rules_file.as_ref())
@@ -3038,13 +2797,12 @@ impl ActiveThread {
}
fn handle_open_rules(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
let project_context = self.thread.read(cx).project_context();
let project_context = project_context.borrow();
let Some(project_context) = project_context.as_ref() else {
let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref()
else {
return;
};
let abs_paths = project_context
let abs_paths = system_prompt_context
.worktrees
.iter()
.flat_map(|worktree| worktree.rules_file.as_ref())
@@ -3130,12 +2888,6 @@ impl ActiveThread {
}
}
pub enum ActiveThreadEvent {
EditingMessageTokenCountChanged,
}
impl EventEmitter<ActiveThreadEvent> for ActiveThread {}
impl Render for ActiveThread {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
v_flex()
@@ -3211,21 +2963,28 @@ pub(crate) fn open_context(
.start
.to_point(&snapshot);
open_editor_at_position(project_path, target_position, &workspace, window, cx)
.detach();
}
}
AssistantContext::Excerpt(excerpt_context) => {
if let Some(project_path) = excerpt_context
.context_buffer
.buffer
.read(cx)
.project_path(cx)
{
let snapshot = excerpt_context.context_buffer.buffer.read(cx).snapshot();
let target_position = excerpt_context.range.start.to_point(&snapshot);
open_editor_at_position(project_path, target_position, &workspace, window, cx)
let open_task = workspace.update(cx, |workspace, cx| {
workspace.open_path(project_path, None, true, window, cx)
});
window
.spawn(cx, async move |cx| {
if let Some(active_editor) = open_task
.await
.log_err()
.and_then(|item| item.downcast::<Editor>())
{
active_editor
.downgrade()
.update_in(cx, |editor, window, cx| {
editor.go_to_singleton_buffer_point(
target_position,
window,
cx,
);
})
.log_err();
}
})
.detach();
}
}
@@ -3246,29 +3005,3 @@ pub(crate) fn open_context(
}
}
}
fn open_editor_at_position(
project_path: project::ProjectPath,
target_position: Point,
workspace: &Entity<Workspace>,
window: &mut Window,
cx: &mut App,
) -> Task<()> {
let open_task = workspace.update(cx, |workspace, cx| {
workspace.open_path(project_path, None, true, window, cx)
});
window.spawn(cx, async move |cx| {
if let Some(active_editor) = open_task
.await
.log_err()
.and_then(|item| item.downcast::<Editor>())
{
active_editor
.downgrade()
.update_in(cx, |editor, window, cx| {
editor.go_to_singleton_buffer_point(target_position, window, cx);
})
.log_err();
}
})
}

View File

@@ -1,7 +1,7 @@
use crate::{Keep, KeepAll, Reject, RejectAll, Thread, ThreadEvent};
use crate::{Keep, Reject, Thread, ThreadEvent};
use anyhow::Result;
use buffer_diff::DiffHunkStatus;
use collections::{HashMap, HashSet};
use collections::HashSet;
use editor::{
Direction, Editor, EditorEvent, MultiBuffer, ToPoint,
actions::{GoToHunk, GoToPreviousHunk},
@@ -355,24 +355,16 @@ impl AgentDiff {
self.update_selection(&diff_hunks_in_ranges, window, cx);
}
let mut ranges_by_buffer = HashMap::default();
for hunk in &diff_hunks_in_ranges {
let buffer = self.multibuffer.read(cx).buffer(hunk.buffer_id);
if let Some(buffer) = buffer {
ranges_by_buffer
.entry(buffer.clone())
.or_insert_with(Vec::new)
.push(hunk.buffer_range.clone());
self.thread
.update(cx, |thread, cx| {
thread.reject_edits_in_range(buffer, hunk.buffer_range.clone(), cx)
})
.detach_and_log_err(cx);
}
}
for (buffer, ranges) in ranges_by_buffer {
self.thread
.update(cx, |thread, cx| {
thread.reject_edits_in_ranges(buffer, ranges, cx)
})
.detach_and_log_err(cx);
}
}
fn update_selection(
@@ -851,7 +843,7 @@ impl ToolbarItemView for AgentDiffToolbar {
}
impl Render for AgentDiffToolbar {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let agent_diff = match self.agent_diff(cx) {
Some(ad) => ad,
None => return div(),
@@ -863,8 +855,6 @@ impl Render for AgentDiffToolbar {
return div();
}
let focus_handle = agent_diff.focus_handle(cx);
h_group_xl()
.my_neg_1()
.items_center()
@@ -874,25 +864,15 @@ impl Render for AgentDiffToolbar {
.child(
h_group_sm()
.child(
Button::new("reject-all", "Reject All")
.key_binding({
KeyBinding::for_action_in(&RejectAll, &focus_handle, window, cx)
.map(|kb| kb.size(rems_from_px(12.)))
})
.on_click(cx.listener(|this, _, window, cx| {
this.dispatch_action(&RejectAll, window, cx)
})),
Button::new("reject-all", "Reject All").on_click(cx.listener(
|this, _, window, cx| {
this.dispatch_action(&crate::RejectAll, window, cx)
},
)),
)
.child(
Button::new("keep-all", "Keep All")
.key_binding({
KeyBinding::for_action_in(&KeepAll, &focus_handle, window, cx)
.map(|kb| kb.size(rems_from_px(12.)))
})
.on_click(cx.listener(|this, _, window, cx| {
this.dispatch_action(&KeepAll, window, cx)
})),
),
.child(Button::new("keep-all", "Keep All").on_click(cx.listener(
|this, _, window, cx| this.dispatch_action(&crate::KeepAll, window, cx),
))),
)
}
}
@@ -902,7 +882,6 @@ mod tests {
use super::*;
use crate::{ThreadStore, thread_store};
use assistant_settings::AssistantSettings;
use assistant_tool::ToolWorkingSet;
use context_server::ContextServerSettings;
use editor::EditorSettings;
use gpui::TestAppContext;
@@ -942,16 +921,15 @@ mod tests {
})
.unwrap();
let thread_store = cx
.update(|cx| {
ThreadStore::load(
project.clone(),
cx.new(|_| ToolWorkingSet::default()),
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)
})
.await;
let thread_store = cx.update(|cx| {
ThreadStore::new(
project.clone(),
Arc::default(),
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)
.unwrap()
});
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());

View File

@@ -18,7 +18,6 @@ mod terminal_inline_assistant;
mod thread;
mod thread_history;
mod thread_store;
mod tool_compatibility;
mod tool_use;
mod ui;
@@ -47,11 +46,10 @@ pub use agent_diff::{AgentDiff, AgentDiffToolbar};
actions!(
agent,
[
NewTextThread,
NewPromptEditor,
ToggleContextPicker,
ToggleProfileSelector,
RemoveAllContext,
ExpandMessageEditor,
OpenHistory,
AddContextServer,
RemoveSelectedThread,

View File

@@ -9,15 +9,10 @@ use assistant_tool::{ToolSource, ToolWorkingSet};
use collections::HashMap;
use context_server::manager::ContextServerManager;
use fs::Fs;
use gpui::{
Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, ScrollHandle, Subscription,
};
use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
use settings::{Settings, update_settings_file};
use ui::{
Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Scrollbar, ScrollbarState,
Switch, Tooltip, prelude::*,
};
use ui::{Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch, prelude::*};
use util::ResultExt as _;
use zed_actions::ExtensionCategoryFilter;
@@ -32,17 +27,15 @@ pub struct AssistantConfiguration {
configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
context_server_manager: Entity<ContextServerManager>,
expanded_context_server_tools: HashMap<Arc<str>, bool>,
tools: Entity<ToolWorkingSet>,
tools: Arc<ToolWorkingSet>,
_registry_subscription: Subscription,
scroll_handle: ScrollHandle,
scrollbar_state: ScrollbarState,
}
impl AssistantConfiguration {
pub fn new(
fs: Arc<dyn Fs>,
context_server_manager: Entity<ContextServerManager>,
tools: Entity<ToolWorkingSet>,
tools: Arc<ToolWorkingSet>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -65,9 +58,6 @@ impl AssistantConfiguration {
},
);
let scroll_handle = ScrollHandle::new();
let scrollbar_state = ScrollbarState::new(scroll_handle.clone());
let mut this = Self {
fs,
focus_handle,
@@ -76,8 +66,6 @@ impl AssistantConfiguration {
expanded_context_server_tools: HashMap::default(),
tools,
_registry_subscription: registry_subscription,
scroll_handle,
scrollbar_state,
};
this.build_provider_configuration_views(window, cx);
this
@@ -119,7 +107,7 @@ pub enum AssistantConfigurationEvent {
impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
impl AssistantConfiguration {
fn render_provider_configuration_block(
fn render_provider_configuration(
&mut self,
provider: &Arc<dyn LanguageModelProvider>,
cx: &mut Context<Self>,
@@ -174,7 +162,7 @@ impl AssistantConfiguration {
.p(DynamicSpacing::Base08.rems(cx))
.bg(cx.theme().colors().editor_background)
.border_1()
.border_color(cx.theme().colors().border)
.border_color(cx.theme().colors().border_variant)
.rounded_sm()
.map(|parent| match configuration_view {
Some(configuration_view) => parent.child(configuration_view),
@@ -185,33 +173,6 @@ impl AssistantConfiguration {
)
}
fn render_provider_configuration_section(
&mut self,
cx: &mut Context<Self>,
) -> impl IntoElement {
let providers = LanguageModelRegistry::read_global(cx).providers();
v_flex()
.p(DynamicSpacing::Base16.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx))
.gap_4()
.flex_1()
.child(
v_flex()
.gap_0p5()
.child(Headline::new("LLM Providers").size(HeadlineSize::Small))
.child(
Label::new("Add at least one provider to use AI-powered features.")
.color(Color::Muted),
),
)
.children(
providers
.into_iter()
.map(|provider| self.render_provider_configuration_block(&provider, cx)),
)
}
fn render_command_permission(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
let always_allow_tool_actions = AssistantSettings::get_global(cx).always_allow_tool_actions;
@@ -219,7 +180,6 @@ impl AssistantConfiguration {
v_flex()
.p(DynamicSpacing::Base16.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx))
.gap_2()
.flex_1()
.child(Headline::new("General Settings").size(HeadlineSize::Small))
@@ -264,23 +224,19 @@ impl AssistantConfiguration {
fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
let context_servers = self.context_server_manager.read(cx).all_servers().clone();
let tools_by_source = self.tools.read(cx).tools_by_source(cx);
let tools_by_source = self.tools.tools_by_source(cx);
let empty = Vec::new();
const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
v_flex()
.p(DynamicSpacing::Base16.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx))
.gap_2()
.flex_1()
.child(
v_flex()
.gap_0p5()
.child(
Headline::new("Model Context Protocol (MCP) Servers")
.size(HeadlineSize::Small),
)
.child(Headline::new("Context Servers (MCP)").size(HeadlineSize::Small))
.child(Label::new(SUBHEADING).color(Color::Muted)),
)
.children(context_servers.into_iter().map(|context_server| {
@@ -306,9 +262,10 @@ impl AssistantConfiguration {
.bg(cx.theme().colors().editor_background)
.child(
h_flex()
.p_1()
.justify_between()
.when(are_tools_expanded && tool_count > 1, |element| {
.px_2()
.py_1()
.when(are_tools_expanded, |element| {
element
.border_b_1()
.border_color(cx.theme().colors().border)
@@ -318,7 +275,6 @@ impl AssistantConfiguration {
.gap_2()
.child(
Disclosure::new("tool-list-disclosure", are_tools_expanded)
.disabled(tool_count == 0)
.on_click(cx.listener({
let context_server_id = context_server.id();
move |this, _event, _window, _cx| {
@@ -339,11 +295,10 @@ impl AssistantConfiguration {
.child(Label::new(context_server.id()))
.child(
Label::new(format!("{tool_count} tools"))
.color(Color::Muted)
.size(LabelSize::Small),
.color(Color::Muted),
),
)
.child(
.child(h_flex().child(
Switch::new("context-server-switch", is_running.into()).on_click({
let context_server_manager =
self.context_server_manager.clone();
@@ -379,7 +334,7 @@ impl AssistantConfiguration {
}
}
}),
),
)),
)
.map(|parent| {
if !are_tools_expanded {
@@ -389,29 +344,14 @@ impl AssistantConfiguration {
parent.child(v_flex().children(tools.into_iter().enumerate().map(
|(ix, tool)| {
h_flex()
.id("tool-item")
.pl_2()
.pr_1()
.px_2()
.py_1()
.gap_2()
.justify_between()
.when(ix < tool_count - 1, |element| {
element
.border_b_1()
.border_color(cx.theme().colors().border_variant)
.border_color(cx.theme().colors().border)
})
.child(
Label::new(tool.name())
.buffer_font(cx)
.size(LabelSize::Small),
)
.child(
IconButton::new(("tool-description", ix), IconName::Info)
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::Small)
.icon_color(Color::Ignored)
.tooltip(Tooltip::text(tool.description())),
)
.child(Label::new(tool.name()))
},
)))
})
@@ -422,7 +362,7 @@ impl AssistantConfiguration {
.gap_2()
.child(
h_flex().w_full().child(
Button::new("add-context-server", "Add MCPs Directly")
Button::new("add-context-server", "Add Context Server")
.style(ButtonStyle::Filled)
.layer(ElevationIndex::ModalSurface)
.full_width()
@@ -438,7 +378,7 @@ impl AssistantConfiguration {
h_flex().w_full().child(
Button::new(
"install-context-server-extensions",
"Install MCP Extensions",
"Install Context Server Extensions",
)
.style(ButtonStyle::Filled)
.layer(ElevationIndex::ModalSurface)
@@ -465,51 +405,39 @@ impl AssistantConfiguration {
impl Render for AssistantConfiguration {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let providers = LanguageModelRegistry::read_global(cx).providers();
v_flex()
.id("assistant-configuration")
.key_context("AgentConfiguration")
.track_focus(&self.focus_handle(cx))
.relative()
.size_full()
.pb_8()
.bg(cx.theme().colors().panel_background)
.size_full()
.overflow_y_scroll()
.child(self.render_command_permission(cx))
.child(Divider::horizontal().color(DividerColor::Border))
.child(self.render_context_servers_section(cx))
.child(Divider::horizontal().color(DividerColor::Border))
.child(
v_flex()
.id("assistant-configuration-content")
.track_scroll(&self.scroll_handle)
.size_full()
.overflow_y_scroll()
.child(self.render_command_permission(cx))
.child(Divider::horizontal().color(DividerColor::Border))
.child(self.render_context_servers_section(cx))
.child(Divider::horizontal().color(DividerColor::Border))
.child(self.render_provider_configuration_section(cx)),
)
.child(
div()
.id("assistant-configuration-scrollbar")
.occlude()
.absolute()
.right(px(3.))
.top_0()
.bottom_0()
.pb_6()
.w(px(12.))
.cursor_default()
.on_mouse_move(cx.listener(|_, _, _window, cx| {
cx.notify();
cx.stop_propagation()
}))
.on_hover(|_, _window, cx| {
cx.stop_propagation();
})
.on_any_mouse_down(|_, _window, cx| {
cx.stop_propagation();
})
.on_scroll_wheel(cx.listener(|_, _, _window, cx| {
cx.notify();
}))
.children(Scrollbar::vertical(self.scrollbar_state.clone())),
.p(DynamicSpacing::Base16.rems(cx))
.mt_1()
.gap_6()
.flex_1()
.child(
v_flex()
.gap_0p5()
.child(Headline::new("LLM Providers").size(HeadlineSize::Small))
.child(
Label::new("Add at least one provider to use AI-powered features.")
.color(Color::Muted),
),
)
.children(
providers
.into_iter()
.map(|provider| self.render_provider_configuration(&provider, cx)),
),
)
}
}

View File

@@ -84,7 +84,7 @@ pub struct NewProfileMode {
pub struct ManageProfilesModal {
fs: Arc<dyn Fs>,
tools: Entity<ToolWorkingSet>,
tools: Arc<ToolWorkingSet>,
thread_store: WeakEntity<ThreadStore>,
focus_handle: FocusHandle,
mode: Mode,
@@ -117,7 +117,7 @@ impl ManageProfilesModal {
pub fn new(
fs: Arc<dyn Fs>,
tools: Entity<ToolWorkingSet>,
tools: Arc<ToolWorkingSet>,
thread_store: WeakEntity<ThreadStore>,
window: &mut Window,
cx: &mut Context<Self>,

View File

@@ -60,7 +60,7 @@ pub struct ToolPickerDelegate {
impl ToolPickerDelegate {
pub fn new(
fs: Arc<dyn Fs>,
tool_set: Entity<ToolWorkingSet>,
tool_set: Arc<ToolWorkingSet>,
thread_store: WeakEntity<ThreadStore>,
profile_id: AgentProfileId,
profile: AgentProfile,
@@ -68,7 +68,7 @@ impl ToolPickerDelegate {
) -> Self {
let mut tool_entries = Vec::new();
for (source, tools) in tool_set.read(cx).tools_by_source(cx) {
for (source, tools) in tool_set.tools_by_source(cx) {
tool_entries.extend(tools.into_iter().map(|tool| ToolEntry {
name: tool.name().into(),
source: source.clone(),
@@ -192,7 +192,7 @@ impl PickerDelegate for ToolPickerDelegate {
if active_profile_id == &self.profile_id {
self.thread_store
.update(cx, |this, cx| {
this.load_profile(self.profile.clone(), cx);
this.load_profile(&self.profile, cx);
})
.log_err();
}

View File

@@ -80,16 +80,17 @@ impl AssistantModelSelector {
impl Render for AssistantModelSelector {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let focus_handle = self.focus_handle.clone();
let model_registry = LanguageModelRegistry::read_global(cx);
let model = match self.model_type {
ModelType::Default => model_registry.default_model(),
ModelType::InlineAssistant => model_registry.inline_assistant_model(),
};
let (model_name, model_icon) = match model {
Some(model) => (model.model.name().0, Some(model.provider.icon())),
_ => (SharedString::from("No model selected"), None),
let focus_handle = self.focus_handle.clone();
let model_name = match model {
Some(model) => model.model.name().0,
_ => SharedString::from("No model selected"),
};
LanguageModelSelectorPopoverMenu::new(
@@ -99,16 +100,10 @@ impl Render for AssistantModelSelector {
.child(
h_flex()
.gap_0p5()
.children(
model_icon.map(|icon| {
Icon::new(icon).color(Color::Muted).size(IconSize::Small)
}),
)
.child(
Label::new(model_name)
.size(LabelSize::Small)
.color(Color::Muted)
.ml_1(),
.color(Color::Muted),
)
.child(
Icon::new(IconName::ChevronDown)

View File

@@ -1,4 +1,3 @@
use std::ops::Range;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
@@ -6,14 +5,14 @@ use std::time::Duration;
use anyhow::{Result, anyhow};
use assistant_context_editor::{
AssistantPanelDelegate, ConfigurationError, ContextEditor, SlashCommandCompletionProvider,
humanize_token_count, make_lsp_adapter_delegate, render_remaining_tokens,
make_lsp_adapter_delegate, render_remaining_tokens,
};
use assistant_settings::{AssistantDockPosition, AssistantSettings};
use assistant_slash_command::SlashCommandWorkingSet;
use assistant_tool::ToolWorkingSet;
use client::zed_urls;
use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer};
use editor::{Editor, EditorEvent, MultiBuffer};
use fs::Fs;
use gpui::{
Action, Animation, AnimationExt as _, AnyElement, App, AsyncWindowContext, Corner, Entity,
@@ -26,7 +25,6 @@ use language_model_selector::ToggleModelSelector;
use project::Project;
use prompt_library::{PromptLibrary, open_prompt_library};
use prompt_store::PromptBuilder;
use proto::Plan;
use settings::{Settings, update_settings_file};
use time::UtcOffset;
use ui::{
@@ -38,17 +36,16 @@ use workspace::dock::{DockPosition, Panel, PanelEvent};
use zed_actions::agent::OpenConfiguration;
use zed_actions::assistant::{OpenPromptLibrary, ToggleFocus};
use crate::active_thread::{ActiveThread, ActiveThreadEvent};
use crate::active_thread::ActiveThread;
use crate::assistant_configuration::{AssistantConfiguration, AssistantConfigurationEvent};
use crate::history_store::{HistoryEntry, HistoryStore};
use crate::message_editor::{MessageEditor, MessageEditorEvent};
use crate::message_editor::MessageEditor;
use crate::thread::{Thread, ThreadError, ThreadId, TokenUsageRatio};
use crate::thread_history::{PastContext, PastThread, ThreadHistory};
use crate::thread_store::ThreadStore;
use crate::ui::UsageBanner;
use crate::{
AgentDiff, ExpandMessageEditor, InlineAssistant, NewTextThread, NewThread,
OpenActiveThreadAsMarkdown, OpenAgentDiff, OpenHistory, ThreadEvent, ToggleContextPicker,
AgentDiff, InlineAssistant, NewPromptEditor, NewThread, OpenActiveThreadAsMarkdown,
OpenAgentDiff, OpenHistory, ThreadEvent, ToggleContextPicker,
};
pub fn init(cx: &mut App) {
@@ -73,7 +70,7 @@ pub fn init(cx: &mut App) {
panel.update(cx, |panel, cx| panel.open_configuration(window, cx));
}
})
.register_action(|workspace, _: &NewTextThread, window, cx| {
.register_action(|workspace, _: &NewPromptEditor, window, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
workspace.focus_panel::<AssistantPanel>(window, cx);
panel.update(cx, |panel, cx| panel.new_prompt_editor(window, cx));
@@ -93,16 +90,6 @@ pub fn init(cx: &mut App) {
let thread = panel.read(cx).thread.read(cx).thread().clone();
AgentDiff::deploy_in_workspace(thread, workspace, window, cx);
}
})
.register_action(|workspace, _: &ExpandMessageEditor, window, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
workspace.focus_panel::<AssistantPanel>(window, cx);
panel.update(cx, |panel, cx| {
panel.message_editor.update(cx, |editor, cx| {
editor.expand_message_editor(&ExpandMessageEditor, window, cx);
});
});
}
});
},
)
@@ -114,9 +101,7 @@ enum ActiveView {
change_title_editor: Entity<Editor>,
_subscriptions: Vec<gpui::Subscription>,
},
PromptEditor {
context_editor: Entity<ContextEditor>,
},
PromptEditor,
History,
Configuration,
}
@@ -185,9 +170,10 @@ pub struct AssistantPanel {
language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>,
thread: Entity<ActiveThread>,
_thread_subscription: Subscription,
message_editor: Entity<MessageEditor>,
_active_thread_subscriptions: Vec<Subscription>,
context_store: Entity<assistant_context_editor::ContextStore>,
context_editor: Option<Entity<ContextEditor>>,
configuration: Option<Entity<AssistantConfiguration>>,
configuration_subscription: Option<Subscription>,
local_timezone: UtcOffset,
@@ -207,13 +193,11 @@ impl AssistantPanel {
cx: AsyncWindowContext,
) -> Task<Result<Entity<Self>>> {
cx.spawn(async move |cx| {
let tools = cx.new(|_| ToolWorkingSet::default())?;
let thread_store = workspace
.update(cx, |workspace, cx| {
let project = workspace.project().clone();
ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx)
})?
.await;
let tools = Arc::new(ToolWorkingSet::default());
let thread_store = workspace.update(cx, |workspace, cx| {
let project = workspace.project().clone();
ThreadStore::new(project, tools.clone(), prompt_builder.clone(), cx)
})??;
let slash_commands = Arc::new(SlashCommandWorkingSet::default());
let context_store = workspace
@@ -267,13 +251,6 @@ impl AssistantPanel {
)
});
let message_editor_subscription =
cx.subscribe(&message_editor, |_, _, event, cx| match event {
MessageEditorEvent::Changed | MessageEditorEvent::EstimatedTokenCount => {
cx.notify();
}
});
let history_store =
cx.new(|cx| HistoryStore::new(thread_store.clone(), context_store.clone(), cx));
@@ -298,12 +275,6 @@ impl AssistantPanel {
)
});
let active_thread_subscription = cx.subscribe(&thread, |_, _, event, cx| match &event {
ActiveThreadEvent::EditingMessageTokenCountChanged => {
cx.notify();
}
});
Self {
active_view,
workspace,
@@ -312,13 +283,10 @@ impl AssistantPanel {
language_registry,
thread_store: thread_store.clone(),
thread,
_thread_subscription: thread_subscription,
message_editor,
_active_thread_subscriptions: vec![
thread_subscription,
active_thread_subscription,
message_editor_subscription,
],
context_store,
context_editor: None,
configuration: None,
configuration_subscription: None,
local_timezone: UtcOffset::from_whole_seconds(
@@ -401,13 +369,6 @@ impl AssistantPanel {
.detach_and_log_err(cx);
}
let thread_subscription = cx.subscribe(&thread, |_, _, event, cx| {
if let ThreadEvent::MessageAdded(_) = &event {
// needed to leave empty state
cx.notify();
}
});
self.thread = cx.new(|cx| {
ActiveThread::new(
thread.clone(),
@@ -420,12 +381,12 @@ impl AssistantPanel {
)
});
let active_thread_subscription =
cx.subscribe(&self.thread, |_, _, event, cx| match &event {
ActiveThreadEvent::EditingMessageTokenCountChanged => {
cx.notify();
}
});
self._thread_subscription = cx.subscribe(&thread, |_, _, event, cx| {
if let ThreadEvent::MessageAdded(_) = &event {
// needed to leave empty state
cx.notify();
}
});
self.message_editor = cx.new(|cx| {
MessageEditor::new(
@@ -439,22 +400,11 @@ impl AssistantPanel {
)
});
self.message_editor.focus_handle(cx).focus(window);
let message_editor_subscription =
cx.subscribe(&self.message_editor, |_, _, event, cx| match event {
MessageEditorEvent::Changed | MessageEditorEvent::EstimatedTokenCount => {
cx.notify();
}
});
self._active_thread_subscriptions = vec![
thread_subscription,
active_thread_subscription,
message_editor_subscription,
];
}
fn new_prompt_editor(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.set_active_view(ActiveView::PromptEditor, window, cx);
let context = self
.context_store
.update(cx, |context_store, cx| context_store.create(cx));
@@ -462,7 +412,7 @@ impl AssistantPanel {
.log_err()
.flatten();
let context_editor = cx.new(|cx| {
self.context_editor = Some(cx.new(|cx| {
let mut editor = ContextEditor::for_context(
context,
self.fs.clone(),
@@ -474,16 +424,11 @@ impl AssistantPanel {
);
editor.insert_default_prompt(window, cx);
editor
});
}));
self.set_active_view(
ActiveView::PromptEditor {
context_editor: context_editor.clone(),
},
window,
cx,
);
context_editor.focus_handle(cx).focus(window);
if let Some(context_editor) = self.context_editor.as_ref() {
context_editor.focus_handle(cx).focus(window);
}
}
fn deploy_prompt_library(
@@ -550,13 +495,8 @@ impl AssistantPanel {
cx,
)
});
this.set_active_view(
ActiveView::PromptEditor {
context_editor: editor,
},
window,
cx,
);
this.set_active_view(ActiveView::PromptEditor, window, cx);
this.context_editor = Some(editor);
anyhow::Ok(())
})??;
@@ -585,13 +525,6 @@ impl AssistantPanel {
Some(this.thread_store.downgrade()),
)
});
let thread_subscription = cx.subscribe(&thread, |_, _, event, cx| {
if let ThreadEvent::MessageAdded(_) = &event {
// needed to leave empty state
cx.notify();
}
});
this.thread = cx.new(|cx| {
ActiveThread::new(
thread.clone(),
@@ -603,14 +536,6 @@ impl AssistantPanel {
cx,
)
});
let active_thread_subscription =
cx.subscribe(&this.thread, |_, _, event, cx| match &event {
ActiveThreadEvent::EditingMessageTokenCountChanged => {
cx.notify();
}
});
this.message_editor = cx.new(|cx| {
MessageEditor::new(
this.fs.clone(),
@@ -623,19 +548,6 @@ impl AssistantPanel {
)
});
this.message_editor.focus_handle(cx).focus(window);
let message_editor_subscription =
cx.subscribe(&this.message_editor, |_, _, event, cx| match event {
MessageEditorEvent::Changed | MessageEditorEvent::EstimatedTokenCount => {
cx.notify();
}
});
this._active_thread_subscriptions = vec![
thread_subscription,
active_thread_subscription,
message_editor_subscription,
];
})
})
}
@@ -645,7 +557,6 @@ impl AssistantPanel {
ActiveView::Configuration | ActiveView::History => {
self.active_view =
ActiveView::thread(self.thread.read(cx).thread().clone(), window, cx);
self.message_editor.focus_handle(cx).focus(window);
cx.notify();
}
_ => {}
@@ -787,15 +698,8 @@ impl AssistantPanel {
.update(cx, |this, cx| this.delete_thread(thread_id, cx))
}
pub(crate) fn has_active_thread(&self) -> bool {
matches!(self.active_view, ActiveView::Thread { .. })
}
pub(crate) fn active_context_editor(&self) -> Option<Entity<ContextEditor>> {
match &self.active_view {
ActiveView::PromptEditor { context_editor } => Some(context_editor.clone()),
_ => None,
}
self.context_editor.clone()
}
pub(crate) fn delete_context(
@@ -833,10 +737,16 @@ impl AssistantPanel {
impl Focusable for AssistantPanel {
fn focus_handle(&self, cx: &App) -> FocusHandle {
match &self.active_view {
match self.active_view {
ActiveView::Thread { .. } => self.message_editor.focus_handle(cx),
ActiveView::History => self.history.focus_handle(cx),
ActiveView::PromptEditor { context_editor } => context_editor.focus_handle(cx),
ActiveView::PromptEditor => {
if let Some(context_editor) = self.context_editor.as_ref() {
context_editor.focus_handle(cx)
} else {
cx.focus_handle()
}
}
ActiveView::Configuration => {
if let Some(configuration) = self.configuration.as_ref() {
configuration.focus_handle(cx)
@@ -929,7 +839,7 @@ impl Panel for AssistantPanel {
}
impl AssistantPanel {
fn render_title_view(&self, _window: &mut Window, cx: &Context<Self>) -> AnyElement {
fn render_title_view(&self, _window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
const LOADING_SUMMARY_PLACEHOLDER: &str = "Loading Summary…";
let content = match &self.active_view {
@@ -960,8 +870,15 @@ impl AssistantPanel {
.into_any_element()
}
}
ActiveView::PromptEditor { context_editor } => {
let title = SharedString::from(context_editor.read(cx).title(cx).to_string());
ActiveView::PromptEditor => {
let title = self
.context_editor
.as_ref()
.map(|context_editor| {
SharedString::from(context_editor.read(cx).title(cx).to_string())
})
.unwrap_or_else(|| SharedString::from(LOADING_SUMMARY_PLACEHOLDER));
Label::new(title).ml_2().truncate().into_any_element()
}
ActiveView::History => Label::new("History").truncate().into_any_element(),
@@ -982,18 +899,21 @@ impl AssistantPanel {
fn render_toolbar(&self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let active_thread = self.thread.read(cx);
let thread = active_thread.thread().read(cx);
let token_usage = thread.total_token_usage(cx);
let thread_id = thread.id().clone();
let is_generating = thread.is_generating();
let is_empty = active_thread.is_empty();
let focus_handle = self.focus_handle(cx);
let is_history = matches!(self.active_view, ActiveView::History);
let show_token_count = match &self.active_view {
ActiveView::Thread { .. } => !is_empty,
ActiveView::PromptEditor { .. } => true,
ActiveView::PromptEditor => self.context_editor.is_some(),
_ => false,
};
let focus_handle = self.focus_handle(cx);
let go_back_button = match &self.active_view {
ActiveView::History | ActiveView::Configuration => Some(
div().pl_1().child(
@@ -1040,9 +960,69 @@ impl AssistantPanel {
h_flex()
.h_full()
.gap_2()
.when(show_token_count, |parent|
parent.children(self.render_token_count(&thread, cx))
)
.when(show_token_count, |parent| match self.active_view {
ActiveView::Thread { .. } => {
if token_usage.total == 0 {
return parent;
}
let token_color = match token_usage.ratio {
TokenUsageRatio::Normal => Color::Muted,
TokenUsageRatio::Warning => Color::Warning,
TokenUsageRatio::Exceeded => Color::Error,
};
parent.child(
h_flex()
.flex_shrink_0()
.gap_0p5()
.child(
Label::new(assistant_context_editor::humanize_token_count(
token_usage.total,
))
.size(LabelSize::Small)
.color(token_color)
.map(|label| {
if is_generating {
label
.with_animation(
"used-tokens-label",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(
0.6, 1.,
)),
|label, delta| label.alpha(delta),
)
.into_any()
} else {
label.into_any_element()
}
}),
)
.child(
Label::new("/").size(LabelSize::Small).color(Color::Muted),
)
.child(
Label::new(assistant_context_editor::humanize_token_count(
token_usage.max,
))
.size(LabelSize::Small)
.color(Color::Muted),
),
)
}
ActiveView::PromptEditor => {
let Some(editor) = self.context_editor.as_ref() else {
return parent;
};
let Some(element) = render_remaining_tokens(editor, cx) else {
return parent;
};
parent.child(element)
}
_ => parent,
})
.child(
h_flex()
.h_full()
@@ -1106,30 +1086,20 @@ impl AssistantPanel {
window,
cx,
|menu, _window, _cx| {
menu
.when(!is_empty, |menu| {
menu.action(
"Start New From Summary",
Box::new(NewThread {
from_thread_id: Some(thread_id.clone()),
}),
).separator()
})
.action(
menu.action(
"New Text Thread",
NewTextThread.boxed_clone(),
NewPromptEditor.boxed_clone(),
)
.action("Prompt Library", Box::new(OpenPromptLibrary))
.action("Settings", Box::new(OpenConfiguration))
.separator()
.action(
"Install MCPs",
Box::new(zed_actions::Extensions {
category_filter: Some(
zed_actions::ExtensionCategoryFilter::ContextServers,
),
.when(!is_empty, |menu| {
menu.action(
"Continue in New Thread",
Box::new(NewThread {
from_thread_id: Some(thread_id.clone()),
}),
)
)
})
.separator()
.action("Settings", OpenConfiguration.boxed_clone())
},
))
}),
@@ -1138,110 +1108,6 @@ impl AssistantPanel {
)
}
fn render_token_count(&self, thread: &Thread, cx: &App) -> Option<AnyElement> {
let is_generating = thread.is_generating();
let message_editor = self.message_editor.read(cx);
let conversation_token_usage = thread.total_token_usage(cx);
let (total_token_usage, is_estimating) = if let Some((editing_message_id, unsent_tokens)) =
self.thread.read(cx).editing_message_id()
{
let combined = thread
.token_usage_up_to_message(editing_message_id, cx)
.add(unsent_tokens);
(combined, unsent_tokens > 0)
} else {
let unsent_tokens = message_editor.last_estimated_token_count().unwrap_or(0);
let combined = conversation_token_usage.add(unsent_tokens);
(combined, unsent_tokens > 0)
};
let is_waiting_to_update_token_count = message_editor.is_waiting_to_update_token_count();
match &self.active_view {
ActiveView::Thread { .. } => {
if total_token_usage.total == 0 {
return None;
}
let token_color = match total_token_usage.ratio() {
TokenUsageRatio::Normal if is_estimating => Color::Default,
TokenUsageRatio::Normal => Color::Muted,
TokenUsageRatio::Warning => Color::Warning,
TokenUsageRatio::Exceeded => Color::Error,
};
let token_count = h_flex()
.id("token-count")
.flex_shrink_0()
.gap_0p5()
.when(!is_generating && is_estimating, |parent| {
parent
.child(
h_flex()
.mr_1()
.size_2p5()
.justify_center()
.rounded_full()
.bg(cx.theme().colors().text.opacity(0.1))
.child(
div().size_1().rounded_full().bg(cx.theme().colors().text),
),
)
.tooltip(move |window, cx| {
Tooltip::with_meta(
"Estimated New Token Count",
None,
format!(
"Current Conversation Tokens: {}",
humanize_token_count(conversation_token_usage.total)
),
window,
cx,
)
})
})
.child(
Label::new(humanize_token_count(total_token_usage.total))
.size(LabelSize::Small)
.color(token_color)
.map(|label| {
if is_generating || is_waiting_to_update_token_count {
label
.with_animation(
"used-tokens-label",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.6, 1.)),
|label, delta| label.alpha(delta),
)
.into_any()
} else {
label.into_any_element()
}
}),
)
.child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
.child(
Label::new(humanize_token_count(total_token_usage.max))
.size(LabelSize::Small)
.color(Color::Muted),
)
.into_any();
Some(token_count)
}
ActiveView::PromptEditor { context_editor } => {
let element = render_remaining_tokens(context_editor, cx)?;
Some(element.into_any_element())
}
_ => None,
}
}
fn render_active_thread_or_empty_state(
&self,
window: &mut Window,
@@ -1430,7 +1296,6 @@ impl AssistantPanel {
let configuration_error_ref = &configuration_error;
parent
.overflow_hidden()
.p_1p5()
.justify_end()
.gap_1()
@@ -1542,12 +1407,6 @@ impl AssistantPanel {
})
}
fn render_usage_banner(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
let usage = self.thread.read(cx).last_usage()?;
Some(UsageBanner::new(zed_llm_client::Plan::ZedProTrial, usage.amount).into_any_element())
}
fn render_last_error(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
let last_error = self.thread.read(cx).last_error()?;
@@ -1566,9 +1425,6 @@ impl AssistantPanel {
ThreadError::MaxMonthlySpendReached => {
self.render_max_monthly_spend_reached_error(cx)
}
ThreadError::ModelRequestLimitReached { plan } => {
self.render_model_request_limit_reached_error(plan, cx)
}
ThreadError::Message { header, message } => {
self.render_error_message(header, message, cx)
}
@@ -1671,71 +1527,6 @@ impl AssistantPanel {
.into_any()
}
fn render_model_request_limit_reached_error(
&self,
plan: Plan,
cx: &mut Context<Self>,
) -> AnyElement {
let error_message = match plan {
Plan::ZedPro => {
"Model request limit reached. Upgrade to usage-based billing for more requests."
}
Plan::ZedProTrial => {
"Model request limit reached. Upgrade to Zed Pro for more requests."
}
Plan::Free => "Model request limit reached. Upgrade to Zed Pro for more requests.",
};
let call_to_action = match plan {
Plan::ZedPro => "Upgrade to usage-based billing",
Plan::ZedProTrial => "Upgrade to Zed Pro",
Plan::Free => "Upgrade to Zed Pro",
};
v_flex()
.gap_0p5()
.child(
h_flex()
.gap_1p5()
.items_center()
.child(Icon::new(IconName::XCircle).color(Color::Error))
.child(Label::new("Model Request Limit Reached").weight(FontWeight::MEDIUM)),
)
.child(
div()
.id("error-message")
.max_h_24()
.overflow_y_scroll()
.child(Label::new(error_message)),
)
.child(
h_flex()
.justify_end()
.mt_1()
.child(
Button::new("subscribe", call_to_action).on_click(cx.listener(
|this, _, _, cx| {
this.thread.update(cx, |this, _cx| {
this.clear_last_error();
});
cx.open_url(&zed_urls::account_url(cx));
cx.notify();
},
)),
)
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|this, _, _, cx| {
this.thread.update(cx, |this, _cx| {
this.clear_last_error();
});
cx.notify();
},
))),
)
.into_any()
}
fn render_error_message(
&self,
header: SharedString,
@@ -1778,7 +1569,7 @@ impl AssistantPanel {
fn key_context(&self) -> KeyContext {
let mut key_context = KeyContext::new_with_defaults();
key_context.add("AgentPanel");
if matches!(self.active_view, ActiveView::PromptEditor { .. }) {
if matches!(self.active_view, ActiveView::PromptEditor) {
key_context.add("prompt_editor");
}
key_context
@@ -1806,14 +1597,13 @@ impl Render for AssistantPanel {
.on_action(cx.listener(Self::open_agent_diff))
.on_action(cx.listener(Self::go_back))
.child(self.render_toolbar(window, cx))
.map(|parent| match &self.active_view {
.map(|parent| match self.active_view {
ActiveView::Thread { .. } => parent
.child(self.render_active_thread_or_empty_state(window, cx))
.children(self.render_usage_banner(cx))
.child(h_flex().child(self.message_editor.clone()))
.children(self.render_last_error(cx)),
ActiveView::History => parent.child(self.history.clone()),
ActiveView::PromptEditor { context_editor } => parent.child(context_editor.clone()),
ActiveView::PromptEditor => parent.children(self.context_editor.clone()),
ActiveView::Configuration => parent.children(self.configuration.clone()),
})
}
@@ -1878,7 +1668,7 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
cx: &mut Context<Workspace>,
) -> Option<Entity<ContextEditor>> {
let panel = workspace.panel::<AssistantPanel>(cx)?;
panel.read(cx).active_context_editor()
panel.update(cx, |panel, _cx| panel.context_editor.clone())
}
fn open_saved_context(
@@ -1909,59 +1699,10 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
fn quote_selection(
&self,
workspace: &mut Workspace,
selection_ranges: Vec<Range<Anchor>>,
buffer: Entity<MultiBuffer>,
window: &mut Window,
cx: &mut Context<Workspace>,
_workspace: &mut Workspace,
_creases: Vec<(String, String)>,
_window: &mut Window,
_cx: &mut Context<Workspace>,
) {
let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
return;
};
if !panel.focus_handle(cx).contains_focused(window, cx) {
workspace.toggle_panel_focus::<AssistantPanel>(window, cx);
}
panel.update(cx, |_, cx| {
// Wait to create a new context until the workspace is no longer
// being updated.
cx.defer_in(window, move |panel, window, cx| {
if panel.has_active_thread() {
panel.thread.update(cx, |thread, cx| {
thread.context_store().update(cx, |store, cx| {
let buffer = buffer.read(cx);
let selection_ranges = selection_ranges
.into_iter()
.flat_map(|range| {
let (start_buffer, start) =
buffer.text_anchor_for_position(range.start, cx)?;
let (end_buffer, end) =
buffer.text_anchor_for_position(range.end, cx)?;
if start_buffer != end_buffer {
return None;
}
Some((start_buffer, start..end))
})
.collect::<Vec<_>>();
for (buffer, range) in selection_ranges {
store.add_excerpt(range, buffer, cx).detach_and_log_err(cx);
}
})
})
} else if let Some(context_editor) = panel.active_context_editor() {
let snapshot = buffer.read(cx).snapshot(cx);
let selection_ranges = selection_ranges
.into_iter()
.map(|range| range.to_point(&snapshot))
.collect::<Vec<_>>();
context_editor.update(cx, |context_editor, cx| {
context_editor.quote_ranges(selection_ranges, snapshot, window, cx)
});
}
});
});
}
}

View File

@@ -4,7 +4,6 @@ use gpui::{App, Entity, SharedString};
use language::{Buffer, File};
use language_model::LanguageModelRequestMessage;
use project::{ProjectPath, Worktree};
use rope::Point;
use serde::{Deserialize, Serialize};
use text::{Anchor, BufferId};
use ui::IconName;
@@ -24,7 +23,6 @@ pub enum ContextKind {
File,
Directory,
Symbol,
Excerpt,
FetchedUrl,
Thread,
}
@@ -35,7 +33,6 @@ impl ContextKind {
ContextKind::File => IconName::File,
ContextKind::Directory => IconName::Folder,
ContextKind::Symbol => IconName::Code,
ContextKind::Excerpt => IconName::Code,
ContextKind::FetchedUrl => IconName::Globe,
ContextKind::Thread => IconName::MessageBubbles,
}
@@ -49,7 +46,6 @@ pub enum AssistantContext {
Symbol(SymbolContext),
FetchedUrl(FetchedUrlContext),
Thread(ThreadContext),
Excerpt(ExcerptContext),
}
impl AssistantContext {
@@ -60,7 +56,6 @@ impl AssistantContext {
Self::Symbol(symbol) => symbol.id,
Self::FetchedUrl(url) => url.id,
Self::Thread(thread) => thread.id,
Self::Excerpt(excerpt) => excerpt.id,
}
}
}
@@ -160,14 +155,6 @@ pub struct ContextSymbolId {
pub range: Range<Anchor>,
}
#[derive(Debug, Clone)]
pub struct ExcerptContext {
pub id: ContextId,
pub range: Range<Anchor>,
pub line_range: Range<Point>,
pub context_buffer: ContextBuffer,
}
/// Formats a collection of contexts into a string representation
pub fn format_context_as_string<'a>(
contexts: impl Iterator<Item = &'a AssistantContext>,
@@ -176,7 +163,6 @@ pub fn format_context_as_string<'a>(
let mut file_context = Vec::new();
let mut directory_context = Vec::new();
let mut symbol_context = Vec::new();
let mut excerpt_context = Vec::new();
let mut fetch_context = Vec::new();
let mut thread_context = Vec::new();
@@ -185,7 +171,6 @@ pub fn format_context_as_string<'a>(
AssistantContext::File(context) => file_context.push(context),
AssistantContext::Directory(context) => directory_context.push(context),
AssistantContext::Symbol(context) => symbol_context.push(context),
AssistantContext::Excerpt(context) => excerpt_context.push(context),
AssistantContext::FetchedUrl(context) => fetch_context.push(context),
AssistantContext::Thread(context) => thread_context.push(context),
}
@@ -194,7 +179,6 @@ pub fn format_context_as_string<'a>(
if file_context.is_empty()
&& directory_context.is_empty()
&& symbol_context.is_empty()
&& excerpt_context.is_empty()
&& fetch_context.is_empty()
&& thread_context.is_empty()
{
@@ -232,15 +216,6 @@ pub fn format_context_as_string<'a>(
result.push_str("</symbols>\n");
}
if !excerpt_context.is_empty() {
result.push_str("<excerpts>\n");
for context in excerpt_context {
result.push_str(&context.context_buffer.text);
result.push('\n');
}
result.push_str("</excerpts>\n");
}
if !fetch_context.is_empty() {
result.push_str("<fetched_urls>\n");
for context in &fetch_context {

View File

@@ -34,6 +34,12 @@ use crate::context_store::ContextStore;
use crate::thread::ThreadId;
use crate::thread_store::ThreadStore;
#[derive(Debug, Clone, Copy)]
pub enum ConfirmBehavior {
KeepOpen,
Close,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ContextPickerMode {
File,
@@ -99,6 +105,7 @@ pub(super) struct ContextPicker {
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
thread_store: Option<WeakEntity<ThreadStore>>,
confirm_behavior: ConfirmBehavior,
_subscriptions: Vec<Subscription>,
}
@@ -107,6 +114,7 @@ impl ContextPicker {
workspace: WeakEntity<Workspace>,
thread_store: Option<WeakEntity<ThreadStore>>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -135,6 +143,7 @@ impl ContextPicker {
workspace,
context_store,
thread_store,
confirm_behavior,
_subscriptions: subscriptions,
}
}
@@ -157,32 +166,37 @@ impl ContextPicker {
let modes = supported_context_picker_modes(&self.thread_store);
menu.when(has_recent, |menu| {
menu.custom_row(|_, _| {
div()
.mb_1()
.child(
Label::new("Recent")
.color(Color::Muted)
.size(LabelSize::Small),
)
.into_any_element()
})
})
.extend(recent_entries)
.when(has_recent, |menu| menu.separator())
.extend(modes.into_iter().map(|mode| {
let context_picker = context_picker.clone();
ContextMenuEntry::new(mode.label())
.icon(mode.icon())
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.handler(move |window, cx| {
context_picker.update(cx, |this, cx| this.select_mode(mode, window, cx))
let menu = menu
.when(has_recent, |menu| {
menu.custom_row(|_, _| {
div()
.mb_1()
.child(
Label::new("Recent")
.color(Color::Muted)
.size(LabelSize::Small),
)
.into_any_element()
})
}))
.keep_open_on_confirm()
})
.extend(recent_entries)
.when(has_recent, |menu| menu.separator())
.extend(modes.into_iter().map(|mode| {
let context_picker = context_picker.clone();
ContextMenuEntry::new(mode.label())
.icon(mode.icon())
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.handler(move |window, cx| {
context_picker.update(cx, |this, cx| this.select_mode(mode, window, cx))
})
}));
match self.confirm_behavior {
ConfirmBehavior::KeepOpen => menu.keep_open_on_confirm(),
ConfirmBehavior::Close => menu,
}
});
cx.subscribe(&menu, move |_, _, _: &DismissEvent, cx| {
@@ -213,6 +227,7 @@ impl ContextPicker {
context_picker.clone(),
self.workspace.clone(),
self.context_store.clone(),
self.confirm_behavior,
window,
cx,
)
@@ -224,6 +239,7 @@ impl ContextPicker {
context_picker.clone(),
self.workspace.clone(),
self.context_store.clone(),
self.confirm_behavior,
window,
cx,
)
@@ -235,6 +251,7 @@ impl ContextPicker {
context_picker.clone(),
self.workspace.clone(),
self.context_store.clone(),
self.confirm_behavior,
window,
cx,
)
@@ -247,6 +264,7 @@ impl ContextPicker {
thread_store.clone(),
context_picker.clone(),
self.context_store.clone(),
self.confirm_behavior,
window,
cx,
)
@@ -491,13 +509,14 @@ fn recent_context_picker_entries(
recent
}
pub(crate) fn insert_fold_for_mention(
pub(crate) fn insert_crease_for_mention(
excerpt_id: ExcerptId,
crease_start: text::Anchor,
content_len: usize,
crease_label: SharedString,
crease_icon_path: SharedString,
editor_entity: Entity<Editor>,
window: &mut Window,
cx: &mut App,
) {
editor_entity.update(cx, |editor, cx| {
@@ -516,7 +535,6 @@ pub(crate) fn insert_fold_for_mention(
crease_label,
editor_entity.downgrade(),
),
merge_adjacent: false,
..Default::default()
};
@@ -530,9 +548,8 @@ pub(crate) fn insert_fold_for_mention(
render_trailer,
);
editor.display_map.update(cx, |display_map, cx| {
display_map.fold(vec![crease], cx);
});
editor.insert_creases(vec![crease.clone()], cx);
editor.fold_creases(vec![crease], false, window, cx);
});
}
@@ -589,13 +606,12 @@ fn render_fold_icon_button(
.gap_1()
.child(
Icon::from_path(icon_path.clone())
.size(IconSize::XSmall)
.size(IconSize::Small)
.color(Color::Muted),
)
.child(
Label::new(label.clone())
.size(LabelSize::Small)
.buffer_font(cx)
.single_line(),
),
)

View File

@@ -8,7 +8,6 @@ use std::sync::atomic::AtomicBool;
use anyhow::Result;
use editor::{CompletionProvider, Editor, ExcerptId};
use file_icons::FileIcons;
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{App, Entity, Task, WeakEntity};
use http_client::HttpClientWithUrl;
use language::{Buffer, CodeLabel, HighlightId};
@@ -38,24 +37,7 @@ pub(crate) enum Match {
File(FileMatch),
Thread(ThreadMatch),
Fetch(SharedString),
Mode(ModeMatch),
}
pub struct ModeMatch {
mat: Option<StringMatch>,
mode: ContextPickerMode,
}
impl Match {
pub fn score(&self) -> f64 {
match self {
Match::File(file) => file.mat.score,
Match::Mode(mode) => mode.mat.as_ref().map(|mat| mat.score).unwrap_or(1.),
Match::Thread(_) => 1.,
Match::Symbol(_) => 1.,
Match::Fetch(_) => 1.,
}
}
Mode(ContextPickerMode),
}
fn search(
@@ -144,54 +126,19 @@ fn search(
matches.extend(
supported_context_picker_modes(&thread_store)
.into_iter()
.map(|mode| Match::Mode(ModeMatch { mode, mat: None })),
.map(Match::Mode),
);
Task::ready(matches)
} else {
let executor = cx.background_executor().clone();
let search_files_task =
search_files(query.clone(), cancellation_flag.clone(), &workspace, cx);
let modes = supported_context_picker_modes(&thread_store);
let mode_candidates = modes
.iter()
.enumerate()
.map(|(ix, mode)| StringMatchCandidate::new(ix, mode.mention_prefix()))
.collect::<Vec<_>>();
cx.background_spawn(async move {
let mut matches = search_files_task
search_files_task
.await
.into_iter()
.map(Match::File)
.collect::<Vec<_>>();
let mode_matches = fuzzy::match_strings(
&mode_candidates,
&query,
false,
100,
&Arc::new(AtomicBool::default()),
executor,
)
.await;
matches.extend(mode_matches.into_iter().map(|mat| {
Match::Mode(ModeMatch {
mode: modes[mat.candidate_id],
mat: Some(mat),
})
}));
matches.sort_by(|a, b| {
b.score()
.partial_cmp(&a.score())
.unwrap_or(std::cmp::Ordering::Equal)
});
matches
.collect()
})
}
}
@@ -601,7 +548,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
context_store.clone(),
http_client.clone(),
)),
Match::Mode(ModeMatch { mode, .. }) => {
Match::Mode(mode) => {
Some(Self::completion_for_mode(source_range.clone(), mode))
}
})
@@ -663,20 +610,21 @@ fn confirm_completion_callback(
editor: Entity<Editor>,
add_context_fn: impl Fn(&mut App) -> () + Send + Sync + 'static,
) -> Arc<dyn Fn(CompletionIntent, &mut Window, &mut App) -> bool + Send + Sync> {
Arc::new(move |_, _, cx| {
Arc::new(move |_, window, cx| {
add_context_fn(cx);
let crease_text = crease_text.clone();
let crease_icon_path = crease_icon_path.clone();
let editor = editor.clone();
cx.defer(move |cx| {
crate::context_picker::insert_fold_for_mention(
window.defer(cx, move |window, cx| {
crate::context_picker::insert_crease_for_mention(
excerpt_id,
start,
content_len,
crease_text,
crease_icon_path,
editor,
window,
cx,
);
});
@@ -746,7 +694,6 @@ impl MentionCompletion {
#[cfg(test)]
mod tests {
use super::*;
use editor::AnchorRangeExt;
use gpui::{EventEmitter, FocusHandle, Focusable, TestAppContext, VisualTestContext};
use project::{Project, ProjectPath};
use serde_json::json;
@@ -927,7 +874,7 @@ mod tests {
let editor = workspace.update_in(&mut cx, |workspace, window, cx| {
let editor = cx.new(|cx| {
Editor::new(
editor::EditorMode::full(),
editor::EditorMode::Full,
multi_buffer::MultiBuffer::build_simple("", cx),
None,
window,
@@ -1020,7 +967,7 @@ mod tests {
assert_eq!(editor.text(cx), "Lorem [@one.txt](@file:dir/a/one.txt)",);
assert!(!editor.has_visible_completions_menu());
assert_eq!(
fold_ranges(editor, cx),
crease_ranges(editor, cx),
vec![Point::new(0, 6)..Point::new(0, 37)]
);
});
@@ -1031,7 +978,7 @@ mod tests {
assert_eq!(editor.text(cx), "Lorem [@one.txt](@file:dir/a/one.txt) ",);
assert!(!editor.has_visible_completions_menu());
assert_eq!(
fold_ranges(editor, cx),
crease_ranges(editor, cx),
vec![Point::new(0, 6)..Point::new(0, 37)]
);
});
@@ -1045,7 +992,7 @@ mod tests {
);
assert!(!editor.has_visible_completions_menu());
assert_eq!(
fold_ranges(editor, cx),
crease_ranges(editor, cx),
vec![Point::new(0, 6)..Point::new(0, 37)]
);
});
@@ -1059,7 +1006,7 @@ mod tests {
);
assert!(editor.has_visible_completions_menu());
assert_eq!(
fold_ranges(editor, cx),
crease_ranges(editor, cx),
vec![Point::new(0, 6)..Point::new(0, 37)]
);
});
@@ -1077,7 +1024,7 @@ mod tests {
);
assert!(!editor.has_visible_completions_menu());
assert_eq!(
fold_ranges(editor, cx),
crease_ranges(editor, cx),
vec![
Point::new(0, 6)..Point::new(0, 37),
Point::new(0, 44)..Point::new(0, 79)
@@ -1094,7 +1041,7 @@ mod tests {
);
assert!(editor.has_visible_completions_menu());
assert_eq!(
fold_ranges(editor, cx),
crease_ranges(editor, cx),
vec![
Point::new(0, 6)..Point::new(0, 37),
Point::new(0, 44)..Point::new(0, 79)
@@ -1115,7 +1062,7 @@ mod tests {
);
assert!(!editor.has_visible_completions_menu());
assert_eq!(
fold_ranges(editor, cx),
crease_ranges(editor, cx),
vec![
Point::new(0, 6)..Point::new(0, 37),
Point::new(0, 44)..Point::new(0, 79),
@@ -1125,13 +1072,15 @@ mod tests {
});
}
fn fold_ranges(editor: &Editor, cx: &mut App) -> Vec<Range<Point>> {
fn crease_ranges(editor: &Editor, cx: &mut App) -> Vec<Range<Point>> {
let snapshot = editor.buffer().read(cx).snapshot(cx);
editor.display_map.update(cx, |display_map, cx| {
display_map
.snapshot(cx)
.folds_in_range(0..snapshot.len())
.map(|fold| fold.range.to_point(&snapshot))
.crease_snapshot
.crease_items_with_offsets(&snapshot)
.into_iter()
.map(|(_, range)| range)
.collect()
})
}

View File

@@ -11,7 +11,7 @@ use picker::{Picker, PickerDelegate};
use ui::{Context, ListItem, Window, prelude::*};
use workspace::Workspace;
use crate::context_picker::ContextPicker;
use crate::context_picker::{ConfirmBehavior, ContextPicker};
use crate::context_store::ContextStore;
pub struct FetchContextPicker {
@@ -23,10 +23,16 @@ impl FetchContextPicker {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let delegate = FetchContextPickerDelegate::new(context_picker, workspace, context_store);
let delegate = FetchContextPickerDelegate::new(
context_picker,
workspace,
context_store,
confirm_behavior,
);
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
Self { picker }
@@ -56,6 +62,7 @@ pub struct FetchContextPickerDelegate {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
url: String,
}
@@ -64,11 +71,13 @@ impl FetchContextPickerDelegate {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
) -> Self {
FetchContextPickerDelegate {
context_picker,
workspace,
context_store,
confirm_behavior,
url: String::new(),
}
}
@@ -195,15 +204,25 @@ impl PickerDelegate for FetchContextPickerDelegate {
let http_client = workspace.read(cx).client().http_client().clone();
let url = self.url.clone();
let confirm_behavior = self.confirm_behavior;
cx.spawn_in(window, async move |this, cx| {
let text = cx
.background_spawn(fetch_url_content(http_client, url.clone()))
.await?;
this.update(cx, |this, cx| {
this.delegate.context_store.update(cx, |context_store, cx| {
context_store.add_fetched_url(url, text, cx)
})
this.update_in(cx, |this, window, cx| {
this.delegate
.context_store
.update(cx, |context_store, cx| {
context_store.add_fetched_url(url, text, cx)
})?;
match confirm_behavior {
ConfirmBehavior::KeepOpen => {}
ConfirmBehavior::Close => this.delegate.dismissed(window, cx),
}
anyhow::Ok(())
})??;
anyhow::Ok(())

View File

@@ -11,9 +11,9 @@ use picker::{Picker, PickerDelegate};
use project::{PathMatchCandidateSet, ProjectPath, WorktreeId};
use ui::{ListItem, Tooltip, prelude::*};
use util::ResultExt as _;
use workspace::Workspace;
use workspace::{Workspace, notifications::NotifyResultExt};
use crate::context_picker::ContextPicker;
use crate::context_picker::{ConfirmBehavior, ContextPicker};
use crate::context_store::{ContextStore, FileInclusion};
pub struct FileContextPicker {
@@ -25,10 +25,16 @@ impl FileContextPicker {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let delegate = FileContextPickerDelegate::new(context_picker, workspace, context_store);
let delegate = FileContextPickerDelegate::new(
context_picker,
workspace,
context_store,
confirm_behavior,
);
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
Self { picker }
@@ -51,6 +57,7 @@ pub struct FileContextPickerDelegate {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
matches: Vec<FileMatch>,
selected_index: usize,
}
@@ -60,11 +67,13 @@ impl FileContextPickerDelegate {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
) -> Self {
Self {
context_picker,
workspace,
context_store,
confirm_behavior,
matches: Vec::new(),
selected_index: 0,
}
@@ -118,7 +127,7 @@ impl PickerDelegate for FileContextPickerDelegate {
})
}
fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
let Some(FileMatch { mat, .. }) = self.matches.get(self.selected_index) else {
return;
};
@@ -144,7 +153,17 @@ impl PickerDelegate for FileContextPickerDelegate {
return;
};
task.detach_and_log_err(cx);
let confirm_behavior = self.confirm_behavior;
cx.spawn_in(window, async move |this, cx| {
match task.await.notify_async_err(cx) {
None => anyhow::Ok(()),
Some(()) => this.update_in(cx, |this, window, cx| match confirm_behavior {
ConfirmBehavior::KeepOpen => {}
ConfirmBehavior::Close => this.delegate.dismissed(window, cx),
}),
}
})
.detach_and_log_err(cx);
}
fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {

View File

@@ -15,7 +15,7 @@ use ui::{ListItem, prelude::*};
use util::ResultExt as _;
use workspace::Workspace;
use crate::context_picker::ContextPicker;
use crate::context_picker::{ConfirmBehavior, ContextPicker};
use crate::context_store::ContextStore;
pub struct SymbolContextPicker {
@@ -27,10 +27,16 @@ impl SymbolContextPicker {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let delegate = SymbolContextPickerDelegate::new(context_picker, workspace, context_store);
let delegate = SymbolContextPickerDelegate::new(
context_picker,
workspace,
context_store,
confirm_behavior,
);
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
Self { picker }
@@ -53,6 +59,7 @@ pub struct SymbolContextPickerDelegate {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
matches: Vec<SymbolEntry>,
selected_index: usize,
}
@@ -62,11 +69,13 @@ impl SymbolContextPickerDelegate {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
) -> Self {
Self {
context_picker,
workspace,
context_store,
confirm_behavior,
matches: Vec::new(),
selected_index: 0,
}
@@ -126,7 +135,7 @@ impl PickerDelegate for SymbolContextPickerDelegate {
})
}
fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
let Some(mat) = self.matches.get(self.selected_index) else {
return;
};
@@ -134,6 +143,7 @@ impl PickerDelegate for SymbolContextPickerDelegate {
return;
};
let confirm_behavior = self.confirm_behavior;
let add_symbol_task = add_symbol(
mat.symbol.clone(),
true,
@@ -143,12 +153,16 @@ impl PickerDelegate for SymbolContextPickerDelegate {
);
let selected_index = self.selected_index;
cx.spawn(async move |this, cx| {
cx.spawn_in(window, async move |this, cx| {
let included = add_symbol_task.await?;
this.update(cx, |this, _| {
this.update_in(cx, |this, window, cx| {
if let Some(mat) = this.delegate.matches.get_mut(selected_index) {
mat.is_included = included;
}
match confirm_behavior {
ConfirmBehavior::KeepOpen => {}
ConfirmBehavior::Close => this.delegate.dismissed(window, cx),
}
})
})
.detach_and_log_err(cx);

View File

@@ -6,7 +6,7 @@ use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity};
use picker::{Picker, PickerDelegate};
use ui::{ListItem, prelude::*};
use crate::context_picker::ContextPicker;
use crate::context_picker::{ConfirmBehavior, ContextPicker};
use crate::context_store::{self, ContextStore};
use crate::thread::ThreadId;
use crate::thread_store::ThreadStore;
@@ -20,11 +20,16 @@ impl ThreadContextPicker {
thread_store: WeakEntity<ThreadStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
confirm_behavior: ConfirmBehavior,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let delegate =
ThreadContextPickerDelegate::new(thread_store, context_picker, context_store);
let delegate = ThreadContextPickerDelegate::new(
thread_store,
context_picker,
context_store,
confirm_behavior,
);
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
ThreadContextPicker { picker }
@@ -53,6 +58,7 @@ pub struct ThreadContextPickerDelegate {
thread_store: WeakEntity<ThreadStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
confirm_behavior: ConfirmBehavior,
matches: Vec<ThreadContextEntry>,
selected_index: usize,
}
@@ -62,11 +68,13 @@ impl ThreadContextPickerDelegate {
thread_store: WeakEntity<ThreadStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
confirm_behavior: ConfirmBehavior,
) -> Self {
ThreadContextPickerDelegate {
thread_store,
context_picker,
context_store,
confirm_behavior,
matches: Vec::new(),
selected_index: 0,
}
@@ -119,7 +127,7 @@ impl PickerDelegate for ThreadContextPickerDelegate {
})
}
fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
let Some(entry) = self.matches.get(self.selected_index) else {
return;
};
@@ -130,15 +138,20 @@ impl PickerDelegate for ThreadContextPickerDelegate {
let open_thread_task = thread_store.update(cx, |this, cx| this.open_thread(&entry.id, cx));
cx.spawn(async move |this, cx| {
cx.spawn_in(window, async move |this, cx| {
let thread = open_thread_task.await?;
this.update(cx, |this, cx| {
this.update_in(cx, |this, window, cx| {
this.delegate
.context_store
.update(cx, |context_store, cx| {
context_store.add_thread(thread, true, cx)
})
.ok();
match this.delegate.confirm_behavior {
ConfirmBehavior::KeepOpen => {}
ConfirmBehavior::Close => this.delegate.dismissed(window, cx),
}
})
})
.detach_and_log_err(cx);

View File

@@ -9,14 +9,14 @@ use futures::{self, Future, FutureExt, future};
use gpui::{App, AppContext as _, Context, Entity, SharedString, Task, WeakEntity};
use language::{Buffer, File};
use project::{Project, ProjectItem, ProjectPath, Worktree};
use rope::{Point, Rope};
use rope::Rope;
use text::{Anchor, BufferId, OffsetRangeExt};
use util::{ResultExt as _, maybe};
use crate::ThreadStore;
use crate::context::{
AssistantContext, ContextBuffer, ContextId, ContextSymbol, ContextSymbolId, DirectoryContext,
ExcerptContext, FetchedUrlContext, FileContext, SymbolContext, ThreadContext,
FetchedUrlContext, FileContext, SymbolContext, ThreadContext,
};
use crate::context_strip::SuggestedContext;
use crate::thread::{Thread, ThreadId};
@@ -110,7 +110,7 @@ impl ContextStore {
}
let (buffer_info, text_task) =
this.update(cx, |_, cx| collect_buffer_info_and_text(buffer, cx))??;
this.update(cx, |_, cx| collect_buffer_info_and_text(buffer, None, cx))??;
let text = text_task.await;
@@ -129,7 +129,7 @@ impl ContextStore {
) -> Task<Result<()>> {
cx.spawn(async move |this, cx| {
let (buffer_info, text_task) =
this.update(cx, |_, cx| collect_buffer_info_and_text(buffer, cx))??;
this.update(cx, |_, cx| collect_buffer_info_and_text(buffer, None, cx))??;
let text = text_task.await;
@@ -206,7 +206,7 @@ impl ContextStore {
// Skip all binary files and other non-UTF8 files
for buffer in buffers.into_iter().flatten() {
if let Some((buffer_info, text_task)) =
collect_buffer_info_and_text(buffer, cx).log_err()
collect_buffer_info_and_text(buffer, None, cx).log_err()
{
buffer_infos.push(buffer_info);
text_tasks.push(text_task);
@@ -290,14 +290,11 @@ impl ContextStore {
}
}
let (buffer_info, collect_content_task) = match collect_buffer_info_and_text_for_range(
buffer,
symbol_enclosing_range.clone(),
cx,
) {
Ok((_, buffer_info, collect_context_task)) => (buffer_info, collect_context_task),
Err(err) => return Task::ready(Err(err)),
};
let (buffer_info, collect_content_task) =
match collect_buffer_info_and_text(buffer, Some(symbol_enclosing_range.clone()), cx) {
Ok((buffer_info, collect_context_task)) => (buffer_info, collect_context_task),
Err(err) => return Task::ready(Err(err)),
};
cx.spawn(async move |this, cx| {
let content = collect_content_task.await;
@@ -419,49 +416,6 @@ impl ContextStore {
cx.notify();
}
pub fn add_excerpt(
&mut self,
range: Range<Anchor>,
buffer: Entity<Buffer>,
cx: &mut Context<ContextStore>,
) -> Task<Result<()>> {
cx.spawn(async move |this, cx| {
let (line_range, buffer_info, text_task) = this.update(cx, |_, cx| {
collect_buffer_info_and_text_for_range(buffer, range.clone(), cx)
})??;
let text = text_task.await;
this.update(cx, |this, cx| {
this.insert_excerpt(
make_context_buffer(buffer_info, text),
range,
line_range,
cx,
)
})?;
anyhow::Ok(())
})
}
fn insert_excerpt(
&mut self,
context_buffer: ContextBuffer,
range: Range<Anchor>,
line_range: Range<Point>,
cx: &mut Context<Self>,
) {
let id = self.next_context_id.post_inc();
self.context.push(AssistantContext::Excerpt(ExcerptContext {
id,
range,
line_range,
context_buffer,
}));
cx.notify();
}
pub fn accept_suggested_context(
&mut self,
suggested: &SuggestedContext,
@@ -511,7 +465,6 @@ impl ContextStore {
self.symbol_buffers.remove(&symbol.context_symbol.id);
self.symbols.retain(|_, context_id| *context_id != id);
}
AssistantContext::Excerpt(_) => {}
AssistantContext::FetchedUrl(_) => {
self.fetched_urls.retain(|_, context_id| *context_id != id);
}
@@ -639,7 +592,6 @@ impl ContextStore {
}
AssistantContext::Directory(_)
| AssistantContext::Symbol(_)
| AssistantContext::Excerpt(_)
| AssistantContext::FetchedUrl(_)
| AssistantContext::Thread(_) => None,
})
@@ -691,78 +643,41 @@ fn make_context_symbol(
}
}
fn collect_buffer_info_and_text_for_range(
buffer: Entity<Buffer>,
range: Range<Anchor>,
cx: &App,
) -> Result<(Range<Point>, BufferInfo, Task<SharedString>)> {
let content = buffer
.read(cx)
.text_for_range(range.clone())
.collect::<Rope>();
let line_range = range.to_point(&buffer.read(cx).snapshot());
let buffer_info = collect_buffer_info(buffer, cx)?;
let full_path = buffer_info.file.full_path(cx);
let text_task = cx.background_spawn({
let line_range = line_range.clone();
async move { to_fenced_codeblock(&full_path, content, Some(line_range)) }
});
Ok((line_range, buffer_info, text_task))
}
fn collect_buffer_info_and_text(
buffer: Entity<Buffer>,
range: Option<Range<Anchor>>,
cx: &App,
) -> Result<(BufferInfo, Task<SharedString>)> {
let content = buffer.read(cx).as_rope().clone();
let buffer_info = collect_buffer_info(buffer, cx)?;
let full_path = buffer_info.file.full_path(cx);
let text_task =
cx.background_spawn(async move { to_fenced_codeblock(&full_path, content, None) });
Ok((buffer_info, text_task))
}
fn collect_buffer_info(buffer: Entity<Buffer>, cx: &App) -> Result<BufferInfo> {
let buffer_ref = buffer.read(cx);
let file = buffer_ref.file().context("file context must have a path")?;
// Important to collect version at the same time as content so that staleness logic is correct.
let version = buffer_ref.version();
let content = if let Some(range) = range {
buffer_ref.text_for_range(range).collect::<Rope>()
} else {
buffer_ref.as_rope().clone()
};
Ok(BufferInfo {
let buffer_info = BufferInfo {
buffer,
id: buffer_ref.remote_id(),
file: file.clone(),
version,
})
};
let full_path = file.full_path(cx);
let text_task = cx.background_spawn(async move { to_fenced_codeblock(&full_path, content) });
Ok((buffer_info, text_task))
}
fn to_fenced_codeblock(
path: &Path,
content: Rope,
line_range: Option<Range<Point>>,
) -> SharedString {
let line_range_text = line_range.map(|range| {
if range.start.row == range.end.row {
format!(":{}", range.start.row + 1)
} else {
format!(":{}-{}", range.start.row + 1, range.end.row + 1)
}
});
fn to_fenced_codeblock(path: &Path, content: Rope) -> SharedString {
let path_extension = path.extension().and_then(|ext| ext.to_str());
let path_string = path.to_string_lossy();
let capacity = 3
+ path_extension.map_or(0, |extension| extension.len() + 1)
+ path_string.len()
+ line_range_text.as_ref().map_or(0, |text| text.len())
+ 1
+ content.len()
+ 5;
@@ -776,10 +691,6 @@ fn to_fenced_codeblock(
}
buffer.push_str(&path_string);
if let Some(line_range_text) = line_range_text {
buffer.push_str(&line_range_text);
}
buffer.push('\n');
for chunk in content.chunks() {
buffer.push_str(&chunk);
@@ -858,14 +769,6 @@ pub fn refresh_context_store_text(
return refresh_symbol_text(context_store, symbol_context, cx);
}
}
AssistantContext::Excerpt(excerpt_context) => {
if changed_buffers.is_empty()
|| changed_buffers.contains(&excerpt_context.context_buffer.buffer)
{
let context_store = context_store.clone();
return refresh_excerpt_text(context_store, excerpt_context, cx);
}
}
AssistantContext::Thread(thread_context) => {
if changed_buffers.is_empty() {
let context_store = context_store.clone();
@@ -977,34 +880,6 @@ fn refresh_symbol_text(
}
}
fn refresh_excerpt_text(
context_store: Entity<ContextStore>,
excerpt_context: &ExcerptContext,
cx: &App,
) -> Option<Task<()>> {
let id = excerpt_context.id;
let range = excerpt_context.range.clone();
let task = refresh_context_excerpt(&excerpt_context.context_buffer, range.clone(), cx);
if let Some(task) = task {
Some(cx.spawn(async move |cx| {
let (line_range, context_buffer) = task.await;
context_store
.update(cx, |context_store, _| {
let new_excerpt_context = ExcerptContext {
id,
range,
line_range,
context_buffer,
};
context_store.replace_context(AssistantContext::Excerpt(new_excerpt_context));
})
.ok();
}))
} else {
None
}
}
fn refresh_thread_text(
context_store: Entity<ContextStore>,
thread_context: &ThreadContext,
@@ -1033,29 +908,13 @@ fn refresh_context_buffer(
let buffer = context_buffer.buffer.read(cx);
if buffer.version.changed_since(&context_buffer.version) {
let (buffer_info, text_task) =
collect_buffer_info_and_text(context_buffer.buffer.clone(), cx).log_err()?;
collect_buffer_info_and_text(context_buffer.buffer.clone(), None, cx).log_err()?;
Some(text_task.map(move |text| make_context_buffer(buffer_info, text)))
} else {
None
}
}
fn refresh_context_excerpt(
context_buffer: &ContextBuffer,
range: Range<Anchor>,
cx: &App,
) -> Option<impl Future<Output = (Range<Point>, ContextBuffer)> + use<>> {
let buffer = context_buffer.buffer.read(cx);
if buffer.version.changed_since(&context_buffer.version) {
let (line_range, buffer_info, text_task) =
collect_buffer_info_and_text_for_range(context_buffer.buffer.clone(), range, cx)
.log_err()?;
Some(text_task.map(move |text| (line_range, make_context_buffer(buffer_info, text))))
} else {
None
}
}
fn refresh_context_symbol(
context_symbol: &ContextSymbol,
cx: &App,
@@ -1063,9 +922,9 @@ fn refresh_context_symbol(
let buffer = context_symbol.buffer.read(cx);
let project_path = buffer.project_path(cx)?;
if buffer.version.changed_since(&context_symbol.buffer_version) {
let (_, buffer_info, text_task) = collect_buffer_info_and_text_for_range(
let (buffer_info, text_task) = collect_buffer_info_and_text(
context_symbol.buffer.clone(),
context_symbol.enclosing_range.clone(),
Some(context_symbol.enclosing_range.clone()),
cx,
)
.log_err()?;

View File

@@ -15,7 +15,7 @@ use ui::{KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*};
use workspace::{Workspace, notifications::NotifyResultExt};
use crate::context::{ContextId, ContextKind};
use crate::context_picker::ContextPicker;
use crate::context_picker::{ConfirmBehavior, ContextPicker};
use crate::context_store::ContextStore;
use crate::thread::Thread;
use crate::thread_store::ThreadStore;
@@ -52,6 +52,7 @@ impl ContextStrip {
workspace.clone(),
thread_store.clone(),
context_store.downgrade(),
ConfirmBehavior::KeepOpen,
window,
cx,
)

File diff suppressed because it is too large Load Diff

View File

@@ -86,7 +86,7 @@ impl ProfileSelector {
thread_store
.update(cx, |this, cx| {
this.load_profile_by_id(profile_id.clone(), cx);
this.load_profile_by_id(&profile_id, cx);
})
.log_err();
}

View File

@@ -2,41 +2,38 @@ use std::fmt::Write as _;
use std::io::Write;
use std::ops::Range;
use std::sync::Arc;
use std::time::Instant;
use agent_rules::load_worktree_rules_file;
use anyhow::{Context as _, Result, anyhow};
use assistant_settings::AssistantSettings;
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap};
use feature_flags::{self, FeatureFlagAppExt};
use fs::Fs;
use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
use git::repository::DiffType;
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
TokenUsage,
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
PaymentRequiredError, Role, StopReason, TokenUsage,
};
use project::Project;
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
use prompt_store::PromptBuilder;
use proto::Plan;
use project::{Project, Worktree};
use prompt_store::{AssistantSystemPromptContext, PromptBuilder, WorktreeInfoForSystemPrompt};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use thiserror::Error;
use util::{ResultExt as _, TryFutureExt as _, post_inc};
use uuid::Uuid;
use crate::context::{AssistantContext, ContextId, format_context_as_string};
use crate::thread_store::{
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
SerializedToolUse, SharedProjectContext,
SerializedToolUse,
};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
@@ -228,36 +225,10 @@ pub enum DetailedSummaryState {
pub struct TotalTokenUsage {
pub total: usize,
pub max: usize,
pub ratio: TokenUsageRatio,
}
impl TotalTokenUsage {
pub fn ratio(&self) -> TokenUsageRatio {
#[cfg(debug_assertions)]
let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
.unwrap_or("0.8".to_string())
.parse()
.unwrap();
#[cfg(not(debug_assertions))]
let warning_threshold: f32 = 0.8;
if self.total >= self.max {
TokenUsageRatio::Exceeded
} else if self.total as f32 / self.max as f32 >= warning_threshold {
TokenUsageRatio::Warning
} else {
TokenUsageRatio::Normal
}
}
pub fn add(&self, tokens: usize) -> TotalTokenUsage {
TotalTokenUsage {
total: self.total + tokens,
max: self.max,
}
}
}
#[derive(Debug, Default, PartialEq, Eq)]
#[derive(Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
#[default]
Normal,
@@ -276,43 +247,28 @@ pub struct Thread {
next_message_id: MessageId,
context: BTreeMap<ContextId, AssistantContext>,
context_by_message: HashMap<MessageId, Vec<ContextId>>,
project_context: SharedProjectContext,
system_prompt_context: Option<AssistantSystemPromptContext>,
checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
project: Entity<Project>,
prompt_builder: Arc<PromptBuilder>,
tools: Entity<ToolWorkingSet>,
tools: Arc<ToolWorkingSet>,
tool_use: ToolUseState,
action_log: Entity<ActionLog>,
last_restore_checkpoint: Option<LastRestoreCheckpoint>,
pending_checkpoint: Option<ThreadCheckpoint>,
initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
request_token_usage: Vec<TokenUsage>,
cumulative_token_usage: TokenUsage,
exceeded_window_error: Option<ExceededWindowError>,
feedback: Option<ThreadFeedback>,
message_feedback: HashMap<MessageId, ThreadFeedback>,
last_auto_capture_at: Option<Instant>,
request_callback: Option<
Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
/// Model used when last message exceeded context window
model_id: LanguageModelId,
/// Token count including last message
token_count: usize,
}
impl Thread {
pub fn new(
project: Entity<Project>,
tools: Entity<ToolWorkingSet>,
tools: Arc<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
system_prompt: SharedProjectContext,
cx: &mut Context<Self>,
) -> Self {
Self {
@@ -325,7 +281,7 @@ impl Thread {
next_message_id: MessageId(0),
context: BTreeMap::default(),
context_by_message: HashMap::default(),
project_context: system_prompt,
system_prompt_context: None,
checkpoints_by_message: HashMap::default(),
completion_count: 0,
pending_completions: Vec::new(),
@@ -342,13 +298,9 @@ impl Thread {
.spawn(async move { Some(project_snapshot.await) })
.shared()
},
request_token_usage: Vec::new(),
cumulative_token_usage: TokenUsage::default(),
exceeded_window_error: None,
feedback: None,
message_feedback: HashMap::default(),
last_auto_capture_at: None,
request_callback: None,
}
}
@@ -356,9 +308,8 @@ impl Thread {
id: ThreadId,
serialized: SerializedThread,
project: Entity<Project>,
tools: Entity<ToolWorkingSet>,
tools: Arc<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
project_context: SharedProjectContext,
cx: &mut Context<Self>,
) -> Self {
let next_message_id = MessageId(
@@ -399,7 +350,7 @@ impl Thread {
next_message_id,
context: BTreeMap::default(),
context_by_message: HashMap::default(),
project_context,
system_prompt_context: None,
checkpoints_by_message: HashMap::default(),
completion_count: 0,
pending_completions: Vec::new(),
@@ -411,24 +362,12 @@ impl Thread {
tool_use,
action_log: cx.new(|_| ActionLog::new(project)),
initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
request_token_usage: serialized.request_token_usage,
cumulative_token_usage: serialized.cumulative_token_usage,
exceeded_window_error: None,
feedback: None,
message_feedback: HashMap::default(),
last_auto_capture_at: None,
request_callback: None,
}
}
pub fn set_request_callback(
&mut self,
callback: impl 'static
+ FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
) {
self.request_callback = Some(Box::new(callback));
}
pub fn id(&self) -> &ThreadId {
&self.id
}
@@ -449,10 +388,6 @@ impl Thread {
self.summary.clone()
}
pub fn project_context(&self) -> SharedProjectContext {
self.project_context.clone()
}
pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
pub fn summary_or_default(&self) -> SharedString {
@@ -502,7 +437,7 @@ impl Thread {
!self.pending_completions.is_empty() || !self.all_tools_finished()
}
pub fn tools(&self) -> &Entity<ToolWorkingSet> {
pub fn tools(&self) -> &Arc<ToolWorkingSet> {
&self.tools
}
@@ -674,30 +609,10 @@ impl Thread {
self.tool_use.tool_result(id)
}
pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
Some(&self.tool_use.tool_result(id)?.content)
}
pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
self.tool_use.tool_result_card(id).cloned()
}
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
self.tool_use.message_has_tool_results(message_id)
}
/// Filter out contexts that have already been included in previous messages
pub fn filter_new_context<'a>(
&self,
context: impl Iterator<Item = &'a AssistantContext>,
) -> impl Iterator<Item = &'a AssistantContext> {
context.filter(|ctx| self.is_context_new(ctx))
}
fn is_context_new(&self, context: &AssistantContext) -> bool {
!self.context.contains_key(&context.id())
}
pub fn insert_user_message(
&mut self,
text: impl Into<String>,
@@ -709,9 +624,10 @@ impl Thread {
let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
// Filter out contexts that have already been included in previous messages
let new_context: Vec<_> = context
.into_iter()
.filter(|ctx| self.is_context_new(ctx))
.filter(|ctx| !self.context.contains_key(&ctx.id()))
.collect();
if !new_context.is_empty() {
@@ -739,12 +655,6 @@ impl Thread {
cx,
);
}
AssistantContext::Excerpt(excerpt_context) => {
log.buffer_added_as_context(
excerpt_context.context_buffer.buffer.clone(),
cx,
);
}
AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
}
}
@@ -896,14 +806,92 @@ impl Thread {
})
.collect(),
initial_project_snapshot,
cumulative_token_usage: this.cumulative_token_usage,
request_token_usage: this.request_token_usage.clone(),
cumulative_token_usage: this.cumulative_token_usage.clone(),
detailed_summary_state: this.detailed_summary_state.clone(),
exceeded_window_error: this.exceeded_window_error.clone(),
})
})
}
pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
self.system_prompt_context = Some(context);
}
pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
&self.system_prompt_context
}
pub fn load_system_prompt_context(
&self,
cx: &App,
) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
let project = self.project.read(cx);
let tasks = project
.visible_worktrees(cx)
.map(|worktree| {
Self::load_worktree_info_for_system_prompt(
project.fs().clone(),
worktree.read(cx),
cx,
)
})
.collect::<Vec<_>>();
cx.spawn(async |_cx| {
let results = futures::future::join_all(tasks).await;
let mut first_err = None;
let worktrees = results
.into_iter()
.map(|(worktree, err)| {
if first_err.is_none() && err.is_some() {
first_err = err;
}
worktree
})
.collect::<Vec<_>>();
(AssistantSystemPromptContext::new(worktrees), first_err)
})
}
fn load_worktree_info_for_system_prompt(
fs: Arc<dyn Fs>,
worktree: &Worktree,
cx: &App,
) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
let root_name = worktree.root_name().into();
let abs_path = worktree.abs_path();
let rules_task = load_worktree_rules_file(fs, worktree, cx);
let Some(rules_task) = rules_task else {
return Task::ready((
WorktreeInfoForSystemPrompt {
root_name,
abs_path,
rules_file: None,
},
None,
));
};
cx.spawn(async move |_| {
let (rules_file, rules_file_error) = match rules_task.await {
Ok(rules_file) => (Some(rules_file), None),
Err(err) => (
None,
Some(ThreadError::Message {
header: "Error loading rules file".into(),
message: format!("{err}").into(),
}),
),
};
let worktree_info = WorktreeInfoForSystemPrompt {
root_name,
abs_path,
rules_file,
};
(worktree_info, rules_file_error)
})
}
pub fn send_to_model(
&mut self,
model: Arc<dyn LanguageModel>,
@@ -914,21 +902,13 @@ impl Thread {
if model.supports_tools() {
request.tools = {
let mut tools = Vec::new();
tools.extend(
self.tools()
.read(cx)
.enabled_tools(cx)
.into_iter()
.filter_map(|tool| {
// Skip tools that cannot be supported
let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
Some(LanguageModelRequestTool {
name: tool.name(),
description: tool.description(),
input_schema,
})
}),
);
tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
LanguageModelRequestTool {
name: tool.name(),
description: tool.description(),
input_schema: tool.input_schema(model.tool_input_format()),
}
}));
tools
};
@@ -961,10 +941,10 @@ impl Thread {
temperature: None,
};
if let Some(project_context) = self.project_context.borrow().as_ref() {
if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
if let Some(system_prompt) = self
.prompt_builder
.generate_assistant_system_prompt(project_context)
.generate_assistant_system_prompt(system_prompt_context)
.context("failed to generate assistant system prompt")
.log_err()
{
@@ -975,7 +955,7 @@ impl Thread {
});
}
} else {
log::error!("project_context not set.")
log::error!("system_prompt_context not set.")
}
for message in &self.messages {
@@ -1056,6 +1036,15 @@ impl Thread {
content.push(stale_message.into());
}
if action_log.has_edited_files_since_project_diagnostics_check() {
content.push(
"\n\nWhen you're done making changes, make sure to check project diagnostics \
and fix all errors AND warnings you introduced! \
DO NOT mention you're going to do this until you're done."
.into(),
);
}
if !content.is_empty() {
let context_message = LanguageModelRequestMessage {
role: Role::User,
@@ -1074,36 +1063,17 @@ impl Thread {
cx: &mut Context<Self>,
) {
let pending_completion_id = post_inc(&mut self.completion_count);
let mut request_callback_parameters = if self.request_callback.is_some() {
Some((request.clone(), Vec::new()))
} else {
None
};
let task = cx.spawn(async move |thread, cx| {
let stream_completion_future = model.stream_completion_with_usage(request, &cx);
let stream = model.stream_completion(request, &cx);
let initial_token_usage =
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
let stream_completion = async {
let (mut events, usage) = stream_completion_future.await?;
let mut events = stream.await?;
let mut stop_reason = StopReason::EndTurn;
let mut current_token_usage = TokenUsage::default();
if let Some(usage) = usage {
thread
.update(cx, |_thread, cx| {
cx.emit(ThreadEvent::UsageUpdated(usage));
})
.ok();
}
while let Some(event) = events.next().await {
if let Some((_, response_events)) = request_callback_parameters.as_mut() {
response_events
.push(event.as_ref().map_err(|error| error.to_string()).cloned());
}
let event = event?;
thread.update(cx, |thread, cx| {
@@ -1119,10 +1089,9 @@ impl Thread {
stop_reason = reason;
}
LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
thread.update_token_usage_at_last_message(token_usage);
thread.cumulative_token_usage = thread.cumulative_token_usage
+ token_usage
- current_token_usage;
thread.cumulative_token_usage =
thread.cumulative_token_usage.clone() + token_usage.clone()
- current_token_usage.clone();
current_token_usage = token_usage;
}
LanguageModelCompletionEvent::Text(chunk) => {
@@ -1231,26 +1200,6 @@ impl Thread {
cx.emit(ThreadEvent::ShowError(
ThreadError::MaxMonthlySpendReached,
));
} else if let Some(error) =
error.downcast_ref::<ModelRequestLimitReachedError>()
{
cx.emit(ThreadEvent::ShowError(
ThreadError::ModelRequestLimitReached { plan: error.plan },
));
} else if let Some(known_error) =
error.downcast_ref::<LanguageModelKnownError>()
{
match known_error {
LanguageModelKnownError::ContextWindowLimitExceeded {
tokens,
} => {
thread.exceeded_window_error = Some(ExceededWindowError {
model_id: model.id(),
token_count: *tokens,
});
cx.notify();
}
}
} else {
let error_message = error
.chain()
@@ -1266,20 +1215,12 @@ impl Thread {
thread.cancel_last_completion(cx);
}
}
cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
if let Some((request_callback, (request, response_events))) = thread
.request_callback
.as_mut()
.zip(request_callback_parameters.as_ref())
{
request_callback(request, response_events);
}
cx.emit(ThreadEvent::DoneStreaming);
thread.auto_capture_telemetry(cx);
if let Ok(initial_usage) = initial_token_usage {
let usage = thread.cumulative_token_usage - initial_usage;
let usage = thread.cumulative_token_usage.clone() - initial_usage;
telemetry::event!(
"Assistant Thread Completion",
@@ -1326,15 +1267,8 @@ impl Thread {
self.pending_summary = cx.spawn(async move |this, cx| {
async move {
let stream = model.model.stream_completion_text_with_usage(request, &cx);
let (mut messages, usage) = stream.await?;
if let Some(usage) = usage {
this.update(cx, |_thread, cx| {
cx.emit(ThreadEvent::UsageUpdated(usage));
})
.ok();
}
let stream = model.model.stream_completion_text(request, &cx);
let mut messages = stream.await?;
let mut new_summary = String::new();
while let Some(message) = messages.stream.next().await {
@@ -1457,7 +1391,7 @@ impl Thread {
.collect::<Vec<_>>();
for tool_use in pending_tool_uses.iter() {
if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
if tool.needs_confirmation(&tool_use.input, cx)
&& !AssistantSettings::get_global(cx).always_allow_tool_actions
{
@@ -1509,8 +1443,8 @@ impl Thread {
) -> Task<()> {
let tool_name: Arc<str> = tool.name().into();
let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
let run_tool = if self.tools.is_disabled(&tool.source(), &tool_name) {
Task::ready(Err(anyhow!("tool is disabled: {tool_name}")))
} else {
tool.run(
input,
@@ -1521,15 +1455,9 @@ impl Thread {
)
};
// Store the card separately if it exists
if let Some(card) = tool_result.card.clone() {
self.tool_use
.insert_tool_result_card(tool_use_id.clone(), card);
}
cx.spawn({
async move |thread: WeakEntity<Thread>, cx| {
let output = tool_result.output.await;
let output = run_tool.await;
thread
.update(cx, |thread, cx| {
@@ -1569,10 +1497,17 @@ impl Thread {
});
}
/// Insert an empty message to be populated with tool results upon send.
pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
self.insert_message(Role::User, vec![], cx);
self.auto_capture_telemetry(cx);
// Insert a user message to contain the tool results.
self.insert_user_message(
// TODO: Sending up a user message without any content results in the model sending back
// responses that also don't have any content. We currently don't handle this case well,
// so for now we provide some text to keep the model on track.
"Here are the tool results.",
Vec::new(),
None,
cx,
);
}
/// Cancels the last pending completion, if there are any pending.
@@ -1623,7 +1558,6 @@ impl Thread {
let enabled_tool_names: Vec<String> = self
.tools()
.read(cx)
.enabled_tools(cx)
.iter()
.map(|tool| tool.name().to_string())
@@ -1902,14 +1836,14 @@ impl Thread {
.update(cx, |action_log, cx| action_log.keep_all_edits(cx));
}
pub fn reject_edits_in_ranges(
pub fn reject_edits_in_range(
&mut self,
buffer: Entity<language::Buffer>,
buffer_ranges: Vec<Range<language::Anchor>>,
buffer_range: Range<language::Anchor>,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
self.action_log.update(cx, |action_log, cx| {
action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
action_log.reject_edits_in_range(buffer, buffer_range, cx)
})
}
@@ -1921,83 +1855,69 @@ impl Thread {
&self.project
}
pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
pub fn cumulative_token_usage(&self) -> TokenUsage {
self.cumulative_token_usage.clone()
}
pub fn auto_capture_telemetry(&self, cx: &mut Context<Self>) {
static mut LAST_CAPTURE: Option<std::time::Instant> = None;
let now = std::time::Instant::now();
let should_check = unsafe {
if let Some(last) = LAST_CAPTURE {
if now.duration_since(last).as_secs() < 10 {
return;
}
}
LAST_CAPTURE = Some(now);
true
};
if !should_check {
return;
}
let now = Instant::now();
if let Some(last) = self.last_auto_capture_at {
if now.duration_since(last).as_secs() < 10 {
let feature_flag_enabled = cx.has_flag::<feature_flags::ThreadAutoCapture>();
if cfg!(debug_assertions) {
if !feature_flag_enabled {
return;
}
}
self.last_auto_capture_at = Some(now);
let thread_id = self.id().clone();
let github_login = self
let github_handle = self
.project
.read(cx)
.user_store()
.read(cx)
.current_user()
.map(|user| user.github_login.clone());
let client = self.project.read(cx).client().clone();
let serialize_task = self.serialize(cx);
cx.background_executor()
let serialized_thread = self.serialize(cx);
cx.foreground_executor()
.spawn(async move {
if let Ok(serialized_thread) = serialize_task.await {
if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
telemetry::event!(
"Agent Thread Auto-Captured",
thread_id = thread_id.to_string(),
thread_data = thread_data,
auto_capture_reason = "tracked_user",
github_login = github_login
);
if let Ok(serialized_thread) = serialized_thread.await {
let thread_data = serde_json::to_value(serialized_thread)
.unwrap_or_else(|_| serde_json::Value::Null);
client.telemetry().flush_events();
}
telemetry::event!(
"Agent Thread AutoCaptured",
thread_id = thread_id.to_string(),
thread_data = thread_data,
auto_capture_reason = "tracked_user",
github_handle = github_handle
);
client.telemetry().flush_events();
}
})
.detach();
}
pub fn cumulative_token_usage(&self) -> TokenUsage {
self.cumulative_token_usage
}
pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
return TotalTokenUsage::default();
};
let max = model.model.max_token_count();
let index = self
.messages
.iter()
.position(|msg| msg.id == message_id)
.unwrap_or(0);
if index == 0 {
return TotalTokenUsage { total: 0, max };
}
let token_usage = &self
.request_token_usage
.get(index - 1)
.cloned()
.unwrap_or_default();
TotalTokenUsage {
total: token_usage.total_tokens() as usize,
max,
}
}
pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(model) = model_registry.default_model() else {
@@ -2006,38 +1926,25 @@ impl Thread {
let max = model.model.max_token_count();
if let Some(exceeded_error) = &self.exceeded_window_error {
if model.model.id() == exceeded_error.model_id {
return TotalTokenUsage {
total: exceeded_error.token_count,
max,
};
}
}
#[cfg(debug_assertions)]
let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
.unwrap_or("0.8".to_string())
.parse()
.unwrap();
#[cfg(not(debug_assertions))]
let warning_threshold: f32 = 0.8;
let total = self
.token_usage_at_last_message()
.unwrap_or_default()
.total_tokens() as usize;
let total = self.cumulative_token_usage.total_tokens() as usize;
TotalTokenUsage { total, max }
}
let ratio = if total >= max {
TokenUsageRatio::Exceeded
} else if total as f32 / max as f32 >= warning_threshold {
TokenUsageRatio::Warning
} else {
TokenUsageRatio::Normal
};
fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
self.request_token_usage
.get(self.messages.len().saturating_sub(1))
.or_else(|| self.request_token_usage.last())
.cloned()
}
fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
let placeholder = self.token_usage_at_last_message().unwrap_or_default();
self.request_token_usage
.resize(self.messages.len(), placeholder);
if let Some(last) = self.request_token_usage.last_mut() {
*last = token_usage;
}
TotalTokenUsage { total, max, ratio }
}
pub fn deny_tool_use(
@@ -2056,15 +1963,10 @@ impl Thread {
}
}
#[derive(Debug, Clone, Error)]
#[derive(Debug, Clone)]
pub enum ThreadError {
#[error("Payment required")]
PaymentRequired,
#[error("Max monthly spend reached")]
MaxMonthlySpendReached,
#[error("Model request limit reached")]
ModelRequestLimitReached { plan: Plan },
#[error("Message {header}: {message}")]
Message {
header: SharedString,
message: SharedString,
@@ -2074,11 +1976,10 @@ pub enum ThreadError {
#[derive(Debug, Clone)]
pub enum ThreadEvent {
ShowError(ThreadError),
UsageUpdated(RequestUsage),
StreamedCompletion,
StreamedAssistantText(MessageId, String),
StreamedAssistantThinking(MessageId, String),
Stopped(Result<StopReason, Arc<anyhow::Error>>),
DoneStreaming,
MessageAdded(MessageId),
MessageEdited(MessageId),
MessageDeleted(MessageId),
@@ -2184,9 +2085,9 @@ fn main() {{
thread.to_completion_request(RequestKind::Chat, cx)
});
assert_eq!(request.messages.len(), 2);
assert_eq!(request.messages.len(), 1);
let expected_full_message = format!("{}Please explain this code", expected_context);
assert_eq!(request.messages[1].string_contents(), expected_full_message);
assert_eq!(request.messages[0].string_contents(), expected_full_message);
}
#[gpui::test]
@@ -2277,20 +2178,20 @@ fn main() {{
});
// The request should contain all 3 messages
assert_eq!(request.messages.len(), 4);
assert_eq!(request.messages.len(), 3);
// Check that the contexts are properly formatted in each message
assert!(request.messages[1].string_contents().contains("file1.rs"));
assert!(!request.messages[1].string_contents().contains("file2.rs"));
assert!(request.messages[0].string_contents().contains("file1.rs"));
assert!(!request.messages[0].string_contents().contains("file2.rs"));
assert!(!request.messages[0].string_contents().contains("file3.rs"));
assert!(!request.messages[1].string_contents().contains("file1.rs"));
assert!(request.messages[1].string_contents().contains("file2.rs"));
assert!(!request.messages[1].string_contents().contains("file3.rs"));
assert!(!request.messages[2].string_contents().contains("file1.rs"));
assert!(request.messages[2].string_contents().contains("file2.rs"));
assert!(!request.messages[2].string_contents().contains("file3.rs"));
assert!(!request.messages[3].string_contents().contains("file1.rs"));
assert!(!request.messages[3].string_contents().contains("file2.rs"));
assert!(request.messages[3].string_contents().contains("file3.rs"));
assert!(!request.messages[2].string_contents().contains("file2.rs"));
assert!(request.messages[2].string_contents().contains("file3.rs"));
}
#[gpui::test]
@@ -2328,9 +2229,9 @@ fn main() {{
thread.to_completion_request(RequestKind::Chat, cx)
});
assert_eq!(request.messages.len(), 2);
assert_eq!(request.messages.len(), 1);
assert_eq!(
request.messages[1].string_contents(),
request.messages[0].string_contents(),
"What is the best way to learn Rust?"
);
@@ -2348,13 +2249,13 @@ fn main() {{
thread.to_completion_request(RequestKind::Chat, cx)
});
assert_eq!(request.messages.len(), 3);
assert_eq!(request.messages.len(), 2);
assert_eq!(
request.messages[1].string_contents(),
request.messages[0].string_contents(),
"What is the best way to learn Rust?"
);
assert_eq!(
request.messages[2].string_contents(),
request.messages[1].string_contents(),
"Are there any good books?"
);
}
@@ -2475,16 +2376,15 @@ fn main() {{
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
let thread_store = cx
.update(|_, cx| {
ThreadStore::load(
project.clone(),
cx.new(|_| ToolWorkingSet::default()),
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)
})
.await;
let thread_store = cx.update(|_, cx| {
ThreadStore::new(
project.clone(),
Arc::default(),
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)
.unwrap()
});
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

View File

@@ -4,14 +4,11 @@ use assistant_context_editor::SavedContextMetadata;
use editor::{Editor, EditorEvent};
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{
App, Entity, FocusHandle, Focusable, ScrollStrategy, Stateful, Task, UniformListScrollHandle,
WeakEntity, Window, uniform_list,
App, Entity, FocusHandle, Focusable, ScrollStrategy, Task, UniformListScrollHandle, WeakEntity,
Window, uniform_list,
};
use time::{OffsetDateTime, UtcOffset};
use ui::{
HighlightedLabel, IconButtonShape, ListItem, ListItemSpacing, Scrollbar, ScrollbarState,
Tooltip, prelude::*,
};
use ui::{HighlightedLabel, IconButtonShape, ListItem, ListItemSpacing, Tooltip, prelude::*};
use util::ResultExt;
use crate::history_store::{HistoryEntry, HistoryStore};
@@ -29,8 +26,6 @@ pub struct ThreadHistory {
matches: Vec<StringMatch>,
_subscriptions: Vec<gpui::Subscription>,
_search_task: Option<Task<()>>,
scrollbar_visibility: bool,
scrollbar_state: ScrollbarState,
}
impl ThreadHistory {
@@ -63,13 +58,10 @@ impl ThreadHistory {
this.update_all_entries(cx);
});
let scroll_handle = UniformListScrollHandle::default();
let scrollbar_state = ScrollbarState::new(scroll_handle.clone());
Self {
assistant_panel,
history_store,
scroll_handle,
scroll_handle: UniformListScrollHandle::default(),
selected_index: 0,
search_query: SharedString::new_static(""),
all_entries: entries,
@@ -77,8 +69,6 @@ impl ThreadHistory {
search_editor,
_subscriptions: vec![search_editor_subscription, history_store_subscription],
_search_task: None,
scrollbar_visibility: true,
scrollbar_state,
}
}
@@ -230,43 +220,6 @@ impl ThreadHistory {
cx.notify();
}
fn render_scrollbar(&self, cx: &mut Context<Self>) -> Option<Stateful<Div>> {
if !(self.scrollbar_visibility || self.scrollbar_state.is_dragging()) {
return None;
}
Some(
div()
.occlude()
.id("thread-history-scroll")
.h_full()
.bg(cx.theme().colors().panel_background.opacity(0.8))
.border_l_1()
.border_color(cx.theme().colors().border_variant)
.absolute()
.right_1()
.top_0()
.bottom_0()
.w_4()
.pl_1()
.cursor_default()
.on_mouse_move(cx.listener(|_, _, _window, cx| {
cx.notify();
cx.stop_propagation()
}))
.on_hover(|_, _window, cx| {
cx.stop_propagation();
})
.on_any_mouse_down(|_, _window, cx| {
cx.stop_propagation();
})
.on_scroll_wheel(cx.listener(|_, _, _window, cx| {
cx.notify();
}))
.children(Scrollbar::vertical(self.scrollbar_state.clone())),
)
}
fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
if let Some(entry) = self.get_match(self.selected_index) {
let task_result = match entry {
@@ -352,11 +305,7 @@ impl Render for ThreadHistory {
)
})
.child({
let view = v_flex()
.id("list-container")
.relative()
.overflow_hidden()
.flex_grow();
let view = v_flex().overflow_hidden().flex_grow();
if self.all_entries.is_empty() {
view.justify_center()
@@ -373,70 +322,59 @@ impl Render for ThreadHistory {
),
)
} else {
view.pr_5()
.child(
uniform_list(
cx.entity().clone(),
"thread-history",
self.matched_count(),
move |history, range, _window, _cx| {
let range_start = range.start;
let assistant_panel = history.assistant_panel.clone();
view.p_1().child(
uniform_list(
cx.entity().clone(),
"thread-history",
self.matched_count(),
move |history, range, _window, _cx| {
let range_start = range.start;
let assistant_panel = history.assistant_panel.clone();
let render_item = |index: usize,
entry: &HistoryEntry,
highlight_positions: Vec<usize>|
-> Div {
h_flex().w_full().pb_1().child(match entry {
HistoryEntry::Thread(thread) => PastThread::new(
thread.clone(),
assistant_panel.clone(),
selected_index == index + range_start,
highlight_positions,
)
.into_any_element(),
HistoryEntry::Context(context) => PastContext::new(
context.clone(),
assistant_panel.clone(),
selected_index == index + range_start,
highlight_positions,
)
.into_any_element(),
})
};
let render_item = |index: usize,
entry: &HistoryEntry,
highlight_positions: Vec<usize>|
-> Div {
h_flex().w_full().pb_1().child(match entry {
HistoryEntry::Thread(thread) => PastThread::new(
thread.clone(),
assistant_panel.clone(),
selected_index == index + range_start,
highlight_positions,
)
.into_any_element(),
HistoryEntry::Context(context) => PastContext::new(
context.clone(),
assistant_panel.clone(),
selected_index == index + range_start,
highlight_positions,
)
.into_any_element(),
})
};
if history.has_search_query() {
history.matches[range]
.iter()
.enumerate()
.filter_map(|(index, m)| {
history.all_entries.get(m.candidate_id).map(
|entry| {
render_item(
index,
entry,
m.positions.clone(),
)
},
)
if history.has_search_query() {
history.matches[range]
.iter()
.enumerate()
.filter_map(|(index, m)| {
history.all_entries.get(m.candidate_id).map(|entry| {
render_item(index, entry, m.positions.clone())
})
.collect()
} else {
history.all_entries[range]
.iter()
.enumerate()
.map(|(index, entry)| render_item(index, entry, vec![]))
.collect()
}
},
)
.p_1()
.track_scroll(self.scroll_handle.clone())
.flex_grow(),
})
.collect()
} else {
history.all_entries[range]
.iter()
.enumerate()
.map(|(index, entry)| render_item(index, entry, vec![]))
.collect()
}
},
)
.when_some(self.render_scrollbar(cx), |div, scrollbar| {
div.child(scrollbar)
})
.track_scroll(self.scroll_handle.clone())
.flex_grow(),
)
}
})
}
@@ -502,7 +440,6 @@ impl RenderOnce for PastThread {
IconButton::new("delete", IconName::TrashAlt)
.shape(IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.tooltip(move |window, cx| {
Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx)
})
@@ -594,7 +531,6 @@ impl RenderOnce for PastContext {
IconButton::new("delete", IconName::TrashAlt)
.shape(IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.tooltip(move |window, cx| {
Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx)
})

View File

@@ -1,254 +1,88 @@
use std::borrow::Cow;
use std::cell::{Ref, RefCell};
use std::path::{Path, PathBuf};
use std::rc::Rc;
use std::path::PathBuf;
use std::sync::Arc;
use anyhow::{Context as _, Result, anyhow};
use anyhow::{Result, anyhow};
use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::HashMap;
use context_server::manager::ContextServerManager;
use context_server::{ContextServerFactoryRegistry, ContextServerTool};
use fs::Fs;
use futures::FutureExt as _;
use futures::future::{self, BoxFuture, Shared};
use gpui::{
App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
Subscription, Task, prelude::*,
App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Subscription, Task,
prelude::*,
};
use heed::Database;
use heed::types::SerdeBincode;
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
use project::{Project, Worktree};
use prompt_store::{ProjectContext, PromptBuilder, RulesFileContext, WorktreeContext};
use project::Project;
use prompt_store::PromptBuilder;
use serde::{Deserialize, Serialize};
use settings::{Settings as _, SettingsStore};
use util::ResultExt as _;
use crate::thread::{
DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId,
};
const RULES_FILE_NAMES: [&'static str; 6] = [
".rules",
".cursorrules",
".windsurfrules",
".clinerules",
".github/copilot-instructions.md",
"CLAUDE.md",
];
pub fn init(cx: &mut App) {
ThreadsDatabase::init(cx);
}
/// A system prompt shared by all threads created by this ThreadStore
#[derive(Clone, Default)]
pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
impl SharedProjectContext {
pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
self.0.borrow()
}
}
pub struct ThreadStore {
project: Entity<Project>,
tools: Entity<ToolWorkingSet>,
tools: Arc<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
context_server_manager: Entity<ContextServerManager>,
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
threads: Vec<SerializedThreadMetadata>,
project_context: SharedProjectContext,
_subscriptions: Vec<Subscription>,
}
pub struct RulesLoadingError {
pub message: SharedString,
}
impl EventEmitter<RulesLoadingError> for ThreadStore {}
impl ThreadStore {
pub fn load(
pub fn new(
project: Entity<Project>,
tools: Entity<ToolWorkingSet>,
tools: Arc<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut App,
) -> Task<Entity<Self>> {
let thread_store = cx.new(|cx| Self::new(project, tools, prompt_builder, cx));
let reload = thread_store.update(cx, |store, cx| store.reload_system_prompt(cx));
cx.foreground_executor().spawn(async move {
reload.await;
thread_store
})
}
fn new(
project: Entity<Project>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut Context<Self>,
) -> Self {
let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
let context_server_manager = cx.new(|cx| {
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
});
let settings_subscription =
cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
this.load_default_profile(cx);
) -> Result<Entity<Self>> {
let this = cx.new(|cx| {
let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
let context_server_manager = cx.new(|cx| {
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
});
let project_subscription = cx.subscribe(&project, Self::handle_project_event);
let settings_subscription =
cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
this.load_default_profile(cx);
});
let this = Self {
project,
tools,
prompt_builder,
context_server_manager,
context_server_tool_ids: HashMap::default(),
threads: Vec::new(),
project_context: SharedProjectContext::default(),
_subscriptions: vec![settings_subscription, project_subscription],
};
this.load_default_profile(cx);
this.register_context_server_handlers(cx);
this.reload(cx).detach_and_log_err(cx);
this
}
fn handle_project_event(
&mut self,
_project: Entity<Project>,
event: &project::Event,
cx: &mut Context<Self>,
) {
match event {
project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
self.reload_system_prompt(cx).detach();
}
project::Event::WorktreeUpdatedEntries(_, items) => {
if items.iter().any(|(path, _, _)| {
RULES_FILE_NAMES
.iter()
.any(|name| path.as_ref() == Path::new(name))
}) {
self.reload_system_prompt(cx).detach();
}
}
_ => {}
}
}
pub fn reload_system_prompt(&self, cx: &mut Context<Self>) -> Task<()> {
let project = self.project.read(cx);
let tasks = project
.visible_worktrees(cx)
.map(|worktree| {
Self::load_worktree_info_for_system_prompt(
project.fs().clone(),
worktree.read(cx),
cx,
)
})
.collect::<Vec<_>>();
cx.spawn(async move |this, cx| {
let results = futures::future::join_all(tasks).await;
let worktrees = results
.into_iter()
.map(|(worktree, rules_error)| {
if let Some(rules_error) = rules_error {
this.update(cx, |_, cx| cx.emit(rules_error)).ok();
}
worktree
})
.collect::<Vec<_>>();
this.update(cx, |this, _cx| {
*this.project_context.0.borrow_mut() = Some(ProjectContext::new(worktrees));
})
.ok();
})
}
fn load_worktree_info_for_system_prompt(
fs: Arc<dyn Fs>,
worktree: &Worktree,
cx: &App,
) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
let root_name = worktree.root_name().into();
let abs_path = worktree.abs_path();
let rules_task = Self::load_worktree_rules_file(fs, worktree, cx);
let Some(rules_task) = rules_task else {
return Task::ready((
WorktreeContext {
root_name,
abs_path,
rules_file: None,
},
None,
));
};
cx.spawn(async move |_| {
let (rules_file, rules_file_error) = match rules_task.await {
Ok(rules_file) => (Some(rules_file), None),
Err(err) => (
None,
Some(RulesLoadingError {
message: format!("{err}").into(),
}),
),
let this = Self {
project,
tools,
prompt_builder,
context_server_manager,
context_server_tool_ids: HashMap::default(),
threads: Vec::new(),
_subscriptions: vec![settings_subscription],
};
let worktree_info = WorktreeContext {
root_name,
abs_path,
rules_file,
};
(worktree_info, rules_file_error)
})
}
this.load_default_profile(cx);
this.register_context_server_handlers(cx);
this.reload(cx).detach_and_log_err(cx);
fn load_worktree_rules_file(
fs: Arc<dyn Fs>,
worktree: &Worktree,
cx: &App,
) -> Option<Task<Result<RulesFileContext>>> {
let selected_rules_file = RULES_FILE_NAMES
.into_iter()
.filter_map(|name| {
worktree
.entry_for_path(name)
.filter(|entry| entry.is_file())
.map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
})
.next();
this
});
// Note that Cline supports `.clinerules` being a directory, but that is not currently
// supported. This doesn't seem to occur often in GitHub repositories.
selected_rules_file.map(|(path_in_worktree, abs_path)| {
let fs = fs.clone();
cx.background_spawn(async move {
let abs_path = abs_path?;
let text = fs.load(&abs_path).await.with_context(|| {
format!("Failed to load assistant rules file {:?}", abs_path)
})?;
anyhow::Ok(RulesFileContext {
path_in_worktree,
abs_path: abs_path.into(),
text: text.trim().to_string(),
})
})
})
Ok(this)
}
pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
self.context_server_manager.clone()
}
pub fn tools(&self) -> Entity<ToolWorkingSet> {
pub fn tools(&self) -> Arc<ToolWorkingSet> {
self.tools.clone()
}
@@ -273,7 +107,6 @@ impl ThreadStore {
self.project.clone(),
self.tools.clone(),
self.prompt_builder.clone(),
self.project_context.clone(),
cx,
)
})
@@ -301,12 +134,21 @@ impl ThreadStore {
this.project.clone(),
this.tools.clone(),
this.prompt_builder.clone(),
this.project_context.clone(),
cx,
)
})
})?;
let (system_prompt_context, load_error) = thread
.update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
.await;
thread.update(cx, |thread, cx| {
thread.set_system_prompt_context(system_prompt_context);
if let Some(load_error) = load_error {
cx.emit(ThreadEvent::ShowError(load_error));
}
})?;
Ok(thread)
})
}
@@ -355,60 +197,52 @@ impl ThreadStore {
})
}
fn load_default_profile(&self, cx: &mut Context<Self>) {
fn load_default_profile(&self, cx: &Context<Self>) {
let assistant_settings = AssistantSettings::get_global(cx);
self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
self.load_profile_by_id(&assistant_settings.default_profile, cx);
}
pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
pub fn load_profile_by_id(&self, profile_id: &AgentProfileId, cx: &Context<Self>) {
let assistant_settings = AssistantSettings::get_global(cx);
if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
self.load_profile(profile.clone(), cx);
if let Some(profile) = assistant_settings.profiles.get(profile_id) {
self.load_profile(profile, cx);
}
}
pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
self.tools.update(cx, |tools, cx| {
tools.disable_all_tools(cx);
tools.enable(
ToolSource::Native,
&profile
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
cx,
);
});
pub fn load_profile(&self, profile: &AgentProfile, cx: &Context<Self>) {
self.tools.disable_all_tools();
self.tools.enable(
ToolSource::Native,
&profile
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
);
if profile.enable_all_context_servers {
for context_server in self.context_server_manager.read(cx).all_servers() {
self.tools.update(cx, |tools, cx| {
tools.enable_source(
ToolSource::ContextServer {
id: context_server.id().into(),
},
cx,
);
});
self.tools.enable_source(
ToolSource::ContextServer {
id: context_server.id().into(),
},
cx,
);
}
} else {
for (context_server_id, preset) in &profile.context_servers {
self.tools.update(cx, |tools, cx| {
tools.enable(
ToolSource::ContextServer {
id: context_server_id.clone().into(),
},
&preset
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
cx,
)
})
self.tools.enable(
ToolSource::ContextServer {
id: context_server_id.clone().into(),
},
&preset
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
)
}
}
}
@@ -442,36 +276,29 @@ impl ThreadStore {
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
if let Some(tools) = protocol.list_tools().await.log_err() {
let tool_ids = tool_working_set
.update(cx, |tool_working_set, _| {
tools
.tools
.into_iter()
.map(|tool| {
log::info!(
"registering context server tool: {:?}",
tool.name
);
tool_working_set.insert(Arc::new(
ContextServerTool::new(
context_server_manager.clone(),
server.id(),
tool,
),
))
})
.collect::<Vec<_>>()
let tool_ids = tools
.tools
.into_iter()
.map(|tool| {
log::info!(
"registering context server tool: {:?}",
tool.name
);
tool_working_set.insert(Arc::new(
ContextServerTool::new(
context_server_manager.clone(),
server.id(),
tool,
),
))
})
.log_err();
.collect::<Vec<_>>();
if let Some(tool_ids) = tool_ids {
this.update(cx, |this, cx| {
this.context_server_tool_ids
.insert(server_id, tool_ids);
this.load_default_profile(cx);
})
.log_err();
}
this.update(cx, |this, cx| {
this.context_server_tool_ids.insert(server_id, tool_ids);
this.load_default_profile(cx);
})
.log_err();
}
}
}
@@ -481,9 +308,7 @@ impl ThreadStore {
}
context_server::manager::Event::ServerStopped { server_id } => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.update(cx, |tool_working_set, _| {
tool_working_set.remove(&tool_ids);
});
tool_working_set.remove(&tool_ids);
self.load_default_profile(cx);
}
}
@@ -509,11 +334,7 @@ pub struct SerializedThread {
#[serde(default)]
pub cumulative_token_usage: TokenUsage,
#[serde(default)]
pub request_token_usage: Vec<TokenUsage>,
#[serde(default)]
pub detailed_summary_state: DetailedSummaryState,
#[serde(default)]
pub exceeded_window_error: Option<ExceededWindowError>,
}
impl SerializedThread {
@@ -599,9 +420,7 @@ impl LegacySerializedThread {
messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
initial_project_snapshot: self.initial_project_snapshot,
cumulative_token_usage: TokenUsage::default(),
request_token_usage: Vec::new(),
detailed_summary_state: DetailedSummaryState::default(),
exceeded_window_error: None,
}
}
}
@@ -672,7 +491,7 @@ impl ThreadsDatabase {
let database_future = executor
.spawn({
let executor = executor.clone();
let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
async move { ThreadsDatabase::new(database_path, executor) }
})
.then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))

View File

@@ -1,94 +0,0 @@
use std::sync::Arc;
use assistant_tool::{Tool, ToolSource, ToolWorkingSet, ToolWorkingSetEvent};
use collections::HashMap;
use gpui::{App, Context, Entity, IntoElement, Render, Subscription, Window};
use language_model::{LanguageModel, LanguageModelToolSchemaFormat};
use ui::prelude::*;
pub struct IncompatibleToolsState {
cache: HashMap<LanguageModelToolSchemaFormat, Vec<Arc<dyn Tool>>>,
tool_working_set: Entity<ToolWorkingSet>,
_tool_working_set_subscription: Subscription,
}
impl IncompatibleToolsState {
pub fn new(tool_working_set: Entity<ToolWorkingSet>, cx: &mut Context<Self>) -> Self {
let _tool_working_set_subscription =
cx.subscribe(&tool_working_set, |this, _, event, _| match event {
ToolWorkingSetEvent::EnabledToolsChanged => {
this.cache.clear();
}
});
Self {
cache: HashMap::default(),
tool_working_set,
_tool_working_set_subscription,
}
}
pub fn incompatible_tools(
&mut self,
model: &Arc<dyn LanguageModel>,
cx: &App,
) -> &[Arc<dyn Tool>] {
self.cache
.entry(model.tool_input_format())
.or_insert_with(|| {
self.tool_working_set
.read(cx)
.enabled_tools(cx)
.iter()
.filter(|tool| tool.input_schema(model.tool_input_format()).is_err())
.cloned()
.collect()
})
}
}
pub struct IncompatibleToolsTooltip {
pub incompatible_tools: Vec<Arc<dyn Tool>>,
}
impl Render for IncompatibleToolsTooltip {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
ui::tooltip_container(window, cx, |container, _, cx| {
container
.w_72()
.child(Label::new("Incompatible Tools").size(LabelSize::Small))
.child(
Label::new(
"This model is incompatible with the following tools from your MCPs:",
)
.size(LabelSize::Small)
.color(Color::Muted),
)
.child(
v_flex()
.my_1p5()
.py_0p5()
.border_b_1()
.border_color(cx.theme().colors().border_variant)
.children(
self.incompatible_tools
.iter()
.map(|tool| h_flex().gap_4().child(Label::new(tool.name()).size(LabelSize::Small)).map(|parent|
match tool.source() {
ToolSource::Native => parent,
ToolSource::ContextServer { id } => parent.child(Label::new(id).size(LabelSize::Small).color(Color::Muted)),
}
)),
),
)
.child(Label::new("What To Do Instead").size(LabelSize::Small))
.child(
Label::new(
"Every other tool continues to work with this model, but to specifically use those, switch to another model.",
)
.size(LabelSize::Small)
.color(Color::Muted),
)
})
}
}

View File

@@ -1,11 +1,11 @@
use std::sync::Arc;
use anyhow::Result;
use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
use assistant_tool::{Tool, ToolWorkingSet};
use collections::HashMap;
use futures::FutureExt as _;
use futures::future::Shared;
use gpui::{App, Entity, SharedString, Task};
use gpui::{App, SharedString, Task};
use language_model::{
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
@@ -27,26 +27,45 @@ pub struct ToolUse {
pub needs_confirmation: bool,
}
pub const USING_TOOL_MARKER: &str = "<using_tool>";
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
NeedsConfirmation,
Pending,
Running,
Finished(SharedString),
Error(SharedString),
}
impl ToolUseStatus {
pub fn text(&self) -> SharedString {
match self {
ToolUseStatus::NeedsConfirmation => "".into(),
ToolUseStatus::Pending => "".into(),
ToolUseStatus::Running => "".into(),
ToolUseStatus::Finished(out) => out.clone(),
ToolUseStatus::Error(out) => out.clone(),
}
}
}
pub struct ToolUseState {
tools: Entity<ToolWorkingSet>,
tools: Arc<ToolWorkingSet>,
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
}
pub const USING_TOOL_MARKER: &str = "<using_tool>";
impl ToolUseState {
pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
Self {
tools,
tool_uses_by_assistant_message: HashMap::default(),
tool_uses_by_user_message: HashMap::default(),
tool_results: HashMap::default(),
pending_tool_uses_by_id: HashMap::default(),
tool_result_cards: HashMap::default(),
}
}
@@ -54,7 +73,7 @@ impl ToolUseState {
///
/// Accepts a function to filter the tools that should be used to populate the state.
pub fn from_serialized_messages(
tools: Entity<ToolWorkingSet>,
tools: Arc<ToolWorkingSet>,
messages: &[SerializedMessage],
mut filter_by_tool_name: impl FnMut(&str) -> bool,
) -> Self {
@@ -180,12 +199,12 @@ impl ToolUseState {
}
})();
let (icon, needs_confirmation) =
if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
(tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
} else {
(IconName::Cog, false)
};
let (icon, needs_confirmation) = if let Some(tool) = self.tools.tool(&tool_use.name, cx)
{
(tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
} else {
(IconName::Cog, false)
};
tool_uses.push(ToolUse {
id: tool_use.id.clone(),
@@ -207,7 +226,7 @@ impl ToolUseState {
input: &serde_json::Value,
cx: &App,
) -> SharedString {
if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
if let Some(tool) = self.tools.tool(tool_name, cx) {
tool.ui_text(input).into()
} else {
format!("Unknown tool {tool_name:?}").into()
@@ -238,18 +257,6 @@ impl ToolUseState {
self.tool_results.get(tool_use_id)
}
pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
self.tool_result_cards.get(tool_use_id)
}
pub fn insert_tool_result_card(
&mut self,
tool_use_id: LanguageModelToolUseId,
card: AnyToolCard,
) {
self.tool_result_cards.insert(tool_use_id, card);
}
pub fn request_tool_use(
&mut self,
assistant_message_id: MessageId,

View File

@@ -1,7 +1,7 @@
mod agent_notification;
mod context_pill;
mod usage_banner;
mod user_spending;
pub use agent_notification::*;
pub use context_pill::*;
pub use usage_banner::*;
// pub use user_spending::*;

View File

@@ -191,12 +191,15 @@ impl RenderOnce for ContextPill {
ContextPill::Suggested {
name,
icon_path: _,
kind: _,
kind,
focused,
on_click,
} => base_pill
.cursor_pointer()
.pr_1()
.when(*focused, |this| {
this.bg(color.element_background.opacity(0.5))
})
.border_dashed()
.border_color(if *focused {
color.border_focused
@@ -204,17 +207,30 @@ impl RenderOnce for ContextPill {
color.border
})
.hover(|style| style.bg(color.element_hover.opacity(0.5)))
.when(*focused, |this| {
this.bg(color.element_background.opacity(0.5))
})
.child(
div().max_w_64().child(
div().px_0p5().max_w_64().child(
Label::new(name.clone())
.size(LabelSize::Small)
.color(Color::Muted)
.truncate(),
),
)
.child(
Label::new(match kind {
ContextKind::File => "Active Tab",
ContextKind::Thread
| ContextKind::Directory
| ContextKind::FetchedUrl
| ContextKind::Symbol => "Active",
})
.size(LabelSize::XSmall)
.color(Color::Muted),
)
.child(
Icon::new(IconName::Plus)
.size(IconSize::XSmall)
.into_any_element(),
)
.tooltip(|window, cx| {
Tooltip::with_meta("Suggested Context", None, "Click to add it", window, cx)
})
@@ -299,39 +315,6 @@ impl AddedContext {
summarizing: false,
},
AssistantContext::Excerpt(excerpt_context) => {
let full_path = excerpt_context.context_buffer.file.full_path(cx);
let mut full_path_string = full_path.to_string_lossy().into_owned();
let mut name = full_path
.file_name()
.map(|n| n.to_string_lossy().into_owned())
.unwrap_or_else(|| full_path_string.clone());
let line_range_text = format!(
" ({}-{})",
excerpt_context.line_range.start.row + 1,
excerpt_context.line_range.end.row + 1
);
full_path_string.push_str(&line_range_text);
name.push_str(&line_range_text);
let parent = full_path
.parent()
.and_then(|p| p.file_name())
.map(|n| n.to_string_lossy().into_owned().into());
AddedContext {
id: excerpt_context.id,
kind: ContextKind::File, // Use File icon for excerpts
name: name.into(),
parent,
tooltip: Some(full_path_string.into()),
icon_path: FileIcons::get_icon(&full_path, cx),
summarizing: false,
}
}
AssistantContext::FetchedUrl(fetched_url_context) => AddedContext {
id: fetched_url_context.id,
kind: ContextKind::FetchedUrl,

View File

@@ -1,202 +0,0 @@
use client::zed_urls;
use ui::{Banner, ProgressBar, Severity, prelude::*};
use zed_llm_client::{Plan, UsageLimit};
#[derive(IntoElement, RegisterComponent)]
pub struct UsageBanner {
plan: Plan,
requests: i32,
}
impl UsageBanner {
pub fn new(plan: Plan, requests: i32) -> Self {
Self { plan, requests }
}
}
impl RenderOnce for UsageBanner {
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
let request_limit = self.plan.model_requests_limit();
let used_percentage = match request_limit {
UsageLimit::Limited(limit) => Some((self.requests as f32 / limit as f32) * 100.),
UsageLimit::Unlimited => None,
};
let (severity, message) = match request_limit {
UsageLimit::Limited(limit) => {
if self.requests >= limit {
let message = match self.plan {
Plan::ZedPro => "Monthly request limit reached",
Plan::ZedProTrial => "Trial request limit reached",
Plan::Free => "Free tier request limit reached",
};
(Severity::Error, message)
} else if (self.requests as f32 / limit as f32) >= 0.9 {
(Severity::Warning, "Approaching request limit")
} else {
let message = match self.plan {
Plan::ZedPro => "Zed Pro",
Plan::ZedProTrial => "Zed Pro (Trial)",
Plan::Free => "Zed Free",
};
(Severity::Info, message)
}
}
UsageLimit::Unlimited => {
let message = match self.plan {
Plan::ZedPro => "Zed Pro",
Plan::ZedProTrial => "Zed Pro (Trial)",
Plan::Free => "Zed Free",
};
(Severity::Info, message)
}
};
let action = match self.plan {
Plan::ZedProTrial | Plan::Free => {
Button::new("upgrade", "Upgrade").on_click(|_, _window, cx| {
cx.open_url(&zed_urls::account_url(cx));
})
}
Plan::ZedPro => Button::new("manage", "Manage").on_click(|_, _window, cx| {
cx.open_url(&zed_urls::account_url(cx));
}),
};
Banner::new().severity(severity).children(
h_flex().flex_1().gap_1().child(Label::new(message)).child(
h_flex()
.flex_1()
.justify_end()
.gap_1p5()
.children(used_percentage.map(|percent| {
h_flex()
.items_center()
.w_full()
.max_w(px(180.))
.child(ProgressBar::new("usage", percent, 100., cx))
}))
.child(
Label::new(match request_limit {
UsageLimit::Limited(limit) => {
format!("{} / {limit}", self.requests)
}
UsageLimit::Unlimited => format!("{} / ∞", self.requests),
})
.size(LabelSize::Small)
.color(Color::Muted),
)
// Note: This should go in the banner's `action_slot`, but doing that messes with the size of the
// progress bar.
.child(action),
),
)
}
}
impl Component for UsageBanner {
fn sort_name() -> &'static str {
"AgentUsageBanner"
}
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
let trial_examples = vec![
single_example(
"Zed Pro Trial - New User",
div()
.size_full()
.child(UsageBanner::new(Plan::ZedProTrial, 10))
.into_any_element(),
),
single_example(
"Zed Pro Trial - Approaching Limit",
div()
.size_full()
.child(UsageBanner::new(Plan::ZedProTrial, 135))
.into_any_element(),
),
single_example(
"Zed Pro Trial - Request Limit Reached",
div()
.size_full()
.child(UsageBanner::new(Plan::ZedProTrial, 150))
.into_any_element(),
),
];
let free_examples = vec![
single_example(
"Free - Normal Usage",
div()
.size_full()
.child(UsageBanner::new(Plan::Free, 25))
.into_any_element(),
),
single_example(
"Free - Approaching Limit",
div()
.size_full()
.child(UsageBanner::new(Plan::Free, 45))
.into_any_element(),
),
single_example(
"Free - Request Limit Reached",
div()
.size_full()
.child(UsageBanner::new(Plan::Free, 50))
.into_any_element(),
),
];
let zed_pro_examples = vec![
single_example(
"Zed Pro - Normal Usage",
div()
.size_full()
.child(UsageBanner::new(Plan::ZedPro, 250))
.into_any_element(),
),
single_example(
"Zed Pro - Approaching Limit",
div()
.size_full()
.child(UsageBanner::new(Plan::ZedPro, 450))
.into_any_element(),
),
single_example(
"Zed Pro - Request Limit Reached",
div()
.size_full()
.child(UsageBanner::new(Plan::ZedPro, 500))
.into_any_element(),
),
];
Some(
v_flex()
.gap_6()
.p_4()
.children(vec![
Label::new("Trial Plan")
.size(LabelSize::Large)
.into_any_element(),
example_group(trial_examples).vertical().into_any_element(),
Label::new("Free Plan")
.size(LabelSize::Large)
.into_any_element(),
example_group(free_examples).vertical().into_any_element(),
Label::new("Pro Plan")
.size(LabelSize::Large)
.into_any_element(),
example_group(zed_pro_examples)
.vertical()
.into_any_element(),
])
.into_any_element(),
)
}
}

View File

@@ -0,0 +1,186 @@
use gpui::{Entity, Render};
use ui::{ProgressBar, prelude::*};
#[derive(RegisterComponent)]
pub struct UserSpending {
free_tier_current: u32,
free_tier_cap: u32,
over_tier_current: u32,
over_tier_cap: u32,
free_tier_progress: Entity<ProgressBar>,
over_tier_progress: Entity<ProgressBar>,
}
impl UserSpending {
pub fn new(
free_tier_current: u32,
free_tier_cap: u32,
over_tier_current: u32,
over_tier_cap: u32,
cx: &mut App,
) -> Self {
let free_tier_capped = free_tier_current == free_tier_cap;
let free_tier_near_capped =
free_tier_current as f32 / 100.0 >= free_tier_cap as f32 / 100.0 * 0.9;
let over_tier_capped = over_tier_current == over_tier_cap;
let over_tier_near_capped =
over_tier_current as f32 / 100.0 >= over_tier_cap as f32 / 100.0 * 0.9;
let free_tier_progress = cx.new(|cx| {
ProgressBar::new(
"free_tier",
free_tier_current as f32,
free_tier_cap as f32,
cx,
)
});
let over_tier_progress = cx.new(|cx| {
ProgressBar::new(
"over_tier",
over_tier_current as f32,
over_tier_cap as f32,
cx,
)
});
if free_tier_capped {
free_tier_progress.update(cx, |progress_bar, cx| {
progress_bar.fg_color(cx.theme().status().error);
});
} else if free_tier_near_capped {
free_tier_progress.update(cx, |progress_bar, cx| {
progress_bar.fg_color(cx.theme().status().warning);
});
}
if over_tier_capped {
over_tier_progress.update(cx, |progress_bar, cx| {
progress_bar.fg_color(cx.theme().status().error);
});
} else if over_tier_near_capped {
over_tier_progress.update(cx, |progress_bar, cx| {
progress_bar.fg_color(cx.theme().status().warning);
});
}
Self {
free_tier_current,
free_tier_cap,
over_tier_current,
over_tier_cap,
free_tier_progress,
over_tier_progress,
}
}
}
impl Render for UserSpending {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let formatted_free_tier = format!(
"${} / ${}",
self.free_tier_current as f32 / 100.0,
self.free_tier_cap as f32 / 100.0
);
let formatted_over_tier = format!(
"${} / ${}",
self.over_tier_current as f32 / 100.0,
self.over_tier_cap as f32 / 100.0
);
v_group()
.elevation_2(cx)
.py_1p5()
.px_2p5()
.w(px(360.))
.child(
v_flex()
.child(
v_flex()
.p_1p5()
.gap_0p5()
.child(
h_flex()
.justify_between()
.child(Label::new("Free Tier Usage").size(LabelSize::Small))
.child(
Label::new(formatted_free_tier)
.size(LabelSize::Small)
.color(Color::Muted),
),
)
.child(self.free_tier_progress.clone()),
)
.child(
v_flex()
.p_1p5()
.gap_0p5()
.child(
h_flex()
.justify_between()
.child(Label::new("Current Spending").size(LabelSize::Small))
.child(
Label::new(formatted_over_tier)
.size(LabelSize::Small)
.color(Color::Muted),
),
)
.child(self.over_tier_progress.clone()),
),
)
}
}
impl Component for UserSpending {
fn scope() -> ComponentScope {
ComponentScope::None
}
fn preview(_window: &mut Window, cx: &mut App) -> Option<AnyElement> {
let new_user = cx.new(|cx| UserSpending::new(0, 2000, 0, 2000, cx));
let free_capped = cx.new(|cx| UserSpending::new(2000, 2000, 0, 2000, cx));
let free_near_capped = cx.new(|cx| UserSpending::new(1800, 2000, 0, 2000, cx));
let over_near_capped = cx.new(|cx| UserSpending::new(2000, 2000, 1800, 2000, cx));
let over_capped = cx.new(|cx| UserSpending::new(1000, 2000, 2000, 2000, cx));
Some(
v_flex()
.gap_6()
.p_4()
.children(vec![example_group(vec![
single_example(
"New User",
div().size_full().child(new_user.clone()).into_any_element(),
),
single_example(
"Free Tier Capped",
div()
.size_full()
.child(free_capped.clone())
.into_any_element(),
),
single_example(
"Free Tier Near Capped",
div()
.size_full()
.child(free_near_capped.clone())
.into_any_element(),
),
single_example(
"Over Tier Near Capped",
div()
.size_full()
.child(over_near_capped.clone())
.into_any_element(),
),
single_example(
"Over Tier Capped",
div()
.size_full()
.child(over_capped.clone())
.into_any_element(),
),
])])
.into_any_element(),
)
}
}

View File

@@ -0,0 +1,46 @@
[package]
name = "agent_eval"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[[bin]]
name = "agent_eval"
path = "src/main.rs"
[dependencies]
agent.workspace = true
anyhow.workspace = true
assistant_tool.workspace = true
assistant_tools.workspace = true
clap.workspace = true
client.workspace = true
collections.workspace = true
context_server.workspace = true
dap.workspace = true
env_logger.workspace = true
fs.workspace = true
futures.workspace = true
gpui.workspace = true
gpui_tokio.workspace = true
language.workspace = true
language_model.workspace = true
language_models.workspace = true
node_runtime.workspace = true
project.workspace = true
prompt_store.workspace = true
release_channel.workspace = true
reqwest_client.workspace = true
serde.workspace = true
serde_json.workspace = true
serde_json_lenient.workspace = true
settings.workspace = true
smol.workspace = true
tempfile.workspace = true
util.workspace = true
walkdir.workspace = true
workspace-hack.workspace = true

View File

@@ -0,0 +1,52 @@
// Copied from `crates/zed/build.rs`, with removal of code for including the zed icon on windows.
use std::process::Command;
fn main() {
if cfg!(target_os = "macos") {
println!("cargo:rustc-env=MACOSX_DEPLOYMENT_TARGET=10.15.7");
// Weakly link ReplayKit to ensure Zed can be used on macOS 10.15+.
println!("cargo:rustc-link-arg=-Wl,-weak_framework,ReplayKit");
// Seems to be required to enable Swift concurrency
println!("cargo:rustc-link-arg=-Wl,-rpath,/usr/lib/swift");
// Register exported Objective-C selectors, protocols, etc
println!("cargo:rustc-link-arg=-Wl,-ObjC");
}
// Populate git sha environment variable if git is available
println!("cargo:rerun-if-changed=../../.git/logs/HEAD");
println!(
"cargo:rustc-env=TARGET={}",
std::env::var("TARGET").unwrap()
);
if let Ok(output) = Command::new("git").args(["rev-parse", "HEAD"]).output() {
if output.status.success() {
let git_sha = String::from_utf8_lossy(&output.stdout);
let git_sha = git_sha.trim();
println!("cargo:rustc-env=ZED_COMMIT_SHA={git_sha}");
if let Ok(build_profile) = std::env::var("PROFILE") {
if build_profile == "release" {
// This is currently the best way to make `cargo build ...`'s build script
// to print something to stdout without extra verbosity.
println!(
"cargo:warning=Info: using '{git_sha}' hash for ZED_COMMIT_SHA env var"
);
}
}
}
}
#[cfg(target_os = "windows")]
{
#[cfg(target_env = "msvc")]
{
// todo(windows): This is to avoid stack overflow. Remove it when solved.
println!("cargo:rustc-link-arg=/stack:{}", 8 * 1024 * 1024);
}
}
}

View File

@@ -0,0 +1,384 @@
use crate::git_commands::{run_git, setup_temp_repo};
use crate::headless_assistant::{HeadlessAppState, HeadlessAssistant};
use crate::{get_exercise_language, get_exercise_name};
use agent::RequestKind;
use anyhow::{Result, anyhow};
use collections::HashMap;
use gpui::{App, Task};
use language_model::{LanguageModel, TokenUsage};
use serde::{Deserialize, Serialize};
use std::{
fs,
io::Write,
path::{Path, PathBuf},
sync::Arc,
time::{Duration, SystemTime},
};
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct EvalResult {
pub exercise_name: String,
pub diff: String,
pub assistant_response: String,
pub elapsed_time_ms: u128,
pub timestamp: u128,
// Token usage fields
pub input_tokens: usize,
pub output_tokens: usize,
pub total_tokens: usize,
pub tool_use_counts: usize,
}
pub struct EvalOutput {
pub diff: String,
pub last_message: String,
pub elapsed_time: Duration,
pub assistant_response_count: usize,
pub tool_use_counts: HashMap<Arc<str>, u32>,
pub token_usage: TokenUsage,
}
#[derive(Deserialize)]
pub struct EvalSetup {
pub url: String,
pub base_sha: String,
}
pub struct Eval {
pub repo_path: PathBuf,
pub eval_setup: EvalSetup,
pub user_prompt: String,
}
impl Eval {
// Keep this method for potential future use, but mark it as intentionally unused
#[allow(dead_code)]
pub async fn load(_name: String, path: PathBuf, repos_dir: &Path) -> Result<Self> {
let prompt_path = path.join("prompt.txt");
let user_prompt = smol::unblock(|| std::fs::read_to_string(prompt_path)).await?;
let setup_path = path.join("setup.json");
let setup_contents = smol::unblock(|| std::fs::read_to_string(setup_path)).await?;
let eval_setup = serde_json_lenient::from_str_lenient::<EvalSetup>(&setup_contents)?;
// Move this internal function inside the load method since it's only used here
fn repo_dir_name(url: &str) -> String {
url.trim_start_matches("https://")
.replace(|c: char| !c.is_alphanumeric(), "_")
}
let repo_path = repos_dir.join(repo_dir_name(&eval_setup.url));
Ok(Eval {
repo_path,
eval_setup,
user_prompt,
})
}
pub fn run(
self,
app_state: Arc<HeadlessAppState>,
model: Arc<dyn LanguageModel>,
cx: &mut App,
) -> Task<Result<EvalOutput>> {
cx.spawn(async move |cx| {
run_git(&self.repo_path, &["checkout", &self.eval_setup.base_sha]).await?;
let (assistant, done_rx) =
cx.update(|cx| HeadlessAssistant::new(app_state.clone(), cx))??;
let _worktree = assistant
.update(cx, |assistant, cx| {
assistant.project.update(cx, |project, cx| {
project.create_worktree(&self.repo_path, true, cx)
})
})?
.await?;
let start_time = std::time::SystemTime::now();
let (system_prompt_context, load_error) = cx
.update(|cx| {
assistant
.read(cx)
.thread
.read(cx)
.load_system_prompt_context(cx)
})?
.await;
if let Some(load_error) = load_error {
return Err(anyhow!("{:?}", load_error));
};
assistant.update(cx, |assistant, cx| {
assistant.thread.update(cx, |thread, cx| {
let context = vec![];
thread.insert_user_message(self.user_prompt.clone(), context, None, cx);
thread.set_system_prompt_context(system_prompt_context);
thread.send_to_model(model, RequestKind::Chat, cx);
});
})?;
done_rx.recv().await??;
// Add this section to check untracked files
println!("Checking for untracked files:");
let untracked = run_git(
&self.repo_path,
&["ls-files", "--others", "--exclude-standard"],
)
.await?;
if untracked.is_empty() {
println!("No untracked files found");
} else {
// Add all files to git so they appear in the diff
println!("Adding untracked files to git");
run_git(&self.repo_path, &["add", "."]).await?;
}
// get git status
let _status = run_git(&self.repo_path, &["status", "--short"]).await?;
let elapsed_time = start_time.elapsed()?;
// Get diff of staged changes (the files we just added)
let staged_diff = run_git(&self.repo_path, &["diff", "--staged"]).await?;
// Get diff of unstaged changes
let unstaged_diff = run_git(&self.repo_path, &["diff"]).await?;
// Combine both diffs
let diff = if unstaged_diff.is_empty() {
staged_diff
} else if staged_diff.is_empty() {
unstaged_diff
} else {
format!(
"# Staged changes\n{}\n\n# Unstaged changes\n{}",
staged_diff, unstaged_diff
)
};
assistant.update(cx, |assistant, cx| {
let thread = assistant.thread.read(cx);
let last_message = thread.messages().last().unwrap();
if last_message.role != language_model::Role::Assistant {
return Err(anyhow!("Last message is not from assistant"));
}
let assistant_response_count = thread
.messages()
.filter(|message| message.role == language_model::Role::Assistant)
.count();
Ok(EvalOutput {
diff,
last_message: last_message.to_string(),
elapsed_time,
assistant_response_count,
tool_use_counts: assistant.tool_use_counts.clone(),
token_usage: thread.cumulative_token_usage(),
})
})?
})
}
}
impl EvalOutput {
// Keep this method for potential future use, but mark it as intentionally unused
#[allow(dead_code)]
pub fn save_to_directory(&self, output_dir: &Path, eval_output_value: String) -> Result<()> {
// Create the output directory if it doesn't exist
fs::create_dir_all(&output_dir)?;
// Save the diff to a file
let diff_path = output_dir.join("diff.patch");
let mut diff_file = fs::File::create(&diff_path)?;
diff_file.write_all(self.diff.as_bytes())?;
// Save the last message to a file
let message_path = output_dir.join("assistant_response.txt");
let mut message_file = fs::File::create(&message_path)?;
message_file.write_all(self.last_message.as_bytes())?;
// Current metrics for this run
let current_metrics = serde_json::json!({
"elapsed_time_ms": self.elapsed_time.as_millis(),
"assistant_response_count": self.assistant_response_count,
"tool_use_counts": self.tool_use_counts,
"token_usage": self.token_usage,
"eval_output_value": eval_output_value,
});
// Get current timestamp in milliseconds
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)?
.as_millis()
.to_string();
// Path to metrics file
let metrics_path = output_dir.join("metrics.json");
// Load existing metrics if the file exists, or create a new object
let mut historical_metrics = if metrics_path.exists() {
let metrics_content = fs::read_to_string(&metrics_path)?;
serde_json::from_str::<serde_json::Value>(&metrics_content)
.unwrap_or_else(|_| serde_json::json!({}))
} else {
serde_json::json!({})
};
// Add new run with timestamp as key
if let serde_json::Value::Object(ref mut map) = historical_metrics {
map.insert(timestamp, current_metrics);
}
// Write updated metrics back to file
let metrics_json = serde_json::to_string_pretty(&historical_metrics)?;
let mut metrics_file = fs::File::create(&metrics_path)?;
metrics_file.write_all(metrics_json.as_bytes())?;
Ok(())
}
}
pub async fn read_instructions(exercise_path: &Path) -> Result<String> {
let instructions_path = exercise_path.join(".docs").join("instructions.md");
println!("Reading instructions from: {}", instructions_path.display());
let instructions = smol::unblock(move || std::fs::read_to_string(&instructions_path)).await?;
Ok(instructions)
}
pub async fn save_eval_results(exercise_path: &Path, results: Vec<EvalResult>) -> Result<()> {
let eval_dir = exercise_path.join("evaluation");
fs::create_dir_all(&eval_dir)?;
let eval_file = eval_dir.join("evals.json");
println!("Saving evaluation results to: {}", eval_file.display());
println!(
"Results to save: {} evaluations for exercise path: {}",
results.len(),
exercise_path.display()
);
// Check file existence before reading/writing
if eval_file.exists() {
println!("Existing evals.json file found, will update it");
} else {
println!("No existing evals.json file found, will create new one");
}
// Structure to organize evaluations by test name and timestamp
let mut eval_data: serde_json::Value = if eval_file.exists() {
let content = fs::read_to_string(&eval_file)?;
serde_json::from_str(&content).unwrap_or_else(|_| serde_json::json!({}))
} else {
serde_json::json!({})
};
// Get current timestamp for this batch of results
let timestamp = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)?
.as_millis()
.to_string();
// Group the new results by test name (exercise name)
for result in results {
let exercise_name = &result.exercise_name;
println!("Adding result: exercise={}", exercise_name);
// Ensure the exercise entry exists
if eval_data.get(exercise_name).is_none() {
eval_data[exercise_name] = serde_json::json!({});
}
// Ensure the timestamp entry exists as an object
if eval_data[exercise_name].get(&timestamp).is_none() {
eval_data[exercise_name][&timestamp] = serde_json::json!({});
}
// Add this result under the timestamp with template name as key
eval_data[exercise_name][&timestamp] = serde_json::to_value(&result)?;
}
// Write back to file with pretty formatting
let json_content = serde_json::to_string_pretty(&eval_data)?;
match fs::write(&eval_file, json_content) {
Ok(_) => println!("✓ Successfully saved results to {}", eval_file.display()),
Err(e) => println!("✗ Failed to write results file: {}", e),
}
Ok(())
}
pub async fn run_exercise_eval(
exercise_path: PathBuf,
model: Arc<dyn LanguageModel>,
app_state: Arc<HeadlessAppState>,
base_sha: String,
_framework_path: PathBuf,
cx: gpui::AsyncApp,
) -> Result<EvalResult> {
let exercise_name = get_exercise_name(&exercise_path);
let language = get_exercise_language(&exercise_path)?;
let mut instructions = read_instructions(&exercise_path).await?;
instructions.push_str(&format!(
"\n\nWhen writing the code for this prompt, use {} to achieve the goal.",
language
));
println!("Running evaluation for exercise: {}", exercise_name);
// Create temporary directory with exercise files
let temp_dir = setup_temp_repo(&exercise_path, &base_sha).await?;
let temp_path = temp_dir.path().to_path_buf();
let local_commit_sha = run_git(&temp_path, &["rev-parse", "HEAD"]).await?;
let start_time = SystemTime::now();
// Create a basic eval struct to work with the existing system
let eval = Eval {
repo_path: temp_path.clone(),
eval_setup: EvalSetup {
url: format!("file://{}", temp_path.display()),
base_sha: local_commit_sha, // Use the local commit SHA instead of the framework base SHA
},
user_prompt: instructions.clone(),
};
// Run the evaluation
let eval_output = cx
.update(|cx| eval.run(app_state.clone(), model.clone(), cx))?
.await?;
// Get diff from git
let diff = eval_output.diff.clone();
let elapsed_time = start_time.elapsed()?;
// Calculate total tokens as the sum of input and output tokens
let input_tokens = eval_output.token_usage.input_tokens;
let output_tokens = eval_output.token_usage.output_tokens;
let tool_use_counts = eval_output.tool_use_counts.values().sum::<u32>();
let total_tokens = input_tokens + output_tokens;
// Save results to evaluation directory
let result = EvalResult {
exercise_name: exercise_name.clone(),
diff,
assistant_response: eval_output.last_message.clone(),
elapsed_time_ms: elapsed_time.as_millis(),
timestamp: SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)?
.as_millis(),
// Convert u32 token counts to usize
input_tokens: input_tokens.try_into().unwrap(),
output_tokens: output_tokens.try_into().unwrap(),
total_tokens: total_tokens.try_into().unwrap(),
tool_use_counts: tool_use_counts.try_into().unwrap(),
};
Ok(result)
}

View File

@@ -0,0 +1,149 @@
use anyhow::{Result, anyhow};
use std::{
fs,
path::{Path, PathBuf},
};
pub fn get_exercise_name(exercise_path: &Path) -> String {
exercise_path
.file_name()
.unwrap_or_default()
.to_string_lossy()
.to_string()
}
pub fn get_exercise_language(exercise_path: &Path) -> Result<String> {
// Extract the language from path (data/python/exercises/... => python)
let parts: Vec<_> = exercise_path.components().collect();
for (i, part) in parts.iter().enumerate() {
if i > 0 && part.as_os_str() == "eval_code" {
if i + 1 < parts.len() {
let language = parts[i + 1].as_os_str().to_string_lossy().to_string();
return Ok(language);
}
}
}
Err(anyhow!(
"Could not determine language from path: {:?}",
exercise_path
))
}
pub fn find_exercises(
framework_path: &Path,
languages: &[&str],
max_per_language: Option<usize>,
) -> Result<Vec<PathBuf>> {
let mut all_exercises = Vec::new();
println!("Searching for exercises in languages: {:?}", languages);
for language in languages {
let language_dir = framework_path
.join("eval_code")
.join(language)
.join("exercises")
.join("practice");
println!("Checking language directory: {:?}", language_dir);
if !language_dir.exists() {
println!("Warning: Language directory not found: {:?}", language_dir);
continue;
}
let mut exercises = Vec::new();
match fs::read_dir(&language_dir) {
Ok(entries) => {
for entry_result in entries {
match entry_result {
Ok(entry) => {
let path = entry.path();
if path.is_dir() {
// Special handling for "internal" directory
if *language == "internal" {
// Check for repo_info.json to validate it's an internal exercise
let repo_info_path = path.join(".meta").join("repo_info.json");
let instructions_path =
path.join(".docs").join("instructions.md");
if repo_info_path.exists() && instructions_path.exists() {
exercises.push(path);
}
} else {
// Map the language to the file extension - original code
let language_extension = match *language {
"python" => "py",
"go" => "go",
"rust" => "rs",
"typescript" => "ts",
"javascript" => "js",
"ruby" => "rb",
"php" => "php",
"bash" => "sh",
"multi" => "diff",
_ => continue, // Skip unsupported languages
};
// Check if this is a valid exercise with instructions and example
let instructions_path =
path.join(".docs").join("instructions.md");
let has_instructions = instructions_path.exists();
let example_path = path
.join(".meta")
.join(format!("example.{}", language_extension));
let has_example = example_path.exists();
if has_instructions && has_example {
exercises.push(path);
}
}
}
}
Err(err) => println!("Error reading directory entry: {}", err),
}
}
}
Err(err) => println!(
"Error reading directory {}: {}",
language_dir.display(),
err
),
}
// Sort exercises by name for consistent selection
exercises.sort_by(|a, b| {
let a_name = a.file_name().unwrap_or_default().to_string_lossy();
let b_name = b.file_name().unwrap_or_default().to_string_lossy();
a_name.cmp(&b_name)
});
// Apply the limit if specified
if let Some(limit) = max_per_language {
if exercises.len() > limit {
println!(
"Limiting {} exercises to {} for language {}",
exercises.len(),
limit,
language
);
exercises.truncate(limit);
}
}
println!(
"Found {} exercises for language {}: {:?}",
exercises.len(),
language,
exercises
.iter()
.map(|p| p.file_name().unwrap_or_default().to_string_lossy())
.collect::<Vec<_>>()
);
all_exercises.extend(exercises);
}
Ok(all_exercises)
}

View File

@@ -0,0 +1,125 @@
use anyhow::{Result, anyhow};
use serde::Deserialize;
use std::{fs, path::Path};
use tempfile::TempDir;
use util::command::new_smol_command;
use walkdir::WalkDir;
#[derive(Debug, Deserialize)]
pub struct SetupConfig {
#[serde(rename = "base.sha")]
pub base_sha: String,
}
#[derive(Debug, Deserialize)]
pub struct RepoInfo {
pub remote_url: String,
pub head_sha: String,
}
pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
let output = new_smol_command("git")
.current_dir(repo_path)
.args(args)
.output()
.await?;
if output.status.success() {
Ok(String::from_utf8(output.stdout)?.trim().to_string())
} else {
Err(anyhow!(
"Git command failed: {} with status: {}",
args.join(" "),
output.status
))
}
}
pub async fn read_base_sha(framework_path: &Path) -> Result<String> {
let setup_path = framework_path.join("setup.json");
let setup_content = smol::unblock(move || std::fs::read_to_string(&setup_path)).await?;
let setup_config: SetupConfig = serde_json_lenient::from_str_lenient(&setup_content)?;
Ok(setup_config.base_sha)
}
pub async fn read_repo_info(exercise_path: &Path) -> Result<RepoInfo> {
let repo_info_path = exercise_path.join(".meta").join("repo_info.json");
println!("Reading repo info from: {}", repo_info_path.display());
let repo_info_content = smol::unblock(move || std::fs::read_to_string(&repo_info_path)).await?;
let repo_info: RepoInfo = serde_json_lenient::from_str_lenient(&repo_info_content)?;
// Remove any quotes from the strings
let remote_url = repo_info.remote_url.trim_matches('"').to_string();
let head_sha = repo_info.head_sha.trim_matches('"').to_string();
Ok(RepoInfo {
remote_url,
head_sha,
})
}
pub async fn setup_temp_repo(exercise_path: &Path, _base_sha: &str) -> Result<TempDir> {
let temp_dir = TempDir::new()?;
// Check if this is an internal exercise by looking for repo_info.json
let repo_info_path = exercise_path.join(".meta").join("repo_info.json");
if repo_info_path.exists() {
// This is an internal exercise, handle it differently
let repo_info = read_repo_info(exercise_path).await?;
// Clone the repository to the temp directory
let url = repo_info.remote_url;
let clone_path = temp_dir.path();
println!(
"Cloning repository from {} to {}",
url,
clone_path.display()
);
run_git(
&std::env::current_dir()?,
&["clone", &url, &clone_path.to_string_lossy()],
)
.await?;
// Checkout the specified commit
println!("Checking out commit: {}", repo_info.head_sha);
run_git(temp_dir.path(), &["checkout", &repo_info.head_sha]).await?;
println!("Successfully set up internal repository");
} else {
// Original code for regular exercises
// Copy the exercise files to the temp directory, excluding .docs and .meta
for entry in WalkDir::new(exercise_path).min_depth(0).max_depth(10) {
let entry = entry?;
let source_path = entry.path();
// Skip .docs and .meta directories completely
if source_path.starts_with(exercise_path.join(".docs"))
|| source_path.starts_with(exercise_path.join(".meta"))
{
continue;
}
if source_path.is_file() {
let relative_path = source_path.strip_prefix(exercise_path)?;
let dest_path = temp_dir.path().join(relative_path);
// Make sure parent directories exist
if let Some(parent) = dest_path.parent() {
fs::create_dir_all(parent)?;
}
fs::copy(source_path, dest_path)?;
}
}
// Initialize git repo in the temp directory
run_git(temp_dir.path(), &["init"]).await?;
run_git(temp_dir.path(), &["add", "."]).await?;
run_git(temp_dir.path(), &["commit", "-m", "Initial commit"]).await?;
println!("Created temp repo without .docs and .meta directories");
}
Ok(temp_dir)
}

View File

@@ -0,0 +1,229 @@
use agent::{RequestKind, Thread, ThreadEvent, ThreadStore};
use anyhow::anyhow;
use assistant_tool::ToolWorkingSet;
use client::{Client, UserStore};
use collections::HashMap;
use dap::DapRegistry;
use gpui::{App, Entity, SemanticVersion, Subscription, Task, prelude::*};
use language::LanguageRegistry;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
};
use node_runtime::NodeRuntime;
use project::{Project, RealFs};
use prompt_store::PromptBuilder;
use settings::SettingsStore;
use smol::channel;
use std::sync::Arc;
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
pub struct HeadlessAppState {
pub languages: Arc<LanguageRegistry>,
pub client: Arc<Client>,
pub user_store: Entity<UserStore>,
pub fs: Arc<dyn fs::Fs>,
pub node_runtime: NodeRuntime,
// Additional fields not present in `workspace::AppState`.
pub prompt_builder: Arc<PromptBuilder>,
}
pub struct HeadlessAssistant {
pub thread: Entity<Thread>,
pub project: Entity<Project>,
#[allow(dead_code)]
pub thread_store: Entity<ThreadStore>,
pub tool_use_counts: HashMap<Arc<str>, u32>,
pub done_tx: channel::Sender<anyhow::Result<()>>,
_subscription: Subscription,
}
impl HeadlessAssistant {
pub fn new(
app_state: Arc<HeadlessAppState>,
cx: &mut App,
) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
let env = None;
let project = Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
Arc::new(DapRegistry::default()),
app_state.fs.clone(),
env,
cx,
);
let tools = Arc::new(ToolWorkingSet::default());
let thread_store =
ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
let headless_thread = cx.new(move |cx| Self {
_subscription: cx.subscribe(&thread, Self::handle_thread_event),
thread,
project,
thread_store,
tool_use_counts: HashMap::default(),
done_tx,
});
Ok((headless_thread, done_rx))
}
fn handle_thread_event(
&mut self,
thread: Entity<Thread>,
event: &ThreadEvent,
cx: &mut Context<Self>,
) {
match event {
ThreadEvent::ShowError(err) => self
.done_tx
.send_blocking(Err(anyhow!("{:?}", err)))
.unwrap(),
ThreadEvent::DoneStreaming => {
let thread = thread.read(cx);
if let Some(message) = thread.messages().last() {
println!("Message: {}", message.to_string());
}
if thread.all_tools_finished() {
self.done_tx.send_blocking(Ok(())).unwrap()
}
}
ThreadEvent::UsePendingTools { .. } => {}
ThreadEvent::ToolConfirmationNeeded => {
// Automatically approve all tools that need confirmation in headless mode
println!("Tool confirmation needed - automatically approving in headless mode");
// Get the tools needing confirmation
let tools_needing_confirmation: Vec<_> = thread
.read(cx)
.tools_needing_confirmation()
.cloned()
.collect();
// Run each tool that needs confirmation
for tool_use in tools_needing_confirmation {
if let Some(tool) = thread.read(cx).tools().tool(&tool_use.name, cx) {
thread.update(cx, |thread, cx| {
println!("Auto-approving tool: {}", tool_use.name);
// Create a request to send to the tool
let request = thread.to_completion_request(RequestKind::Chat, cx);
let messages = Arc::new(request.messages);
// Run the tool
thread.run_tool(
tool_use.id.clone(),
tool_use.ui_text.clone(),
tool_use.input.clone(),
&messages,
tool,
cx,
);
});
}
}
}
ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,
..
} => {
if let Some(pending_tool_use) = pending_tool_use {
println!(
"Used tool {} with input: {}",
pending_tool_use.name, pending_tool_use.input
);
*self
.tool_use_counts
.entry(pending_tool_use.name.clone())
.or_insert(0) += 1;
}
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
println!("Tool result: {:?}", tool_result);
}
}
_ => {}
}
}
}
pub fn init(cx: &mut App) -> Arc<HeadlessAppState> {
release_channel::init(SemanticVersion::default(), cx);
gpui_tokio::init(cx);
let mut settings_store = SettingsStore::new(cx);
settings_store
.set_default_settings(settings::default_settings().as_ref(), cx)
.unwrap();
cx.set_global(settings_store);
client::init_settings(cx);
Project::init_settings(cx);
let client = Client::production(cx);
cx.set_http_client(client.http_client().clone());
let git_binary_path = None;
let fs = Arc::new(RealFs::new(
git_binary_path,
cx.background_executor().clone(),
));
let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language::init(cx);
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
assistant_tools::init(client.http_client().clone(), cx);
context_server::init(cx);
let stdout_is_a_pty = false;
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
Arc::new(HeadlessAppState {
languages,
client,
user_store,
fs,
node_runtime: NodeRuntime::unavailable(),
prompt_builder,
})
}
pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
let model_registry = LanguageModelRegistry::read_global(cx);
let model = model_registry
.available_models(cx)
.find(|model| model.id().0 == model_name);
let Some(model) = model else {
return Err(anyhow!(
"No language model named {} was available. Available models: {}",
model_name,
model_registry
.available_models(cx)
.map(|model| model.id().0.clone())
.collect::<Vec<_>>()
.join(", ")
));
};
Ok(model)
}
pub fn authenticate_model_provider(
provider_id: LanguageModelProviderId,
cx: &mut App,
) -> Task<std::result::Result<(), AuthenticateError>> {
let model_registry = LanguageModelRegistry::read_global(cx);
let model_provider = model_registry.provider(&provider_id).unwrap();
model_provider.authenticate(cx)
}

View File

@@ -0,0 +1,205 @@
mod eval;
mod get_exercise;
mod git_commands;
mod headless_assistant;
use clap::Parser;
use eval::{run_exercise_eval, save_eval_results};
use futures::stream::{self, StreamExt};
use get_exercise::{find_exercises, get_exercise_language, get_exercise_name};
use git_commands::read_base_sha;
use gpui::Application;
use headless_assistant::{authenticate_model_provider, find_model};
use language_model::LanguageModelRegistry;
use reqwest_client::ReqwestClient;
use std::{path::PathBuf, sync::Arc};
#[derive(Parser, Debug)]
#[command(
name = "agent_eval",
disable_version_flag = true,
before_help = "Tool eval runner"
)]
struct Args {
/// Match the names of evals to run.
#[arg(long)]
exercise_names: Vec<String>,
/// Runs all exercises, causes the exercise_names to be ignored.
#[arg(long)]
all: bool,
/// Supported language types to evaluate (default: internal).
/// Internal is data generated from the agent panel
#[arg(long, default_value = "internal")]
languages: String,
/// Name of the model (default: "claude-3-7-sonnet-latest")
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
model_name: String,
/// Name of the editor model (default: value of `--model_name`).
#[arg(long)]
editor_model_name: Option<String>,
/// Number of evaluations to run concurrently (default: 3)
#[arg(short, long, default_value = "5")]
concurrency: usize,
/// Maximum number of exercises to evaluate per language
#[arg(long)]
max_exercises_per_language: Option<usize>,
}
fn main() {
env_logger::init();
let args = Args::parse();
let http_client = Arc::new(ReqwestClient::new());
let app = Application::headless().with_http_client(http_client.clone());
// Path to the zed-ace-framework repo
let framework_path = PathBuf::from("../zed-ace-framework")
.canonicalize()
.unwrap();
// Fix the 'languages' lifetime issue by creating owned Strings instead of slices
let languages: Vec<String> = args.languages.split(',').map(|s| s.to_string()).collect();
println!("Using zed-ace-framework at: {:?}", framework_path);
println!("Evaluating languages: {:?}", languages);
app.run(move |cx| {
let app_state = headless_assistant::init(cx);
let model = find_model(&args.model_name, cx).unwrap();
let editor_model = if let Some(model_name) = &args.editor_model_name {
find_model(model_name, cx).unwrap()
} else {
model.clone()
};
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.set_default_model(Some(model.clone()), cx);
});
let model_provider_id = model.provider_id();
let editor_model_provider_id = editor_model.provider_id();
let framework_path_clone = framework_path.clone();
let languages_clone = languages.clone();
let exercise_names = args.exercise_names.clone();
let all_flag = args.all;
cx.spawn(async move |cx| {
// Authenticate all model providers first
cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
.unwrap()
.await
.unwrap();
cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx))
.unwrap()
.await
.unwrap();
println!("framework path: {}", framework_path_clone.display());
let base_sha = read_base_sha(&framework_path_clone).await.unwrap();
println!("base sha: {}", base_sha);
let all_exercises = find_exercises(
&framework_path_clone,
&languages_clone
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>(),
args.max_exercises_per_language,
)
.unwrap();
println!("Found {} exercises total", all_exercises.len());
// Filter exercises if specific ones were requested
let exercises_to_run = if !exercise_names.is_empty() {
// If exercise names are specified, filter by them regardless of --all flag
all_exercises
.into_iter()
.filter(|path| {
let name = get_exercise_name(path);
exercise_names.iter().any(|filter| name.contains(filter))
})
.collect()
} else if all_flag {
// Only use all_flag if no exercise names are specified
all_exercises
} else {
// Default behavior (no filters)
all_exercises
};
println!("Will run {} exercises", exercises_to_run.len());
// Create exercise eval tasks - each exercise is a single task that will run templates sequentially
let exercise_tasks: Vec<_> = exercises_to_run
.into_iter()
.map(|exercise_path| {
let exercise_name = get_exercise_name(&exercise_path);
let model_clone = model.clone();
let app_state_clone = app_state.clone();
let base_sha_clone = base_sha.clone();
let framework_path_clone = framework_path_clone.clone();
let cx_clone = cx.clone();
async move {
println!("Processing exercise: {}", exercise_name);
let mut exercise_results = Vec::new();
match run_exercise_eval(
exercise_path.clone(),
model_clone.clone(),
app_state_clone.clone(),
base_sha_clone.clone(),
framework_path_clone.clone(),
cx_clone.clone(),
)
.await
{
Ok(result) => {
println!("Completed {}", exercise_name);
exercise_results.push(result);
}
Err(err) => {
println!("Error running {}: {}", exercise_name, err);
}
}
// Save results for this exercise
if !exercise_results.is_empty() {
if let Err(err) =
save_eval_results(&exercise_path, exercise_results.clone()).await
{
println!("Error saving results for {}: {}", exercise_name, err);
} else {
println!("Saved results for {}", exercise_name);
}
}
exercise_results
}
})
.collect();
println!(
"Running {} exercises with concurrency: {}",
exercise_tasks.len(),
args.concurrency
);
// Run exercises concurrently, with each exercise running its templates sequentially
let all_results = stream::iter(exercise_tasks)
.buffer_unordered(args.concurrency)
.flat_map(stream::iter)
.collect::<Vec<_>>()
.await;
println!("Completed {} evaluation runs", all_results.len());
cx.update(|cx| cx.quit()).unwrap();
})
.detach();
});
println!("Done running evals");
}

View File

@@ -0,0 +1,25 @@
[package]
name = "agent_rules"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/agent_rules.rs"
doctest = false
[dependencies]
anyhow.workspace = true
fs.workspace = true
gpui.workspace = true
prompt_store.workspace = true
util.workspace = true
worktree.workspace = true
workspace-hack = { version = "0.1", path = "../../tooling/workspace-hack" }
[dev-dependencies]
indoc.workspace = true

View File

@@ -0,0 +1,51 @@
use std::sync::Arc;
use anyhow::{Context as _, Result};
use fs::Fs;
use gpui::{App, AppContext, Task};
use prompt_store::SystemPromptRulesFile;
use util::maybe;
use worktree::Worktree;
const RULES_FILE_NAMES: [&'static str; 6] = [
".rules",
".cursorrules",
".windsurfrules",
".clinerules",
".github/copilot-instructions.md",
"CLAUDE.md",
];
pub fn load_worktree_rules_file(
fs: Arc<dyn Fs>,
worktree: &Worktree,
cx: &App,
) -> Option<Task<Result<SystemPromptRulesFile>>> {
let selected_rules_file = RULES_FILE_NAMES
.into_iter()
.filter_map(|name| {
worktree
.entry_for_path(name)
.filter(|entry| entry.is_file())
.map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
})
.next();
// Note that Cline supports `.clinerules` being a directory, but that is not currently
// supported. This doesn't seem to occur often in GitHub repositories.
selected_rules_file.map(|(path_in_worktree, abs_path)| {
let fs = fs.clone();
cx.background_spawn(maybe!(async move {
let abs_path = abs_path?;
let text = fs
.load(&abs_path)
.await
.with_context(|| format!("Failed to load assistant rules file {:?}", abs_path))?;
anyhow::Ok(SystemPromptRulesFile {
path_in_worktree,
abs_path: abs_path.into(),
text: text.trim().to_string(),
})
}))
})
}

View File

@@ -25,4 +25,5 @@ serde.workspace = true
serde_json.workspace = true
strum.workspace = true
thiserror.workspace = true
util.workspace = true
workspace-hack.workspace = true

View File

@@ -10,6 +10,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use strum::{EnumIter, EnumString};
use thiserror::Error;
use util::ResultExt as _;
pub use supported_countries::*;
@@ -36,9 +37,9 @@ pub enum AnthropicModelMode {
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
#[default]
#[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-latest")]
Claude3_5Sonnet,
#[default]
#[serde(rename = "claude-3-7-sonnet", alias = "claude-3-7-sonnet-latest")]
Claude3_7Sonnet,
#[serde(
@@ -362,25 +363,11 @@ pub struct RateLimitInfo {
impl RateLimitInfo {
fn from_headers(headers: &HeaderMap<HeaderValue>) -> Self {
// Check if any rate limit headers exist
let has_rate_limit_headers = headers
.keys()
.any(|k| k.as_str().starts_with("anthropic-ratelimit-"));
if !has_rate_limit_headers {
return Self {
requests: None,
tokens: None,
input_tokens: None,
output_tokens: None,
};
}
Self {
requests: RateLimit::from_headers("requests", headers).ok(),
tokens: RateLimit::from_headers("tokens", headers).ok(),
input_tokens: RateLimit::from_headers("input-tokens", headers).ok(),
output_tokens: RateLimit::from_headers("output-tokens", headers).ok(),
requests: RateLimit::from_headers("requests", headers).log_err(),
tokens: RateLimit::from_headers("tokens", headers).log_err(),
input_tokens: RateLimit::from_headers("input-tokens", headers).log_err(),
output_tokens: RateLimit::from_headers("output-tokens", headers).log_err(),
}
}
}
@@ -737,54 +724,4 @@ impl ApiError {
pub fn is_rate_limit_error(&self) -> bool {
matches!(self.error_type.as_str(), "rate_limit_error")
}
pub fn match_window_exceeded(&self) -> Option<usize> {
let Some(ApiErrorCode::InvalidRequestError) = self.code() else {
return None;
};
parse_prompt_too_long(&self.message)
}
}
pub fn parse_prompt_too_long(message: &str) -> Option<usize> {
message
.strip_prefix("prompt is too long: ")?
.split_once(" tokens")?
.0
.parse::<usize>()
.ok()
}
#[test]
fn test_match_window_exceeded() {
let error = ApiError {
error_type: "invalid_request_error".to_string(),
message: "prompt is too long: 220000 tokens > 200000".to_string(),
};
assert_eq!(error.match_window_exceeded(), Some(220_000));
let error = ApiError {
error_type: "invalid_request_error".to_string(),
message: "prompt is too long: 1234953 tokens".to_string(),
};
assert_eq!(error.match_window_exceeded(), Some(1234953));
let error = ApiError {
error_type: "invalid_request_error".to_string(),
message: "not a prompt length error".to_string(),
};
assert_eq!(error.match_window_exceeded(), None);
let error = ApiError {
error_type: "rate_limit_error".to_string(),
message: "prompt is too long: 12345 tokens".to_string(),
};
assert_eq!(error.match_window_exceeded(), None);
let error = ApiError {
error_type: "invalid_request_error".to_string(),
message: "prompt is too long: invalid tokens".to_string(),
};
assert_eq!(error.match_window_exceeded(), None);
}

View File

@@ -18,4 +18,5 @@ gpui.workspace = true
smol.workspace = true
tempfile.workspace = true
util.workspace = true
which.workspace = true
workspace-hack.workspace = true

View File

@@ -72,8 +72,6 @@ impl AskPassSession {
let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
let listener =
UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
let zed_path = std::env::current_exe()
.context("Failed to figure out current executable path for use in askpass")?;
let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
let mut kill_tx = Some(askpass_kill_master_tx);
@@ -112,10 +110,21 @@ impl AskPassSession {
drop(temp_dir)
});
anyhow::ensure!(
which::which("nc").is_ok(),
"Cannot find `nc` command (netcat), which is required to connect over SSH."
);
// Create an askpass script that communicates back to this process.
let askpass_script = format!(
"{shebang}\n{print_args} | {zed_exe} --askpass={askpass_socket} 2> /dev/null \n",
zed_exe = zed_path.display(),
"{shebang}\n{print_args} | {nc} -U {askpass_socket} 2> /dev/null \n",
// on macOS `brew install netcat` provides the GNU netcat implementation
// which does not support -U.
nc = if cfg!(target_os = "macos") {
"/usr/bin/nc"
} else {
"nc"
},
askpass_socket = askpass_socket.display(),
print_args = "printf '%s\\0' \"$@\"",
shebang = "#!/bin/sh",
@@ -161,51 +170,6 @@ impl AskPassSession {
}
}
/// The main function for when Zed is running in netcat mode for use in askpass.
/// Called from both the remote server binary and the zed binary in their respective main functions.
#[cfg(unix)]
pub fn main(socket: &str) {
use std::io::{self, Read, Write};
use std::os::unix::net::UnixStream;
use std::process::exit;
let mut stream = match UnixStream::connect(socket) {
Ok(stream) => stream,
Err(err) => {
eprintln!("Error connecting to socket {}: {}", socket, err);
exit(1);
}
};
let mut buffer = Vec::new();
if let Err(err) = io::stdin().read_to_end(&mut buffer) {
eprintln!("Error reading from stdin: {}", err);
exit(1);
}
if buffer.last() != Some(&b'\0') {
buffer.push(b'\0');
}
if let Err(err) = stream.write_all(&buffer) {
eprintln!("Error writing to socket: {}", err);
exit(1);
}
let mut response = Vec::new();
if let Err(err) = stream.read_to_end(&mut response) {
eprintln!("Error reading from socket: {}", err);
exit(1);
}
if let Err(err) = io::stdout().write_all(&response) {
eprintln!("Error writing to stdout: {}", err);
exit(1);
}
}
#[cfg(not(unix))]
pub fn main(_socket: &str) {}
#[cfg(not(unix))]
pub struct AskPassSession {
path: PathBuf,

View File

@@ -13,7 +13,7 @@ use assistant_context_editor::{
use assistant_settings::{AssistantDockPosition, AssistantSettings};
use assistant_slash_command::SlashCommandWorkingSet;
use client::{Client, Status, proto};
use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer};
use editor::{Editor, EditorEvent};
use fs::Fs;
use gpui::{
Action, App, AsyncWindowContext, Entity, EventEmitter, ExternalPaths, FocusHandle, Focusable,
@@ -28,12 +28,9 @@ use language_model::{
use project::Project;
use prompt_library::{PromptLibrary, open_prompt_library};
use prompt_store::PromptBuilder;
use search::{BufferSearchBar, buffer_search::DivRegistrar};
use settings::{Settings, update_settings_file};
use smol::stream::StreamExt;
use std::ops::Range;
use std::{ops::ControlFlow, path::PathBuf, sync::Arc};
use terminal_view::{TerminalView, terminal_panel::TerminalPanel};
use ui::{ContextMenu, PopoverMenu, Tooltip, prelude::*};
@@ -1416,8 +1413,7 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
fn quote_selection(
&self,
workspace: &mut Workspace,
selection_ranges: Vec<Range<Anchor>>,
buffer: Entity<MultiBuffer>,
creases: Vec<(String, String)>,
window: &mut Window,
cx: &mut Context<Workspace>,
) {
@@ -1429,12 +1425,6 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
workspace.toggle_panel_focus::<AssistantPanel>(window, cx);
}
let snapshot = buffer.read(cx).snapshot(cx);
let selection_ranges = selection_ranges
.into_iter()
.map(|range| range.to_point(&snapshot))
.collect::<Vec<_>>();
panel.update(cx, |_, cx| {
// Wait to create a new context until the workspace is no longer
// being updated.
@@ -1443,9 +1433,7 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
.active_context_editor(cx)
.or_else(|| panel.new_context(window, cx))
{
context.update(cx, |context, cx| {
context.quote_ranges(selection_ranges, snapshot, window, cx)
});
context.update(cx, |context, cx| context.quote_creases(creases, window, cx));
};
});
});

View File

@@ -8,8 +8,8 @@ use assistant_slash_commands::{
use client::{proto, zed_urls};
use collections::{BTreeSet, HashMap, HashSet, hash_map};
use editor::{
Anchor, Editor, EditorEvent, MenuInlineCompletionsPolicy, MultiBuffer, MultiBufferSnapshot,
ProposedChangeLocation, ProposedChangesEditor, RowExt, ToOffset as _, ToPoint,
Anchor, Editor, EditorEvent, MenuInlineCompletionsPolicy, ProposedChangeLocation,
ProposedChangesEditor, RowExt, ToOffset as _, ToPoint,
actions::{MoveToEndOfLine, Newline, ShowCompletions},
display_map::{
BlockContext, BlockId, BlockPlacement, BlockProperties, BlockStyle, Crease, CreaseMetadata,
@@ -155,8 +155,7 @@ pub trait AssistantPanelDelegate {
fn quote_selection(
&self,
workspace: &mut Workspace,
selection_ranges: Vec<Range<Anchor>>,
buffer: Entity<MultiBuffer>,
creases: Vec<(String, String)>,
window: &mut Window,
cx: &mut Context<Workspace>,
);
@@ -1801,42 +1800,23 @@ impl ContextEditor {
return;
};
let Some((selections, buffer)) = maybe!({
let editor = workspace
.active_item(cx)
.and_then(|item| item.act_as::<Editor>(cx))?;
let buffer = editor.read(cx).buffer().clone();
let snapshot = buffer.read(cx).snapshot(cx);
let selections = editor.update(cx, |editor, cx| {
editor
.selections
.all_adjusted(cx)
.into_iter()
.map(|s| snapshot.anchor_after(s.start)..snapshot.anchor_before(s.end))
.collect::<Vec<_>>()
});
Some((selections, buffer))
}) else {
let Some(creases) = selections_creases(workspace, cx) else {
return;
};
if selections.is_empty() {
if creases.is_empty() {
return;
}
assistant_panel_delegate.quote_selection(workspace, selections, buffer, window, cx);
assistant_panel_delegate.quote_selection(workspace, creases, window, cx);
}
pub fn quote_ranges(
pub fn quote_creases(
&mut self,
ranges: Vec<Range<Point>>,
snapshot: MultiBufferSnapshot,
creases: Vec<(String, String)>,
window: &mut Window,
cx: &mut Context<Self>,
) {
let creases = selections_creases(ranges, snapshot, cx);
self.editor.update(cx, |editor, cx| {
editor.insert("\n", window, cx);
for (text, crease_title) in creases {

View File

@@ -69,7 +69,7 @@ pub enum AssistantProviderContentV1 {
},
}
#[derive(Clone, Debug, Default)]
#[derive(Debug, Default)]
pub struct AssistantSettings {
pub enabled: bool,
pub button: bool,
@@ -742,7 +742,7 @@ mod tests {
AssistantSettings::get_global(cx).default_model,
LanguageModelSelection {
provider: "zed.dev".into(),
model: "claude-3-7-sonnet-latest".into(),
model: "claude-3-5-sonnet-latest".into(),
}
);
});

View File

@@ -3,12 +3,10 @@ use assistant_slash_command::{
ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent,
SlashCommandOutputSection, SlashCommandResult,
};
use editor::{Editor, MultiBufferSnapshot};
use editor::Editor;
use futures::StreamExt;
use gpui::{App, SharedString, Task, WeakEntity, Window};
use gpui::{App, Context, SharedString, Task, WeakEntity, Window};
use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
use rope::Point;
use std::ops::Range;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use ui::IconName;
@@ -71,22 +69,7 @@ impl SlashCommand for SelectionCommand {
let mut events = vec![];
let Some(creases) = workspace
.update(cx, |workspace, cx| {
let editor = workspace
.active_item(cx)
.and_then(|item| item.act_as::<Editor>(cx))?;
editor.update(cx, |editor, cx| {
let selection_ranges = editor
.selections
.all_adjusted(cx)
.iter()
.map(|selection| selection.range())
.collect::<Vec<_>>();
let snapshot = editor.buffer().read(cx).snapshot(cx);
Some(selections_creases(selection_ranges, snapshot, cx))
})
})
.update(cx, selections_creases)
.unwrap_or_else(|e| {
events.push(Err(e));
None
@@ -119,82 +102,94 @@ impl SlashCommand for SelectionCommand {
}
pub fn selections_creases(
selection_ranges: Vec<Range<Point>>,
snapshot: MultiBufferSnapshot,
cx: &App,
) -> Vec<(String, String)> {
let mut creases = Vec::new();
for range in selection_ranges {
let selected_text = snapshot.text_for_range(range.clone()).collect::<String>();
if selected_text.is_empty() {
continue;
}
let start_language = snapshot.language_at(range.start);
let end_language = snapshot.language_at(range.end);
let language_name = if start_language == end_language {
start_language.map(|language| language.code_fence_block_name())
} else {
None
};
let language_name = language_name.as_deref().unwrap_or("");
let filename = snapshot.file_at(range.start).map(|file| file.full_path(cx));
let text = if language_name == "markdown" {
selected_text
.lines()
.map(|line| format!("> {}", line))
.collect::<Vec<_>>()
.join("\n")
} else {
let start_symbols = snapshot
.symbols_containing(range.start, None)
.map(|(_, symbols)| symbols);
let end_symbols = snapshot
.symbols_containing(range.end, None)
.map(|(_, symbols)| symbols);
workspace: &mut workspace::Workspace,
cx: &mut Context<Workspace>,
) -> Option<Vec<(String, String)>> {
let editor = workspace
.active_item(cx)
.and_then(|item| item.act_as::<Editor>(cx))?;
let outline_text =
if let Some((start_symbols, end_symbols)) = start_symbols.zip(end_symbols) {
Some(
start_symbols
.into_iter()
.zip(end_symbols)
.take_while(|(a, b)| a == b)
.map(|(a, _)| a.text)
.collect::<Vec<_>>()
.join(" > "),
)
let mut creases = vec![];
editor.update(cx, |editor, cx| {
let selections = editor.selections.all_adjusted(cx);
let buffer = editor.buffer().read(cx).snapshot(cx);
for selection in selections {
let range = editor::ToOffset::to_offset(&selection.start, &buffer)
..editor::ToOffset::to_offset(&selection.end, &buffer);
let selected_text = buffer.text_for_range(range.clone()).collect::<String>();
if selected_text.is_empty() {
continue;
}
let start_language = buffer.language_at(range.start);
let end_language = buffer.language_at(range.end);
let language_name = if start_language == end_language {
start_language.map(|language| language.code_fence_block_name())
} else {
None
};
let language_name = language_name.as_deref().unwrap_or("");
let filename = buffer
.file_at(selection.start)
.map(|file| file.full_path(cx));
let text = if language_name == "markdown" {
selected_text
.lines()
.map(|line| format!("> {}", line))
.collect::<Vec<_>>()
.join("\n")
} else {
let start_symbols = buffer
.symbols_containing(selection.start, None)
.map(|(_, symbols)| symbols);
let end_symbols = buffer
.symbols_containing(selection.end, None)
.map(|(_, symbols)| symbols);
let outline_text =
if let Some((start_symbols, end_symbols)) = start_symbols.zip(end_symbols) {
Some(
start_symbols
.into_iter()
.zip(end_symbols)
.take_while(|(a, b)| a == b)
.map(|(a, _)| a.text)
.collect::<Vec<_>>()
.join(" > "),
)
} else {
None
};
let line_comment_prefix = start_language
.and_then(|l| l.default_scope().line_comment_prefixes().first().cloned());
let fence = codeblock_fence_for_path(
filename.as_deref(),
Some(selection.start.row..=selection.end.row),
);
if let Some((line_comment_prefix, outline_text)) =
line_comment_prefix.zip(outline_text)
{
let breadcrumb = format!("{line_comment_prefix}Excerpt from: {outline_text}\n");
format!("{fence}{breadcrumb}{selected_text}\n```")
} else {
None
};
let line_comment_prefix = start_language
.and_then(|l| l.default_scope().line_comment_prefixes().first().cloned());
let fence = codeblock_fence_for_path(
filename.as_deref(),
Some(range.start.row..=range.end.row),
);
if let Some((line_comment_prefix, outline_text)) = line_comment_prefix.zip(outline_text)
{
let breadcrumb = format!("{line_comment_prefix}Excerpt from: {outline_text}\n");
format!("{fence}{breadcrumb}{selected_text}\n```")
format!("{fence}{selected_text}\n```")
}
};
let crease_title = if let Some(path) = filename {
let start_line = selection.start.row + 1;
let end_line = selection.end.row + 1;
if start_line == end_line {
format!("{}, Line {}", path.display(), start_line)
} else {
format!("{}, Lines {} to {}", path.display(), start_line, end_line)
}
} else {
format!("{fence}{selected_text}\n```")
}
};
let crease_title = if let Some(path) = filename {
let start_line = range.start.row + 1;
let end_line = range.end.row + 1;
if start_line == end_line {
format!("{}, Line {}", path.display(), start_line)
} else {
format!("{}, Lines {} to {}", path.display(), start_line, end_line)
}
} else {
"Quoted selection".to_string()
};
creases.push((text, crease_title));
}
creases
"Quoted selection".to_string()
};
creases.push((text, crease_title));
}
});
Some(creases)
}

View File

@@ -3,8 +3,8 @@ use buffer_diff::BufferDiff;
use collections::BTreeMap;
use futures::{StreamExt, channel::mpsc};
use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task, WeakEntity};
use language::{Anchor, Buffer, BufferEvent, DiskState, Point, ToPoint};
use project::{Project, ProjectItem, lsp_store::OpenLspBufferHandle};
use language::{Anchor, Buffer, BufferEvent, DiskState, Point};
use project::{Project, ProjectItem};
use std::{cmp, ops::Range, sync::Arc};
use text::{Edit, Patch, Rope};
use util::RangeExt;
@@ -49,10 +49,6 @@ impl ActionLog {
.tracked_buffers
.entry(buffer.clone())
.or_insert_with(|| {
let open_lsp_handle = self.project.update(cx, |project, cx| {
project.register_buffer_with_language_servers(&buffer, cx)
});
let text_snapshot = buffer.read(cx).text_snapshot();
let diff = cx.new(|cx| BufferDiff::new(&text_snapshot, cx));
let (diff_update_tx, diff_update_rx) = mpsc::unbounded();
@@ -80,7 +76,6 @@ impl ActionLog {
version: buffer.read(cx).version(),
diff,
diff_update: diff_update_tx,
_open_lsp_handle: open_lsp_handle,
_maintain_diff: cx.spawn({
let buffer = buffer.clone();
async move |this, cx| {
@@ -240,7 +235,7 @@ impl ActionLog {
.await;
diff.update(cx, |diff, cx| {
diff.set_snapshot(diff_snapshot, &buffer_snapshot, cx)
diff.set_snapshot(diff_snapshot, &buffer_snapshot, None, cx)
})?;
}
this.update(cx, |this, cx| {
@@ -363,10 +358,10 @@ impl ActionLog {
}
}
pub fn reject_edits_in_ranges(
pub fn reject_edits_in_range(
&mut self,
buffer: Entity<Buffer>,
buffer_ranges: Vec<Range<impl language::ToPoint>>,
buffer_range: Range<impl language::ToPoint>,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let Some(tracked_buffer) = self.tracked_buffers.get_mut(&buffer) else {
@@ -403,15 +398,29 @@ impl ActionLog {
}
TrackedBufferStatus::Modified => {
buffer.update(cx, |buffer, cx| {
let mut buffer_row_ranges = buffer_ranges
.into_iter()
.map(|range| {
range.start.to_point(buffer).row..range.end.to_point(buffer).row
})
.peekable();
let buffer_range =
buffer_range.start.to_point(buffer)..buffer_range.end.to_point(buffer);
let mut edits_to_revert = Vec::new();
for edit in tracked_buffer.unreviewed_changes.edits() {
if buffer_range.end.row < edit.new.start {
break;
} else if buffer_range.start.row > edit.new.end {
continue;
}
let old_range = tracked_buffer
.base_text
.point_to_offset(Point::new(edit.old.start, 0))
..tracked_buffer.base_text.point_to_offset(cmp::min(
Point::new(edit.old.end, 0),
tracked_buffer.base_text.max_point(),
));
let old_text = tracked_buffer
.base_text
.chunks_in_range(old_range)
.collect::<String>();
let new_range = tracked_buffer
.snapshot
.anchor_before(Point::new(edit.new.start, 0))
@@ -419,35 +428,7 @@ impl ActionLog {
Point::new(edit.new.end, 0),
tracked_buffer.snapshot.max_point(),
));
let new_row_range = new_range.start.to_point(buffer).row
..new_range.end.to_point(buffer).row;
let mut revert = false;
while let Some(buffer_row_range) = buffer_row_ranges.peek() {
if buffer_row_range.end < new_row_range.start {
buffer_row_ranges.next();
} else if buffer_row_range.start > new_row_range.end {
break;
} else {
revert = true;
break;
}
}
if revert {
let old_range = tracked_buffer
.base_text
.point_to_offset(Point::new(edit.old.start, 0))
..tracked_buffer.base_text.point_to_offset(cmp::min(
Point::new(edit.old.end, 0),
tracked_buffer.base_text.max_point(),
));
let old_text = tracked_buffer
.base_text
.chunks_in_range(old_range)
.collect::<String>();
edits_to_revert.push((new_range, old_text));
}
edits_to_revert.push((new_range, old_text));
}
buffer.edit(edits_to_revert, None, cx);
@@ -613,7 +594,6 @@ fn point_to_row_edit(edit: Edit<Point>, old_text: &Rope, new_text: &Rope) -> Edi
}
}
#[derive(Copy, Clone, Debug)]
enum ChangeAuthor {
User,
Agent,
@@ -635,7 +615,6 @@ struct TrackedBuffer {
diff: Entity<BufferDiff>,
snapshot: text::BufferSnapshot,
diff_update: mpsc::UnboundedSender<(ChangeAuthor, text::BufferSnapshot)>,
_open_lsp_handle: OpenLspBufferHandle,
_maintain_diff: Task<()>,
_subscription: Subscription,
}
@@ -1150,48 +1129,9 @@ mod tests {
)]
);
// If the rejected range doesn't overlap with any hunk, we ignore it.
action_log
.update(cx, |log, cx| {
log.reject_edits_in_ranges(
buffer.clone(),
vec![Point::new(4, 0)..Point::new(4, 0)],
cx,
)
})
.await
.unwrap();
cx.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.text()),
"abc\ndE\nXYZf\nghi\njkl\nmnO"
);
assert_eq!(
unreviewed_hunks(&action_log, cx),
vec![(
buffer.clone(),
vec![
HunkStatus {
range: Point::new(1, 0)..Point::new(3, 0),
diff_status: DiffHunkStatusKind::Modified,
old_text: "def\n".into(),
},
HunkStatus {
range: Point::new(5, 0)..Point::new(5, 3),
diff_status: DiffHunkStatusKind::Modified,
old_text: "mno".into(),
}
],
)]
);
action_log
.update(cx, |log, cx| {
log.reject_edits_in_ranges(
buffer.clone(),
vec![Point::new(0, 0)..Point::new(1, 0)],
cx,
)
log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(1, 0), cx)
})
.await
.unwrap();
@@ -1214,11 +1154,7 @@ mod tests {
action_log
.update(cx, |log, cx| {
log.reject_edits_in_ranges(
buffer.clone(),
vec![Point::new(4, 0)..Point::new(4, 0)],
cx,
)
log.reject_edits_in_range(buffer.clone(), Point::new(4, 0)..Point::new(4, 0), cx)
})
.await
.unwrap();
@@ -1230,82 +1166,6 @@ mod tests {
assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
}
#[gpui::test(iterations = 10)]
async fn test_reject_multiple_edits(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/dir"), json!({"file": "abc\ndef\nghi\njkl\nmno"}))
.await;
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let file_path = project
.read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(file_path, cx))
.await
.unwrap();
cx.update(|cx| {
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| {
buffer
.edit([(Point::new(1, 1)..Point::new(1, 2), "E\nXYZ")], None, cx)
.unwrap()
});
buffer.update(cx, |buffer, cx| {
buffer
.edit([(Point::new(5, 2)..Point::new(5, 3), "O")], None, cx)
.unwrap()
});
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
});
cx.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.text()),
"abc\ndE\nXYZf\nghi\njkl\nmnO"
);
assert_eq!(
unreviewed_hunks(&action_log, cx),
vec![(
buffer.clone(),
vec![
HunkStatus {
range: Point::new(1, 0)..Point::new(3, 0),
diff_status: DiffHunkStatusKind::Modified,
old_text: "def\n".into(),
},
HunkStatus {
range: Point::new(5, 0)..Point::new(5, 3),
diff_status: DiffHunkStatusKind::Modified,
old_text: "mno".into(),
}
],
)]
);
action_log.update(cx, |log, cx| {
let range_1 = buffer.read(cx).anchor_before(Point::new(0, 0))
..buffer.read(cx).anchor_before(Point::new(1, 0));
let range_2 = buffer.read(cx).anchor_before(Point::new(5, 0))
..buffer.read(cx).anchor_before(Point::new(5, 3));
log.reject_edits_in_ranges(buffer.clone(), vec![range_1, range_2], cx)
.detach();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.text()),
"abc\ndef\nghi\njkl\nmno"
);
});
cx.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.text()),
"abc\ndef\nghi\njkl\nmno"
);
assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
}
#[gpui::test(iterations = 10)]
async fn test_reject_deleted_file(cx: &mut TestAppContext) {
init_test(cx);
@@ -1349,11 +1209,7 @@ mod tests {
action_log
.update(cx, |log, cx| {
log.reject_edits_in_ranges(
buffer.clone(),
vec![Point::new(0, 0)..Point::new(0, 0)],
cx,
)
log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(0, 0), cx)
})
.await
.unwrap();
@@ -1404,11 +1260,7 @@ mod tests {
action_log
.update(cx, |log, cx| {
log.reject_edits_in_ranges(
buffer.clone(),
vec![Point::new(0, 0)..Point::new(0, 11)],
cx,
)
log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(0, 11), cx)
})
.await
.unwrap();
@@ -1454,7 +1306,7 @@ mod tests {
.update(cx, |log, cx| {
let range = buffer.read(cx).random_byte_range(0, &mut rng);
log::info!("rejecting edits in range {:?}", range);
log.reject_edits_in_ranges(buffer.clone(), vec![range], cx)
log.reject_edits_in_range(buffer.clone(), range, cx)
})
.await
.unwrap();

View File

@@ -1,6 +1,5 @@
mod action_log;
mod tool_registry;
mod tool_schema;
mod tool_working_set;
use std::fmt;
@@ -9,10 +8,6 @@ use std::fmt::Formatter;
use std::sync::Arc;
use anyhow::Result;
use gpui::AnyElement;
use gpui::Context;
use gpui::IntoElement;
use gpui::Window;
use gpui::{App, Entity, SharedString, Task};
use icons::IconName;
use language_model::LanguageModelRequestMessage;
@@ -21,97 +16,12 @@ use project::Project;
pub use crate::action_log::*;
pub use crate::tool_registry::*;
pub use crate::tool_schema::*;
pub use crate::tool_working_set::*;
pub fn init(cx: &mut App) {
ToolRegistry::default_global(cx);
}
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
NeedsConfirmation,
Pending,
Running,
Finished(SharedString),
Error(SharedString),
}
impl ToolUseStatus {
pub fn text(&self) -> SharedString {
match self {
ToolUseStatus::NeedsConfirmation => "".into(),
ToolUseStatus::Pending => "".into(),
ToolUseStatus::Running => "".into(),
ToolUseStatus::Finished(out) => out.clone(),
ToolUseStatus::Error(out) => out.clone(),
}
}
}
/// The result of running a tool, containing both the asynchronous output
/// and an optional card view that can be rendered immediately.
pub struct ToolResult {
/// The asynchronous task that will eventually resolve to the tool's output
pub output: Task<Result<String>>,
/// An optional view to present the output of the tool.
pub card: Option<AnyToolCard>,
}
pub trait ToolCard: 'static + Sized {
fn render(
&mut self,
status: &ToolUseStatus,
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement;
}
#[derive(Clone)]
pub struct AnyToolCard {
entity: gpui::AnyEntity,
render: fn(
entity: gpui::AnyEntity,
status: &ToolUseStatus,
window: &mut Window,
cx: &mut App,
) -> AnyElement,
}
impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
fn from(entity: Entity<T>) -> Self {
fn downcast_render<T: ToolCard>(
entity: gpui::AnyEntity,
status: &ToolUseStatus,
window: &mut Window,
cx: &mut App,
) -> AnyElement {
let entity = entity.downcast::<T>().unwrap();
entity.update(cx, |entity, cx| {
entity.render(status, window, cx).into_any_element()
})
}
Self {
entity: entity.into(),
render: downcast_render::<T>,
}
}
}
impl AnyToolCard {
pub fn render(&self, status: &ToolUseStatus, window: &mut Window, cx: &mut App) -> AnyElement {
(self.render)(self.entity.clone(), status, window, cx)
}
}
impl From<Task<Result<String>>> for ToolResult {
/// Convert from a task to a ToolResult with no card
fn from(output: Task<Result<String>>) -> Self {
Self { output, card: None }
}
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
pub enum ToolSource {
/// A native tool built-in to Zed.
@@ -141,8 +51,8 @@ pub trait Tool: 'static + Send + Sync {
fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
/// Returns the JSON schema that describes the tool's input.
fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
Ok(serde_json::Value::Object(serde_json::Map::default()))
fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> serde_json::Value {
serde_json::Value::Object(serde_json::Map::default())
}
/// Returns markdown to be displayed in the UI for this tool.
@@ -156,7 +66,7 @@ pub trait Tool: 'static + Send + Sync {
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult;
) -> Task<Result<String>>;
}
impl Debug for dyn Tool {

View File

@@ -1,236 +0,0 @@
use anyhow::Result;
use serde_json::Value;
use crate::LanguageModelToolSchemaFormat;
/// Tries to adapt a JSON schema representation to be compatible with the specified format.
///
/// If the json cannot be made compatible with the specified format, an error is returned.
pub fn adapt_schema_to_format(
json: &mut Value,
format: LanguageModelToolSchemaFormat,
) -> Result<()> {
match format {
LanguageModelToolSchemaFormat::JsonSchema => Ok(()),
LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
}
}
/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
if let Value::Object(obj) = json {
const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
for key in UNSUPPORTED_KEYS {
if obj.contains_key(key) {
return Err(anyhow::anyhow!(
"Schema cannot be made compatible because it contains \"{}\" ",
key
));
}
}
const KEYS_TO_REMOVE: [&str; 2] = ["format", "$schema"];
for key in KEYS_TO_REMOVE {
obj.remove(key);
}
if let Some(default) = obj.get("default") {
let is_null = default.is_null();
// Default is not supported, so we need to remove it
obj.remove("default");
if is_null {
obj.insert("nullable".to_string(), Value::Bool(true));
}
}
// If a type is not specified for an input parameter, add a default type
if obj.contains_key("description")
&& !obj.contains_key("type")
&& !(obj.contains_key("anyOf")
|| obj.contains_key("oneOf")
|| obj.contains_key("allOf"))
{
obj.insert("type".to_string(), Value::String("string".to_string()));
}
// Handle oneOf -> anyOf conversion
if let Some(subschemas) = obj.get_mut("oneOf") {
if subschemas.is_array() {
let subschemas_clone = subschemas.clone();
obj.remove("oneOf");
obj.insert("anyOf".to_string(), subschemas_clone);
}
}
// Recursively process all nested objects and arrays
for (_, value) in obj.iter_mut() {
if let Value::Object(_) | Value::Array(_) = value {
adapt_to_json_schema_subset(value)?;
}
}
} else if let Value::Array(arr) = json {
for item in arr.iter_mut() {
adapt_to_json_schema_subset(item)?;
}
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_transform_default_null_to_nullable() {
let mut json = json!({
"description": "A test field",
"type": "string",
"default": null
});
adapt_to_json_schema_subset(&mut json).unwrap();
assert_eq!(
json,
json!({
"description": "A test field",
"type": "string",
"nullable": true
})
);
}
#[test]
fn test_transform_adds_type_when_missing() {
let mut json = json!({
"description": "A test field without type"
});
adapt_to_json_schema_subset(&mut json).unwrap();
assert_eq!(
json,
json!({
"description": "A test field without type",
"type": "string"
})
);
}
#[test]
fn test_transform_removes_format() {
let mut json = json!({
"description": "A test field",
"type": "integer",
"format": "uint32"
});
adapt_to_json_schema_subset(&mut json).unwrap();
assert_eq!(
json,
json!({
"description": "A test field",
"type": "integer"
})
);
}
#[test]
fn test_transform_one_of_to_any_of() {
let mut json = json!({
"description": "A test field",
"oneOf": [
{ "type": "string" },
{ "type": "integer" }
]
});
adapt_to_json_schema_subset(&mut json).unwrap();
assert_eq!(
json,
json!({
"description": "A test field",
"anyOf": [
{ "type": "string" },
{ "type": "integer" }
]
})
);
}
#[test]
fn test_transform_nested_objects() {
let mut json = json!({
"type": "object",
"properties": {
"nested": {
"oneOf": [
{ "type": "string" },
{ "type": "null" }
],
"format": "email"
}
}
});
adapt_to_json_schema_subset(&mut json).unwrap();
assert_eq!(
json,
json!({
"type": "object",
"properties": {
"nested": {
"anyOf": [
{ "type": "string" },
{ "type": "null" }
]
}
}
})
);
}
#[test]
fn test_transform_fails_if_unsupported_keys_exist() {
let mut json = json!({
"type": "object",
"properties": {
"$ref": "#/definitions/User",
}
});
assert!(adapt_to_json_schema_subset(&mut json).is_err());
let mut json = json!({
"type": "object",
"properties": {
"if": "...",
}
});
assert!(adapt_to_json_schema_subset(&mut json).is_err());
let mut json = json!({
"type": "object",
"properties": {
"then": "...",
}
});
assert!(adapt_to_json_schema_subset(&mut json).is_err());
let mut json = json!({
"type": "object",
"properties": {
"else": "...",
}
});
assert!(adapt_to_json_schema_subset(&mut json).is_err());
}
}

View File

@@ -1,7 +1,8 @@
use std::sync::Arc;
use collections::{HashMap, HashSet, IndexMap};
use gpui::{App, Context, EventEmitter};
use gpui::App;
use parking_lot::Mutex;
use crate::{Tool, ToolRegistry, ToolSource};
@@ -11,6 +12,11 @@ pub struct ToolId(usize);
/// A working set of tools for use in one instance of the Assistant Panel.
#[derive(Default)]
pub struct ToolWorkingSet {
state: Mutex<WorkingSetState>,
}
#[derive(Default)]
struct WorkingSetState {
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
enabled_sources: HashSet<ToolSource>,
@@ -18,27 +24,99 @@ pub struct ToolWorkingSet {
next_tool_id: ToolId,
}
pub enum ToolWorkingSetEvent {
EnabledToolsChanged,
}
impl EventEmitter<ToolWorkingSetEvent> for ToolWorkingSet {}
impl ToolWorkingSet {
pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
self.context_server_tools_by_name
self.state
.lock()
.context_server_tools_by_name
.get(name)
.cloned()
.or_else(|| ToolRegistry::global(cx).tool(name))
}
pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
let mut tools = ToolRegistry::global(cx).tools();
tools.extend(self.context_server_tools_by_id.values().cloned());
tools
self.state.lock().tools(cx)
}
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
self.state.lock().tools_by_source(cx)
}
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
self.state.lock().enabled_tools(cx)
}
pub fn disable_all_tools(&self) {
let mut state = self.state.lock();
state.disable_all_tools();
}
pub fn enable_source(&self, source: ToolSource, cx: &App) {
let mut state = self.state.lock();
state.enable_source(source, cx);
}
pub fn disable_source(&self, source: &ToolSource) {
let mut state = self.state.lock();
state.disable_source(source);
}
pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
let mut state = self.state.lock();
let tool_id = state.next_tool_id;
state.next_tool_id.0 += 1;
state
.context_server_tools_by_id
.insert(tool_id, tool.clone());
state.tools_changed();
tool_id
}
pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
self.state.lock().is_enabled(source, name)
}
pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
self.state.lock().is_disabled(source, name)
}
pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
let mut state = self.state.lock();
state.enable(source, tools_to_enable);
}
pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
let mut state = self.state.lock();
state.disable(source, tools_to_disable);
}
pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
let mut state = self.state.lock();
state
.context_server_tools_by_id
.retain(|id, _| !tool_ids_to_remove.contains(id));
state.tools_changed();
}
}
impl WorkingSetState {
fn tools_changed(&mut self) {
self.context_server_tools_by_name.clear();
self.context_server_tools_by_name.extend(
self.context_server_tools_by_id
.values()
.map(|tool| (tool.name(), tool.clone())),
);
}
fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
let mut tools = ToolRegistry::global(cx).tools();
tools.extend(self.context_server_tools_by_id.values().cloned());
tools
}
fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
let mut tools_by_source = IndexMap::default();
for tool in self.tools(cx) {
@@ -57,7 +135,7 @@ impl ToolWorkingSet {
tools_by_source
}
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
let all_tools = self.tools(cx);
all_tools
@@ -66,12 +144,31 @@ impl ToolWorkingSet {
.collect()
}
pub fn disable_all_tools(&mut self, cx: &mut Context<Self>) {
self.enabled_tools_by_source.clear();
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
self.enabled_tools_by_source
.get(source)
.map_or(false, |enabled_tools| enabled_tools.contains(name))
}
pub fn enable_source(&mut self, source: ToolSource, cx: &mut Context<Self>) {
fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
!self.is_enabled(source, name)
}
fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
self.enabled_tools_by_source
.entry(source)
.or_default()
.extend(tools_to_enable.into_iter().cloned());
}
fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
self.enabled_tools_by_source
.entry(source)
.or_default()
.retain(|name| !tools_to_disable.contains(name));
}
fn enable_source(&mut self, source: ToolSource, cx: &App) {
self.enabled_sources.insert(source.clone());
let tools_by_source = self.tools_by_source(cx);
@@ -84,72 +181,14 @@ impl ToolWorkingSet {
.collect::<HashSet<_>>(),
);
}
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
}
pub fn disable_source(&mut self, source: &ToolSource, cx: &mut Context<Self>) {
fn disable_source(&mut self, source: &ToolSource) {
self.enabled_sources.remove(source);
self.enabled_tools_by_source.remove(source);
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
}
pub fn insert(&mut self, tool: Arc<dyn Tool>) -> ToolId {
let tool_id = self.next_tool_id;
self.next_tool_id.0 += 1;
self.context_server_tools_by_id
.insert(tool_id, tool.clone());
self.tools_changed();
tool_id
}
pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
self.enabled_tools_by_source
.get(source)
.map_or(false, |enabled_tools| enabled_tools.contains(name))
}
pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
!self.is_enabled(source, name)
}
pub fn enable(
&mut self,
source: ToolSource,
tools_to_enable: &[Arc<str>],
cx: &mut Context<Self>,
) {
self.enabled_tools_by_source
.entry(source)
.or_default()
.extend(tools_to_enable.into_iter().cloned());
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
}
pub fn disable(
&mut self,
source: ToolSource,
tools_to_disable: &[Arc<str>],
cx: &mut Context<Self>,
) {
self.enabled_tools_by_source
.entry(source)
.or_default()
.retain(|name| !tools_to_disable.contains(name));
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
}
pub fn remove(&mut self, tool_ids_to_remove: &[ToolId]) {
self.context_server_tools_by_id
.retain(|id, _| !tool_ids_to_remove.contains(id));
self.tools_changed();
}
fn tools_changed(&mut self) {
self.context_server_tools_by_name.clear();
self.context_server_tools_by_name.extend(
self.context_server_tools_by_id
.values()
.map(|tool| (tool.name(), tool.clone())),
);
fn disable_all_tools(&mut self) {
self.enabled_tools_by_source.clear();
}
}

View File

@@ -16,18 +16,13 @@ anyhow.workspace = true
assistant_tool.workspace = true
chrono.workspace = true
collections.workspace = true
component.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
html_to_markdown.workspace = true
http_client.workspace = true
indoc.workspace = true
itertools.workspace = true
language.workspace = true
language_model.workspace = true
linkme.workspace = true
open.workspace = true
project.workspace = true
regex.workspace = true
schemars.workspace = true
@@ -35,10 +30,9 @@ serde.workspace = true
serde_json.workspace = true
ui.workspace = true
util.workspace = true
web_search.workspace = true
workspace-hack.workspace = true
worktree.workspace = true
zed_llm_client.workspace = true
open = { workspace = true }
workspace-hack.workspace = true
[dev-dependencies]
collections = { workspace = true, features = ["test-support"] }
@@ -46,8 +40,5 @@ gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
rand.workspace = true
pretty_assertions.workspace = true
settings = { workspace = true, features = ["test-support"] }
tree-sitter-rust.workspace = true
workspace = { workspace = true, features = ["test-support"] }
unindent.workspace = true

View File

@@ -1,14 +1,13 @@
mod batch_tool;
mod code_action_tool;
mod code_symbols_tool;
mod contents_tool;
mod copy_path_tool;
mod create_directory_tool;
mod create_file_tool;
mod delete_path_tool;
mod diagnostics_tool;
mod edit_file_tool;
mod fetch_tool;
mod find_replace_file_tool;
mod list_directory_tool;
mod move_path_tool;
mod now_tool;
@@ -22,28 +21,24 @@ mod schema;
mod symbol_info_tool;
mod terminal_tool;
mod thinking_tool;
mod web_search_tool;
use std::sync::Arc;
use assistant_tool::ToolRegistry;
use copy_path_tool::CopyPathTool;
use feature_flags::FeatureFlagAppExt;
use gpui::App;
use http_client::HttpClientWithUrl;
use move_path_tool::MovePathTool;
use web_search_tool::WebSearchTool;
use crate::batch_tool::BatchTool;
use crate::code_action_tool::CodeActionTool;
use crate::code_symbols_tool::CodeSymbolsTool;
use crate::contents_tool::ContentsTool;
use crate::create_directory_tool::CreateDirectoryTool;
use crate::create_file_tool::CreateFileTool;
use crate::delete_path_tool::DeletePathTool;
use crate::diagnostics_tool::DiagnosticsTool;
use crate::edit_file_tool::EditFileTool;
use crate::fetch_tool::FetchTool;
use crate::find_replace_file_tool::FindReplaceFileTool;
use crate::list_directory_tool::ListDirectoryTool;
use crate::now_tool::NowTool;
use crate::open_tool::OpenTool;
@@ -65,7 +60,7 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(CreateFileTool);
registry.register_tool(CopyPathTool);
registry.register_tool(DeletePathTool);
registry.register_tool(EditFileTool);
registry.register_tool(FindReplaceFileTool);
registry.register_tool(SymbolInfoTool);
registry.register_tool(CodeActionTool);
registry.register_tool(MovePathTool);
@@ -74,61 +69,10 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(NowTool);
registry.register_tool(OpenTool);
registry.register_tool(CodeSymbolsTool);
registry.register_tool(ContentsTool);
registry.register_tool(PathSearchTool);
registry.register_tool(ReadFileTool);
registry.register_tool(RegexSearchTool);
registry.register_tool(RenameTool);
registry.register_tool(ThinkingTool);
registry.register_tool(FetchTool::new(http_client));
cx.observe_flag::<feature_flags::ZedProWebSearchTool, _>({
move |is_enabled, cx| {
if is_enabled {
ToolRegistry::global(cx).register_tool(WebSearchTool);
} else {
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
}
}
})
.detach();
}
#[cfg(test)]
mod tests {
use http_client::FakeHttpClient;
use super::*;
#[gpui::test]
fn test_builtin_tool_schema_compatibility(cx: &mut App) {
crate::init(
Arc::new(http_client::HttpClientWithUrl::new(
FakeHttpClient::with_200_response(),
"https://zed.dev",
None,
)),
cx,
);
for tool in ToolRegistry::global(cx).tools() {
let actual_schema = tool
.input_schema(language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset)
.unwrap();
let mut expected_schema = actual_schema.clone();
assistant_tool::adapt_schema_to_format(
&mut expected_schema,
language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
)
.unwrap();
let error_message = format!(
"Tool schema for `{}` is not compatible with `language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset` (Gemini Models).\n\
Are you using `schema::json_schema_for<T>(format)` to generate the schema?",
tool.name(),
);
assert_eq!(actual_schema, expected_schema, "{}", error_message)
}
}
}

View File

@@ -1,6 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult, ToolWorkingSet};
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
use futures::future::join_all;
use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
@@ -172,7 +172,7 @@ impl Tool for BatchTool {
IconName::Cog
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<BatchToolInput>(format)
}
@@ -219,14 +219,14 @@ impl Tool for BatchTool {
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
) -> Task<Result<String>> {
let input = match serde_json::from_value::<BatchToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
Err(err) => return Task::ready(Err(anyhow!(err))),
};
if input.invocations.is_empty() {
return Task::ready(Err(anyhow!("No tool invocations provided"))).into();
return Task::ready(Err(anyhow!("No tool invocations provided")));
}
let run_tools_concurrently = input.run_tools_concurrently;
@@ -257,11 +257,11 @@ impl Tool for BatchTool {
let project = project.clone();
let action_log = action_log.clone();
let messages = messages.clone();
let tool_result = cx
let task = cx
.update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))
.map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;
tasks.push(tool_result.output);
tasks.push(task);
}
Ok((tasks, tool_names))
@@ -306,6 +306,5 @@ impl Tool for BatchTool {
Ok(formatted_results.trim().to_string())
})
.into()
}
}

View File

@@ -1,8 +1,8 @@
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language::{self, Anchor, Buffer, ToPointUtf16};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use language_model::LanguageModelRequestMessage;
use project::{self, LspAction, Project};
use regex::Regex;
use schemars::JsonSchema;
@@ -10,8 +10,6 @@ use serde::{Deserialize, Serialize};
use std::{ops::Range, sync::Arc};
use ui::IconName;
use crate::schema::json_schema_for;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct CodeActionToolInput {
/// The relative path to the file containing the text range.
@@ -97,8 +95,12 @@ impl Tool for CodeActionTool {
IconName::Wand
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<CodeActionToolInput>(format)
fn input_schema(
&self,
_format: language_model::LanguageModelToolSchemaFormat,
) -> serde_json::Value {
let schema = schemars::schema_for!(CodeActionToolInput);
serde_json::to_value(&schema).unwrap()
}
fn ui_text(&self, input: &serde_json::Value) -> String {
@@ -141,10 +143,10 @@ impl Tool for CodeActionTool {
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
) -> Task<Result<String>> {
let input = match serde_json::from_value::<CodeActionToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
Err(err) => return Task::ready(Err(anyhow!(err))),
};
cx.spawn(async move |cx| {
@@ -319,7 +321,7 @@ impl Tool for CodeActionTool {
Ok(response)
}
}).into()
})
}
}

View File

@@ -4,7 +4,7 @@ use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use assistant_tool::{ActionLog, Tool};
use collections::IndexMap;
use gpui::{App, AsyncApp, Entity, Task};
use language::{OutlineItem, ParseStatus, Point};
@@ -91,7 +91,7 @@ impl Tool for CodeSymbolsTool {
IconName::Code
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<CodeSymbolsInput>(format)
}
@@ -129,10 +129,10 @@ impl Tool for CodeSymbolsTool {
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
) -> Task<Result<String>> {
let input = match serde_json::from_value::<CodeSymbolsInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let regex = match input.regex {
@@ -141,16 +141,15 @@ impl Tool for CodeSymbolsTool {
.build()
{
Ok(regex) => Some(regex),
Err(err) => return Task::ready(Err(anyhow!("Invalid regex: {err}"))).into(),
Err(err) => return Task::ready(Err(anyhow!("Invalid regex: {err}"))),
},
None => None,
};
cx.spawn(async move |cx| match input.path {
Some(path) => file_outline(project, path, action_log, regex, cx).await,
Some(path) => file_outline(project, path, action_log, regex, input.offset, cx).await,
None => project_symbols(project, regex, input.offset, cx).await,
})
.into()
}
}
@@ -159,6 +158,7 @@ pub async fn file_outline(
path: String,
action_log: Entity<ActionLog>,
regex: Option<Regex>,
offset: u32,
cx: &mut AsyncApp,
) -> anyhow::Result<String> {
let buffer = {
@@ -179,9 +179,11 @@ pub async fn file_outline(
// Wait until the buffer has been fully parsed, so that we can read its outline.
let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
while *parse_status.borrow() != ParseStatus::Idle {
parse_status.changed().await?;
}
while parse_status
.recv()
.await
.map_or(false, |status| status != ParseStatus::Idle)
{}
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
let Some(outline) = snapshot.outline(None) else {
@@ -194,8 +196,7 @@ pub async fn file_outline(
.into_iter()
.map(|item| item.to_point(&snapshot)),
regex,
0,
usize::MAX,
offset,
)
.await
}
@@ -294,10 +295,11 @@ async fn project_symbols(
async fn render_outline(
items: impl IntoIterator<Item = OutlineItem<Point>>,
regex: Option<Regex>,
offset: usize,
results_per_page: usize,
offset: u32,
) -> Result<String> {
let mut items = items.into_iter().skip(offset);
const RESULTS_PER_PAGE_USIZE: usize = RESULTS_PER_PAGE as usize;
let mut items = items.into_iter().skip(offset as usize);
let entries = items
.by_ref()
@@ -306,7 +308,7 @@ async fn render_outline(
.as_ref()
.is_none_or(|regex| regex.is_match(&item.text))
})
.take(results_per_page)
.take(RESULTS_PER_PAGE_USIZE)
.collect::<Vec<_>>();
let has_more = items.next().is_some();
@@ -337,10 +339,7 @@ async fn render_outline(
Ok(output)
}
fn render_entries(
output: &mut String,
items: impl IntoIterator<Item = OutlineItem<Point>>,
) -> usize {
fn render_entries(output: &mut String, items: impl IntoIterator<Item = OutlineItem<Point>>) -> u32 {
let mut entries_rendered = 0;
for item in items {

View File

@@ -1,239 +0,0 @@
use std::sync::Arc;
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use itertools::Itertools;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{fmt::Write, path::Path};
use ui::IconName;
use util::markdown::MarkdownString;
/// If the model requests to read a file whose size exceeds this, then
/// the tool will return the file's symbol outline instead of its contents,
/// and suggest trying again using line ranges from the outline.
const MAX_FILE_SIZE_TO_READ: usize = 16384;
/// If the model requests to list the entries in a directory with more
/// entries than this, then the tool will return a subset of the entries
/// and suggest trying again.
const MAX_DIR_ENTRIES: usize = 1024;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ContentsToolInput {
/// The relative path of the file or directory to access.
///
/// This path should never be absolute, and the first component
/// of the path should always be a root directory in a project.
///
/// <example>
/// If the project has the following root directories:
///
/// - directory1
/// - directory2
///
/// If you want to access `file.txt` in `directory1`, you should use the path `directory1/file.txt`.
/// If you want to list contents in the directory `directory2/subfolder`, you should use the path `directory2/subfolder`.
/// </example>
pub path: String,
/// Optional position (1-based index) to start reading on, if you want to read a subset of the contents.
/// When reading a file, this refers to a line number in the file (e.g. 1 is the first line).
/// When reading a directory, this refers to the number of the directory entry (e.g. 1 is the first entry).
///
/// Defaults to 1.
pub start: Option<u32>,
/// Optional position (1-based index) to end reading on, if you want to read a subset of the contents.
/// When reading a file, this refers to a line number in the file (e.g. 1 is the first line).
/// When reading a directory, this refers to the number of the directory entry (e.g. 1 is the first entry).
///
/// Defaults to reading until the end of the file or directory.
pub end: Option<u32>,
}
pub struct ContentsTool;
impl Tool for ContentsTool {
fn name(&self) -> String {
"contents".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
fn description(&self) -> String {
include_str!("./contents_tool/description.md").into()
}
fn icon(&self) -> IconName {
IconName::FileSearch
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<ContentsToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<ContentsToolInput>(input.clone()) {
Ok(input) => {
let path = MarkdownString::inline_code(&input.path);
match (input.start, input.end) {
(Some(start), None) => format!("Read {path} (from line {start})"),
(Some(start), Some(end)) => {
format!("Read {path} (lines {start}-{end})")
}
_ => format!("Read {path}"),
}
}
Err(_) => "Read file or directory".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<ContentsToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
// Sometimes models will return these even though we tell it to give a path and not a glob.
// When this happens, just list the root worktree directories.
if matches!(input.path.as_str(), "." | "" | "./" | "*") {
let output = project
.read(cx)
.worktrees(cx)
.filter_map(|worktree| {
worktree.read(cx).root_entry().and_then(|entry| {
if entry.is_dir() {
entry.path.to_str()
} else {
None
}
})
})
.collect::<Vec<_>>()
.join("\n");
return Task::ready(Ok(output)).into();
}
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into();
};
let Some(worktree) = project
.read(cx)
.worktree_for_id(project_path.worktree_id, cx)
else {
return Task::ready(Err(anyhow!("Worktree not found"))).into();
};
let worktree = worktree.read(cx);
let Some(entry) = worktree.entry_for_path(&project_path.path) else {
return Task::ready(Err(anyhow!("Path not found: {}", input.path))).into();
};
// If it's a directory, list its contents
if entry.is_dir() {
let mut output = String::new();
let start_index = input
.start
.map(|line| (line as usize).saturating_sub(1))
.unwrap_or(0);
let end_index = input
.end
.map(|line| (line as usize).saturating_sub(1))
.unwrap_or(MAX_DIR_ENTRIES);
let mut skipped = 0;
for (index, entry) in worktree.child_entries(&project_path.path).enumerate() {
if index >= start_index && index <= end_index {
writeln!(
output,
"{}",
Path::new(worktree.root_name()).join(&entry.path).display(),
)
.unwrap();
} else {
skipped += 1;
}
}
if output.is_empty() {
output.push_str(&input.path);
output.push_str(" is empty.");
}
if skipped > 0 {
write!(
output,
"\n\nNote: Skipped {skipped} entries. Adjust start and end to see other entries.",
).ok();
}
Task::ready(Ok(output)).into()
} else {
// It's a file, so read its contents
let file_path = input.path.clone();
cx.spawn(async move |cx| {
let buffer = cx
.update(|cx| {
project.update(cx, |project, cx| project.open_buffer(project_path, cx))
})?
.await?;
if input.start.is_some() || input.end.is_some() {
let result = buffer.read_with(cx, |buffer, _cx| {
let text = buffer.text();
let start = input.start.unwrap_or(1);
let lines = text.split('\n').skip(start as usize - 1);
if let Some(end) = input.end {
let count = end.saturating_sub(start).max(1); // Ensure at least 1 line
Itertools::intersperse(lines.take(count as usize), "\n").collect()
} else {
Itertools::intersperse(lines, "\n").collect()
}
})?;
action_log.update(cx, |log, cx| {
log.buffer_read(buffer, cx);
})?;
Ok(result)
} else {
// No line ranges specified, so check file size to see if it's too big.
let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
if file_size <= MAX_FILE_SIZE_TO_READ {
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
action_log.update(cx, |log, cx| {
log.buffer_read(buffer, cx);
})?;
Ok(result)
} else {
// File is too big, so return its outline and a suggestion to
// read again with a line number range specified.
let outline = file_outline(project, file_path, action_log, None, cx).await?;
Ok(format!("This file was too big to read all at once. Here is an outline of its symbols:\n\n{outline}\n\nUsing the line numbers in this outline, you can call this tool again while specifying the start and end fields to see the implementations of symbols in the outline."))
}
}
}).into()
}
}
}

View File

@@ -1,9 +0,0 @@
Reads the contents of a path on the filesystem.
If the path is a directory, this lists all files and directories within that path.
If the path is a file, this returns the file's contents.
When reading a file, if the file is too big and no line range is specified, an outline of the file's code symbols is listed instead, which can be used to request specific line ranges in a subsequent call.
Similarly, if a directory has too many entries to show at once, a subset of entries will be shown,
and subsequent requests can use starting and ending line numbers to get other subsets.

View File

@@ -1,6 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, AppContext, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
@@ -55,7 +55,7 @@ impl Tool for CopyPathTool {
IconName::Clipboard
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<CopyPathToolInput>(format)
}
@@ -77,10 +77,10 @@ impl Tool for CopyPathTool {
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
) -> Task<Result<String>> {
let input = match serde_json::from_value::<CopyPathToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let copy_task = project.update(cx, |project, cx| {
match project
@@ -117,6 +117,5 @@ impl Tool for CopyPathTool {
)),
}
})
.into()
}
}

View File

@@ -1,6 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
@@ -45,7 +45,7 @@ impl Tool for CreateDirectoryTool {
IconName::Folder
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<CreateDirectoryToolInput>(format)
}
@@ -68,16 +68,14 @@ impl Tool for CreateDirectoryTool {
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
) -> Task<Result<String>> {
let input = match serde_json::from_value::<CreateDirectoryToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let project_path = match project.read(cx).find_project_path(&input.path, cx) {
Some(project_path) => project_path,
None => {
return Task::ready(Err(anyhow!("Path to create was outside the project"))).into();
}
None => return Task::ready(Err(anyhow!("Path to create was outside the project"))),
};
let destination_path: Arc<str> = input.path.as_str().into();
@@ -91,6 +89,5 @@ impl Tool for CreateDirectoryTool {
Ok(format!("Created directory {destination_path}"))
})
.into()
}
}

View File

@@ -1,6 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
@@ -52,7 +52,7 @@ impl Tool for CreateFileTool {
IconName::FileCreate
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<CreateFileToolInput>(format)
}
@@ -73,16 +73,14 @@ impl Tool for CreateFileTool {
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
) -> Task<Result<String>> {
let input = match serde_json::from_value::<CreateFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let project_path = match project.read(cx).find_project_path(&input.path, cx) {
Some(project_path) => project_path,
None => {
return Task::ready(Err(anyhow!("Path to create was outside the project"))).into();
}
None => return Task::ready(Err(anyhow!("Path to create was outside the project"))),
};
let contents: Arc<str> = input.contents.as_str().into();
let destination_path: Arc<str> = input.path.as_str().into();
@@ -108,6 +106,5 @@ impl Tool for CreateFileTool {
Ok(format!("Created file {destination_path}"))
})
.into()
}
}

View File

@@ -1,6 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use assistant_tool::{ActionLog, Tool};
use futures::{SinkExt, StreamExt, channel::mpsc};
use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
@@ -45,7 +45,7 @@ impl Tool for DeletePathTool {
IconName::FileDelete
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<DeletePathToolInput>(format)
}
@@ -63,16 +63,15 @@ impl Tool for DeletePathTool {
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
) -> Task<Result<String>> {
let path_str = match serde_json::from_value::<DeletePathToolInput>(input) {
Ok(input) => input.path,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let Some(project_path) = project.read(cx).find_project_path(&path_str, cx) else {
return Task::ready(Err(anyhow!(
"Couldn't delete {path_str} because that path isn't in this project."
)))
.into();
)));
};
let Some(worktree) = project
@@ -81,8 +80,7 @@ impl Tool for DeletePathTool {
else {
return Task::ready(Err(anyhow!(
"Couldn't delete {path_str} because that path isn't in this project."
)))
.into();
)));
};
let worktree_snapshot = worktree.read(cx).snapshot();
@@ -134,6 +132,5 @@ impl Tool for DeletePathTool {
)),
}
})
.into()
}
}

View File

@@ -1,6 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language::{DiagnosticSeverity, OffsetRangeExt};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
@@ -58,7 +58,7 @@ impl Tool for DiagnosticsTool {
IconName::XCircle
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<DiagnosticsToolInput>(format)
}
@@ -83,15 +83,14 @@ impl Tool for DiagnosticsTool {
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
) -> Task<Result<String>> {
match serde_json::from_value::<DiagnosticsToolInput>(input)
.ok()
.and_then(|input| input.path)
{
Some(path) if !path.is_empty() => {
let Some(project_path) = project.read(cx).find_project_path(&path, cx) else {
return Task::ready(Err(anyhow!("Could not find path {path} in project",)))
.into();
return Task::ready(Err(anyhow!("Could not find path {path} in project",)));
};
let buffer =
@@ -126,7 +125,6 @@ impl Tool for DiagnosticsTool {
Ok(output)
}
})
.into()
}
_ => {
let project = project.read(cx);
@@ -157,10 +155,9 @@ impl Tool for DiagnosticsTool {
});
if has_diagnostics {
Task::ready(Ok(output)).into()
Task::ready(Ok(output))
} else {
Task::ready(Ok("No errors or warnings found in the project.".to_string()))
.into()
}
}
}

View File

@@ -15,7 +15,6 @@ To get a project-wide diagnostic summary:
{}
</example>
<guidelines>
- If you think you can fix a diagnostic, make 1-2 attempts and then give up.
- Don't remove code you've generated just because you can't fix an error. The user can help you fix it.
</guidelines>
IMPORTANT: When you're done making changes, you **MUST** get the **project** diagnostics (input: `{}`) at the end of your edits so you can fix any problems you might have introduced. **DO NOT** tell the user you're done before doing this!
You may only attempt to fix these up to 3 times. If you have tried 3 times to fix them, and there are still problems remaining, you must not continue trying to fix them, and must instead tell the user that there are problems remaining - and ask if the user would like you to attempt to solve them further.

View File

@@ -1,183 +0,0 @@
use crate::{replace::replace_with_flexible_indent, schema::json_schema_for};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, sync::Arc};
use ui::IconName;
use crate::replace::replace_exact;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFileToolInput {
/// The full path of the file to modify in the project.
///
/// WARNING: When specifying which file path need changing, you MUST
/// start each path with one of the project's root directories.
///
/// The following examples assume we have two root directories in the project:
/// - backend
/// - frontend
///
/// <example>
/// `backend/src/main.rs`
///
/// Notice how the file path starts with root-1. Without that, the path
/// would be ambiguous and the call would fail!
/// </example>
///
/// <example>
/// `frontend/db.js`
/// </example>
pub path: PathBuf,
/// A user-friendly markdown description of what's being replaced. This will be shown in the UI.
///
/// <example>Fix API endpoint URLs</example>
/// <example>Update copyright year in `page_footer`</example>
pub display_description: String,
/// The text to replace.
pub old_string: String,
/// The text to replace it with.
pub new_string: String,
}
pub struct EditFileTool;
impl Tool for EditFileTool {
fn name(&self) -> String {
"edit_file".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
fn description(&self) -> String {
include_str!("edit_file_tool/description.md").to_string()
}
fn icon(&self) -> IconName {
IconName::Pencil
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<EditFileToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<EditFileToolInput>(input.clone()) {
Ok(input) => input.display_description,
Err(_) => "Edit file".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<EditFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
cx.spawn(async move |cx: &mut AsyncApp| {
let project_path = project.read_with(cx, |project, cx| {
project
.find_project_path(&input.path, cx)
.context("Path not found in project")
})??;
let buffer = project
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?;
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
if input.old_string.is_empty() {
return Err(anyhow!("`old_string` cannot be empty. Use a different tool if you want to create a file."));
}
if input.old_string == input.new_string {
return Err(anyhow!("The `old_string` and `new_string` are identical, so no changes would be made."));
}
let result = cx
.background_spawn(async move {
// Try to match exactly
let diff = replace_exact(&input.old_string, &input.new_string, &snapshot)
.await
// If that fails, try being flexible about indentation
.or_else(|| replace_with_flexible_indent(&input.old_string, &input.new_string, &snapshot))?;
if diff.edits.is_empty() {
return None;
}
let old_text = snapshot.text();
Some((old_text, diff))
})
.await;
let Some((old_text, diff)) = result else {
let err = buffer.read_with(cx, |buffer, _cx| {
let file_exists = buffer
.file()
.map_or(false, |file| file.disk_state().exists());
if !file_exists {
anyhow!("{} does not exist", input.path.display())
} else if buffer.is_empty() {
anyhow!(
"{} is empty, so the provided `old_string` wasn't found.",
input.path.display()
)
} else {
anyhow!("Failed to match the provided `old_string`")
}
})?;
return Err(err)
};
let snapshot = cx.update(|cx| {
action_log.update(cx, |log, cx| {
log.buffer_read(buffer.clone(), cx)
});
let snapshot = buffer.update(cx, |buffer, cx| {
buffer.finalize_last_transaction();
buffer.apply_diff(diff, cx);
buffer.finalize_last_transaction();
buffer.snapshot()
});
action_log.update(cx, |log, cx| {
log.buffer_edited(buffer.clone(), cx)
});
snapshot
})?;
project.update( cx, |project, cx| {
project.save_buffer(buffer, cx)
})?.await?;
let diff_str = cx.background_spawn(async move {
let new_text = snapshot.text();
language::unified_diff(&old_text, &new_text)
}).await;
Ok(format!("Edited {}:\n\n```diff\n{}\n```", input.path.display(), diff_str))
}).into()
}
}

View File

@@ -1,45 +0,0 @@
This is a tool for editing files. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead. For larger edits, use the `create_file` tool to overwrite files.
Before using this tool:
1. Use the `read_file` tool to understand the file's contents and context
2. Verify the directory path is correct (only applicable when creating new files):
- Use the `list_directory` tool to verify the parent directory exists and is the correct location
To make a file edit, provide the following:
1. path: The full path to the file you wish to modify in the project. This path must include the root directory in the project.
2. old_string: The text to replace (must be unique within the file, and must match the file contents exactly, including all whitespace and indentation)
3. new_string: The edited text, which will replace the old_string in the file.
The tool will replace ONE occurrence of old_string with new_string in the specified file.
CRITICAL REQUIREMENTS FOR USING THIS TOOL:
1. UNIQUENESS: The old_string MUST uniquely identify the specific instance you want to change. This means:
- Include AT LEAST 3-5 lines of context BEFORE the change point
- Include AT LEAST 3-5 lines of context AFTER the change point
- Include all whitespace, indentation, and surrounding code exactly as it appears in the file
2. SINGLE INSTANCE: This tool can only change ONE instance at a time. If you need to change multiple instances:
- Make separate calls to this tool for each instance
- Each call must uniquely identify its specific instance using extensive context
3. VERIFICATION: Before using this tool:
- Check how many instances of the target text exist in the file
- If multiple instances exist, gather enough context to uniquely identify each one
- Plan separate tool calls for each instance
WARNING: If you do not follow these requirements:
- The tool will fail if old_string matches multiple locations
- The tool will fail if old_string doesn't match exactly (including whitespace)
- You may change the wrong instance if you don't include enough context
When making edits:
- Ensure the edit results in idiomatic, correct code
- Do not leave the code in a broken state
- Always use fully-qualified project paths (starting with the name of one of the project's root directories)
If you want to create a new file, use the `create_file` tool instead of this tool. Don't pass an empty `old_string`.
Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.

View File

@@ -4,7 +4,7 @@ use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow, bail};
use assistant_tool::{ActionLog, Tool, ToolResult};
use assistant_tool::{ActionLog, Tool};
use futures::AsyncReadExt as _;
use gpui::{App, AppContext as _, Entity, Task};
use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
@@ -128,7 +128,7 @@ impl Tool for FetchTool {
IconName::Globe
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<FetchToolInput>(format)
}
@@ -146,10 +146,10 @@ impl Tool for FetchTool {
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
) -> Task<Result<String>> {
let input = match serde_json::from_value::<FetchToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let text = cx.background_spawn({
@@ -158,15 +158,13 @@ impl Tool for FetchTool {
async move { Self::build_message(http_client, &url).await }
});
cx.foreground_executor()
.spawn(async move {
let text = text.await?;
if text.trim().is_empty() {
bail!("no textual content found");
}
cx.foreground_executor().spawn(async move {
let text = text.await?;
if text.trim().is_empty() {
bail!("no textual content found");
}
Ok(text)
})
.into()
Ok(text)
})
}
}

View File

@@ -1,6 +1,6 @@
use crate::{replace::replace_with_flexible_indent, schema::json_schema_for};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -151,7 +151,7 @@ impl Tool for FindReplaceFileTool {
IconName::Pencil
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<FindReplaceFileToolInput>(format)
}
@@ -169,10 +169,10 @@ impl Tool for FindReplaceFileTool {
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
) -> Task<Result<String>> {
let input = match serde_json::from_value::<FindReplaceFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
Err(err) => return Task::ready(Err(anyhow!(err))),
};
cx.spawn(async move |cx: &mut AsyncApp| {
@@ -263,6 +263,6 @@ impl Tool for FindReplaceFileTool {
Ok(format!("Edited {}:\n\n```diff\n{}\n```", input.path.display(), diff_str))
}).into()
})
}
}

View File

@@ -1,6 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -12,7 +12,7 @@ use util::markdown::MarkdownString;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ListDirectoryToolInput {
/// The fully-qualified path of the directory to list in the project.
/// The relative path of the directory to list.
///
/// This path should never be absolute, and the first component
/// of the path should always be a root directory in a project.
@@ -56,7 +56,7 @@ impl Tool for ListDirectoryTool {
IconName::Folder
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<ListDirectoryToolInput>(format)
}
@@ -77,10 +77,10 @@ impl Tool for ListDirectoryTool {
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
) -> Task<Result<String>> {
let input = match serde_json::from_value::<ListDirectoryToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
Err(err) => return Task::ready(Err(anyhow!(err))),
};
// Sometimes models will return these even though we tell it to give a path and not a glob.
@@ -101,26 +101,26 @@ impl Tool for ListDirectoryTool {
.collect::<Vec<_>>()
.join("\n");
return Task::ready(Ok(output)).into();
return Task::ready(Ok(output));
}
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path {} not found in project", input.path))).into();
return Task::ready(Err(anyhow!("Path {} not found in project", input.path)));
};
let Some(worktree) = project
.read(cx)
.worktree_for_id(project_path.worktree_id, cx)
else {
return Task::ready(Err(anyhow!("Worktree not found"))).into();
return Task::ready(Err(anyhow!("Worktree not found")));
};
let worktree = worktree.read(cx);
let Some(entry) = worktree.entry_for_path(&project_path.path) else {
return Task::ready(Err(anyhow!("Path not found: {}", input.path))).into();
return Task::ready(Err(anyhow!("Path not found: {}", input.path)));
};
if !entry.is_dir() {
return Task::ready(Err(anyhow!("{} is not a directory.", input.path))).into();
return Task::ready(Err(anyhow!("{} is not a directory.", input.path)));
}
let mut output = String::new();
@@ -133,8 +133,8 @@ impl Tool for ListDirectoryTool {
.unwrap();
}
if output.is_empty() {
return Task::ready(Ok(format!("{} is empty.", input.path))).into();
return Task::ready(Ok(format!("{} is empty.", input.path)));
}
Task::ready(Ok(output)).into()
Task::ready(Ok(output))
}
}

View File

@@ -1 +1 @@
Lists files and directories in a given path. Prefer the `regex_search` or `path_search` tools when searching the codebase.
Lists files and directories in a given path.

View File

@@ -1,6 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -54,7 +54,7 @@ impl Tool for MovePathTool {
IconName::ArrowRightLeft
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<MovePathToolInput>(format)
}
@@ -90,10 +90,10 @@ impl Tool for MovePathTool {
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
) -> Task<Result<String>> {
let input = match serde_json::from_value::<MovePathToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let rename_task = project.update(cx, |project, cx| {
match project
@@ -128,6 +128,5 @@ impl Tool for MovePathTool {
)),
}
})
.into()
}
}

View File

@@ -2,7 +2,7 @@ use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use assistant_tool::{ActionLog, Tool};
use chrono::{Local, Utc};
use gpui::{App, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
@@ -45,7 +45,7 @@ impl Tool for NowTool {
IconName::Info
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<NowToolInput>(format)
}
@@ -60,10 +60,10 @@ impl Tool for NowTool {
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_cx: &mut App,
) -> ToolResult {
) -> Task<Result<String>> {
let input: NowToolInput = match serde_json::from_value(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let now = match input.timezone {
@@ -72,6 +72,6 @@ impl Tool for NowTool {
};
let text = format!("The current datetime is {now}.");
Task::ready(Ok(text)).into()
Task::ready(Ok(text))
}
}

View File

@@ -1,6 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -35,7 +35,7 @@ impl Tool for OpenTool {
IconName::ArrowUpRight
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<OpenToolInput>(format)
}
@@ -53,10 +53,10 @@ impl Tool for OpenTool {
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
) -> Task<Result<String>> {
let input: OpenToolInput = match serde_json::from_value(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
Err(err) => return Task::ready(Err(anyhow!(err))),
};
cx.background_spawn(async move {
@@ -64,6 +64,5 @@ impl Tool for OpenTool {
Ok(format!("Successfully opened {}", input.path_or_url))
})
.into()
}
}

View File

@@ -1,19 +1,19 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{cmp, fmt::Write as _, path::PathBuf, sync::Arc};
use std::{path::PathBuf, sync::Arc};
use ui::IconName;
use util::paths::PathMatcher;
use worktree::Snapshot;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct PathSearchToolInput {
/// The glob to match against every path in the project.
/// The glob to search all project paths for.
///
/// <example>
/// If the project has the following root directories:
@@ -53,7 +53,7 @@ impl Tool for PathSearchTool {
IconName::SearchCode
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<PathSearchToolInput>(format)
}
@@ -71,130 +71,71 @@ impl Tool for PathSearchTool {
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
) -> Task<Result<String>> {
let (offset, glob) = match serde_json::from_value::<PathSearchToolInput>(input) {
Ok(input) => (input.offset, input.glob),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let offset = offset as usize;
let task = search_paths(&glob, project, cx);
let path_matcher = match PathMatcher::new([
// Sometimes models try to search for "". In this case, return all paths in the project.
if glob.is_empty() { "*" } else { &glob },
]) {
Ok(matcher) => matcher,
Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))),
};
let snapshots: Vec<Snapshot> = project
.read(cx)
.worktrees(cx)
.map(|worktree| worktree.read(cx).snapshot())
.collect();
cx.background_spawn(async move {
let matches = task.await?;
let paginated_matches = &matches[cmp::min(offset, matches.len())
..cmp::min(offset + RESULTS_PER_PAGE, matches.len())];
let mut matches = Vec::new();
for worktree in snapshots {
let root_name = worktree.root_name();
// Don't consider ignored entries.
for entry in worktree.entries(false, 0) {
if path_matcher.is_match(&entry.path) {
matches.push(
PathBuf::from(root_name)
.join(&entry.path)
.to_string_lossy()
.to_string(),
);
}
}
}
if matches.is_empty() {
Ok("No matches found".to_string())
Ok(format!("No paths in the project matched the glob {glob:?}"))
} else {
let mut message = format!("Found {} total matches.", matches.len());
if matches.len() > RESULTS_PER_PAGE {
write!(
&mut message,
"\nShowing results {}-{} (provide 'offset' parameter for more results):",
// Sort to group entries in the same directory together.
matches.sort();
let total_matches = matches.len();
let response = if total_matches > RESULTS_PER_PAGE + offset as usize {
let paginated_matches: Vec<_> = matches
.into_iter()
.skip(offset as usize)
.take(RESULTS_PER_PAGE)
.collect();
format!(
"Found {} total matches. Showing results {}-{} (provide 'offset' parameter for more results):\n\n{}",
total_matches,
offset + 1,
offset + paginated_matches.len()
offset as usize + paginated_matches.len(),
paginated_matches.join("\n")
)
.unwrap();
}
for mat in matches.into_iter().skip(offset).take(RESULTS_PER_PAGE) {
write!(&mut message, "\n{}", mat.display()).unwrap();
}
Ok(message)
} else {
matches.join("\n")
};
Ok(response)
}
})
.into()
}
}
fn search_paths(glob: &str, project: Entity<Project>, cx: &mut App) -> Task<Result<Vec<PathBuf>>> {
let path_matcher = match PathMatcher::new([
// Sometimes models try to search for "". In this case, return all paths in the project.
if glob.is_empty() { "*" } else { glob },
]) {
Ok(matcher) => matcher,
Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))),
};
let snapshots: Vec<Snapshot> = project
.read(cx)
.worktrees(cx)
.map(|worktree| worktree.read(cx).snapshot())
.collect();
cx.background_spawn(async move {
Ok(snapshots
.iter()
.flat_map(|snapshot| {
let root_name = PathBuf::from(snapshot.root_name());
snapshot
.entries(false, 0)
.map(move |entry| root_name.join(&entry.path))
.filter(|path| path_matcher.is_match(&path))
})
.collect())
})
}
#[cfg(test)]
mod test {
use super::*;
use gpui::TestAppContext;
use project::{FakeFs, Project};
use settings::SettingsStore;
use util::path;
#[gpui::test]
async fn test_path_search_tool(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
serde_json::json!({
"apple": {
"banana": {
"carrot": "1",
},
"bandana": {
"carbonara": "2",
},
"endive": "3"
}
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let matches = cx
.update(|cx| search_paths("root/**/car*", project.clone(), cx))
.await
.unwrap();
assert_eq!(
matches,
&[
PathBuf::from("root/apple/banana/carrot"),
PathBuf::from("root/apple/bandana/carbonara")
]
);
let matches = cx
.update(|cx| search_paths("**/car*", project.clone(), cx))
.await
.unwrap();
assert_eq!(
matches,
&[
PathBuf::from("root/apple/banana/carrot"),
PathBuf::from("root/apple/bandana/carbonara")
]
);
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
});
}
}

View File

@@ -1,7 +1,3 @@
Fast file pattern matching tool that works with any codebase size
Returns paths in the project which match the given glob.
- Supports glob patterns like "**/*.js" or "src/**/*.ts"
- Returns matching file paths sorted alphabetically
- Prefer the `regex_search` tool to this tool when searching for symbols unless you have specific information about paths.
- Use this tool when you need to find files by name patterns
- Results are paginated with 50 matches per page. Use the optional 'offset' parameter to request subsequent pages.
Results are paginated with 50 matches per page. Use the optional 'offset' parameter to request subsequent pages.

View File

@@ -1,14 +1,14 @@
use std::sync::Arc;
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use indoc::formatdoc;
use itertools::Itertools;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use ui::IconName;
use util::markdown::MarkdownString;
@@ -63,7 +63,7 @@ impl Tool for ReadFileTool {
IconName::FileSearch
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<ReadFileToolInput>(format)
}
@@ -88,31 +88,18 @@ impl Tool for ReadFileTool {
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
) -> Task<Result<String>> {
let input = match serde_json::from_value::<ReadFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into();
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path,)));
};
let Some(worktree) = project
.read(cx)
.worktree_for_id(project_path.worktree_id, cx)
else {
return Task::ready(Err(anyhow!("Worktree not found for project path"))).into();
};
let exists = worktree.update(cx, |worktree, cx| {
worktree.file_exists(&project_path.path, cx)
});
let file_path = input.path.clone();
cx.spawn(async move |cx| {
if !exists.await? {
return Err(anyhow!("{} not found", file_path))
}
let buffer = cx
.update(|cx| {
project.update(cx, |project, cx| project.open_buffer(project_path, cx))
@@ -154,231 +141,11 @@ impl Tool for ReadFileTool {
} else {
// File is too big, so return an error with the outline
// and a suggestion to read again with line numbers.
let outline = file_outline(project, file_path, action_log, None, cx).await?;
Ok(formatdoc! {"
This file was too big to read all at once. Here is an outline of its symbols:
let outline = file_outline(project, file_path, action_log, None, 0, cx).await?;
{outline}
Using the line numbers in this outline, you can call this tool again while specifying
the start_line and end_line fields to see the implementations of symbols in the outline."
})
Ok(format!("This file was too big to read all at once. Here is an outline of its symbols:\n\n{outline}\n\nUsing the line numbers in this outline, you can call this tool again while specifying the start_line and end_line fields to see the implementations of symbols in the outline."))
}
}
}).into()
}
}
#[cfg(test)]
mod test {
use super::*;
use gpui::{AppContext, TestAppContext};
use language::{Language, LanguageConfig, LanguageMatcher};
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use util::path;
#[gpui::test]
async fn test_read_nonexistent_file(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree("/root", json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let result = cx
.update(|cx| {
let input = json!({
"path": "root/nonexistent_file.txt"
});
Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log, cx)
.output
})
.await;
assert_eq!(
result.unwrap_err().to_string(),
"root/nonexistent_file.txt not found"
);
}
#[gpui::test]
async fn test_read_small_file(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
json!({
"small_file.txt": "This is a small file content"
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let result = cx
.update(|cx| {
let input = json!({
"path": "root/small_file.txt"
});
Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log, cx)
.output
})
.await;
assert_eq!(result.unwrap(), "This is a small file content");
}
#[gpui::test]
async fn test_read_large_file(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
json!({
"large_file.rs": (0..1000).map(|i| format!("struct Test{} {{\n a: u32,\n b: usize,\n}}", i)).collect::<Vec<_>>().join("\n")
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
language_registry.add(Arc::new(rust_lang()));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let result = cx
.update(|cx| {
let input = json!({
"path": "root/large_file.rs"
});
Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log.clone(), cx)
.output
})
.await;
let content = result.unwrap();
assert_eq!(
content.lines().skip(2).take(6).collect::<Vec<_>>(),
vec![
"struct Test0 [L1-4]",
" a [L2]",
" b [L3]",
"struct Test1 [L5-8]",
" a [L6]",
" b [L7]",
]
);
let result = cx
.update(|cx| {
let input = json!({
"path": "root/large_file.rs",
"offset": 1
});
Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log, cx)
.output
})
.await;
let content = result.unwrap();
let expected_content = (0..1000)
.flat_map(|i| {
vec![
format!("struct Test{} [L{}-{}]", i, i * 4 + 1, i * 4 + 4),
format!(" a [L{}]", i * 4 + 2),
format!(" b [L{}]", i * 4 + 3),
]
})
.collect::<Vec<_>>();
pretty_assertions::assert_eq!(
content
.lines()
.skip(2)
.take(expected_content.len())
.collect::<Vec<_>>(),
expected_content
);
}
#[gpui::test]
async fn test_read_file_with_line_range(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
json!({
"multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let result = cx
.update(|cx| {
let input = json!({
"path": "root/multiline.txt",
"start_line": 2,
"end_line": 4
});
Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log, cx)
.output
})
.await;
assert_eq!(result.unwrap(), "Line 2\nLine 3");
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
});
}
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(tree_sitter_rust::LANGUAGE.into()),
)
.with_outline_query(
r#"
(line_comment) @annotation
(struct_item
"struct" @context
name: (_) @name) @item
(enum_item
"enum" @context
name: (_) @name) @item
(enum_variant
name: (_) @name) @item
(field_declaration
name: (_) @name) @item
(impl_item
"impl" @context
trait: (_)? @name
"for"? @context
type: (_) @name
body: (_ "{" (_)* "}")) @item
(function_item
"fn" @context
name: (_) @name) @item
(mod_item
"mod" @context
name: (_) @name) @item
"#,
)
.unwrap()
})
}
}

View File

@@ -1,6 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use assistant_tool::{ActionLog, Tool};
use futures::StreamExt;
use gpui::{App, Entity, Task};
use language::OffsetRangeExt;
@@ -60,7 +60,7 @@ impl Tool for RegexSearchTool {
IconName::Regex
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<RegexSearchToolInput>(format)
}
@@ -92,13 +92,13 @@ impl Tool for RegexSearchTool {
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
) -> Task<Result<String>> {
const CONTEXT_LINES: u32 = 2;
let (offset, regex, case_sensitive) =
match serde_json::from_value::<RegexSearchToolInput>(input) {
Ok(input) => (input.offset, input.regex, input.case_sensitive),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let query = match SearchQuery::regex(
@@ -112,7 +112,7 @@ impl Tool for RegexSearchTool {
None,
) {
Ok(query) => query,
Err(error) => return Task::ready(Err(error)).into(),
Err(error) => return Task::ready(Err(error)),
};
let results = project.update(cx, |project, cx| project.search(query, cx));
@@ -201,6 +201,6 @@ impl Tool for RegexSearchTool {
} else {
Ok(format!("Found {matches_found} matches:\n{output}"))
}
}).into()
})
}
}

View File

@@ -1,6 +1,7 @@
Searches the entire project for the given regular expression.
- Prefer this tool when searching for files containing symbols in the project.
- Supports full regex syntax (eg. "log.*Error", "function\\s+\\w+", etc.)
- Use this tool when you need to find files containing specific patterns
- Results are paginated with 20 matches per page. Use the optional 'offset' parameter to request subsequent pages.
Returns a list of paths that matched the query. For each path, it returns some excerpts of the matched text.
Results are paginated with 20 matches per page. Use the optional 'offset' parameter to request subsequent pages.
This tool is not aware of semantics and does not use any information from language servers, so it should only be used when no available semantic tool (e.g. one that uses language servers) could fit a particular use case instead.

View File

@@ -1,16 +1,14 @@
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language::{self, Buffer, ToPointUtf16};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use language_model::LanguageModelRequestMessage;
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use ui::IconName;
use crate::schema::json_schema_for;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RenameToolInput {
/// The relative path to the file containing the symbol to rename.
@@ -68,8 +66,12 @@ impl Tool for RenameTool {
IconName::Pencil
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<RenameToolInput>(format)
fn input_schema(
&self,
_format: language_model::LanguageModelToolSchemaFormat,
) -> serde_json::Value {
let schema = schemars::schema_for!(RenameToolInput);
serde_json::to_value(&schema).unwrap()
}
fn ui_text(&self, input: &serde_json::Value) -> String {
@@ -88,10 +90,10 @@ impl Tool for RenameTool {
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
) -> Task<Result<String>> {
let input = match serde_json::from_value::<RenameToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
Err(err) => return Task::ready(Err(anyhow!(err))),
};
cx.spawn(async move |cx| {
@@ -138,7 +140,7 @@ impl Tool for RenameTool {
})?;
Ok(format!("Renamed '{}' to '{}'", input.symbol, input.new_name))
}).into()
})
}
}

View File

@@ -5,20 +5,23 @@ use schemars::{
schema::{RootSchema, Schema, SchemaObject},
};
pub fn json_schema_for<T: JsonSchema>(
format: LanguageModelToolSchemaFormat,
) -> Result<serde_json::Value> {
pub fn json_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> serde_json::Value {
let schema = root_schema_for::<T>(format);
schema_to_json(&schema, format)
schema_to_json(&schema, format).expect("Failed to convert tool calling schema to JSON")
}
fn schema_to_json(
pub fn schema_to_json(
schema: &RootSchema,
format: LanguageModelToolSchemaFormat,
) -> Result<serde_json::Value> {
let mut value = serde_json::to_value(schema)?;
assistant_tool::adapt_schema_to_format(&mut value, format)?;
Ok(value)
match format {
LanguageModelToolSchemaFormat::JsonSchema => Ok(value),
LanguageModelToolSchemaFormat::JsonSchemaSubset => {
transform_fields_to_json_schema_subset(&mut value);
Ok(value)
}
}
}
fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> RootSchema {
@@ -76,3 +79,42 @@ impl schemars::visit::Visitor for TransformToJsonSchemaSubsetVisitor {
schemars::visit::visit_schema_object(self, schema)
}
}
fn transform_fields_to_json_schema_subset(json: &mut serde_json::Value) {
if let serde_json::Value::Object(obj) = json {
if let Some(default) = obj.get("default") {
let is_null = default.is_null();
//Default is not supported, so we need to remove it.
obj.remove("default");
if is_null {
obj.insert("nullable".to_string(), serde_json::Value::Bool(true));
}
}
// If a type is not specified for an input parameter we need to add it.
if obj.contains_key("description")
&& !obj.contains_key("type")
&& !(obj.contains_key("anyOf")
|| obj.contains_key("oneOf")
|| obj.contains_key("allOf"))
{
obj.insert(
"type".to_string(),
serde_json::Value::String("string".to_string()),
);
}
//Format field is only partially supported (e.g. not uint compatibility)
obj.remove("format");
for (_, value) in obj.iter_mut() {
if let serde_json::Value::Object(_) | serde_json::Value::Array(_) = value {
transform_fields_to_json_schema_subset(value);
}
}
} else if let serde_json::Value::Array(arr) = json {
for item in arr.iter_mut() {
transform_fields_to_json_schema_subset(item);
}
}
}

View File

@@ -1,5 +1,5 @@
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, AsyncApp, Entity, Task};
use language::{self, Anchor, Buffer, BufferSnapshot, Location, Point, ToPoint, ToPointUtf16};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
@@ -84,7 +84,7 @@ impl Tool for SymbolInfoTool {
IconName::Code
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<SymbolInfoToolInput>(format)
}
@@ -122,10 +122,10 @@ impl Tool for SymbolInfoTool {
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
) -> Task<Result<String>> {
let input = match serde_json::from_value::<SymbolInfoToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
Err(err) => return Task::ready(Err(anyhow!(err))),
};
cx.spawn(async move |cx| {
@@ -205,7 +205,7 @@ impl Tool for SymbolInfoTool {
} else {
Ok(output)
}
}).into()
})
}
}

Some files were not shown because too many files have changed in this diff Show More