Compare commits

..

1 Commits

Author SHA1 Message Date
Peter Tripp
d5e93c08d1 vim: Support keyboard navigation of context menus and code actions 2025-06-25 12:42:40 -04:00
528 changed files with 12525 additions and 25924 deletions

View File

@@ -29,8 +29,6 @@ jobs:
outputs:
run_tests: ${{ steps.filter.outputs.run_tests }}
run_license: ${{ steps.filter.outputs.run_license }}
run_docs: ${{ steps.filter.outputs.run_docs }}
run_nix: ${{ steps.filter.outputs.run_nix }}
runs-on:
- ubuntu-latest
steps:
@@ -60,22 +58,11 @@ jobs:
else
echo "run_tests=false" >> $GITHUB_OUTPUT
fi
if [[ $(git diff --name-only $COMPARE_REV ${{ github.sha }} | grep '^docs/') ]]; then
echo "run_docs=true" >> $GITHUB_OUTPUT
else
echo "run_docs=false" >> $GITHUB_OUTPUT
fi
if [[ $(git diff --name-only $COMPARE_REV ${{ github.sha }} | grep '^Cargo.lock') ]]; then
echo "run_license=true" >> $GITHUB_OUTPUT
else
echo "run_license=false" >> $GITHUB_OUTPUT
fi
NIX_REGEX='^(nix/|flake\.|Cargo\.|rust-toolchain.toml|\.cargo/config.toml)'
if [[ $(git diff --name-only $COMPARE_REV ${{ github.sha }} | grep "$NIX_REGEX") ]]; then
echo "run_nix=true" >> $GITHUB_OUTPUT
else
echo "run_nix=false" >> $GITHUB_OUTPUT
fi
migration_checks:
name: Check Postgres and Protobuf migrations, mergability
@@ -211,9 +198,7 @@ jobs:
timeout-minutes: 60
name: Check docs
needs: [job_spec]
if: |
github.repository_owner == 'zed-industries' &&
(needs.job_spec.outputs.run_tests == 'true' || needs.job_spec.outputs.run_docs == 'true')
if: github.repository_owner == 'zed-industries'
runs-on:
- buildjet-8vcpu-ubuntu-2204
steps:
@@ -467,10 +452,8 @@ jobs:
RET_CODE=0
# Always check style
[[ "${{ needs.style.result }}" != 'success' ]] && { RET_CODE=1; echo "style tests failed"; }
[[ "${{ needs.check_docs.result }}" != 'success' ]] && { RET_CODE=1; echo "docs checks failed"; }
if [[ "${{ needs.job_spec.outputs.run_docs }}" == "true" ]]; then
[[ "${{ needs.check_docs.result }}" != 'success' ]] && { RET_CODE=1; echo "docs checks failed"; }
fi
# Only check test jobs if they were supposed to run
if [[ "${{ needs.job_spec.outputs.run_tests }}" == "true" ]]; then
[[ "${{ needs.workspace_hack.result }}" != 'success' ]] && { RET_CODE=1; echo "Workspace Hack failed"; }
@@ -753,10 +736,7 @@ jobs:
nix-build:
name: Build with Nix
uses: ./.github/workflows/nix.yml
needs: [job_spec]
if: github.repository_owner == 'zed-industries' &&
(contains(github.event.pull_request.labels.*.name, 'run-nix') ||
needs.job_spec.outputs.run_nix == 'true')
if: github.repository_owner == 'zed-industries' && contains(github.event.pull_request.labels.*.name, 'run-nix')
secrets: inherit
with:
flake-output: debug

View File

@@ -0,0 +1,34 @@
name: Delete Mediafire Comments
on:
issue_comment:
types: [created]
permissions:
issues: write
jobs:
delete_comment:
if: github.repository_owner == 'zed-industries'
runs-on: ubuntu-latest
steps:
- name: Check for specific strings in comment
id: check_comment
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7
with:
script: |
const comment = context.payload.comment.body;
const triggerStrings = ['www.mediafire.com'];
return triggerStrings.some(triggerString => comment.includes(triggerString));
- name: Delete comment if it contains any of the specific strings
if: steps.check_comment.outputs.result == 'true'
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7
with:
script: |
const commentId = context.payload.comment.id;
await github.rest.issues.deleteComment({
owner: context.repo.owner,
repo: context.repo.repo,
comment_id: commentId
});

View File

@@ -7,7 +7,7 @@ on:
pull_request:
branches:
- "**"
types: [synchronize, reopened, labeled]
types: [opened, synchronize, reopened, labeled]
workflow_dispatch:
@@ -25,6 +25,16 @@ env:
ZED_EVAL_TELEMETRY: 1
jobs:
# This is a no-op job that we run to prevent GitHub from marking the workflow
# as failed for PRs that don't have the `run-eval` label.
noop:
name: No-op
runs-on: ubuntu-latest
if: github.repository_owner == 'zed-industries'
steps:
- name: No-op
run: echo "Nothing to do"
run_eval:
timeout-minutes: 60
name: Run Agent Eval

View File

@@ -40,7 +40,6 @@
},
"file_types": {
"Dockerfile": ["Dockerfile*[!dockerignore]"],
"JSONC": ["assets/**/*.json", "renovate.json"],
"Git Ignore": ["dockerignore"]
},
"hard_tabs": false,

105
Cargo.lock generated
View File

@@ -14,7 +14,6 @@ dependencies = [
"gpui",
"language",
"project",
"proto",
"release_channel",
"smallvec",
"ui",
@@ -78,7 +77,6 @@ dependencies = [
"language",
"language_model",
"log",
"parking_lot",
"paths",
"postage",
"pretty_assertions",
@@ -111,11 +109,18 @@ dependencies = [
name = "agent_settings"
version = "0.1.0"
dependencies = [
"anthropic",
"anyhow",
"collections",
"deepseek",
"fs",
"gpui",
"language_model",
"lmstudio",
"log",
"mistral",
"ollama",
"open_ai",
"paths",
"schemars",
"serde",
@@ -1911,6 +1916,7 @@ dependencies = [
"serde_json",
"strum 0.27.1",
"thiserror 2.0.12",
"tokio",
"workspace-hack",
]
@@ -2076,7 +2082,7 @@ dependencies = [
[[package]]
name = "blade-graphics"
version = "0.6.0"
source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad"
source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5"
dependencies = [
"ash",
"ash-window",
@@ -2109,7 +2115,7 @@ dependencies = [
[[package]]
name = "blade-macros"
version = "0.3.0"
source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad"
source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5"
dependencies = [
"proc-macro2",
"quote",
@@ -2119,7 +2125,7 @@ dependencies = [
[[package]]
name = "blade-util"
version = "0.2.0"
source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad"
source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5"
dependencies = [
"blade-graphics",
"bytemuck",
@@ -4132,7 +4138,7 @@ dependencies = [
[[package]]
name = "dap-types"
version = "0.0.1"
source = "git+https://github.com/zed-industries/dap-types?rev=7f39295b441614ca9dbf44293e53c32f666897f9#7f39295b441614ca9dbf44293e53c32f666897f9"
source = "git+https://github.com/zed-industries/dap-types?rev=b40956a7f4d1939da67429d941389ee306a3a308#b40956a7f4d1939da67429d941389ee306a3a308"
dependencies = [
"schemars",
"serde",
@@ -4147,8 +4153,6 @@ dependencies = [
"async-trait",
"collections",
"dap",
"dotenvy",
"fs",
"futures 0.3.31",
"gpui",
"json_dotpath",
@@ -4157,7 +4161,6 @@ dependencies = [
"paths",
"serde",
"serde_json",
"shlex",
"task",
"util",
"workspace-hack",
@@ -4311,7 +4314,6 @@ version = "0.1.0"
dependencies = [
"alacritty_terminal",
"anyhow",
"bitflags 2.9.0",
"client",
"collections",
"command_palette_hooks",
@@ -4677,6 +4679,12 @@ dependencies = [
"syn 2.0.101",
]
[[package]]
name = "dotenv"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f"
[[package]]
name = "dotenvy"
version = "0.15.7"
@@ -4809,7 +4817,6 @@ dependencies = [
"pretty_assertions",
"project",
"rand 0.8.5",
"regex",
"release_channel",
"rpc",
"schemars",
@@ -4830,7 +4837,6 @@ dependencies = [
"tree-sitter-python",
"tree-sitter-rust",
"tree-sitter-typescript",
"tree-sitter-yaml",
"ui",
"unicode-script",
"unicode-segmentation",
@@ -5111,7 +5117,7 @@ dependencies = [
"collections",
"debug_adapter_extension",
"dirs 4.0.0",
"dotenvy",
"dotenv",
"env_logger 0.11.8",
"extension",
"fs",
@@ -8844,7 +8850,6 @@ dependencies = [
"http_client",
"imara-diff",
"indoc",
"inventory",
"itertools 0.14.0",
"log",
"lsp",
@@ -8943,10 +8948,8 @@ dependencies = [
"aws-credential-types",
"aws_http_client",
"bedrock",
"chrono",
"client",
"collections",
"component",
"copilot",
"credentials_provider",
"deepseek",
@@ -9017,13 +9020,11 @@ dependencies = [
"collections",
"copilot",
"editor",
"feature_flags",
"futures 0.3.31",
"gpui",
"itertools 0.14.0",
"language",
"lsp",
"picker",
"project",
"release_channel",
"serde_json",
@@ -9241,7 +9242,7 @@ dependencies = [
[[package]]
name = "libwebrtc"
version = "0.3.10"
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4"
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=80bb8f4c9112789f7c24cc98d8423010977806a6#80bb8f4c9112789f7c24cc98d8423010977806a6"
dependencies = [
"cxx",
"jni",
@@ -9321,7 +9322,7 @@ checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856"
[[package]]
name = "livekit"
version = "0.7.8"
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4"
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=80bb8f4c9112789f7c24cc98d8423010977806a6#80bb8f4c9112789f7c24cc98d8423010977806a6"
dependencies = [
"chrono",
"futures-util",
@@ -9344,7 +9345,7 @@ dependencies = [
[[package]]
name = "livekit-api"
version = "0.4.2"
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4"
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=80bb8f4c9112789f7c24cc98d8423010977806a6#80bb8f4c9112789f7c24cc98d8423010977806a6"
dependencies = [
"futures-util",
"http 0.2.12",
@@ -9368,7 +9369,7 @@ dependencies = [
[[package]]
name = "livekit-protocol"
version = "0.3.9"
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4"
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=80bb8f4c9112789f7c24cc98d8423010977806a6#80bb8f4c9112789f7c24cc98d8423010977806a6"
dependencies = [
"futures-util",
"livekit-runtime",
@@ -9385,7 +9386,7 @@ dependencies = [
[[package]]
name = "livekit-runtime"
version = "0.4.0"
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4"
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=80bb8f4c9112789f7c24cc98d8423010977806a6#80bb8f4c9112789f7c24cc98d8423010977806a6"
dependencies = [
"tokio",
"tokio-stream",
@@ -12259,7 +12260,6 @@ dependencies = [
"language",
"log",
"lsp",
"markdown",
"node_runtime",
"parking_lot",
"pathdiff",
@@ -13168,7 +13168,6 @@ dependencies = [
"thiserror 2.0.12",
"urlencoding",
"util",
"which 6.0.3",
"workspace-hack",
]
@@ -14054,13 +14053,12 @@ dependencies = [
[[package]]
name = "schemars"
version = "1.0.1"
version = "0.8.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe8c9d1c68d67dd9f97ecbc6f932b60eb289c5dbddd8aa1405484a8fd2fcd984"
checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615"
dependencies = [
"dyn-clone",
"indexmap",
"ref-cast",
"schemars_derive",
"serde",
"serde_json",
@@ -14068,9 +14066,9 @@ dependencies = [
[[package]]
name = "schemars_derive"
version = "1.0.1"
version = "0.8.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ca9fcb757952f8e8629b9ab066fc62da523c46c2b247b1708a3be06dd82530b"
checksum = "32e265784ad618884abaea0600a9adf15393368d840e0222d101a072f3f7534d"
dependencies = [
"proc-macro2",
"quote",
@@ -14557,41 +14555,28 @@ dependencies = [
"serde_json",
"serde_json_lenient",
"smallvec",
"streaming-iterator",
"tree-sitter",
"tree-sitter-json",
"unindent",
"util",
"workspace-hack",
"zlog",
]
[[package]]
name = "settings_ui"
version = "0.1.0"
dependencies = [
"anyhow",
"collections",
"command_palette",
"command_palette_hooks",
"component",
"db",
"editor",
"feature_flags",
"fs",
"fuzzy",
"gpui",
"language",
"log",
"menu",
"paths",
"project",
"schemars",
"search",
"serde",
"settings",
"theme",
"tree-sitter-json",
"tree-sitter-rust",
"ui",
"util",
"workspace",
@@ -15542,18 +15527,6 @@ version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0193cc4331cfd2f3d2011ef287590868599a2f33c3e69bc22c1a3d3acf9e02fb"
[[package]]
name = "svg_preview"
version = "0.1.0"
dependencies = [
"editor",
"file_icons",
"gpui",
"ui",
"workspace",
"workspace-hack",
]
[[package]]
name = "svgtypes"
version = "0.15.3"
@@ -16025,7 +15998,6 @@ dependencies = [
"futures 0.3.31",
"gpui",
"indexmap",
"inventory",
"log",
"palette",
"parking_lot",
@@ -16066,6 +16038,7 @@ dependencies = [
"indexmap",
"log",
"palette",
"rust-embed",
"serde",
"serde_json",
"serde_json_lenient",
@@ -16896,7 +16869,8 @@ dependencies = [
[[package]]
name = "tree-sitter-python"
version = "0.23.6"
source = "git+https://github.com/zed-industries/tree-sitter-python?rev=218fcbf3fda3d029225f3dec005cb497d111b35e#218fcbf3fda3d029225f3dec005cb497d111b35e"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d065aaa27f3aaceaf60c1f0e0ac09e1cb9eb8ed28e7bcdaa52129cffc7f4b04"
dependencies = [
"cc",
"tree-sitter-language",
@@ -17348,7 +17322,6 @@ dependencies = [
"rand 0.8.5",
"regex",
"rust-embed",
"schemars",
"serde",
"serde_json",
"serde_json_lenient",
@@ -17458,8 +17431,11 @@ name = "vercel"
version = "0.1.0"
dependencies = [
"anyhow",
"futures 0.3.31",
"http_client",
"schemars",
"serde",
"serde_json",
"strum 0.27.1",
"workspace-hack",
]
@@ -18313,7 +18289,7 @@ dependencies = [
[[package]]
name = "webrtc-sys"
version = "0.3.7"
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4"
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=80bb8f4c9112789f7c24cc98d8423010977806a6#80bb8f4c9112789f7c24cc98d8423010977806a6"
dependencies = [
"cc",
"cxx",
@@ -18326,7 +18302,7 @@ dependencies = [
[[package]]
name = "webrtc-sys-build"
version = "0.3.6"
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4#d2eade7a6b15d6dbdb38ba12a1ff7bf07fcebba4"
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=80bb8f4c9112789f7c24cc98d8423010977806a6#80bb8f4c9112789f7c24cc98d8423010977806a6"
dependencies = [
"fs2",
"regex",
@@ -19946,7 +19922,7 @@ dependencies = [
[[package]]
name = "zed"
version = "0.195.0"
version = "0.193.0"
dependencies = [
"activity_indicator",
"agent",
@@ -20051,7 +20027,6 @@ dependencies = [
"snippet_provider",
"snippets_ui",
"supermaven",
"svg_preview",
"sysinfo",
"tab_switcher",
"task",
@@ -20144,9 +20119,9 @@ dependencies = [
[[package]]
name = "zed_llm_client"
version = "0.8.5"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c740e29260b8797ad252c202ea09a255b3cbc13f30faaf92fb6b2490336106e0"
checksum = "de7d9523255f4e00ee3d0918e5407bd252d798a4a8e71f6d37f23317a1588203"
dependencies = [
"anyhow",
"serde",

View File

@@ -95,7 +95,6 @@ members = [
"crates/markdown_preview",
"crates/media",
"crates/menu",
"crates/svg_preview",
"crates/migrator",
"crates/mistral",
"crates/multi_buffer",
@@ -305,7 +304,6 @@ lmstudio = { path = "crates/lmstudio" }
lsp = { path = "crates/lsp" }
markdown = { path = "crates/markdown" }
markdown_preview = { path = "crates/markdown_preview" }
svg_preview = { path = "crates/svg_preview" }
media = { path = "crates/media" }
menu = { path = "crates/menu" }
migrator = { path = "crates/migrator" }
@@ -425,9 +423,9 @@ aws-smithy-runtime-api = { version = "1.7.4", features = ["http-1x", "client"] }
aws-smithy-types = { version = "1.3.0", features = ["http-body-1-x"] }
base64 = "0.22"
bitflags = "2.6.0"
blade-graphics = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" }
blade-macros = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" }
blade-util = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" }
blade-graphics = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" }
blade-macros = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" }
blade-util = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" }
blake3 = "1.5.3"
bytes = "1.0"
cargo_metadata = "0.19"
@@ -444,12 +442,12 @@ core-video = { version = "0.4.3", features = ["metal"] }
cpal = "0.16"
criterion = { version = "0.5", features = ["html_reports"] }
ctor = "0.4.0"
dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "7f39295b441614ca9dbf44293e53c32f666897f9" }
dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "b40956a7f4d1939da67429d941389ee306a3a308" }
dashmap = "6.0"
derive_more = "0.99.17"
dirs = "4.0"
documented = "0.9.1"
dotenvy = "0.15.0"
dotenv = "0.15.0"
ec4rs = "1.1"
emojis = "0.6.1"
env_logger = "0.11"
@@ -480,7 +478,7 @@ json_dotpath = "1.1"
jsonschema = "0.30.0"
jsonwebtoken = "9.3"
jupyter-protocol = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
jupyter-websocket-client = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
jupyter-websocket-client = { git = "https://github.com/ConradIrwin/runtimed" ,rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
libc = "0.2"
libsqlite3-sys = { version = "0.30.1", features = ["bundled"] }
linkify = "0.10.0"
@@ -491,7 +489,7 @@ metal = "0.29"
moka = { version = "0.12.10", features = ["sync"] }
naga = { version = "25.0", features = ["wgsl-in"] }
nanoid = "0.4"
nbformat = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
nbformat = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
nix = "0.29"
num-format = "0.4.4"
objc = "0.2"
@@ -531,7 +529,7 @@ reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "951c77
"stream",
] }
rsa = "0.9.6"
runtimelib = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734", default-features = false, features = [
runtimelib = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734", default-features = false, features = [
"async-dispatcher-runtime",
] }
rust-embed = { version = "8.4", features = ["include-exclude"] }
@@ -540,7 +538,7 @@ rustc-hash = "2.1.0"
rustls = { version = "0.23.26" }
rustls-platform-verifier = "0.5.0"
scap = { git = "https://github.com/zed-industries/scap", rev = "08f0a01417505cc0990b9931a37e5120db92e0d0", default-features = false }
schemars = { version = "1.0", features = ["indexmap2"] }
schemars = { version = "0.8", features = ["impl_json_schema", "indexmap2"] }
semver = "1.0"
serde = { version = "1.0", features = ["derive", "rc"] }
serde_derive = { version = "1.0", features = ["deserialize_in_place"] }
@@ -597,7 +595,7 @@ tree-sitter-html = "0.23"
tree-sitter-jsdoc = "0.23"
tree-sitter-json = "0.24"
tree-sitter-md = { git = "https://github.com/tree-sitter-grammars/tree-sitter-markdown", rev = "9a23c1a96c0513d8fc6520972beedd419a973539" }
tree-sitter-python = { git = "https://github.com/zed-industries/tree-sitter-python", rev = "218fcbf3fda3d029225f3dec005cb497d111b35e" }
tree-sitter-python = "0.23"
tree-sitter-regex = "0.24"
tree-sitter-ruby = "0.23"
tree-sitter-rust = "0.24"
@@ -625,7 +623,7 @@ wasmtime = { version = "29", default-features = false, features = [
wasmtime-wasi = "29"
which = "6.0.0"
workspace-hack = "0.1.0"
zed_llm_client = "= 0.8.5"
zed_llm_client = "0.8.4"
zstd = "0.11"
[workspace.dependencies.async-stripe]

View File

@@ -1,6 +1,6 @@
# syntax = docker/dockerfile:1.2
FROM rust:1.88-bookworm as builder
FROM rust:1.87-bookworm as builder
WORKDIR app
COPY . .

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-arrow-down10-icon lucide-arrow-down-1-0"><path d="m3 16 4 4 4-4"/><path d="M7 20V4"/><path d="M17 10V4h-2"/><path d="M15 10h4"/><rect x="15" y="14" width="4" height="6" ry="2"/></svg>

Before

Width:  |  Height:  |  Size: 386 B

View File

@@ -1,3 +0,0 @@
<svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M6.75776 5.50003H8.49988C8.70769 5.50003 8.89518 5.62971 8.95455 5.82346C9.04049 6.01876 8.9858 6.23906 8.82956 6.37656L4.82971 9.87643C4.65315 10.0295 4.39488 10.042 4.20614 9.90455C4.01724 9.76705 3.94849 9.51706 4.04052 9.30301L5.24219 6.49999H3.48601C3.2918 6.49999 3.10524 6.37031 3.03197 6.17657C2.9587 5.98126 3.014 5.76096 3.1708 5.62346L7.17018 2.12375C7.34674 1.97001 7.60454 1.95829 7.7936 2.09547C7.98265 2.23275 8.0514 2.48218 7.95922 2.69695L6.75776 5.50003Z" fill="black"/>
</svg>

Before

Width:  |  Height:  |  Size: 601 B

View File

@@ -1,12 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M6 3L7 4" stroke="black" stroke-width="1.33333" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M9 4L10 3" stroke="black" stroke-width="1.33333" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M6.002 6V5.51658C5.98992 5.32067 6.03266 5.12502 6.12762 4.94143C6.22259 4.75784 6.36781 4.59012 6.55453 4.44839C6.74125 4.30666 6.9656 4.19386 7.21403 4.1168C7.46246 4.03973 7.72983 4 8 4C8.27017 4 8.53754 4.03973 8.78597 4.1168C9.0344 4.19386 9.25875 4.30666 9.44547 4.44839C9.63219 4.59012 9.77741 4.75784 9.87238 4.94143C9.96734 5.12502 10.0101 5.32067 9.998 5.51658V6" stroke="black" stroke-width="1.33333" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M8 13C6.35 13 5 11.5462 5 9.76923V8.15385C5 7.58261 5.21071 7.03477 5.58579 6.63085C5.96086 6.22692 6.46957 6 7 6H9C9.53043 6 10.0391 6.22692 10.4142 6.63085C10.7893 7.03477 11 7.58261 11 8.15385V9.76923C11 11.5462 9.65 13 8 13Z" stroke="black" stroke-width="1.33333" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M5 6.16663C3.90652 6.06663 3 5.21663 3 4.16663" stroke="black" stroke-width="1.33333" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M5 9H3" stroke="black" stroke-width="1.33333" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M3 13C3 11.95 3.89474 11.05 5 11" stroke="black" stroke-width="1.33333" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M13 4C13 5.05 12.0857 5.9 11 6" stroke="black" stroke-width="1.33333" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M13 9H11" stroke="black" stroke-width="1.33333" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M11 11C12.1053 11.05 13 11.95 13 13" stroke="black" stroke-width="1.33333" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 1.8 KiB

View File

@@ -1,4 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M3.84265 10.7778C4.39206 11.6001 5.17295 12.241 6.08658 12.6194C7.00021 12.9978 8.00555 13.0969 8.97545 12.9039C9.94535 12.711 10.8363 12.2348 11.5355 11.5355C12.2348 10.8363 12.711 9.94535 12.9039 8.97545C13.0969 8.00555 12.9978 7.00021 12.6194 6.08658C12.241 5.17295 11.6001 4.39206 10.7778 3.84265C9.9556 3.29324 8.9889 3 8 3C6.60219 3.00526 5.26054 3.55068 4.25556 4.52222L3 5.77778" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M3 3V6H6" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 685 B

View File

@@ -1,4 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M8 13C10.7614 13 13 10.7614 13 8C13 5.23858 10.7614 3 8 3C5.23858 3 3 5.23858 3 8C3 10.7614 5.23858 13 8 13Z" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M5 5L11 11" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 409 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-scroll-text-icon lucide-scroll-text"><path d="M15 12h-5"/><path d="M15 8h-5"/><path d="M19 17V5a2 2 0 0 0-2-2H4"/><path d="M8 21h12a2 2 0 0 0 2-2v-1a1 1 0 0 0-1-1H11a1 1 0 0 0-1 1v1a2 2 0 1 1-4 0V5a2 2 0 1 0-4 0v2a1 1 0 0 0 1 1h3"/></svg>

Before

Width:  |  Height:  |  Size: 441 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-split-icon lucide-split"><path d="M16 3h5v5"/><path d="M8 3H3v5"/><path d="M12 22v-8.3a4 4 0 0 0-1.172-2.872L3 3"/><path d="m15 9 6-6"/></svg>

Before

Width:  |  Height:  |  Size: 345 B

View File

@@ -34,15 +34,14 @@
"ctrl-q": "zed::Quit",
"f4": "debugger::Start",
"shift-f5": "debugger::Stop",
"ctrl-shift-f5": "debugger::RerunSession",
"ctrl-shift-f5": "debugger::Restart",
"f6": "debugger::Pause",
"f7": "debugger::StepOver",
"ctrl-f11": "debugger::StepInto",
"shift-f11": "debugger::StepOut",
"f11": "zed::ToggleFullScreen",
"ctrl-alt-z": "edit_prediction::RateCompletions",
"ctrl-shift-i": "edit_prediction::ToggleMenu",
"ctrl-alt-l": "lsp_tool::ToggleMenu"
"ctrl-shift-i": "edit_prediction::ToggleMenu"
}
},
{
@@ -244,8 +243,8 @@
"ctrl-alt-e": "agent::RemoveAllContext",
"ctrl-shift-e": "project_panel::ToggleFocus",
"ctrl-shift-enter": "agent::ContinueThread",
"super-ctrl-b": "agent::ToggleBurnMode",
"alt-enter": "agent::ContinueWithBurnMode"
"alt-enter": "agent::ContinueWithBurnMode",
"ctrl-alt-b": "agent::ToggleBurnMode"
}
},
{
@@ -491,27 +490,13 @@
"ctrl-k r": "editor::RevealInFileManager",
"ctrl-k p": "editor::CopyPath",
"ctrl-\\": "pane::SplitRight",
"ctrl-k v": "markdown::OpenPreviewToTheSide",
"ctrl-shift-v": "markdown::OpenPreview",
"ctrl-alt-shift-c": "editor::DisplayCursorNames",
"alt-.": "editor::GoToHunk",
"alt-,": "editor::GoToPreviousHunk"
}
},
{
"context": "Editor && extension == md",
"use_key_equivalents": true,
"bindings": {
"ctrl-k v": "markdown::OpenPreviewToTheSide",
"ctrl-shift-v": "markdown::OpenPreview"
}
},
{
"context": "Editor && extension == svg",
"use_key_equivalents": true,
"bindings": {
"ctrl-k v": "svg::OpenPreviewToTheSide",
"ctrl-shift-v": "svg::OpenPreview"
}
},
{
"context": "Editor && mode == full",
"bindings": {
@@ -557,13 +542,6 @@
"ctrl-b": "workspace::ToggleLeftDock",
"ctrl-j": "workspace::ToggleBottomDock",
"ctrl-alt-y": "workspace::CloseAllDocks",
"ctrl-alt-0": "workspace::ResetActiveDockSize",
// For 0px parameter, uses UI font size value.
"ctrl-alt--": ["workspace::DecreaseActiveDockSize", { "px": 0 }],
"ctrl-alt-=": ["workspace::IncreaseActiveDockSize", { "px": 0 }],
"ctrl-alt-)": "workspace::ResetOpenDocksSize",
"ctrl-alt-_": ["workspace::DecreaseOpenDocksSize", { "px": 0 }],
"ctrl-alt-+": ["workspace::IncreaseOpenDocksSize", { "px": 0 }],
"shift-find": "pane::DeploySearch",
"ctrl-shift-f": "pane::DeploySearch",
"ctrl-shift-h": ["pane::DeploySearch", { "replace_enabled": true }],
@@ -605,9 +583,7 @@
// "foo-bar": ["task::Spawn", { "task_name": "MyTask", "reveal_target": "dock" }]
// or by tag:
// "foo-bar": ["task::Spawn", { "task_tag": "MyTag" }],
"f5": "debugger::Rerun",
"ctrl-f4": "workspace::CloseActiveDock",
"ctrl-w": "workspace::CloseActiveDock"
"f5": "debugger::RerunLastSession"
}
},
{
@@ -710,13 +686,6 @@
"pagedown": "editor::ContextMenuLast"
}
},
{
"context": "Editor && showing_signature_help && !showing_completions",
"bindings": {
"up": "editor::SignatureHelpPrevious",
"down": "editor::SignatureHelpNext"
}
},
// Custom bindings
{
"bindings": {
@@ -935,9 +904,7 @@
"context": "BreakpointList",
"bindings": {
"space": "debugger::ToggleEnableBreakpoint",
"backspace": "debugger::UnsetBreakpoint",
"left": "debugger::PreviousBreakpointProperty",
"right": "debugger::NextBreakpointProperty"
"backspace": "debugger::UnsetBreakpoint"
}
},
{
@@ -1083,19 +1050,5 @@
"ctrl-tab": "pane::ActivateNextItem",
"ctrl-shift-tab": "pane::ActivatePreviousItem"
}
},
{
"context": "MarkdownPreview",
"bindings": {
"pageup": "markdown::MovePageUp",
"pagedown": "markdown::MovePageDown"
}
},
{
"context": "KeymapEditor",
"use_key_equivalents": true,
"bindings": {
"ctrl-f": "search::FocusSearch"
}
}
]

View File

@@ -5,10 +5,10 @@
"bindings": {
"f4": "debugger::Start",
"shift-f5": "debugger::Stop",
"shift-cmd-f5": "debugger::RerunSession",
"shift-cmd-f5": "debugger::Restart",
"f6": "debugger::Pause",
"f7": "debugger::StepOver",
"ctrl-f11": "debugger::StepInto",
"f11": "debugger::StepInto",
"shift-f11": "debugger::StepOut",
"home": "menu::SelectFirst",
"shift-pageup": "menu::SelectFirst",
@@ -47,8 +47,7 @@
"fn-f": "zed::ToggleFullScreen",
"ctrl-cmd-f": "zed::ToggleFullScreen",
"ctrl-cmd-z": "edit_prediction::RateCompletions",
"ctrl-cmd-i": "edit_prediction::ToggleMenu",
"ctrl-cmd-l": "lsp_tool::ToggleMenu"
"ctrl-cmd-i": "edit_prediction::ToggleMenu"
}
},
{
@@ -283,9 +282,9 @@
"cmd->": "assistant::QuoteSelection",
"cmd-alt-e": "agent::RemoveAllContext",
"cmd-shift-e": "project_panel::ToggleFocus",
"cmd-ctrl-b": "agent::ToggleBurnMode",
"cmd-shift-enter": "agent::ContinueThread",
"alt-enter": "agent::ContinueWithBurnMode"
"alt-enter": "agent::ContinueWithBurnMode",
"cmd-alt-b": "agent::ToggleBurnMode"
}
},
{
@@ -545,23 +544,9 @@
"cmd-k r": "editor::RevealInFileManager",
"cmd-k p": "editor::CopyPath",
"cmd-\\": "pane::SplitRight",
"ctrl-cmd-c": "editor::DisplayCursorNames"
}
},
{
"context": "Editor && extension == md",
"use_key_equivalents": true,
"bindings": {
"cmd-k v": "markdown::OpenPreviewToTheSide",
"cmd-shift-v": "markdown::OpenPreview"
}
},
{
"context": "Editor && extension == svg",
"use_key_equivalents": true,
"bindings": {
"cmd-k v": "svg::OpenPreviewToTheSide",
"cmd-shift-v": "svg::OpenPreview"
"cmd-shift-v": "markdown::OpenPreview",
"ctrl-cmd-c": "editor::DisplayCursorNames"
}
},
{
@@ -602,7 +587,7 @@
"alt-cmd-o": ["projects::OpenRecent", { "create_new_window": false }],
"ctrl-cmd-o": ["projects::OpenRemote", { "from_existing_connection": false, "create_new_window": false }],
"ctrl-cmd-shift-o": ["projects::OpenRemote", { "from_existing_connection": true, "create_new_window": false }],
"cmd-ctrl-b": "branches::OpenRecent",
"alt-cmd-b": "branches::OpenRecent",
"ctrl-~": "workspace::NewTerminal",
"cmd-s": "workspace::Save",
"cmd-k s": "workspace::SaveWithoutFormat",
@@ -620,17 +605,9 @@
"cmd-8": ["workspace::ActivatePane", 7],
"cmd-9": ["workspace::ActivatePane", 8],
"cmd-b": "workspace::ToggleLeftDock",
"cmd-alt-b": "workspace::ToggleRightDock",
"cmd-r": "workspace::ToggleRightDock",
"cmd-j": "workspace::ToggleBottomDock",
"alt-cmd-y": "workspace::CloseAllDocks",
// For 0px parameter, uses UI font size value.
"ctrl-alt-0": "workspace::ResetActiveDockSize",
"ctrl-alt--": ["workspace::DecreaseActiveDockSize", { "px": 0 }],
"ctrl-alt-=": ["workspace::IncreaseActiveDockSize", { "px": 0 }],
"ctrl-alt-)": "workspace::ResetOpenDocksSize",
"ctrl-alt-_": ["workspace::DecreaseOpenDocksSize", { "px": 0 }],
"ctrl-alt-+": ["workspace::IncreaseOpenDocksSize", { "px": 0 }],
"cmd-shift-f": "pane::DeploySearch",
"cmd-shift-h": ["pane::DeploySearch", { "replace_enabled": true }],
"cmd-shift-t": "pane::ReopenClosedItem",
@@ -659,8 +636,7 @@
"cmd-k shift-up": "workspace::SwapPaneUp",
"cmd-k shift-down": "workspace::SwapPaneDown",
"cmd-shift-x": "zed::Extensions",
"f5": "debugger::Rerun",
"cmd-w": "workspace::CloseActiveDock"
"f5": "debugger::RerunLastSession"
}
},
{
@@ -774,13 +750,6 @@
"pagedown": "editor::ContextMenuLast"
}
},
{
"context": "Editor && showing_signature_help && !showing_completions",
"bindings": {
"up": "editor::SignatureHelpPrevious",
"down": "editor::SignatureHelpNext"
}
},
// Custom bindings
{
"use_key_equivalents": true,
@@ -995,9 +964,7 @@
"context": "BreakpointList",
"bindings": {
"space": "debugger::ToggleEnableBreakpoint",
"backspace": "debugger::UnsetBreakpoint",
"left": "debugger::PreviousBreakpointProperty",
"right": "debugger::NextBreakpointProperty"
"backspace": "debugger::UnsetBreakpoint"
}
},
{
@@ -1182,19 +1149,5 @@
"ctrl-tab": "pane::ActivateNextItem",
"ctrl-shift-tab": "pane::ActivatePreviousItem"
}
},
{
"context": "MarkdownPreview",
"bindings": {
"pageup": "markdown::MovePageUp",
"pagedown": "markdown::MovePageDown"
}
},
{
"context": "KeymapEditor",
"use_key_equivalents": true,
"bindings": {
"cmd-f": "search::FocusSearch"
}
}
]

View File

@@ -8,6 +8,7 @@
"ctrl-shift-i": "agent::ToggleFocus",
"ctrl-l": "agent::ToggleFocus",
"ctrl-shift-l": "agent::ToggleFocus",
"ctrl-alt-b": "agent::ToggleFocus",
"ctrl-shift-j": "agent::OpenConfiguration"
}
},
@@ -41,6 +42,7 @@
"ctrl-shift-i": "workspace::ToggleRightDock",
"ctrl-l": "workspace::ToggleRightDock",
"ctrl-shift-l": "workspace::ToggleRightDock",
"ctrl-alt-b": "workspace::ToggleRightDock",
"ctrl-w": "workspace::ToggleRightDock", // technically should close chat
"ctrl-.": "agent::ToggleProfileSelector",
"ctrl-/": "agent::ToggleModelSelector",

View File

@@ -59,8 +59,7 @@
"alt->": "editor::MoveToEnd", // end-of-buffer
"ctrl-l": "editor::ScrollCursorCenterTopBottom", // recenter-top-bottom
"ctrl-s": "buffer_search::Deploy", // isearch-forward
"alt-^": "editor::JoinLines", // join-line
"alt-q": "editor::Rewrap" // fill-paragraph
"alt-^": "editor::JoinLines" // join-line
}
},
{
@@ -98,13 +97,6 @@
"ctrl-n": "editor::ContextMenuNext"
}
},
{
"context": "Editor && showing_signature_help && !showing_completions",
"bindings": {
"ctrl-p": "editor::SignatureHelpPrevious",
"ctrl-n": "editor::SignatureHelpNext"
}
},
{
"context": "Workspace",
"bindings": {

View File

@@ -8,6 +8,7 @@
"cmd-shift-i": "agent::ToggleFocus",
"cmd-l": "agent::ToggleFocus",
"cmd-shift-l": "agent::ToggleFocus",
"cmd-alt-b": "agent::ToggleFocus",
"cmd-shift-j": "agent::OpenConfiguration"
}
},
@@ -42,6 +43,7 @@
"cmd-shift-i": "workspace::ToggleRightDock",
"cmd-l": "workspace::ToggleRightDock",
"cmd-shift-l": "workspace::ToggleRightDock",
"cmd-alt-b": "workspace::ToggleRightDock",
"cmd-w": "workspace::ToggleRightDock", // technically should close chat
"cmd-.": "agent::ToggleProfileSelector",
"cmd-/": "agent::ToggleModelSelector",

View File

@@ -59,8 +59,7 @@
"alt->": "editor::MoveToEnd", // end-of-buffer
"ctrl-l": "editor::ScrollCursorCenterTopBottom", // recenter-top-bottom
"ctrl-s": "buffer_search::Deploy", // isearch-forward
"alt-^": "editor::JoinLines", // join-line
"alt-q": "editor::Rewrap" // fill-paragraph
"alt-^": "editor::JoinLines" // join-line
}
},
{
@@ -98,13 +97,6 @@
"ctrl-n": "editor::ContextMenuNext"
}
},
{
"context": "Editor && showing_signature_help && !showing_completions",
"bindings": {
"ctrl-p": "editor::SignatureHelpPrevious",
"ctrl-n": "editor::SignatureHelpNext"
}
},
{
"context": "Workspace",
"bindings": {

View File

@@ -210,8 +210,7 @@
"ctrl-w space": "editor::OpenExcerptsSplit",
"ctrl-w g space": "editor::OpenExcerptsSplit",
"ctrl-6": "pane::AlternateFile",
"ctrl-^": "pane::AlternateFile",
".": "vim::Repeat"
"ctrl-^": "pane::AlternateFile"
}
},
{
@@ -220,6 +219,7 @@
"ctrl-[": "editor::Cancel",
"escape": "editor::Cancel",
":": "command_palette::Toggle",
".": "vim::Repeat",
"c": "vim::PushChange",
"shift-c": "vim::ChangeToEndOfLine",
"d": "vim::PushDelete",
@@ -477,13 +477,6 @@
"ctrl-n": "editor::ShowWordCompletions"
}
},
{
"context": "vim_mode == insert && showing_signature_help && !showing_completions",
"bindings": {
"ctrl-p": "editor::SignatureHelpPrevious",
"ctrl-n": "editor::SignatureHelpNext"
}
},
{
"context": "vim_mode == replace",
"bindings": {
@@ -841,6 +834,24 @@
"g g": "menu::SelectFirst"
}
},
{
"context": "menu",
"bindings": {
"j": "menu::SelectNext",
"k": "menu::SelectPrevious",
"shift-g": "menu::SelectLast",
"g g": "menu::SelectFirst"
}
},
{
"context": "Editor && showing_code_actions",
"bindings": {
"j": "editor::ContextMenuNext",
"k": "editor::ContextMenuPrevious",
"shift-g": "editor::ContextMenuLast",
"g g": "editor::ContextMenuFirst"
}
},
{
"context": "GitPanel && ChangesList",
"use_key_equivalents": true,
@@ -856,25 +867,6 @@
"shift-u": "git::UnstageAll"
}
},
{
"context": "Editor && mode == auto_height && VimControl",
"bindings": {
// TODO: Implement search
"/": null,
"?": null,
"#": null,
"*": null,
"n": null,
"shift-n": null
}
},
{
"context": "GitCommit > Editor && VimControl && vim_mode == normal",
"bindings": {
"ctrl-c": "menu::Cancel",
"escape": "menu::Cancel"
}
},
{
"context": "Editor && edit_prediction",
"bindings": {
@@ -886,7 +878,14 @@
{
"context": "MessageEditor > Editor && VimControl",
"bindings": {
"enter": "agent::Chat"
"enter": "agent::Chat",
// TODO: Implement search
"/": null,
"?": null,
"#": null,
"*": null,
"n": null,
"shift-n": null
}
},
{

View File

@@ -746,6 +746,8 @@
"default_width": 380
},
"agent": {
// Version of this setting.
"version": "2",
// Whether the agent is enabled.
"enabled": true,
/// What completion mode to start new threads in, if available. Can be 'normal' or 'burn'.
@@ -1290,8 +1292,6 @@
// Whether or not selecting text in the terminal will automatically
// copy to the system clipboard.
"copy_on_select": false,
// Whether to keep the text selection after copying it to the clipboard
"keep_selection_on_copy": false,
// Whether to show the terminal button in the status bar
"button": true,
// Any key-value pairs added to this list will be added to the terminal's
@@ -1656,6 +1656,7 @@
// Different settings for specific language models.
"language_models": {
"anthropic": {
"version": "1",
"api_url": "https://api.anthropic.com"
},
"google": {
@@ -1665,6 +1666,7 @@
"api_url": "http://localhost:11434"
},
"openai": {
"version": "1",
"api_url": "https://api.openai.com/v1"
},
"open_router": {
@@ -1718,11 +1720,6 @@
// }
// }
},
// Common language server settings.
"global_lsp_settings": {
// Whether to show the LSP servers button in the status bar.
"button": true
},
// Jupyter settings
"jupyter": {
"enabled": true
@@ -1782,8 +1779,7 @@
// `socks5h`. `http` will be used when no scheme is specified.
//
// By default no proxy will be used, or Zed will try get proxy settings from
// environment variables. If certain hosts should not be proxied,
// set the `no_proxy` environment variable and provide a comma-separated list.
// environment variables.
//
// Examples:
// - "proxy": "socks5h://localhost:10808"

View File

@@ -21,7 +21,6 @@ futures.workspace = true
gpui.workspace = true
language.workspace = true
project.workspace = true
proto.workspace = true
smallvec.workspace = true
ui.workspace = true
util.workspace = true

View File

@@ -31,13 +31,7 @@ use workspace::{StatusItemView, Workspace, item::ItemHandle};
const GIT_OPERATION_DELAY: Duration = Duration::from_millis(0);
actions!(
activity_indicator,
[
/// Displays error messages from language servers in the status bar.
ShowErrorMessage
]
);
actions!(activity_indicator, [ShowErrorMessage]);
pub enum Event {
ShowStatus {
@@ -86,13 +80,10 @@ impl ActivityIndicator {
let this = cx.new(|cx| {
let mut status_events = languages.language_server_binary_statuses();
cx.spawn(async move |this, cx| {
while let Some((name, binary_status)) = status_events.next().await {
while let Some((name, status)) = status_events.next().await {
this.update(cx, |this: &mut ActivityIndicator, cx| {
this.statuses.retain(|s| s.name != name);
this.statuses.push(ServerStatus {
name,
status: LanguageServerStatusUpdate::Binary(binary_status),
});
this.statuses.push(ServerStatus { name, status });
cx.notify();
})?;
}
@@ -121,76 +112,8 @@ impl ActivityIndicator {
cx.subscribe(
&project.read(cx).lsp_store(),
|activity_indicator, _, event, cx| match event {
LspStoreEvent::LanguageServerUpdate { name, message, .. } => {
if let proto::update_language_server::Variant::StatusUpdate(status_update) =
message
{
let Some(name) = name.clone() else {
return;
};
let status = match &status_update.status {
Some(proto::status_update::Status::Binary(binary_status)) => {
if let Some(binary_status) =
proto::ServerBinaryStatus::from_i32(*binary_status)
{
let binary_status = match binary_status {
proto::ServerBinaryStatus::None => BinaryStatus::None,
proto::ServerBinaryStatus::CheckingForUpdate => {
BinaryStatus::CheckingForUpdate
}
proto::ServerBinaryStatus::Downloading => {
BinaryStatus::Downloading
}
proto::ServerBinaryStatus::Starting => {
BinaryStatus::Starting
}
proto::ServerBinaryStatus::Stopping => {
BinaryStatus::Stopping
}
proto::ServerBinaryStatus::Stopped => {
BinaryStatus::Stopped
}
proto::ServerBinaryStatus::Failed => {
let Some(error) = status_update.message.clone()
else {
return;
};
BinaryStatus::Failed { error }
}
};
LanguageServerStatusUpdate::Binary(binary_status)
} else {
return;
}
}
Some(proto::status_update::Status::Health(health_status)) => {
if let Some(health) =
proto::ServerHealth::from_i32(*health_status)
{
let health = match health {
proto::ServerHealth::Ok => ServerHealth::Ok,
proto::ServerHealth::Warning => ServerHealth::Warning,
proto::ServerHealth::Error => ServerHealth::Error,
};
LanguageServerStatusUpdate::Health(
health,
status_update.message.clone().map(SharedString::from),
)
} else {
return;
}
}
None => return,
};
activity_indicator.statuses.retain(|s| s.name != name);
activity_indicator
.statuses
.push(ServerStatus { name, status });
}
cx.notify()
}
|_, _, event, cx| match event {
LspStoreEvent::LanguageServerUpdate { .. } => cx.notify(),
_ => {}
},
)
@@ -305,23 +228,9 @@ impl ActivityIndicator {
_: &mut Window,
cx: &mut Context<Self>,
) {
let error_dismissed = if let Some(updater) = &self.auto_updater {
updater.update(cx, |updater, cx| updater.dismiss_error(cx))
} else {
false
};
if error_dismissed {
return;
if let Some(updater) = &self.auto_updater {
updater.update(cx, |updater, cx| updater.dismiss_error(cx));
}
self.project.update(cx, |project, cx| {
if project.last_formatting_failure(cx).is_some() {
project.reset_last_formatting_failure(cx);
true
} else {
false
}
});
}
fn pending_language_server_work<'a>(
@@ -490,12 +399,6 @@ impl ActivityIndicator {
let mut servers_to_clear_statuses = HashSet::<LanguageServerName>::default();
for status in &self.statuses {
match &status.status {
LanguageServerStatusUpdate::Binary(
BinaryStatus::Starting | BinaryStatus::Stopping,
) => {}
LanguageServerStatusUpdate::Binary(BinaryStatus::Stopped) => {
servers_to_clear_statuses.insert(status.name.clone());
}
LanguageServerStatusUpdate::Binary(BinaryStatus::CheckingForUpdate) => {
checking_for_update.push(status.name.clone());
}

View File

@@ -72,7 +72,6 @@ gpui = { workspace = true, "features" = ["test-support"] }
indoc.workspace = true
language = { workspace = true, "features" = ["test-support"] }
language_model = { workspace = true, "features" = ["test-support"] }
parking_lot.workspace = true
pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] }
workspace = { workspace = true, features = ["test-support"] }

View File

@@ -1,7 +1,7 @@
use std::sync::Arc;
use agent_settings::{AgentProfileId, AgentProfileSettings, AgentSettings};
use assistant_tool::{AnyTool, ToolSource, ToolWorkingSet, UniqueToolName};
use assistant_tool::{Tool, ToolSource, ToolWorkingSet};
use collections::IndexMap;
use convert_case::{Case, Casing};
use fs::Fs;
@@ -72,7 +72,7 @@ impl AgentProfile {
&self.id
}
pub fn enabled_tools(&self, cx: &App) -> Vec<(UniqueToolName, AnyTool)> {
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else {
return Vec::new();
};
@@ -81,26 +81,23 @@ impl AgentProfile {
.read(cx)
.tools(cx)
.into_iter()
.filter(|(_, tool)| Self::is_enabled(settings, tool.source(), tool.name()))
.filter(|tool| Self::is_enabled(settings, tool.source(), tool.name()))
.collect()
}
pub fn is_tool_enabled(&self, source: ToolSource, tool_name: String, cx: &App) -> bool {
let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else {
return false;
};
return Self::is_enabled(settings, source, tool_name);
}
fn is_enabled(settings: &AgentProfileSettings, source: ToolSource, name: String) -> bool {
match source {
ToolSource::Native => *settings.tools.get(name.as_str()).unwrap_or(&false),
ToolSource::ContextServer { id } => settings
.context_servers
.get(id.as_ref())
.and_then(|preset| preset.tools.get(name.as_str()).copied())
.unwrap_or(settings.enable_all_context_servers),
ToolSource::ContextServer { id } => {
if settings.enable_all_context_servers {
return true;
}
let Some(preset) = settings.context_servers.get(id.as_ref()) else {
return false;
};
*preset.tools.get(name.as_str()).unwrap_or(&false)
}
}
}
}
@@ -108,7 +105,7 @@ impl AgentProfile {
#[cfg(test)]
mod tests {
use agent_settings::ContextServerPreset;
use assistant_tool::{Tool, ToolRegistry};
use assistant_tool::ToolRegistry;
use collections::IndexMap;
use gpui::SharedString;
use gpui::{AppContext, TestAppContext};
@@ -137,7 +134,7 @@ mod tests {
let mut enabled_tools = cx
.read(|cx| profile.enabled_tools(cx))
.into_iter()
.map(|(_, tool)| tool.name())
.map(|tool| tool.name())
.collect::<Vec<_>>();
enabled_tools.sort();
@@ -174,7 +171,7 @@ mod tests {
let mut enabled_tools = cx
.read(|cx| profile.enabled_tools(cx))
.into_iter()
.map(|(_, tool)| tool.name())
.map(|tool| tool.name())
.collect::<Vec<_>>();
enabled_tools.sort();
@@ -207,7 +204,7 @@ mod tests {
let mut enabled_tools = cx
.read(|cx| profile.enabled_tools(cx))
.into_iter()
.map(|(_, tool)| tool.name())
.map(|tool| tool.name())
.collect::<Vec<_>>();
enabled_tools.sort();
@@ -267,16 +264,10 @@ mod tests {
}
fn default_tool_set(cx: &mut TestAppContext) -> Entity<ToolWorkingSet> {
cx.new(|cx| {
cx.new(|_| {
let mut tool_set = ToolWorkingSet::default();
tool_set.insert(
Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")).into(),
cx,
);
tool_set.insert(
Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")).into(),
cx,
);
tool_set.insert(Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")));
tool_set.insert(Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")));
tool_set
})
}
@@ -296,8 +287,6 @@ mod tests {
}
impl Tool for FakeTool {
type Input = ();
fn name(&self) -> String {
self.name.clone()
}
@@ -316,17 +305,17 @@ mod tests {
unimplemented!()
}
fn needs_confirmation(&self, _input: &Self::Input, _cx: &App) -> bool {
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
unimplemented!()
}
fn ui_text(&self, _input: &Self::Input) -> String {
fn ui_text(&self, _input: &serde_json::Value) -> String {
unimplemented!()
}
fn run(
self: Arc<Self>,
_input: Self::Input,
_input: serde_json::Value,
_request: Arc<language_model::LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<assistant_tool::ActionLog>,

View File

@@ -29,8 +29,6 @@ impl ContextServerTool {
}
impl Tool for ContextServerTool {
type Input = serde_json::Value;
fn name(&self) -> String {
self.tool.name.clone()
}
@@ -49,7 +47,7 @@ impl Tool for ContextServerTool {
}
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
true
}
@@ -71,13 +69,13 @@ impl Tool for ContextServerTool {
})
}
fn ui_text(&self, _input: &Self::Input) -> String {
fn ui_text(&self, _input: &serde_json::Value) -> String {
format!("Run MCP tool `{}`", self.tool.name)
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<ActionLog>,

File diff suppressed because it is too large Load Diff

View File

@@ -71,7 +71,7 @@ impl Column for DataType {
}
}
const RULES_FILE_NAMES: [&'static str; 9] = [
const RULES_FILE_NAMES: [&'static str; 8] = [
".rules",
".cursorrules",
".windsurfrules",
@@ -80,7 +80,6 @@ const RULES_FILE_NAMES: [&'static str; 9] = [
"CLAUDE.md",
"AGENT.md",
"AGENTS.md",
"GEMINI.md",
];
pub fn init(cx: &mut App) {
@@ -537,8 +536,8 @@ impl ThreadStore {
}
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.update(cx, |tool_working_set, cx| {
tool_working_set.remove(&tool_ids, cx);
tool_working_set.update(cx, |tool_working_set, _| {
tool_working_set.remove(&tool_ids);
});
}
}
@@ -569,18 +568,19 @@ impl ThreadStore {
.log_err()
{
let tool_ids = tool_working_set
.update(cx, |tool_working_set, cx| {
tool_working_set.extend(
response.tools.into_iter().map(|tool| {
Arc::new(ContextServerTool::new(
.update(cx, |tool_working_set, _| {
response
.tools
.into_iter()
.map(|tool| {
log::info!("registering context server tool: {:?}", tool.name);
tool_working_set.insert(Arc::new(ContextServerTool::new(
context_server_store.clone(),
server.id(),
tool,
))
.into()
}),
cx,
)
)))
})
.collect::<Vec<_>>()
})
.log_err();

View File

@@ -4,7 +4,7 @@ use crate::{
};
use anyhow::Result;
use assistant_tool::{
AnyTool, AnyToolCard, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
};
use collections::HashMap;
use futures::{FutureExt as _, future::Shared};
@@ -378,7 +378,7 @@ impl ToolUseState {
ui_text: impl Into<Arc<str>>,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
tool: AnyTool,
tool: Arc<dyn Tool>,
) {
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
let ui_text = ui_text.into();
@@ -533,7 +533,7 @@ pub struct Confirmation {
pub input: serde_json::Value,
pub ui_text: Arc<str>,
pub request: Arc<LanguageModelRequest>,
pub tool: AnyTool,
pub tool: Arc<dyn Tool>,
}
#[derive(Debug, Clone)]

View File

@@ -12,10 +12,17 @@ workspace = true
path = "src/agent_settings.rs"
[dependencies]
anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
collections.workspace = true
gpui.workspace = true
language_model.workspace = true
lmstudio = { workspace = true, features = ["schemars"] }
log.workspace = true
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
deepseek = { workspace = true, features = ["schemars"] }
mistral = { workspace = true, features = ["schemars"] }
schemars.workspace = true
serde.workspace = true
settings.workspace = true

View File

@@ -2,14 +2,19 @@ mod agent_profile;
use std::sync::Arc;
use ::open_ai::Model as OpenAiModel;
use anthropic::Model as AnthropicModel;
use anyhow::{Result, bail};
use collections::IndexMap;
use deepseek::Model as DeepseekModel;
use gpui::{App, Pixels, SharedString};
use language_model::LanguageModel;
use schemars::{JsonSchema, json_schema};
use lmstudio::Model as LmStudioModel;
use mistral::Model as MistralModel;
use ollama::Model as OllamaModel;
use schemars::{JsonSchema, schema::Schema};
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources};
use std::borrow::Cow;
pub use crate::agent_profile::*;
@@ -43,6 +48,45 @@ pub enum NotifyWhenAgentWaiting {
Never,
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
#[schemars(deny_unknown_fields)]
pub enum AgentProviderContentV1 {
#[serde(rename = "zed.dev")]
ZedDotDev { default_model: Option<String> },
#[serde(rename = "openai")]
OpenAi {
default_model: Option<OpenAiModel>,
api_url: Option<String>,
available_models: Option<Vec<OpenAiModel>>,
},
#[serde(rename = "anthropic")]
Anthropic {
default_model: Option<AnthropicModel>,
api_url: Option<String>,
},
#[serde(rename = "ollama")]
Ollama {
default_model: Option<OllamaModel>,
api_url: Option<String>,
},
#[serde(rename = "lmstudio")]
LmStudio {
default_model: Option<LmStudioModel>,
api_url: Option<String>,
},
#[serde(rename = "deepseek")]
DeepSeek {
default_model: Option<DeepseekModel>,
api_url: Option<String>,
},
#[serde(rename = "mistral")]
Mistral {
default_model: Option<MistralModel>,
api_url: Option<String>,
},
}
#[derive(Default, Clone, Debug)]
pub struct AgentSettings {
pub enabled: bool,
@@ -50,7 +94,7 @@ pub struct AgentSettings {
pub dock: AgentDockPosition,
pub default_width: Pixels,
pub default_height: Pixels,
pub default_model: Option<LanguageModelSelection>,
pub default_model: LanguageModelSelection,
pub inline_assistant_model: Option<LanguageModelSelection>,
pub commit_message_model: Option<LanguageModelSelection>,
pub thread_summary_model: Option<LanguageModelSelection>,
@@ -124,56 +168,366 @@ impl LanguageModelParameters {
}
}
/// Agent panel settings
#[derive(Clone, Serialize, Deserialize, Debug, Default)]
pub struct AgentSettingsContent {
#[serde(flatten)]
pub inner: Option<AgentSettingsContentInner>,
}
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AgentSettingsContentInner {
Versioned(Box<VersionedAgentSettingsContent>),
Legacy(LegacyAgentSettingsContent),
}
impl AgentSettingsContentInner {
fn for_v2(content: AgentSettingsContentV2) -> Self {
AgentSettingsContentInner::Versioned(Box::new(VersionedAgentSettingsContent::V2(content)))
}
}
impl JsonSchema for AgentSettingsContent {
fn schema_name() -> String {
VersionedAgentSettingsContent::schema_name()
}
fn json_schema(r#gen: &mut schemars::r#gen::SchemaGenerator) -> Schema {
VersionedAgentSettingsContent::json_schema(r#gen)
}
fn is_referenceable() -> bool {
VersionedAgentSettingsContent::is_referenceable()
}
}
impl AgentSettingsContent {
pub fn is_version_outdated(&self) -> bool {
match &self.inner {
Some(AgentSettingsContentInner::Versioned(settings)) => match **settings {
VersionedAgentSettingsContent::V1(_) => true,
VersionedAgentSettingsContent::V2(_) => false,
},
Some(AgentSettingsContentInner::Legacy(_)) => true,
None => false,
}
}
fn upgrade(&self) -> AgentSettingsContentV2 {
match &self.inner {
Some(AgentSettingsContentInner::Versioned(settings)) => match **settings {
VersionedAgentSettingsContent::V1(ref settings) => AgentSettingsContentV2 {
enabled: settings.enabled,
button: settings.button,
dock: settings.dock,
default_width: settings.default_width,
default_height: settings.default_width,
default_model: settings
.provider
.clone()
.and_then(|provider| match provider {
AgentProviderContentV1::ZedDotDev { default_model } => default_model
.map(|model| LanguageModelSelection {
provider: "zed.dev".into(),
model,
}),
AgentProviderContentV1::OpenAi { default_model, .. } => default_model
.map(|model| LanguageModelSelection {
provider: "openai".into(),
model: model.id().to_string(),
}),
AgentProviderContentV1::Anthropic { default_model, .. } => {
default_model.map(|model| LanguageModelSelection {
provider: "anthropic".into(),
model: model.id().to_string(),
})
}
AgentProviderContentV1::Ollama { default_model, .. } => default_model
.map(|model| LanguageModelSelection {
provider: "ollama".into(),
model: model.id().to_string(),
}),
AgentProviderContentV1::LmStudio { default_model, .. } => default_model
.map(|model| LanguageModelSelection {
provider: "lmstudio".into(),
model: model.id().to_string(),
}),
AgentProviderContentV1::DeepSeek { default_model, .. } => default_model
.map(|model| LanguageModelSelection {
provider: "deepseek".into(),
model: model.id().to_string(),
}),
AgentProviderContentV1::Mistral { default_model, .. } => default_model
.map(|model| LanguageModelSelection {
provider: "mistral".into(),
model: model.id().to_string(),
}),
}),
inline_assistant_model: None,
commit_message_model: None,
thread_summary_model: None,
inline_alternatives: None,
default_profile: None,
default_view: None,
profiles: None,
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
single_file_review: None,
model_parameters: Vec::new(),
preferred_completion_mode: None,
enable_feedback: None,
play_sound_when_agent_done: None,
},
VersionedAgentSettingsContent::V2(ref settings) => settings.clone(),
},
Some(AgentSettingsContentInner::Legacy(settings)) => AgentSettingsContentV2 {
enabled: None,
button: settings.button,
dock: settings.dock,
default_width: settings.default_width,
default_height: settings.default_height,
default_model: Some(LanguageModelSelection {
provider: "openai".into(),
model: settings
.default_open_ai_model
.clone()
.unwrap_or_default()
.id()
.to_string(),
}),
inline_assistant_model: None,
commit_message_model: None,
thread_summary_model: None,
inline_alternatives: None,
default_profile: None,
default_view: None,
profiles: None,
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
single_file_review: None,
model_parameters: Vec::new(),
preferred_completion_mode: None,
enable_feedback: None,
play_sound_when_agent_done: None,
},
None => AgentSettingsContentV2::default(),
}
}
pub fn set_dock(&mut self, dock: AgentDockPosition) {
self.dock = Some(dock);
match &mut self.inner {
Some(AgentSettingsContentInner::Versioned(settings)) => match **settings {
VersionedAgentSettingsContent::V1(ref mut settings) => {
settings.dock = Some(dock);
}
VersionedAgentSettingsContent::V2(ref mut settings) => {
settings.dock = Some(dock);
}
},
Some(AgentSettingsContentInner::Legacy(settings)) => {
settings.dock = Some(dock);
}
None => {
self.inner = Some(AgentSettingsContentInner::for_v2(AgentSettingsContentV2 {
dock: Some(dock),
..Default::default()
}))
}
}
}
pub fn set_model(&mut self, language_model: Arc<dyn LanguageModel>) {
let model = language_model.id().0.to_string();
let provider = language_model.provider_id().0.to_string();
self.default_model = Some(LanguageModelSelection {
provider: provider.into(),
model,
});
match &mut self.inner {
Some(AgentSettingsContentInner::Versioned(settings)) => match **settings {
VersionedAgentSettingsContent::V1(ref mut settings) => match provider.as_ref() {
"zed.dev" => {
log::warn!("attempted to set zed.dev model on outdated settings");
}
"anthropic" => {
let api_url = match &settings.provider {
Some(AgentProviderContentV1::Anthropic { api_url, .. }) => {
api_url.clone()
}
_ => None,
};
settings.provider = Some(AgentProviderContentV1::Anthropic {
default_model: AnthropicModel::from_id(&model).ok(),
api_url,
});
}
"ollama" => {
let api_url = match &settings.provider {
Some(AgentProviderContentV1::Ollama { api_url, .. }) => api_url.clone(),
_ => None,
};
settings.provider = Some(AgentProviderContentV1::Ollama {
default_model: Some(ollama::Model::new(
&model,
None,
None,
Some(language_model.supports_tools()),
Some(language_model.supports_images()),
None,
)),
api_url,
});
}
"lmstudio" => {
let api_url = match &settings.provider {
Some(AgentProviderContentV1::LmStudio { api_url, .. }) => {
api_url.clone()
}
_ => None,
};
settings.provider = Some(AgentProviderContentV1::LmStudio {
default_model: Some(lmstudio::Model::new(
&model, None, None, false, false,
)),
api_url,
});
}
"openai" => {
let (api_url, available_models) = match &settings.provider {
Some(AgentProviderContentV1::OpenAi {
api_url,
available_models,
..
}) => (api_url.clone(), available_models.clone()),
_ => (None, None),
};
settings.provider = Some(AgentProviderContentV1::OpenAi {
default_model: OpenAiModel::from_id(&model).ok(),
api_url,
available_models,
});
}
"deepseek" => {
let api_url = match &settings.provider {
Some(AgentProviderContentV1::DeepSeek { api_url, .. }) => {
api_url.clone()
}
_ => None,
};
settings.provider = Some(AgentProviderContentV1::DeepSeek {
default_model: DeepseekModel::from_id(&model).ok(),
api_url,
});
}
_ => {}
},
VersionedAgentSettingsContent::V2(ref mut settings) => {
settings.default_model = Some(LanguageModelSelection {
provider: provider.into(),
model,
});
}
},
Some(AgentSettingsContentInner::Legacy(settings)) => {
if let Ok(model) = OpenAiModel::from_id(&language_model.id().0) {
settings.default_open_ai_model = Some(model);
}
}
None => {
self.inner = Some(AgentSettingsContentInner::for_v2(AgentSettingsContentV2 {
default_model: Some(LanguageModelSelection {
provider: provider.into(),
model,
}),
..Default::default()
}));
}
}
}
pub fn set_inline_assistant_model(&mut self, provider: String, model: String) {
self.inline_assistant_model = Some(LanguageModelSelection {
provider: provider.into(),
model,
});
self.v2_setting(|setting| {
setting.inline_assistant_model = Some(LanguageModelSelection {
provider: provider.into(),
model,
});
Ok(())
})
.ok();
}
pub fn set_commit_message_model(&mut self, provider: String, model: String) {
self.commit_message_model = Some(LanguageModelSelection {
provider: provider.into(),
model,
});
self.v2_setting(|setting| {
setting.commit_message_model = Some(LanguageModelSelection {
provider: provider.into(),
model,
});
Ok(())
})
.ok();
}
pub fn v2_setting(
&mut self,
f: impl FnOnce(&mut AgentSettingsContentV2) -> anyhow::Result<()>,
) -> anyhow::Result<()> {
match self.inner.get_or_insert_with(|| {
AgentSettingsContentInner::for_v2(AgentSettingsContentV2 {
..Default::default()
})
}) {
AgentSettingsContentInner::Versioned(boxed) => {
if let VersionedAgentSettingsContent::V2(ref mut settings) = **boxed {
f(settings)
} else {
Ok(())
}
}
_ => Ok(()),
}
}
pub fn set_thread_summary_model(&mut self, provider: String, model: String) {
self.thread_summary_model = Some(LanguageModelSelection {
provider: provider.into(),
model,
});
self.v2_setting(|setting| {
setting.thread_summary_model = Some(LanguageModelSelection {
provider: provider.into(),
model,
});
Ok(())
})
.ok();
}
pub fn set_always_allow_tool_actions(&mut self, allow: bool) {
self.always_allow_tool_actions = Some(allow);
self.v2_setting(|setting| {
setting.always_allow_tool_actions = Some(allow);
Ok(())
})
.ok();
}
pub fn set_play_sound_when_agent_done(&mut self, allow: bool) {
self.play_sound_when_agent_done = Some(allow);
self.v2_setting(|setting| {
setting.play_sound_when_agent_done = Some(allow);
Ok(())
})
.ok();
}
pub fn set_single_file_review(&mut self, allow: bool) {
self.single_file_review = Some(allow);
self.v2_setting(|setting| {
setting.single_file_review = Some(allow);
Ok(())
})
.ok();
}
pub fn set_profile(&mut self, profile_id: AgentProfileId) {
self.default_profile = Some(profile_id);
self.v2_setting(|setting| {
setting.default_profile = Some(profile_id);
Ok(())
})
.ok();
}
pub fn create_profile(
@@ -181,38 +535,79 @@ impl AgentSettingsContent {
profile_id: AgentProfileId,
profile_settings: AgentProfileSettings,
) -> Result<()> {
let profiles = self.profiles.get_or_insert_default();
if profiles.contains_key(&profile_id) {
bail!("profile with ID '{profile_id}' already exists");
}
self.v2_setting(|settings| {
let profiles = settings.profiles.get_or_insert_default();
if profiles.contains_key(&profile_id) {
bail!("profile with ID '{profile_id}' already exists");
}
profiles.insert(
profile_id,
AgentProfileContent {
name: profile_settings.name.into(),
tools: profile_settings.tools,
enable_all_context_servers: Some(profile_settings.enable_all_context_servers),
context_servers: profile_settings
.context_servers
.into_iter()
.map(|(server_id, preset)| {
(
server_id,
ContextServerPresetContent {
tools: preset.tools,
},
)
})
.collect(),
},
);
profiles.insert(
profile_id,
AgentProfileContent {
name: profile_settings.name.into(),
tools: profile_settings.tools,
enable_all_context_servers: Some(profile_settings.enable_all_context_servers),
context_servers: profile_settings
.context_servers
.into_iter()
.map(|(server_id, preset)| {
(
server_id,
ContextServerPresetContent {
tools: preset.tools,
},
)
})
.collect(),
},
);
Ok(())
Ok(())
})
}
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
#[schemars(deny_unknown_fields)]
pub enum VersionedAgentSettingsContent {
#[serde(rename = "1")]
V1(AgentSettingsContentV1),
#[serde(rename = "2")]
V2(AgentSettingsContentV2),
}
impl Default for VersionedAgentSettingsContent {
fn default() -> Self {
Self::V2(AgentSettingsContentV2 {
enabled: None,
button: None,
dock: None,
default_width: None,
default_height: None,
default_model: None,
inline_assistant_model: None,
commit_message_model: None,
thread_summary_model: None,
inline_alternatives: None,
default_profile: None,
default_view: None,
profiles: None,
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
single_file_review: None,
model_parameters: Vec::new(),
preferred_completion_mode: None,
enable_feedback: None,
play_sound_when_agent_done: None,
})
}
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug, Default)]
pub struct AgentSettingsContent {
#[schemars(deny_unknown_fields)]
pub struct AgentSettingsContentV2 {
/// Whether the Agent is enabled.
///
/// Default: true
@@ -321,27 +716,28 @@ pub struct LanguageModelSelection {
pub struct LanguageModelProviderSetting(pub String);
impl JsonSchema for LanguageModelProviderSetting {
fn schema_name() -> Cow<'static, str> {
fn schema_name() -> String {
"LanguageModelProviderSetting".into()
}
fn json_schema(_: &mut schemars::SchemaGenerator) -> schemars::Schema {
json_schema!({
"enum": [
"anthropic",
"amazon-bedrock",
"google",
"lmstudio",
"ollama",
"openai",
"zed.dev",
"copilot_chat",
"deepseek",
"openrouter",
"mistral",
"vercel"
]
})
fn json_schema(_: &mut schemars::r#gen::SchemaGenerator) -> Schema {
schemars::schema::SchemaObject {
enum_values: Some(vec![
"anthropic".into(),
"amazon-bedrock".into(),
"google".into(),
"lmstudio".into(),
"ollama".into(),
"openai".into(),
"zed.dev".into(),
"copilot_chat".into(),
"deepseek".into(),
"openrouter".into(),
"mistral".into(),
]),
..Default::default()
}
.into()
}
}
@@ -357,6 +753,15 @@ impl From<&str> for LanguageModelProviderSetting {
}
}
impl Default for LanguageModelSelection {
fn default() -> Self {
Self {
provider: LanguageModelProviderSetting("openai".to_string()),
model: "gpt-4".to_string(),
}
}
}
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, JsonSchema)]
pub struct AgentProfileContent {
pub name: Arc<str>,
@@ -373,6 +778,65 @@ pub struct ContextServerPresetContent {
pub tools: IndexMap<Arc<str>, bool>,
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[schemars(deny_unknown_fields)]
pub struct AgentSettingsContentV1 {
/// Whether the Agent is enabled.
///
/// Default: true
enabled: Option<bool>,
/// Whether to show the Agent panel button in the status bar.
///
/// Default: true
button: Option<bool>,
/// Where to dock the Agent.
///
/// Default: right
dock: Option<AgentDockPosition>,
/// Default width in pixels when the Agent is docked to the left or right.
///
/// Default: 640
default_width: Option<f32>,
/// Default height in pixels when the Agent is docked to the bottom.
///
/// Default: 320
default_height: Option<f32>,
/// The provider of the Agent service.
///
/// This can be "openai", "anthropic", "ollama", "lmstudio", "deepseek", "zed.dev"
/// each with their respective default models and configurations.
provider: Option<AgentProviderContentV1>,
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[schemars(deny_unknown_fields)]
pub struct LegacyAgentSettingsContent {
/// Whether to show the Agent panel button in the status bar.
///
/// Default: true
pub button: Option<bool>,
/// Where to dock the Agent.
///
/// Default: right
pub dock: Option<AgentDockPosition>,
/// Default width in pixels when the Agent is docked to the left or right.
///
/// Default: 640
pub default_width: Option<f32>,
/// Default height in pixels when the Agent is docked to the bottom.
///
/// Default: 320
pub default_height: Option<f32>,
/// The default OpenAI model to use when creating new chats.
///
/// Default: gpt-4-1106-preview
pub default_open_ai_model: Option<OpenAiModel>,
/// OpenAI API base URL to use when creating new chats.
///
/// Default: <https://api.openai.com/v1>
pub openai_api_url: Option<String>,
}
impl Settings for AgentSettings {
const KEY: Option<&'static str> = Some("agent");
@@ -389,6 +853,11 @@ impl Settings for AgentSettings {
let mut settings = AgentSettings::default();
for value in sources.defaults_and_customizations() {
if value.is_version_outdated() {
settings.using_outdated_settings_version = true;
}
let value = value.upgrade();
merge(&mut settings.enabled, value.enabled);
merge(&mut settings.button, value.button);
merge(&mut settings.dock, value.dock);
@@ -400,26 +869,17 @@ impl Settings for AgentSettings {
&mut settings.default_height,
value.default_height.map(Into::into),
);
settings.default_model = value
.default_model
.clone()
.or(settings.default_model.take());
merge(&mut settings.default_model, value.default_model);
settings.inline_assistant_model = value
.inline_assistant_model
.clone()
.or(settings.inline_assistant_model.take());
settings.commit_message_model = value
.clone()
.commit_message_model
.or(settings.commit_message_model.take());
settings.thread_summary_model = value
.clone()
.thread_summary_model
.or(settings.thread_summary_model.take());
merge(
&mut settings.inline_alternatives,
value.inline_alternatives.clone(),
);
merge(&mut settings.inline_alternatives, value.inline_alternatives);
merge(
&mut settings.always_allow_tool_actions,
value.always_allow_tool_actions,
@@ -434,7 +894,7 @@ impl Settings for AgentSettings {
);
merge(&mut settings.stream_edits, value.stream_edits);
merge(&mut settings.single_file_review, value.single_file_review);
merge(&mut settings.default_profile, value.default_profile.clone());
merge(&mut settings.default_profile, value.default_profile);
merge(&mut settings.default_view, value.default_view);
merge(
&mut settings.preferred_completion_mode,
@@ -446,24 +906,24 @@ impl Settings for AgentSettings {
.model_parameters
.extend_from_slice(&value.model_parameters);
if let Some(profiles) = value.profiles.as_ref() {
if let Some(profiles) = value.profiles {
settings
.profiles
.extend(profiles.into_iter().map(|(id, profile)| {
(
id.clone(),
id,
AgentProfileSettings {
name: profile.name.clone().into(),
tools: profile.tools.clone(),
name: profile.name.into(),
tools: profile.tools,
enable_all_context_servers: profile
.enable_all_context_servers
.unwrap_or_default(),
context_servers: profile
.context_servers
.iter()
.into_iter()
.map(|(context_server_id, preset)| {
(
context_server_id.clone(),
context_server_id,
ContextServerPreset {
tools: preset.tools.clone(),
},
@@ -484,8 +944,28 @@ impl Settings for AgentSettings {
.read_value("chat.agent.enabled")
.and_then(|b| b.as_bool())
{
current.enabled = Some(b);
current.button = Some(b);
match &mut current.inner {
Some(AgentSettingsContentInner::Versioned(versioned)) => match versioned.as_mut() {
VersionedAgentSettingsContent::V1(setting) => {
setting.enabled = Some(b);
setting.button = Some(b);
}
VersionedAgentSettingsContent::V2(setting) => {
setting.enabled = Some(b);
setting.button = Some(b);
}
},
Some(AgentSettingsContentInner::Legacy(setting)) => setting.button = Some(b),
None => {
current.inner =
Some(AgentSettingsContentInner::for_v2(AgentSettingsContentV2 {
enabled: Some(b),
button: Some(b),
..Default::default()
}));
}
}
}
}
}
@@ -495,3 +975,149 @@ fn merge<T>(target: &mut T, value: Option<T>) {
*target = value;
}
}
#[cfg(test)]
mod tests {
use fs::Fs;
use gpui::{ReadGlobal, TestAppContext};
use settings::SettingsStore;
use super::*;
#[gpui::test]
async fn test_deserialize_agent_settings_with_version(cx: &mut TestAppContext) {
let fs = fs::FakeFs::new(cx.executor().clone());
fs.create_dir(paths::settings_file().parent().unwrap())
.await
.unwrap();
cx.update(|cx| {
let test_settings = settings::SettingsStore::test(cx);
cx.set_global(test_settings);
AgentSettings::register(cx);
});
cx.update(|cx| {
assert!(!AgentSettings::get_global(cx).using_outdated_settings_version);
assert_eq!(
AgentSettings::get_global(cx).default_model,
LanguageModelSelection {
provider: "zed.dev".into(),
model: "claude-sonnet-4".into(),
}
);
});
cx.update(|cx| {
settings::SettingsStore::global(cx).update_settings_file::<AgentSettings>(
fs.clone(),
|settings, _| {
*settings = AgentSettingsContent {
inner: Some(AgentSettingsContentInner::for_v2(AgentSettingsContentV2 {
default_model: Some(LanguageModelSelection {
provider: "test-provider".into(),
model: "gpt-99".into(),
}),
inline_assistant_model: None,
commit_message_model: None,
thread_summary_model: None,
inline_alternatives: None,
enabled: None,
button: None,
dock: None,
default_width: None,
default_height: None,
default_profile: None,
default_view: None,
profiles: None,
always_allow_tool_actions: None,
play_sound_when_agent_done: None,
notify_when_agent_waiting: None,
stream_edits: None,
single_file_review: None,
enable_feedback: None,
model_parameters: Vec::new(),
preferred_completion_mode: None,
})),
}
},
);
});
cx.run_until_parked();
let raw_settings_value = fs.load(paths::settings_file()).await.unwrap();
assert!(raw_settings_value.contains(r#""version": "2""#));
#[derive(Debug, Deserialize)]
struct AgentSettingsTest {
agent: AgentSettingsContent,
}
let agent_settings: AgentSettingsTest =
serde_json_lenient::from_str(&raw_settings_value).unwrap();
assert!(!agent_settings.agent.is_version_outdated());
}
#[gpui::test]
async fn test_load_settings_from_old_key(cx: &mut TestAppContext) {
let fs = fs::FakeFs::new(cx.executor().clone());
fs.create_dir(paths::settings_file().parent().unwrap())
.await
.unwrap();
cx.update(|cx| {
let mut test_settings = settings::SettingsStore::test(cx);
let user_settings_content = r#"{
"assistant": {
"enabled": true,
"version": "2",
"default_model": {
"provider": "zed.dev",
"model": "gpt-99"
},
}}"#;
test_settings
.set_user_settings(user_settings_content, cx)
.unwrap();
cx.set_global(test_settings);
AgentSettings::register(cx);
});
cx.run_until_parked();
let agent_settings = cx.update(|cx| AgentSettings::get_global(cx).clone());
assert!(agent_settings.enabled);
assert!(!agent_settings.using_outdated_settings_version);
assert_eq!(agent_settings.default_model.model, "gpt-99");
cx.update_global::<SettingsStore, _>(|settings_store, cx| {
settings_store.update_user_settings::<AgentSettings>(cx, |settings| {
*settings = AgentSettingsContent {
inner: Some(AgentSettingsContentInner::for_v2(AgentSettingsContentV2 {
enabled: Some(false),
default_model: Some(LanguageModelSelection {
provider: "xai".to_owned().into(),
model: "grok".to_owned(),
}),
..Default::default()
})),
};
});
});
cx.run_until_parked();
let settings = cx.update(|cx| SettingsStore::global(cx).raw_user_settings().clone());
#[derive(Debug, Deserialize)]
struct AgentSettingsTest {
assistant: AgentSettingsContent,
agent: Option<serde_json_lenient::Value>,
}
let agent_settings: AgentSettingsTest = serde_json::from_value(settings).unwrap();
assert!(agent_settings.agent.is_none());
}
}

View File

@@ -1,7 +1,9 @@
use crate::context_picker::{ContextPicker, MentionLink};
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::message_editor::{extract_message_creases, insert_message_creases};
use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill};
use crate::ui::{
AddedContext, AgentNotification, AgentNotificationEvent, AnimatedLabel, ContextPill,
};
use crate::{AgentPanel, ModelUsageContext};
use agent::{
ContextStore, LastRestoreCheckpoint, MessageCrease, MessageId, MessageSegment, TextThreadStore,
@@ -17,7 +19,7 @@ use audio::{Audio, Sound};
use collections::{HashMap, HashSet};
use editor::actions::{MoveUp, Paste};
use editor::scroll::Autoscroll;
use editor::{Editor, EditorElement, EditorEvent, EditorStyle, MultiBuffer, SelectionEffects};
use editor::{Editor, EditorElement, EditorEvent, EditorStyle, MultiBuffer};
use gpui::{
AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, ClipboardEntry,
ClipboardItem, DefiniteLength, EdgesRefinement, Empty, Entity, EventEmitter, Focusable, Hsla,
@@ -45,8 +47,8 @@ use std::time::Duration;
use text::ToPoint;
use theme::ThemeSettings;
use ui::{
Banner, Disclosure, KeyBinding, PopoverMenuHandle, Scrollbar, ScrollbarState, TextSize,
Tooltip, prelude::*,
Disclosure, KeyBinding, PopoverMenuHandle, Scrollbar, ScrollbarState, TextSize, Tooltip,
prelude::*,
};
use util::ResultExt as _;
use util::markdown::MarkdownCodeBlock;
@@ -56,7 +58,6 @@ use zed_llm_client::CompletionIntent;
const CODEBLOCK_CONTAINER_GROUP: &str = "codeblock_container";
const EDIT_PREVIOUS_MESSAGE_MIN_LINES: usize = 1;
const RESPONSE_PADDING_X: Pixels = px(19.);
pub struct ActiveThread {
context_store: Entity<ContextStore>,
@@ -203,7 +204,7 @@ pub(crate) fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle
MarkdownStyle {
base_text_style: text_style.clone(),
syntax: cx.theme().syntax().clone(),
selection_background_color: cx.theme().colors().element_selection_background,
selection_background_color: cx.theme().players().local().selection,
code_block_overflow_x_scroll: true,
table_overflow_x_scroll: true,
heading_level_styles: Some(HeadingLevelStyles {
@@ -300,7 +301,7 @@ fn tool_use_markdown_style(window: &Window, cx: &mut App) -> MarkdownStyle {
MarkdownStyle {
base_text_style: text_style,
syntax: cx.theme().syntax().clone(),
selection_background_color: cx.theme().colors().element_selection_background,
selection_background_color: cx.theme().players().local().selection,
code_block_overflow_x_scroll: false,
code_block: StyleRefinement {
margin: EdgesRefinement::default(),
@@ -688,12 +689,9 @@ fn open_markdown_link(
})
.context("Could not find matching symbol")?;
editor.change_selections(
SelectionEffects::scroll(Autoscroll::center()),
window,
cx,
|s| s.select_anchor_ranges([symbol_range.start..symbol_range.start]),
);
editor.change_selections(Some(Autoscroll::center()), window, cx, |s| {
s.select_anchor_ranges([symbol_range.start..symbol_range.start])
});
anyhow::Ok(())
})
})
@@ -710,15 +708,10 @@ fn open_markdown_link(
.downcast::<Editor>()
.context("Item is not an editor")?;
active_editor.update_in(cx, |editor, window, cx| {
editor.change_selections(
SelectionEffects::scroll(Autoscroll::center()),
window,
cx,
|s| {
s.select_ranges([Point::new(line_range.start as u32, 0)
..Point::new(line_range.start as u32, 0)])
},
);
editor.change_selections(Some(Autoscroll::center()), window, cx, |s| {
s.select_ranges([Point::new(line_range.start as u32, 0)
..Point::new(line_range.start as u32, 0)])
});
anyhow::Ok(())
})
})
@@ -816,12 +809,7 @@ impl ActiveThread {
};
for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
let rendered_message = RenderedMessage::from_segments(
&message.segments,
this.language_registry.clone(),
cx,
);
this.push_rendered_message(message.id, rendered_message);
this.push_message(&message.id, &message.segments, window, cx);
for tool_use in thread.read(cx).tool_uses_for_message(message.id, cx) {
this.render_tool_use_markdown(
@@ -887,11 +875,36 @@ impl ActiveThread {
&self.text_thread_store
}
fn push_rendered_message(&mut self, id: MessageId, rendered_message: RenderedMessage) {
fn push_message(
&mut self,
id: &MessageId,
segments: &[MessageSegment],
_window: &mut Window,
cx: &mut Context<Self>,
) {
let old_len = self.messages.len();
self.messages.push(id);
self.messages.push(*id);
self.list_state.splice(old_len..old_len, 1);
self.rendered_messages_by_id.insert(id, rendered_message);
let rendered_message =
RenderedMessage::from_segments(segments, self.language_registry.clone(), cx);
self.rendered_messages_by_id.insert(*id, rendered_message);
}
fn edited_message(
&mut self,
id: &MessageId,
segments: &[MessageSegment],
_window: &mut Window,
cx: &mut Context<Self>,
) {
let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
return;
};
self.list_state.splice(index..index + 1, 1);
let rendered_message =
RenderedMessage::from_segments(segments, self.language_registry.clone(), cx);
self.rendered_messages_by_id.insert(*id, rendered_message);
}
fn deleted_message(&mut self, id: &MessageId) {
@@ -1024,45 +1037,31 @@ impl ActiveThread {
}
}
ThreadEvent::MessageAdded(message_id) => {
self.clear_last_error();
if let Some(rendered_message) = self.thread.update(cx, |thread, cx| {
thread.message(*message_id).map(|message| {
RenderedMessage::from_segments(
&message.segments,
self.language_registry.clone(),
cx,
)
})
}) {
self.push_rendered_message(*message_id, rendered_message);
if let Some(message_segments) = self
.thread
.read(cx)
.message(*message_id)
.map(|message| message.segments.clone())
{
self.push_message(message_id, &message_segments, window, cx);
}
self.save_thread(cx);
cx.notify();
}
ThreadEvent::MessageEdited(message_id) => {
self.clear_last_error();
if let Some(index) = self.messages.iter().position(|id| id == message_id) {
if let Some(rendered_message) = self.thread.update(cx, |thread, cx| {
thread.message(*message_id).map(|message| {
let mut rendered_message = RenderedMessage {
language_registry: self.language_registry.clone(),
segments: Vec::with_capacity(message.segments.len()),
};
for segment in &message.segments {
rendered_message.push_segment(segment, cx);
}
rendered_message
})
}) {
self.list_state.splice(index..index + 1, 1);
self.rendered_messages_by_id
.insert(*message_id, rendered_message);
self.scroll_to_bottom(cx);
self.save_thread(cx);
cx.notify();
}
if let Some(message_segments) = self
.thread
.read(cx)
.message(*message_id)
.map(|message| message.segments.clone())
{
self.edited_message(message_id, &message_segments, window, cx);
}
self.scroll_to_bottom(cx);
self.save_thread(cx);
cx.notify();
}
ThreadEvent::MessageDeleted(message_id) => {
self.deleted_message(message_id);
@@ -1149,9 +1148,6 @@ impl ActiveThread {
self.save_thread(cx);
cx.notify();
}
ThreadEvent::RetriesFailed { message } => {
self.show_notification(message, ui::IconName::Warning, window, cx);
}
}
}
@@ -1315,11 +1311,17 @@ impl ActiveThread {
fn start_editing_message(
&mut self,
message_id: MessageId,
message_text: impl Into<Arc<str>>,
message_segments: &[MessageSegment],
message_creases: &[MessageCrease],
window: &mut Window,
cx: &mut Context<Self>,
) {
// User message should always consist of a single text segment,
// therefore we can skip returning early if it's not a text segment.
let Some(MessageSegment::Text(message_text)) = message_segments.first() else {
return;
};
let editor = crate::message_editor::create_editor(
self.workspace.clone(),
self.context_store.downgrade(),
@@ -1331,7 +1333,7 @@ impl ActiveThread {
cx,
);
editor.update(cx, |editor, cx| {
editor.set_text(message_text, window, cx);
editor.set_text(message_text.clone(), window, cx);
insert_message_creases(editor, message_creases, &self.context_store, window, cx);
editor.focus_handle(cx).focus(window);
editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
@@ -1818,7 +1820,7 @@ impl ActiveThread {
.my_3()
.mx_5()
.when(is_generating_stale || message.is_hidden, |this| {
this.child(LoadingLabel::new("").size(LabelSize::Small))
this.child(AnimatedLabel::new("").size(LabelSize::Small))
})
});
@@ -1826,6 +1828,8 @@ impl ActiveThread {
return div().children(loading_dots).into_any();
}
let message_creases = message.creases.clone();
let Some(rendered_message) = self.rendered_messages_by_id.get(&message_id) else {
return Empty.into_any();
};
@@ -1847,10 +1851,9 @@ impl ActiveThread {
.filter(|(id, _)| *id == message_id)
.map(|(_, state)| state);
let (editor_bg_color, panel_bg) = {
let colors = cx.theme().colors();
(colors.editor_background, colors.panel_background)
};
let colors = cx.theme().colors();
let editor_bg_color = colors.editor_background;
let panel_bg = colors.panel_background;
let open_as_markdown = IconButton::new(("open-as-markdown", ix), IconName::DocumentText)
.icon_size(IconSize::XSmall)
@@ -1875,6 +1878,9 @@ impl ActiveThread {
this.scroll_to_top(cx);
}));
// For all items that should be aligned with the LLM's response.
const RESPONSE_PADDING_X: Pixels = px(19.);
let show_feedback = thread.is_turn_end(ix);
let feedback_container = h_flex()
.group("feedback_container")
@@ -2035,162 +2041,137 @@ impl ActiveThread {
}
});
let styled_message = if message.ui_only {
self.render_ui_notification(message_content, ix, cx)
} else {
match message.role {
Role::User => {
let colors = cx.theme().colors();
let styled_message = match message.role {
Role::User => v_flex()
.id(("message-container", ix))
.pt_2()
.pl_2()
.pr_2p5()
.pb_4()
.child(
v_flex()
.id(("message-container", ix))
.pt_2()
.pl_2()
.pr_2p5()
.pb_4()
.id(("user-message", ix))
.bg(editor_bg_color)
.rounded_lg()
.shadow_md()
.border_1()
.border_color(colors.border)
.hover(|hover| hover.border_color(colors.text_accent.opacity(0.5)))
.child(
v_flex()
.id(("user-message", ix))
.bg(editor_bg_color)
.rounded_lg()
.shadow_md()
.border_1()
.border_color(colors.border)
.hover(|hover| hover.border_color(colors.text_accent.opacity(0.5)))
.child(
v_flex()
.p_2p5()
.gap_1()
.children(message_content)
.when_some(editing_message_state, |this, state| {
let focus_handle = state.editor.focus_handle(cx).clone();
.p_2p5()
.gap_1()
.children(message_content)
.when_some(editing_message_state, |this, state| {
let focus_handle = state.editor.focus_handle(cx).clone();
this.child(
this.child(
h_flex()
.w_full()
.gap_1()
.justify_between()
.flex_wrap()
.child(
h_flex()
.w_full()
.gap_1()
.justify_between()
.flex_wrap()
.gap_1p5()
.child(
h_flex()
.gap_1p5()
div()
.opacity(0.8)
.child(
div()
.opacity(0.8)
.child(
Icon::new(IconName::Warning)
.size(IconSize::Indicator)
.color(Color::Warning)
),
)
.child(
Label::new("Editing will restart the thread from this point.")
.color(Color::Muted)
.size(LabelSize::XSmall),
Icon::new(IconName::Warning)
.size(IconSize::Indicator)
.color(Color::Warning)
),
)
.child(
h_flex()
.gap_0p5()
.child(
IconButton::new(
"cancel-edit-message",
IconName::Close,
)
.shape(ui::IconButtonShape::Square)
.icon_color(Color::Error)
.icon_size(IconSize::Small)
.tooltip({
let focus_handle = focus_handle.clone();
move |window, cx| {
Tooltip::for_action_in(
"Cancel Edit",
&menu::Cancel,
&focus_handle,
window,
cx,
)
}
})
.on_click(cx.listener(Self::handle_cancel_click)),
)
.child(
IconButton::new(
"confirm-edit-message",
IconName::Return,
)
.disabled(state.editor.read(cx).is_empty(cx))
.shape(ui::IconButtonShape::Square)
.icon_color(Color::Muted)
.icon_size(IconSize::Small)
.tooltip({
let focus_handle = focus_handle.clone();
move |window, cx| {
Tooltip::for_action_in(
"Regenerate",
&menu::Confirm,
&focus_handle,
window,
cx,
)
}
})
.on_click(
cx.listener(Self::handle_regenerate_click),
),
),
)
Label::new("Editing will restart the thread from this point.")
.color(Color::Muted)
.size(LabelSize::XSmall),
),
)
}),
)
.on_click(cx.listener({
let message_creases = message.creases.clone();
move |this, _, window, cx| {
if let Some(message_text) =
this.thread.read(cx).message(message_id).and_then(|message| {
message.segments.first().and_then(|segment| {
match segment {
MessageSegment::Text(message_text) => {
Some(Into::<Arc<str>>::into(message_text.as_str()))
}
_ => {
None
}
}
})
})
{
this.start_editing_message(
message_id,
message_text,
&message_creases,
window,
cx,
);
}
}
})),
.child(
h_flex()
.gap_0p5()
.child(
IconButton::new(
"cancel-edit-message",
IconName::Close,
)
.shape(ui::IconButtonShape::Square)
.icon_color(Color::Error)
.icon_size(IconSize::Small)
.tooltip({
let focus_handle = focus_handle.clone();
move |window, cx| {
Tooltip::for_action_in(
"Cancel Edit",
&menu::Cancel,
&focus_handle,
window,
cx,
)
}
})
.on_click(cx.listener(Self::handle_cancel_click)),
)
.child(
IconButton::new(
"confirm-edit-message",
IconName::Return,
)
.disabled(state.editor.read(cx).is_empty(cx))
.shape(ui::IconButtonShape::Square)
.icon_color(Color::Muted)
.icon_size(IconSize::Small)
.tooltip({
let focus_handle = focus_handle.clone();
move |window, cx| {
Tooltip::for_action_in(
"Regenerate",
&menu::Confirm,
&focus_handle,
window,
cx,
)
}
})
.on_click(
cx.listener(Self::handle_regenerate_click),
),
),
)
)
}),
)
}
Role::Assistant => v_flex()
.id(("message-container", ix))
.px(RESPONSE_PADDING_X)
.gap_2()
.children(message_content)
.when(has_tool_uses, |parent| {
parent.children(tool_uses.into_iter().map(|tool_use| {
self.render_tool_use(tool_use, window, workspace.clone(), cx)
}))
}),
Role::System => {
let colors = cx.theme().colors();
div().id(("message-container", ix)).py_1().px_2().child(
v_flex()
.bg(colors.editor_background)
.rounded_sm()
.child(div().p_4().children(message_content)),
)
}
}
.on_click(cx.listener({
let message_segments = message.segments.clone();
move |this, _, window, cx| {
this.start_editing_message(
message_id,
&message_segments,
&message_creases,
window,
cx,
);
}
})),
),
Role::Assistant => v_flex()
.id(("message-container", ix))
.px(RESPONSE_PADDING_X)
.gap_2()
.children(message_content)
.when(has_tool_uses, |parent| {
parent.children(tool_uses.into_iter().map(|tool_use| {
self.render_tool_use(tool_use, window, workspace.clone(), cx)
}))
}),
Role::System => div().id(("message-container", ix)).py_1().px_2().child(
v_flex()
.bg(colors.editor_background)
.rounded_sm()
.child(div().p_4().children(message_content)),
),
};
let after_editing_message = self
@@ -2529,26 +2510,6 @@ impl ActiveThread {
.blend(cx.theme().colors().editor_foreground.opacity(0.025))
}
fn render_ui_notification(
&self,
message_content: impl IntoIterator<Item = impl IntoElement>,
ix: usize,
cx: &mut Context<Self>,
) -> Stateful<Div> {
let message = div()
.flex_1()
.min_w_0()
.text_size(TextSize::XSmall.rems(cx))
.text_color(cx.theme().colors().text_muted)
.children(message_content);
div()
.id(("message-container", ix))
.py_1()
.px_2p5()
.child(Banner::new().severity(ui::Severity::Warning).child(message))
}
fn render_message_thinking_segment(
&self,
message_id: MessageId,
@@ -2584,7 +2545,7 @@ impl ActiveThread {
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(LoadingLabel::new("Thinking").size(LabelSize::Small)),
.child(AnimatedLabel::new("Thinking").size(LabelSize::Small)),
)
.child(
h_flex()
@@ -3153,7 +3114,7 @@ impl ActiveThread {
.border_color(self.tool_card_border_color(cx))
.rounded_b_lg()
.child(
LoadingLabel::new("Waiting for Confirmation").size(LabelSize::Small)
AnimatedLabel::new("Waiting for Confirmation").size(LabelSize::Small)
)
.child(
h_flex()
@@ -3803,9 +3764,9 @@ mod tests {
// Stream response to user message
thread.update(cx, |thread, cx| {
let intent = CompletionIntent::UserPrompt;
let request = thread.to_completion_request(model.clone(), intent, cx);
thread.stream_completion(request, model, intent, cx.active_window(), cx)
let request =
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx);
thread.stream_completion(request, model, cx.active_window(), cx)
});
// Follow the agent
cx.update(|window, cx| {
@@ -3865,15 +3826,13 @@ mod tests {
});
active_thread.update_in(cx, |active_thread, window, cx| {
if let Some(message_text) = message.segments.first().and_then(MessageSegment::text) {
active_thread.start_editing_message(
message.id,
message_text,
message.creases.as_slice(),
window,
cx,
);
}
active_thread.start_editing_message(
message.id,
message.segments.as_slice(),
message.creases.as_slice(),
window,
cx,
);
let editor = active_thread
.editing_message
.as_ref()
@@ -3888,15 +3847,13 @@ mod tests {
let message = thread.update(cx, |thread, _| thread.message(message.id).cloned().unwrap());
active_thread.update_in(cx, |active_thread, window, cx| {
if let Some(message_text) = message.segments.first().and_then(MessageSegment::text) {
active_thread.start_editing_message(
message.id,
message_text,
message.creases.as_slice(),
window,
cx,
);
}
active_thread.start_editing_message(
message.id,
message.segments.as_slice(),
message.creases.as_slice(),
window,
cx,
);
let editor = active_thread
.editing_message
.as_ref()
@@ -3978,15 +3935,13 @@ mod tests {
// Edit the message while the completion is still running
active_thread.update_in(cx, |active_thread, window, cx| {
if let Some(message_text) = message.segments.first().and_then(MessageSegment::text) {
active_thread.start_editing_message(
message.id,
message_text,
message.creases.as_slice(),
window,
cx,
);
}
active_thread.start_editing_message(
message.id,
message.segments.as_slice(),
message.creases.as_slice(),
window,
cx,
);
let editor = active_thread
.editing_message
.as_ref()

View File

@@ -16,9 +16,7 @@ use gpui::{
Focusable, ScrollHandle, Subscription, Task, Transformation, WeakEntity, percentage,
};
use language::LanguageRegistry;
use language_model::{
LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID,
};
use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
use notifications::status_toast::{StatusToast, ToastIcon};
use project::{
context_server_store::{ContextServerConfiguration, ContextServerStatus, ContextServerStore},
@@ -26,8 +24,8 @@ use project::{
};
use settings::{Settings, update_settings_file};
use ui::{
ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, PopoverMenu,
Scrollbar, ScrollbarState, Switch, SwitchColor, Tooltip, prelude::*,
ContextMenu, Disclosure, ElevationIndex, Indicator, PopoverMenu, Scrollbar, ScrollbarState,
Switch, SwitchColor, Tooltip, prelude::*,
};
use util::ResultExt as _;
use workspace::Workspace;
@@ -88,14 +86,6 @@ impl AgentConfiguration {
let scroll_handle = ScrollHandle::new();
let scrollbar_state = ScrollbarState::new(scroll_handle.clone());
let mut expanded_provider_configurations = HashMap::default();
if LanguageModelRegistry::read_global(cx)
.provider(&ZED_CLOUD_PROVIDER_ID)
.map_or(false, |cloud_provider| cloud_provider.must_accept_terms(cx))
{
expanded_provider_configurations.insert(ZED_CLOUD_PROVIDER_ID, true);
}
let mut this = Self {
fs,
language_registry,
@@ -104,7 +94,7 @@ impl AgentConfiguration {
configuration_views_by_provider: HashMap::default(),
context_server_store,
expanded_context_server_tools: HashMap::default(),
expanded_provider_configurations,
expanded_provider_configurations: HashMap::default(),
tools,
_registry_subscription: registry_subscription,
scroll_handle,
@@ -172,29 +162,19 @@ impl AgentConfiguration {
.unwrap_or(false);
v_flex()
.when(is_expanded, |this| this.mb_2())
.child(
div()
.opacity(0.6)
.px_2()
.child(Divider::horizontal().color(DividerColor::Border)),
)
.py_2()
.gap_1p5()
.border_t_1()
.border_color(cx.theme().colors().border.opacity(0.6))
.child(
h_flex()
.map(|this| {
if is_expanded {
this.mt_2().mb_1()
} else {
this.my_2()
}
})
.w_full()
.gap_1()
.justify_between()
.child(
h_flex()
.id(provider_id_string.clone())
.cursor_pointer()
.px_2()
.py_0p5()
.w_full()
.justify_between()
@@ -257,16 +237,12 @@ impl AgentConfiguration {
)
}),
)
.child(
div()
.px_2()
.when(is_expanded, |parent| match configuration_view {
Some(configuration_view) => parent.child(configuration_view),
None => parent.child(Label::new(format!(
"No configuration view for {provider_name}",
))),
}),
)
.when(is_expanded, |parent| match configuration_view {
Some(configuration_view) => parent.child(configuration_view),
None => parent.child(Label::new(format!(
"No configuration view for {provider_name}",
))),
})
}
fn render_provider_configuration_section(
@@ -276,11 +252,12 @@ impl AgentConfiguration {
let providers = LanguageModelRegistry::read_global(cx).providers();
v_flex()
.p(DynamicSpacing::Base16.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx))
.border_b_1()
.border_color(cx.theme().colors().border)
.child(
v_flex()
.p(DynamicSpacing::Base16.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx))
.pb_0()
.mb_2p5()
.gap_0p5()
.child(Headline::new("LLM Providers"))
@@ -289,15 +266,10 @@ impl AgentConfiguration {
.color(Color::Muted),
),
)
.child(
div()
.pl(DynamicSpacing::Base08.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx))
.children(
providers.into_iter().map(|provider| {
self.render_provider_configuration_block(&provider, cx)
}),
),
.children(
providers
.into_iter()
.map(|provider| self.render_provider_configuration_block(&provider, cx)),
)
}

View File

@@ -180,7 +180,7 @@ impl ConfigurationSource {
}
fn context_server_input(existing: Option<(ContextServerId, ContextServerCommand)>) -> String {
let (name, command, args, env) = match existing {
let (name, path, args, env) = match existing {
Some((id, cmd)) => {
let args = serde_json::to_string(&cmd.args).unwrap();
let env = serde_json::to_string(&cmd.env.unwrap_or_default()).unwrap();
@@ -198,12 +198,14 @@ fn context_server_input(existing: Option<(ContextServerId, ContextServerCommand)
r#"{{
/// The name of your MCP server
"{name}": {{
/// The command which runs the MCP server
"command": "{command}",
/// The arguments to pass to the MCP server
"args": {args},
/// The environment variables to set
"env": {env}
"command": {{
/// The path to the executable
"path": "{path}",
/// The arguments to pass to the executable
"args": {args},
/// The environment variables to set for the executable
"env": {env}
}}
}}
}}"#
)
@@ -293,7 +295,10 @@ impl ConfigureContextServerModal {
ContextServerDescriptorRegistry::default_global(cx)
.read(cx)
.context_server_descriptor(&server_id.0)
.map(|_| ContextServerSettings::default_extension())
.map(|_| ContextServerSettings::Extension {
enabled: true,
settings: serde_json::json!({}),
})
})
else {
return Task::ready(Err(anyhow::anyhow!("Context server not found")));
@@ -379,14 +384,6 @@ impl ConfigureContextServerModal {
};
self.state = State::Waiting;
let existing_server = self.context_server_store.read(cx).get_running_server(&id);
if existing_server.is_some() {
self.context_server_store.update(cx, |store, cx| {
store.stop_server(&id, cx).log_err();
});
}
let wait_for_context_server_task =
wait_for_context_server(&self.context_server_store, id.clone(), cx);
cx.spawn({
@@ -407,21 +404,13 @@ impl ConfigureContextServerModal {
})
.detach();
let settings_changed =
ProjectSettings::get_global(cx).context_servers.get(&id.0) != Some(&settings);
if settings_changed {
// When we write the settings to the file, the context server will be restarted.
workspace.update(cx, |workspace, cx| {
let fs = workspace.app_state().fs.clone();
update_settings_file::<ProjectSettings>(fs.clone(), cx, |project_settings, _| {
project_settings.context_servers.insert(id.0, settings);
});
// When we write the settings to the file, the context server will be restarted.
workspace.update(cx, |workspace, cx| {
let fs = workspace.app_state().fs.clone();
update_settings_file::<ProjectSettings>(fs.clone(), cx, |project_settings, _| {
project_settings.context_servers.insert(id.0, settings);
});
} else if let Some(existing_server) = existing_server {
self.context_server_store
.update(cx, |store, cx| store.start_server(existing_server, cx));
}
});
}
fn cancel(&mut self, _: &menu::Cancel, cx: &mut Context<Self>) {
@@ -453,7 +442,8 @@ fn parse_input(text: &str) -> Result<(ContextServerId, ContextServerCommand)> {
let object = value.as_object().context("Expected object")?;
anyhow::ensure!(object.len() == 1, "Expected exactly one key-value pair");
let (context_server_name, value) = object.into_iter().next().unwrap();
let command: ContextServerCommand = serde_json::from_value(value.clone())?;
let command = value.get("command").context("Expected command")?;
let command: ContextServerCommand = serde_json::from_value(command.clone())?;
Ok((ContextServerId(context_server_name.clone().into()), command))
}
@@ -761,7 +751,7 @@ pub(crate) fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle
MarkdownStyle {
base_text_style: text_style.clone(),
selection_background_color: colors.element_selection_background,
selection_background_color: cx.theme().players().local().selection,
link: TextStyleRefinement {
background_color: Some(colors.editor_foreground.opacity(0.025)),
underline: Some(UnderlineStyle {

View File

@@ -272,35 +272,42 @@ impl PickerDelegate for ToolPickerDelegate {
let server_id = server_id.clone();
let tool_name = tool_name.clone();
move |settings: &mut AgentSettingsContent, _cx| {
let profiles = settings.profiles.get_or_insert_default();
let profile = profiles
.entry(profile_id)
.or_insert_with(|| AgentProfileContent {
name: default_profile.name.into(),
tools: default_profile.tools,
enable_all_context_servers: Some(
default_profile.enable_all_context_servers,
),
context_servers: default_profile
.context_servers
.into_iter()
.map(|(server_id, preset)| {
(
server_id,
ContextServerPresetContent {
tools: preset.tools,
},
)
})
.collect(),
});
settings
.v2_setting(|v2_settings| {
let profiles = v2_settings.profiles.get_or_insert_default();
let profile =
profiles
.entry(profile_id)
.or_insert_with(|| AgentProfileContent {
name: default_profile.name.into(),
tools: default_profile.tools,
enable_all_context_servers: Some(
default_profile.enable_all_context_servers,
),
context_servers: default_profile
.context_servers
.into_iter()
.map(|(server_id, preset)| {
(
server_id,
ContextServerPresetContent {
tools: preset.tools,
},
)
})
.collect(),
});
if let Some(server_id) = server_id {
let preset = profile.context_servers.entry(server_id).or_default();
*preset.tools.entry(tool_name).or_default() = !is_currently_enabled;
} else {
*profile.tools.entry(tool_name).or_default() = !is_currently_enabled;
}
if let Some(server_id) = server_id {
let preset = profile.context_servers.entry(server_id).or_default();
*preset.tools.entry(tool_name).or_default() = !is_currently_enabled;
} else {
*profile.tools.entry(tool_name).or_default() = !is_currently_enabled;
}
Ok(())
})
.ok();
}
});
}

View File

@@ -5,8 +5,7 @@ use anyhow::Result;
use buffer_diff::DiffHunkStatus;
use collections::{HashMap, HashSet};
use editor::{
Direction, Editor, EditorEvent, EditorSettings, MultiBuffer, MultiBufferSnapshot,
SelectionEffects, ToPoint,
Direction, Editor, EditorEvent, EditorSettings, MultiBuffer, MultiBufferSnapshot, ToPoint,
actions::{GoToHunk, GoToPreviousHunk},
scroll::Autoscroll,
};
@@ -172,9 +171,15 @@ impl AgentDiffPane {
if let Some(first_hunk) = first_hunk {
let first_hunk_start = first_hunk.multi_buffer_range().start;
editor.change_selections(Default::default(), window, cx, |selections| {
selections.select_anchor_ranges([first_hunk_start..first_hunk_start]);
})
editor.change_selections(
Some(Autoscroll::fit()),
window,
cx,
|selections| {
selections
.select_anchor_ranges([first_hunk_start..first_hunk_start]);
},
)
}
}
@@ -237,7 +242,7 @@ impl AgentDiffPane {
if let Some(first_hunk) = first_hunk {
let first_hunk_start = first_hunk.multi_buffer_range().start;
editor.change_selections(Default::default(), window, cx, |selections| {
editor.change_selections(Some(Autoscroll::fit()), window, cx, |selections| {
selections.select_anchor_ranges([first_hunk_start..first_hunk_start]);
})
}
@@ -411,7 +416,7 @@ fn update_editor_selection(
};
if let Some(target_hunk) = target_hunk {
editor.change_selections(Default::default(), window, cx, |selections| {
editor.change_selections(Some(Autoscroll::fit()), window, cx, |selections| {
let next_hunk_start = target_hunk.multi_buffer_range().start;
selections.select_anchor_ranges([next_hunk_start..next_hunk_start]);
})
@@ -1375,7 +1380,6 @@ impl AgentDiff {
| ThreadEvent::ToolConfirmationNeeded
| ThreadEvent::ToolUseLimitReached
| ThreadEvent::CancelEditing
| ThreadEvent::RetriesFailed { .. }
| ThreadEvent::ProfileChanged => {}
}
}
@@ -1539,7 +1543,7 @@ impl AgentDiff {
let first_hunk_start = first_hunk.multi_buffer_range().start;
editor.change_selections(
SelectionEffects::scroll(Autoscroll::center()),
Some(Autoscroll::center()),
window,
cx,
|selections| {
@@ -1863,7 +1867,7 @@ mod tests {
// Rejecting a hunk also moves the cursor to the next hunk, possibly cycling if it's at the end.
editor.update_in(cx, |editor, window, cx| {
editor.change_selections(SelectionEffects::no_scroll(), window, cx, |selections| {
editor.change_selections(None, window, cx, |selections| {
selections.select_ranges([Point::new(10, 0)..Point::new(10, 0)])
});
});
@@ -2119,7 +2123,7 @@ mod tests {
// Rejecting a hunk also moves the cursor to the next hunk, possibly cycling if it's at the end.
editor1.update_in(cx, |editor, window, cx| {
editor.change_selections(SelectionEffects::no_scroll(), window, cx, |selections| {
editor.change_selections(None, window, cx, |selections| {
selections.select_ranges([Point::new(10, 0)..Point::new(10, 0)])
});
});

View File

@@ -11,7 +11,7 @@ use language_model::{ConfiguredModel, LanguageModelRegistry};
use picker::popover_menu::PickerPopoverMenu;
use settings::update_settings_file;
use std::sync::Arc;
use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*};
use ui::{PopoverMenuHandle, Tooltip, prelude::*};
pub struct AgentModelSelector {
selector: Entity<LanguageModelSelector>,
@@ -94,35 +94,20 @@ impl Render for AgentModelSelector {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let model = self.selector.read(cx).delegate.active_model(cx);
let model_name = model
.as_ref()
.map(|model| model.model.name().0)
.unwrap_or_else(|| SharedString::from("No model selected"));
let provider_icon = model
.as_ref()
.map(|model| model.provider.icon())
.unwrap_or_else(|| IconName::Ai);
let focus_handle = self.focus_handle.clone();
PickerPopoverMenu::new(
self.selector.clone(),
ButtonLike::new("active-model")
.child(
Icon::new(provider_icon)
.color(Color::Muted)
.size(IconSize::XSmall),
)
.child(
Label::new(model_name)
.color(Color::Muted)
.size(LabelSize::Small)
.ml_0p5(),
)
.child(
Icon::new(IconName::ChevronDown)
.color(Color::Muted)
.size(IconSize::XSmall),
),
Button::new("active-model", model_name)
.label_size(LabelSize::Small)
.color(Color::Muted)
.icon(IconName::ChevronDown)
.icon_size(IconSize::XSmall)
.icon_position(IconPosition::End)
.icon_color(Color::Muted),
move |window, cx| {
Tooltip::for_action_in(
"Change Model",

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,6 @@ mod agent_diff;
mod agent_model_selector;
mod agent_panel;
mod buffer_codegen;
mod burn_mode_tooltip;
mod context_picker;
mod context_server_configuration;
mod context_strip;
@@ -12,6 +11,7 @@ mod debug;
mod inline_assistant;
mod inline_prompt_editor;
mod language_model_selector;
mod max_mode_tooltip;
mod message_editor;
mod profile_selector;
mod slash_command;
@@ -48,94 +48,57 @@ pub use crate::agent_panel::{AgentPanel, ConcreteAssistantPanelDelegate};
pub use crate::inline_assistant::InlineAssistant;
use crate::slash_command_settings::SlashCommandSettings;
pub use agent_diff::{AgentDiffPane, AgentDiffToolbar};
pub use text_thread_editor::{AgentPanelDelegate, TextThreadEditor};
pub use text_thread_editor::AgentPanelDelegate;
pub use ui::preview::{all_agent_previews, get_agent_preview};
actions!(
agent,
[
/// Creates a new text-based conversation thread.
NewTextThread,
/// Toggles the context picker interface for adding files, symbols, or other context.
ToggleContextPicker,
/// Toggles the navigation menu for switching between threads and views.
ToggleNavigationMenu,
/// Toggles the options menu for agent settings and preferences.
ToggleOptionsMenu,
/// Deletes the recently opened thread from history.
DeleteRecentlyOpenThread,
/// Toggles the profile selector for switching between agent profiles.
ToggleProfileSelector,
/// Removes all added context from the current conversation.
RemoveAllContext,
/// Expands the message editor to full size.
ExpandMessageEditor,
/// Opens the conversation history view.
OpenHistory,
/// Adds a context server to the configuration.
AddContextServer,
/// Removes the currently selected thread.
RemoveSelectedThread,
/// Starts a chat conversation with the agent.
Chat,
/// Starts a chat conversation with follow-up enabled.
ChatWithFollow,
/// Cycles to the next inline assist suggestion.
CycleNextInlineAssist,
/// Cycles to the previous inline assist suggestion.
CyclePreviousInlineAssist,
/// Moves focus up in the interface.
FocusUp,
/// Moves focus down in the interface.
FocusDown,
/// Moves focus left in the interface.
FocusLeft,
/// Moves focus right in the interface.
FocusRight,
/// Removes the currently focused context item.
RemoveFocusedContext,
/// Accepts the suggested context item.
AcceptSuggestedContext,
/// Opens the active thread as a markdown file.
OpenActiveThreadAsMarkdown,
/// Opens the agent diff view to review changes.
OpenAgentDiff,
/// Keeps the current suggestion or change.
Keep,
/// Rejects the current suggestion or change.
Reject,
/// Rejects all suggestions or changes.
RejectAll,
/// Keeps all suggestions or changes.
KeepAll,
/// Follows the agent's suggestions.
Follow,
/// Resets the trial upsell notification.
ResetTrialUpsell,
/// Resets the trial end upsell notification.
ResetTrialEndUpsell,
/// Continues the current thread.
ContinueThread,
/// Continues the thread with burn mode enabled.
ContinueWithBurnMode,
/// Toggles burn mode for faster responses.
ToggleBurnMode,
]
);
/// Creates a new conversation thread, optionally based on an existing thread.
#[derive(Default, Clone, PartialEq, Deserialize, JsonSchema, Action)]
#[action(namespace = agent)]
#[serde(deny_unknown_fields)]
pub struct NewThread {
#[serde(default)]
from_thread_id: Option<ThreadId>,
}
/// Opens the profile management interface for configuring agent tools and settings.
#[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)]
#[action(namespace = agent)]
#[serde(deny_unknown_fields)]
pub struct ManageProfiles {
#[serde(default)]
pub customize_tools: Option<AgentProfileId>,
@@ -194,7 +157,6 @@ pub fn init(
agent::init(cx);
agent_panel::init(cx);
context_server_configuration::init(language_registry.clone(), fs.clone(), cx);
TextThreadEditor::init(cx);
register_slash_commands(cx);
inline_assistant::init(
@@ -246,7 +208,7 @@ fn update_active_language_model_from_settings(cx: &mut App) {
}
}
let default = settings.default_model.as_ref().map(to_selected_model);
let default = to_selected_model(&settings.default_model);
let inline_assistant = settings
.inline_assistant_model
.as_ref()
@@ -266,7 +228,7 @@ fn update_active_language_model_from_settings(cx: &mut App) {
.collect::<Vec<_>>();
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.select_default_model(default.as_ref(), cx);
registry.select_default_model(Some(&default), cx);
registry.select_inline_assistant_model(inline_assistant.as_ref(), cx);
registry.select_commit_message_model(commit_message.as_ref(), cx);
registry.select_thread_summary_model(thread_summary.as_ref(), cx);

View File

@@ -1094,9 +1094,15 @@ mod tests {
};
use language_model::{LanguageModelRegistry, TokenUsage};
use rand::prelude::*;
use serde::Serialize;
use settings::SettingsStore;
use std::{future, sync::Arc};
#[derive(Serialize)]
pub struct DummyCompletionRequest {
pub name: String,
}
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
init_test(cx);

View File

@@ -661,7 +661,7 @@ fn recent_context_picker_entries(
let active_thread_id = workspace
.panel::<AgentPanel>(cx)
.and_then(|panel| Some(panel.read(cx).active_thread(cx)?.read(cx).id()));
.and_then(|panel| Some(panel.read(cx).active_thread()?.read(cx).id()));
if let Some((thread_store, text_thread_store)) = thread_store
.and_then(|store| store.upgrade())
@@ -930,8 +930,8 @@ impl MentionLink {
format!(
"[@{} ({}-{})]({}:{}:{}-{})",
file_name,
line_range.start + 1,
line_range.end + 1,
line_range.start,
line_range.end,
Self::SELECTION,
full_path,
line_range.start,

View File

@@ -161,7 +161,7 @@ impl ContextStrip {
let workspace = self.workspace.upgrade()?;
let panel = workspace.read(cx).panel::<AgentPanel>(cx)?.read(cx);
if let Some(active_thread) = panel.active_thread(cx) {
if let Some(active_thread) = panel.active_thread() {
let weak_active_thread = active_thread.downgrade();
let active_thread = active_thread.read(cx);

View File

@@ -18,7 +18,6 @@ use agent_settings::AgentSettings;
use anyhow::{Context as _, Result};
use client::telemetry::Telemetry;
use collections::{HashMap, HashSet, VecDeque, hash_map};
use editor::SelectionEffects;
use editor::{
Anchor, AnchorRangeExt, CodeActionProvider, Editor, EditorEvent, ExcerptId, ExcerptRange,
MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint,
@@ -1160,7 +1159,7 @@ impl InlineAssistant {
let position = assist.range.start;
editor.update(cx, |editor, cx| {
editor.change_selections(SelectionEffects::no_scroll(), window, cx, |selections| {
editor.change_selections(None, window, cx, |selections| {
selections.select_anchor_ranges([position..position])
});

View File

@@ -18,7 +18,6 @@ use ui::{ListItem, ListItemSpacing, prelude::*};
actions!(
agent,
[
/// Toggles the language model selector dropdown.
#[action(deprecated_aliases = ["assistant::ToggleModelSelector", "assistant2::ToggleModelSelector"])]
ToggleModelSelector
]
@@ -400,7 +399,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
cx: &mut Context<Picker<Self>>,
) -> Task<()> {
let all_models = self.all_models.clone();
let active_model = (self.get_active_model)(cx);
let current_index = self.selected_index;
let bg_executor = cx.background_executor();
let language_model_registry = LanguageModelRegistry::global(cx);
@@ -442,9 +441,12 @@ impl PickerDelegate for LanguageModelPickerDelegate {
cx.spawn_in(window, async move |this, cx| {
this.update_in(cx, |this, window, cx| {
this.delegate.filtered_entries = filtered_models.entries();
// Finds the currently selected model in the list
let new_index =
Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
// Preserve selection focus
let new_index = if current_index >= this.delegate.filtered_entries.len() {
0
} else {
current_index
};
this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
cx.notify();
})

View File

@@ -1,11 +1,11 @@
use gpui::{Context, FontWeight, IntoElement, Render, Window};
use ui::{prelude::*, tooltip_container};
pub struct BurnModeTooltip {
pub struct MaxModeTooltip {
selected: bool,
}
impl BurnModeTooltip {
impl MaxModeTooltip {
pub fn new() -> Self {
Self { selected: false }
}
@@ -16,7 +16,7 @@ impl BurnModeTooltip {
}
}
impl Render for BurnModeTooltip {
impl Render for MaxModeTooltip {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let (icon, color) = if self.selected {
(IconName::ZedBurnModeOn, Color::Error)

View File

@@ -575,7 +575,7 @@ impl MessageEditor {
fn render_burn_mode_toggle(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
let thread = self.thread.read(cx);
let model = thread.configured_model();
if !model?.model.supports_burn_mode() {
if !model?.model.supports_max_mode() {
return None;
}
@@ -1250,7 +1250,9 @@ impl MessageEditor {
self.thread
.read(cx)
.configured_model()
.map_or(false, |model| model.provider.id() == ZED_CLOUD_PROVIDER_ID)
.map_or(false, |model| {
model.provider.id().0 == ZED_CLOUD_PROVIDER_ID
})
}
fn render_usage_callout(&self, line_height: Pixels, cx: &mut Context<Self>) -> Option<Div> {

View File

@@ -1,8 +1,8 @@
use crate::{
burn_mode_tooltip::BurnModeTooltip,
language_model_selector::{
LanguageModelSelector, ToggleModelSelector, language_model_selector,
},
max_mode_tooltip::MaxModeTooltip,
};
use agent_settings::{AgentSettings, CompletionMode};
use anyhow::Result;
@@ -21,6 +21,7 @@ use editor::{
BlockPlacement, BlockProperties, BlockStyle, Crease, CreaseMetadata, CustomBlockId, FoldId,
RenderBlock, ToDisplayPoint,
},
scroll::Autoscroll,
};
use editor::{FoldPlaceholder, display_map::CreaseId};
use fs::Fs;
@@ -68,7 +69,7 @@ use workspace::{
searchable::{Direction, SearchableItemHandle},
};
use workspace::{
Save, Toast, Workspace,
Save, Toast, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace,
item::{self, FollowableItem, Item, ItemHandle},
notifications::NotificationId,
pane,
@@ -85,24 +86,16 @@ use assistant_context::{
actions!(
assistant,
[
/// Sends the current message to the assistant.
Assist,
/// Confirms and executes the entered slash command.
ConfirmCommand,
/// Copies code from the assistant's response to the clipboard.
CopyCode,
/// Cycles between user and assistant message roles.
CycleMessageRole,
/// Inserts the selected text into the active editor.
InsertIntoEditor,
/// Quotes the current selection in the assistant conversation.
QuoteSelection,
/// Splits the conversation at the current cursor position.
Split,
]
);
/// Inserts files that were dragged and dropped into the assistant conversation.
#[derive(PartialEq, Clone, Action)]
#[action(namespace = assistant, no_json, no_register)]
pub enum InsertDraggedFiles {
@@ -396,7 +389,7 @@ impl TextThreadEditor {
cursor..cursor
};
self.editor.update(cx, |editor, cx| {
editor.change_selections(Default::default(), window, cx, |selections| {
editor.change_selections(Some(Autoscroll::fit()), window, cx, |selections| {
selections.select_ranges([new_selection])
});
});
@@ -456,7 +449,8 @@ impl TextThreadEditor {
if let Some(command) = self.slash_commands.command(name, cx) {
self.editor.update(cx, |editor, cx| {
editor.transact(window, cx, |editor, window, cx| {
editor.change_selections(Default::default(), window, cx, |s| s.try_cancel());
editor
.change_selections(Some(Autoscroll::fit()), window, cx, |s| s.try_cancel());
let snapshot = editor.buffer().read(cx).snapshot(cx);
let newest_cursor = editor.selections.newest::<Point>(cx).head();
if newest_cursor.column > 0
@@ -1589,7 +1583,7 @@ impl TextThreadEditor {
self.editor.update(cx, |editor, cx| {
editor.transact(window, cx, |this, window, cx| {
this.change_selections(Default::default(), window, cx, |s| {
this.change_selections(Some(Autoscroll::fit()), window, cx, |s| {
s.select(selections);
});
this.insert("", window, cx);
@@ -2081,12 +2075,12 @@ impl TextThreadEditor {
)
}
fn render_burn_mode_toggle(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
fn render_max_mode_toggle(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
let context = self.context().read(cx);
let active_model = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|default| default.model)?;
if !active_model.supports_burn_mode() {
if !active_model.supports_max_mode() {
return None;
}
@@ -2113,7 +2107,7 @@ impl TextThreadEditor {
});
}))
.tooltip(move |_window, cx| {
cx.new(|_| BurnModeTooltip::new().selected(burn_mode_enabled))
cx.new(|_| MaxModeTooltip::new().selected(burn_mode_enabled))
.into()
})
.into_any_element(),
@@ -2128,21 +2122,12 @@ impl TextThreadEditor {
let active_model = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|default| default.model);
let focus_handle = self.editor().focus_handle(cx).clone();
let model_name = match active_model {
Some(model) => model.name().0,
None => SharedString::from("No model selected"),
};
let active_provider = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|default| default.provider);
let provider_icon = match active_provider {
Some(provider) => provider.icon(),
None => IconName::Ai,
};
let focus_handle = self.editor().focus_handle(cx).clone();
PickerPopoverMenu::new(
self.language_model_selector.clone(),
ButtonLike::new("active-model")
@@ -2150,16 +2135,10 @@ impl TextThreadEditor {
.child(
h_flex()
.gap_0p5()
.child(
Icon::new(provider_icon)
.color(Color::Muted)
.size(IconSize::XSmall),
)
.child(
Label::new(model_name)
.color(Color::Muted)
.size(LabelSize::Small)
.ml_0p5(),
.color(Color::Muted),
)
.child(
Icon::new(IconName::ChevronDown)
@@ -2596,7 +2575,7 @@ impl Render for TextThreadEditor {
};
let language_model_selector = self.language_model_selector_menu_handle.clone();
let burn_mode_toggle = self.render_burn_mode_toggle(cx);
let max_mode_toggle = self.render_max_mode_toggle(cx);
v_flex()
.key_context("ContextEditor")
@@ -2651,7 +2630,7 @@ impl Render for TextThreadEditor {
h_flex()
.gap_0p5()
.child(self.render_inject_context_menu(cx))
.when_some(burn_mode_toggle, |this, element| this.child(element)),
.when_some(max_mode_toggle, |this, element| this.child(element)),
)
.child(
h_flex()
@@ -2945,6 +2924,13 @@ impl FollowableItem for TextThreadEditor {
}
}
pub struct ContextEditorToolbarItem {
active_context_editor: Option<WeakEntity<TextThreadEditor>>,
model_summary_editor: Entity<Editor>,
}
impl ContextEditorToolbarItem {}
pub fn render_remaining_tokens(
context_editor: &Entity<TextThreadEditor>,
cx: &App,
@@ -2997,6 +2983,98 @@ pub fn render_remaining_tokens(
)
}
impl Render for ContextEditorToolbarItem {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let left_side = h_flex()
.group("chat-title-group")
.gap_1()
.items_center()
.flex_grow()
.child(
div()
.w_full()
.when(self.active_context_editor.is_some(), |left_side| {
left_side.child(self.model_summary_editor.clone())
}),
)
.child(
div().visible_on_hover("chat-title-group").child(
IconButton::new("regenerate-context", IconName::RefreshTitle)
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text("Regenerate Title"))
.on_click(cx.listener(move |_, _, _window, cx| {
cx.emit(ContextEditorToolbarItemEvent::RegenerateSummary)
})),
),
);
let right_side = h_flex()
.gap_2()
// TODO display this in a nicer way, once we have a design for it.
// .children({
// let project = self
// .workspace
// .upgrade()
// .map(|workspace| workspace.read(cx).project().downgrade());
//
// let scan_items_remaining = cx.update_global(|db: &mut SemanticDb, cx| {
// project.and_then(|project| db.remaining_summaries(&project, cx))
// });
// scan_items_remaining
// .map(|remaining_items| format!("Files to scan: {}", remaining_items))
// })
.children(
self.active_context_editor
.as_ref()
.and_then(|editor| editor.upgrade())
.and_then(|editor| render_remaining_tokens(&editor, cx)),
);
h_flex()
.px_0p5()
.size_full()
.gap_2()
.justify_between()
.child(left_side)
.child(right_side)
}
}
impl ToolbarItemView for ContextEditorToolbarItem {
fn set_active_pane_item(
&mut self,
active_pane_item: Option<&dyn ItemHandle>,
_window: &mut Window,
cx: &mut Context<Self>,
) -> ToolbarItemLocation {
self.active_context_editor = active_pane_item
.and_then(|item| item.act_as::<TextThreadEditor>(cx))
.map(|editor| editor.downgrade());
cx.notify();
if self.active_context_editor.is_none() {
ToolbarItemLocation::Hidden
} else {
ToolbarItemLocation::PrimaryRight
}
}
fn pane_focus_update(
&mut self,
_pane_focused: bool,
_window: &mut Window,
cx: &mut Context<Self>,
) {
cx.notify();
}
}
impl EventEmitter<ToolbarItemEvent> for ContextEditorToolbarItem {}
pub enum ContextEditorToolbarItemEvent {
RegenerateSummary,
}
impl EventEmitter<ContextEditorToolbarItemEvent> for ContextEditorToolbarItem {}
enum PendingSlashCommand {}
fn invoked_slash_command_fold_placeholder(
@@ -3162,7 +3240,6 @@ pub fn make_lsp_adapter_delegate(
#[cfg(test)]
mod tests {
use super::*;
use editor::SelectionEffects;
use fs::FakeFs;
use gpui::{App, TestAppContext, VisualTestContext};
use indoc::indoc;
@@ -3388,9 +3465,7 @@ mod tests {
) {
context_editor.update_in(cx, |context_editor, window, cx| {
context_editor.editor.update(cx, |editor, cx| {
editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| {
s.select_ranges([range])
});
editor.change_selections(None, window, cx, |s| s.select_ranges([range]));
});
context_editor.copy(&Default::default(), window, cx);

View File

@@ -1,5 +1,5 @@
use agent::{Thread, ThreadEvent};
use assistant_tool::{AnyTool, ToolSource};
use assistant_tool::{Tool, ToolSource};
use collections::HashMap;
use gpui::{App, Context, Entity, IntoElement, Render, Subscription, Window};
use language_model::{LanguageModel, LanguageModelToolSchemaFormat};
@@ -7,7 +7,7 @@ use std::sync::Arc;
use ui::prelude::*;
pub struct IncompatibleToolsState {
cache: HashMap<LanguageModelToolSchemaFormat, Vec<AnyTool>>,
cache: HashMap<LanguageModelToolSchemaFormat, Vec<Arc<dyn Tool>>>,
thread: Entity<Thread>,
_thread_subscription: Subscription,
}
@@ -29,7 +29,11 @@ impl IncompatibleToolsState {
}
}
pub fn incompatible_tools(&mut self, model: &Arc<dyn LanguageModel>, cx: &App) -> &[AnyTool] {
pub fn incompatible_tools(
&mut self,
model: &Arc<dyn LanguageModel>,
cx: &App,
) -> &[Arc<dyn Tool>] {
self.cache
.entry(model.tool_input_format())
.or_insert_with(|| {
@@ -38,15 +42,15 @@ impl IncompatibleToolsState {
.profile()
.enabled_tools(cx)
.iter()
.filter(|(_, tool)| tool.input_schema(model.tool_input_format()).is_err())
.map(|(_, tool)| tool.clone())
.filter(|tool| tool.input_schema(model.tool_input_format()).is_err())
.cloned()
.collect()
})
}
}
pub struct IncompatibleToolsTooltip {
pub incompatible_tools: Vec<AnyTool>,
pub incompatible_tools: Vec<Arc<dyn Tool>>,
}
impl Render for IncompatibleToolsTooltip {

View File

@@ -1,11 +1,13 @@
mod agent_notification;
mod burn_mode_tooltip;
mod animated_label;
mod context_pill;
mod max_mode_tooltip;
mod onboarding_modal;
pub mod preview;
mod upsell;
pub use agent_notification::*;
pub use burn_mode_tooltip::*;
pub use animated_label::*;
pub use context_pill::*;
pub use max_mode_tooltip::*;
pub use onboarding_modal::*;

View File

@@ -1,24 +1,24 @@
use crate::prelude::*;
use gpui::{Animation, AnimationExt, FontWeight, pulsating_between};
use std::time::Duration;
use ui::prelude::*;
#[derive(IntoElement)]
pub struct LoadingLabel {
pub struct AnimatedLabel {
base: Label,
text: SharedString,
}
impl LoadingLabel {
impl AnimatedLabel {
pub fn new(text: impl Into<SharedString>) -> Self {
let text = text.into();
LoadingLabel {
AnimatedLabel {
base: Label::new(text.clone()),
text,
}
}
}
impl LabelCommon for LoadingLabel {
impl LabelCommon for AnimatedLabel {
fn size(mut self, size: LabelSize) -> Self {
self.base = self.base.size(size);
self
@@ -80,14 +80,14 @@ impl LabelCommon for LoadingLabel {
}
}
impl RenderOnce for LoadingLabel {
impl RenderOnce for AnimatedLabel {
fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
let text = self.text.clone();
self.base
.color(Color::Muted)
.with_animations(
"loading_label",
"animated-label",
vec![
Animation::new(Duration::from_secs(1)),
Animation::new(Duration::from_secs(1)).repeat(),

View File

@@ -6,7 +6,7 @@ use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::http::{self, HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use strum::{EnumIter, EnumString};
use thiserror::Error;
@@ -356,7 +356,7 @@ pub async fn complete(
.send(request)
.await
.map_err(AnthropicError::HttpSend)?;
let status_code = response.status();
let status = response.status();
let mut body = String::new();
response
.body_mut()
@@ -364,12 +364,12 @@ pub async fn complete(
.await
.map_err(AnthropicError::ReadResponse)?;
if status_code.is_success() {
if status.is_success() {
Ok(serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)?)
} else {
Err(AnthropicError::HttpResponseError {
status_code,
message: body,
status: status.as_u16(),
body,
})
}
}
@@ -444,7 +444,11 @@ impl RateLimitInfo {
}
Self {
retry_after: parse_retry_after(headers),
retry_after: headers
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u64>().ok())
.map(Duration::from_secs),
requests: RateLimit::from_headers("requests", headers).ok(),
tokens: RateLimit::from_headers("tokens", headers).ok(),
input_tokens: RateLimit::from_headers("input-tokens", headers).ok(),
@@ -453,17 +457,6 @@ impl RateLimitInfo {
}
}
/// Parses the Retry-After header value as an integer number of seconds (anthropic always uses
/// seconds). Note that other services might specify an HTTP date or some other format for this
/// header. Returns `None` if the header is not present or cannot be parsed.
pub fn parse_retry_after(headers: &HeaderMap<HeaderValue>) -> Option<Duration> {
headers
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u64>().ok())
.map(Duration::from_secs)
}
fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> anyhow::Result<&'a str> {
Ok(headers
.get(key)
@@ -527,10 +520,6 @@ pub async fn stream_completion_with_rate_limit_info(
})
.boxed();
Ok((stream, Some(rate_limits)))
} else if response.status().as_u16() == 529 {
Err(AnthropicError::ServerOverloaded {
retry_after: rate_limits.retry_after,
})
} else if let Some(retry_after) = rate_limits.retry_after {
Err(AnthropicError::RateLimit { retry_after })
} else {
@@ -543,9 +532,10 @@ pub async fn stream_completion_with_rate_limit_info(
match serde_json::from_str::<Event>(&body) {
Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
Ok(_) | Err(_) => Err(AnthropicError::HttpResponseError {
status_code: response.status(),
message: body,
Ok(_) => Err(AnthropicError::UnexpectedResponseFormat(body)),
Err(_) => Err(AnthropicError::HttpResponseError {
status: response.status().as_u16(),
body: body,
}),
}
}
@@ -811,19 +801,16 @@ pub enum AnthropicError {
ReadResponse(io::Error),
/// HTTP error response from the API
HttpResponseError {
status_code: StatusCode,
message: String,
},
HttpResponseError { status: u16, body: String },
/// Rate limit exceeded
RateLimit { retry_after: Duration },
/// Server overloaded
ServerOverloaded { retry_after: Option<Duration> },
/// API returned an error response
ApiError(ApiError),
/// Unexpected response format
UnexpectedResponseFormat(String),
}
#[derive(Debug, Serialize, Deserialize, Error)]

View File

@@ -2140,8 +2140,7 @@ impl AssistantContext {
);
}
LanguageModelCompletionEvent::ToolUse(_) |
LanguageModelCompletionEvent::ToolUseJsonParseError { .. } |
LanguageModelCompletionEvent::UsageUpdate(_) => {}
LanguageModelCompletionEvent::UsageUpdate(_) => {}
}
});
@@ -2347,13 +2346,13 @@ impl AssistantContext {
completion_request.messages.push(request_message);
}
}
let supports_burn_mode = if let Some(model) = model {
model.supports_burn_mode()
let supports_max_mode = if let Some(model) = model {
model.supports_max_mode()
} else {
false
};
if supports_burn_mode {
if supports_max_mode {
completion_request.mode = Some(self.completion_mode.into());
}
completion_request
@@ -2524,12 +2523,6 @@ impl AssistantContext {
}
let message = start_message;
let at_end = range.end >= message.offset_range.end.saturating_sub(1);
let role_after = if range.start == range.end || at_end {
Role::User
} else {
message.role
};
let role = message.role;
let mut edited_buffer = false;
@@ -2564,7 +2557,7 @@ impl AssistantContext {
};
let suffix_metadata = MessageMetadata {
role: role_after,
role,
status: MessageStatus::Done,
timestamp: suffix.id.0,
cache: None,

View File

@@ -74,7 +74,7 @@ impl SlashCommand for DeltaSlashCommand {
.slice(section.range.to_offset(&context_buffer)),
);
file_command_new_outputs.push(Arc::new(FileSlashCommand).run(
std::slice::from_ref(&metadata.path),
&[metadata.path.clone()],
context_slash_command_output_sections,
context_buffer.clone(),
workspace.clone(),

View File

@@ -22,7 +22,6 @@ gpui.workspace = true
icons.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
parking_lot.workspace = true
project.workspace = true
regex.workspace = true

View File

@@ -4,19 +4,25 @@ mod tool_registry;
mod tool_schema;
mod tool_working_set;
use std::{fmt, fmt::Debug, fmt::Formatter, ops::Deref, sync::Arc};
use std::fmt;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::ops::Deref;
use std::sync::Arc;
use anyhow::Result;
use gpui::{
AnyElement, AnyWindowHandle, App, Context, Entity, IntoElement, SharedString, Task, WeakEntity,
Window,
};
use gpui::AnyElement;
use gpui::AnyWindowHandle;
use gpui::Context;
use gpui::IntoElement;
use gpui::Window;
use gpui::{App, Entity, SharedString, Task, WeakEntity};
use icons::IconName;
use language_model::{
LanguageModel, LanguageModelImage, LanguageModelRequest, LanguageModelToolSchemaFormat,
};
use language_model::LanguageModel;
use language_model::LanguageModelImage;
use language_model::LanguageModelRequest;
use language_model::LanguageModelToolSchemaFormat;
use project::Project;
use serde::de::DeserializeOwned;
use workspace::Workspace;
pub use crate::action_log::*;
@@ -193,10 +199,7 @@ pub enum ToolSource {
}
/// A tool that can be used by a language model.
pub trait Tool: Send + Sync + 'static {
/// The input type that is accepted by the tool.
type Input: DeserializeOwned;
pub trait Tool: 'static + Send + Sync {
/// Returns the name of the tool.
fn name(&self) -> String;
@@ -213,7 +216,7 @@ pub trait Tool: Send + Sync + 'static {
/// Returns true if the tool needs the users's confirmation
/// before having permission to run.
fn needs_confirmation(&self, input: &Self::Input, cx: &App) -> bool;
fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
/// Returns true if the tool may perform edits.
fn may_perform_edits(&self) -> bool;
@@ -224,18 +227,18 @@ pub trait Tool: Send + Sync + 'static {
}
/// Returns markdown to be displayed in the UI for this tool.
fn ui_text(&self, input: &Self::Input) -> String;
fn ui_text(&self, input: &serde_json::Value) -> String;
/// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming
/// (so information may be missing).
fn still_streaming_ui_text(&self, input: &Self::Input) -> String {
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
self.ui_text(input)
}
/// Runs the tool with the provided input.
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
@@ -255,199 +258,7 @@ pub trait Tool: Send + Sync + 'static {
}
}
#[derive(Clone)]
pub struct AnyTool {
inner: Arc<dyn ErasedTool>,
}
/// Copy of `Tool` where the Input type is erased.
trait ErasedTool: Send + Sync {
fn name(&self) -> String;
fn description(&self) -> String;
fn icon(&self) -> IconName;
fn source(&self) -> ToolSource;
fn may_perform_edits(&self) -> bool;
fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
fn ui_text(&self, input: &serde_json::Value) -> String;
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String;
fn run(
&self,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
model: Arc<dyn LanguageModel>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult;
fn deserialize_card(
&self,
output: serde_json::Value,
project: Entity<Project>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyToolCard>;
}
struct ErasedToolWrapper<T: Tool> {
tool: Arc<T>,
}
impl<T: Tool> ErasedTool for ErasedToolWrapper<T> {
fn name(&self) -> String {
self.tool.name()
}
fn description(&self) -> String {
self.tool.description()
}
fn icon(&self) -> IconName {
self.tool.icon()
}
fn source(&self) -> ToolSource {
self.tool.source()
}
fn may_perform_edits(&self) -> bool {
self.tool.may_perform_edits()
}
fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool {
match serde_json::from_value::<T::Input>(input.clone()) {
Ok(parsed_input) => self.tool.needs_confirmation(&parsed_input, cx),
Err(_) => true,
}
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
self.tool.input_schema(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<T::Input>(input.clone()) {
Ok(parsed_input) => self.tool.ui_text(&parsed_input),
Err(_) => "Invalid input".to_string(),
}
}
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<T::Input>(input.clone()) {
Ok(parsed_input) => self.tool.still_streaming_ui_text(&parsed_input),
Err(_) => "Invalid input".to_string(),
}
}
fn run(
&self,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
model: Arc<dyn LanguageModel>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
match serde_json::from_value::<T::Input>(input) {
Ok(parsed_input) => self.tool.clone().run(
parsed_input,
request,
project,
action_log,
model,
window,
cx,
),
Err(err) => ToolResult::from(Task::ready(Err(err.into()))),
}
}
fn deserialize_card(
&self,
output: serde_json::Value,
project: Entity<Project>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyToolCard> {
self.tool
.clone()
.deserialize_card(output, project, window, cx)
}
}
impl<T: Tool> From<Arc<T>> for AnyTool {
fn from(tool: Arc<T>) -> Self {
Self {
inner: Arc::new(ErasedToolWrapper { tool }),
}
}
}
impl AnyTool {
pub fn name(&self) -> String {
self.inner.name()
}
pub fn description(&self) -> String {
self.inner.description()
}
pub fn icon(&self) -> IconName {
self.inner.icon()
}
pub fn source(&self) -> ToolSource {
self.inner.source()
}
pub fn may_perform_edits(&self) -> bool {
self.inner.may_perform_edits()
}
pub fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool {
self.inner.needs_confirmation(input, cx)
}
pub fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
self.inner.input_schema(format)
}
pub fn ui_text(&self, input: &serde_json::Value) -> String {
self.inner.ui_text(input)
}
pub fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
self.inner.still_streaming_ui_text(input)
}
pub fn run(
&self,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
model: Arc<dyn LanguageModel>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
self.inner
.run(input, request, project, action_log, model, window, cx)
}
pub fn deserialize_card(
&self,
output: serde_json::Value,
project: Entity<Project>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyToolCard> {
self.inner.deserialize_card(output, project, window, cx)
}
}
impl Debug for AnyTool {
impl Debug for dyn Tool {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("Tool").field("name", &self.name()).finish()
}

View File

@@ -6,7 +6,7 @@ use gpui::Global;
use gpui::{App, ReadGlobal};
use parking_lot::RwLock;
use crate::{AnyTool, Tool};
use crate::Tool;
#[derive(Default, Deref, DerefMut)]
struct GlobalToolRegistry(Arc<ToolRegistry>);
@@ -15,7 +15,7 @@ impl Global for GlobalToolRegistry {}
#[derive(Default)]
struct ToolRegistryState {
tools: HashMap<Arc<str>, AnyTool>,
tools: HashMap<Arc<str>, Arc<dyn Tool>>,
}
#[derive(Default)]
@@ -48,7 +48,7 @@ impl ToolRegistry {
pub fn register_tool(&self, tool: impl Tool) {
let mut state = self.state.write();
let tool_name: Arc<str> = tool.name().into();
state.tools.insert(tool_name, Arc::new(tool).into());
state.tools.insert(tool_name, Arc::new(tool));
}
/// Unregisters the provided [`Tool`].
@@ -63,12 +63,12 @@ impl ToolRegistry {
}
/// Returns the list of tools in the registry.
pub fn tools(&self) -> Vec<AnyTool> {
pub fn tools(&self) -> Vec<Arc<dyn Tool>> {
self.state.read().tools.values().cloned().collect()
}
/// Returns the [`Tool`] with the given name.
pub fn tool(&self, name: &str) -> Option<AnyTool> {
pub fn tool(&self, name: &str) -> Option<Arc<dyn Tool>> {
self.state.read().tools.get(name).cloned()
}
}

View File

@@ -1,77 +1,39 @@
use std::borrow::Borrow;
use std::sync::Arc;
use crate::{AnyTool, ToolRegistry, ToolSource};
use collections::{HashMap, HashSet, IndexMap};
use gpui::{App, SharedString};
use util::debug_panic;
use collections::{HashMap, IndexMap};
use gpui::App;
use crate::{Tool, ToolRegistry, ToolSource};
#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
pub struct ToolId(usize);
/// A unique identifier for a tool within a working set.
#[derive(Clone, PartialEq, Eq, Hash, Default)]
pub struct UniqueToolName(SharedString);
impl Borrow<str> for UniqueToolName {
fn borrow(&self) -> &str {
&self.0
}
}
impl From<String> for UniqueToolName {
fn from(value: String) -> Self {
UniqueToolName(SharedString::new(value))
}
}
impl Into<String> for UniqueToolName {
fn into(self) -> String {
self.0.into()
}
}
impl std::fmt::Debug for UniqueToolName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl std::fmt::Display for UniqueToolName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0.as_ref())
}
}
/// A working set of tools for use in one instance of the Assistant Panel.
#[derive(Default)]
pub struct ToolWorkingSet {
context_server_tools_by_id: HashMap<ToolId, AnyTool>,
context_server_tools_by_name: HashMap<UniqueToolName, AnyTool>,
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
next_tool_id: ToolId,
}
impl ToolWorkingSet {
pub fn tool(&self, name: &str, cx: &App) -> Option<AnyTool> {
pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
self.context_server_tools_by_name
.get(name)
.cloned()
.or_else(|| ToolRegistry::global(cx).tool(name))
}
pub fn tools(&self, cx: &App) -> Vec<(UniqueToolName, AnyTool)> {
let mut tools = ToolRegistry::global(cx)
.tools()
.into_iter()
.map(|tool| (UniqueToolName(tool.name().into()), tool))
.collect::<Vec<_>>();
tools.extend(self.context_server_tools_by_name.clone());
pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
let mut tools = ToolRegistry::global(cx).tools();
tools.extend(self.context_server_tools_by_id.values().cloned());
tools
}
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<AnyTool>> {
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
let mut tools_by_source = IndexMap::default();
for (_, tool) in self.tools(cx) {
for tool in self.tools(cx) {
tools_by_source
.entry(tool.source())
.or_insert_with(Vec::new)
@@ -87,330 +49,27 @@ impl ToolWorkingSet {
tools_by_source
}
pub fn insert(&mut self, tool: AnyTool, cx: &App) -> ToolId {
let tool_id = self.register_tool(tool);
self.tools_changed(cx);
tool_id
}
pub fn extend(&mut self, tools: impl Iterator<Item = AnyTool>, cx: &App) -> Vec<ToolId> {
let ids = tools.map(|tool| self.register_tool(tool)).collect();
self.tools_changed(cx);
ids
}
pub fn remove(&mut self, tool_ids_to_remove: &[ToolId], cx: &App) {
self.context_server_tools_by_id
.retain(|id, _| !tool_ids_to_remove.contains(id));
self.tools_changed(cx);
}
fn register_tool(&mut self, tool: AnyTool) -> ToolId {
pub fn insert(&mut self, tool: Arc<dyn Tool>) -> ToolId {
let tool_id = self.next_tool_id;
self.next_tool_id.0 += 1;
self.context_server_tools_by_id
.insert(tool_id, tool.clone());
self.tools_changed();
tool_id
}
fn tools_changed(&mut self, cx: &App) {
self.context_server_tools_by_name = resolve_context_server_tool_name_conflicts(
&self
.context_server_tools_by_id
pub fn remove(&mut self, tool_ids_to_remove: &[ToolId]) {
self.context_server_tools_by_id
.retain(|id, _| !tool_ids_to_remove.contains(id));
self.tools_changed();
}
fn tools_changed(&mut self) {
self.context_server_tools_by_name.clear();
self.context_server_tools_by_name.extend(
self.context_server_tools_by_id
.values()
.cloned()
.collect::<Vec<_>>(),
&ToolRegistry::global(cx).tools(),
.map(|tool| (tool.name(), tool.clone())),
);
}
}
fn resolve_context_server_tool_name_conflicts(
context_server_tools: &[AnyTool],
native_tools: &[AnyTool],
) -> HashMap<UniqueToolName, AnyTool> {
fn resolve_tool_name(tool: &AnyTool) -> String {
let mut tool_name = tool.name();
tool_name.truncate(MAX_TOOL_NAME_LENGTH);
tool_name
}
const MAX_TOOL_NAME_LENGTH: usize = 64;
let mut duplicated_tool_names = HashSet::default();
let mut seen_tool_names = HashSet::default();
seen_tool_names.extend(native_tools.iter().map(|tool| tool.name()));
for tool in context_server_tools {
let tool_name = resolve_tool_name(tool);
if seen_tool_names.contains(&tool_name) {
debug_assert!(
tool.source() != ToolSource::Native,
"Expected MCP tool but got a native tool: {}",
tool_name
);
duplicated_tool_names.insert(tool_name);
} else {
seen_tool_names.insert(tool_name);
}
}
if duplicated_tool_names.is_empty() {
return context_server_tools
.into_iter()
.map(|tool| (resolve_tool_name(tool).into(), tool.clone()))
.collect();
}
context_server_tools
.into_iter()
.filter_map(|tool| {
let mut tool_name = resolve_tool_name(tool);
if !duplicated_tool_names.contains(&tool_name) {
return Some((tool_name.into(), tool.clone()));
}
match tool.source() {
ToolSource::Native => {
debug_panic!("Expected MCP tool but got a native tool: {}", tool_name);
// Built-in tools always keep their original name
Some((tool_name.into(), tool.clone()))
}
ToolSource::ContextServer { id } => {
// Context server tools are prefixed with the context server ID, and truncated if necessary
tool_name.insert(0, '_');
if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
let mut id = id.to_string();
id.truncate(len);
tool_name.insert_str(0, &id);
} else {
tool_name.insert_str(0, &id);
}
tool_name.truncate(MAX_TOOL_NAME_LENGTH);
if seen_tool_names.contains(&tool_name) {
log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
None
} else {
Some((tool_name.into(), tool.clone()))
}
}
}
})
.collect()
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use gpui::{AnyWindowHandle, Entity, Task, TestAppContext};
use language_model::{LanguageModel, LanguageModelRequest};
use project::Project;
use crate::{ActionLog, Tool, ToolResult};
use super::*;
#[gpui::test]
fn test_unique_tool_names(cx: &mut TestAppContext) {
fn assert_tool(
tool_working_set: &ToolWorkingSet,
unique_name: &str,
expected_name: &str,
expected_source: ToolSource,
cx: &App,
) {
let tool = tool_working_set.tool(unique_name, cx).unwrap();
assert_eq!(tool.name(), expected_name);
assert_eq!(tool.source(), expected_source);
}
let tool_registry = cx.update(ToolRegistry::default_global);
tool_registry.register_tool(TestTool::new("tool1", ToolSource::Native));
tool_registry.register_tool(TestTool::new("tool2", ToolSource::Native));
let mut tool_working_set = ToolWorkingSet::default();
cx.update(|cx| {
tool_working_set.extend(
vec![
Arc::new(TestTool::new(
"tool2",
ToolSource::ContextServer { id: "mcp-1".into() },
))
.into(),
Arc::new(TestTool::new(
"tool2",
ToolSource::ContextServer { id: "mcp-2".into() },
))
.into(),
]
.into_iter(),
cx,
);
});
cx.update(|cx| {
assert_tool(&tool_working_set, "tool1", "tool1", ToolSource::Native, cx);
assert_tool(&tool_working_set, "tool2", "tool2", ToolSource::Native, cx);
assert_tool(
&tool_working_set,
"mcp-1_tool2",
"tool2",
ToolSource::ContextServer { id: "mcp-1".into() },
cx,
);
assert_tool(
&tool_working_set,
"mcp-2_tool2",
"tool2",
ToolSource::ContextServer { id: "mcp-2".into() },
cx,
);
})
}
#[gpui::test]
fn test_resolve_context_server_tool_name_conflicts() {
assert_resolve_context_server_tool_name_conflicts(
vec![
TestTool::new("tool1", ToolSource::Native),
TestTool::new("tool2", ToolSource::Native),
],
vec![TestTool::new(
"tool3",
ToolSource::ContextServer { id: "mcp-1".into() },
)],
vec!["tool3"],
);
assert_resolve_context_server_tool_name_conflicts(
vec![
TestTool::new("tool1", ToolSource::Native),
TestTool::new("tool2", ToolSource::Native),
],
vec![
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
],
vec!["mcp-1_tool3", "mcp-2_tool3"],
);
assert_resolve_context_server_tool_name_conflicts(
vec![
TestTool::new("tool1", ToolSource::Native),
TestTool::new("tool2", ToolSource::Native),
TestTool::new("tool3", ToolSource::Native),
],
vec![
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
],
vec!["mcp-1_tool3", "mcp-2_tool3"],
);
// Test deduplication of tools with very long names, in this case the mcp server name should be truncated
assert_resolve_context_server_tool_name_conflicts(
vec![TestTool::new(
"tool-with-very-very-very-long-name",
ToolSource::Native,
)],
vec![TestTool::new(
"tool-with-very-very-very-long-name",
ToolSource::ContextServer {
id: "mcp-with-very-very-very-long-name".into(),
},
)],
vec!["mcp-with-very-very-very-long-_tool-with-very-very-very-long-name"],
);
fn assert_resolve_context_server_tool_name_conflicts(
builtin_tools: Vec<TestTool>,
context_server_tools: Vec<TestTool>,
expected: Vec<&'static str>,
) {
let context_server_tools: Vec<AnyTool> = context_server_tools
.into_iter()
.map(|t| Arc::new(t).into())
.collect();
let builtin_tools: Vec<AnyTool> = builtin_tools
.into_iter()
.map(|t| Arc::new(t).into())
.collect();
let tools =
resolve_context_server_tool_name_conflicts(&context_server_tools, &builtin_tools);
assert_eq!(tools.len(), expected.len());
for (i, (name, _)) in tools.into_iter().enumerate() {
assert_eq!(
name.0.as_ref(),
expected[i],
"Expected '{}' got '{}' at index {}",
expected[i],
name,
i
);
}
}
}
struct TestTool {
name: String,
source: ToolSource,
}
impl TestTool {
fn new(name: impl Into<String>, source: ToolSource) -> Self {
Self {
name: name.into(),
source,
}
}
}
impl Tool for TestTool {
type Input = ();
fn name(&self) -> String {
self.name.clone()
}
fn icon(&self) -> icons::IconName {
icons::IconName::Ai
}
fn may_perform_edits(&self) -> bool {
false
}
fn needs_confirmation(&self, _input: &Self::Input, _cx: &App) -> bool {
true
}
fn source(&self) -> ToolSource {
self.source.clone()
}
fn description(&self) -> String {
"Test tool".to_string()
}
fn ui_text(&self, _input: &Self::Input) -> String {
"Test tool".to_string()
}
fn run(
self: Arc<Self>,
_input: Self::Input,
_request: Arc<LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>,
_window: Option<AnyWindowHandle>,
_cx: &mut App,
) -> ToolResult {
ToolResult {
output: Task::ready(Err(anyhow::anyhow!("No content"))),
card: None,
}
}
}
}

View File

@@ -40,13 +40,11 @@ pub struct CopyPathToolInput {
pub struct CopyPathTool;
impl Tool for CopyPathTool {
type Input = CopyPathToolInput;
fn name(&self) -> String {
"copy_path".into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -66,15 +64,20 @@ impl Tool for CopyPathTool {
json_schema_for::<CopyPathToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
let src = MarkdownInlineCode(&input.source_path);
let dest = MarkdownInlineCode(&input.destination_path);
format!("Copy {src} to {dest}")
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<CopyPathToolInput>(input.clone()) {
Ok(input) => {
let src = MarkdownInlineCode(&input.source_path);
let dest = MarkdownInlineCode(&input.destination_path);
format!("Copy {src} to {dest}")
}
Err(_) => "Copy path".to_string(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -82,6 +85,10 @@ impl Tool for CopyPathTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CopyPathToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let copy_task = project.update(cx, |project, cx| {
match project
.find_project_path(&input.source_path, cx)

View File

@@ -29,8 +29,6 @@ pub struct CreateDirectoryToolInput {
pub struct CreateDirectoryTool;
impl Tool for CreateDirectoryTool {
type Input = CreateDirectoryToolInput;
fn name(&self) -> String {
"create_directory".into()
}
@@ -39,7 +37,7 @@ impl Tool for CreateDirectoryTool {
include_str!("./create_directory_tool/description.md").into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -55,13 +53,18 @@ impl Tool for CreateDirectoryTool {
json_schema_for::<CreateDirectoryToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
format!("Create directory {}", MarkdownInlineCode(&input.path))
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<CreateDirectoryToolInput>(input.clone()) {
Ok(input) => {
format!("Create directory {}", MarkdownInlineCode(&input.path))
}
Err(_) => "Create directory".to_string(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -69,6 +72,10 @@ impl Tool for CreateDirectoryTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CreateDirectoryToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let project_path = match project.read(cx).find_project_path(&input.path, cx) {
Some(project_path) => project_path,
None => {

View File

@@ -29,13 +29,11 @@ pub struct DeletePathToolInput {
pub struct DeletePathTool;
impl Tool for DeletePathTool {
type Input = DeletePathToolInput;
fn name(&self) -> String {
"delete_path".into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -55,13 +53,16 @@ impl Tool for DeletePathTool {
json_schema_for::<DeletePathToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
format!("Delete “`{}`”", input.path)
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<DeletePathToolInput>(input.clone()) {
Ok(input) => format!("Delete “`{}`”", input.path),
Err(_) => "Delete path".to_string(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
@@ -69,7 +70,10 @@ impl Tool for DeletePathTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let path_str = input.path;
let path_str = match serde_json::from_value::<DeletePathToolInput>(input) {
Ok(input) => input.path,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let Some(project_path) = project.read(cx).find_project_path(&path_str, cx) else {
return Task::ready(Err(anyhow!(
"Couldn't delete {path_str} because that path isn't in this project."

View File

@@ -42,13 +42,11 @@ where
pub struct DiagnosticsTool;
impl Tool for DiagnosticsTool {
type Input = DiagnosticsToolInput;
fn name(&self) -> String {
"diagnostics".into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -68,9 +66,15 @@ impl Tool for DiagnosticsTool {
json_schema_for::<DiagnosticsToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
if let Some(path) = input.path.as_ref().filter(|p| !p.is_empty()) {
format!("Check diagnostics for {}", MarkdownInlineCode(path))
fn ui_text(&self, input: &serde_json::Value) -> String {
if let Some(path) = serde_json::from_value::<DiagnosticsToolInput>(input.clone())
.ok()
.and_then(|input| match input.path {
Some(path) if !path.is_empty() => Some(path),
_ => None,
})
{
format!("Check diagnostics for {}", MarkdownInlineCode(&path))
} else {
"Check project diagnostics".to_string()
}
@@ -78,7 +82,7 @@ impl Tool for DiagnosticsTool {
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
@@ -86,7 +90,10 @@ impl Tool for DiagnosticsTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
match input.path {
match serde_json::from_value::<DiagnosticsToolInput>(input)
.ok()
.and_then(|input| input.path)
{
Some(path) if !path.is_empty() => {
let Some(project_path) = project.read(cx).find_project_path(&path, cx) else {
return Task::ready(Err(anyhow!("Could not find path {path} in project",)))

View File

@@ -29,7 +29,6 @@ use std::{
path::Path,
str::FromStr,
sync::mpsc,
time::Duration,
};
use util::path;
@@ -1471,7 +1470,7 @@ impl EditAgentTest {
Project::init_settings(cx);
language::init(cx);
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
crate::init(client.http_client(), cx);
});
@@ -1659,14 +1658,12 @@ async fn retry_on_rate_limit<R>(mut request: impl AsyncFnMut() -> Result<R>) ->
match request().await {
Ok(result) => return Ok(result),
Err(err) => match err.downcast::<LanguageModelCompletionError>() {
Ok(err) => match &err {
LanguageModelCompletionError::RateLimitExceeded { retry_after, .. }
| LanguageModelCompletionError::ServerOverloaded { retry_after, .. } => {
let retry_after = retry_after.unwrap_or(Duration::from_secs(5));
Ok(err) => match err {
LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
// Wait for the duration supplied, with some jitter to avoid all requests being made at the same time.
let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
eprintln!(
"Attempt #{attempt}: {err}. Retry after {retry_after:?} + jitter of {jitter:?}"
"Attempt #{attempt}: Rate limit exceeded. Retry after {retry_after:?} + jitter of {jitter:?}"
);
Timer::after(retry_after + jitter).await;
continue;

View File

@@ -10,7 +10,7 @@ use assistant_tool::{
ToolUseStatus,
};
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer, PathKey};
use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer, PathKey, scroll::Autoscroll};
use futures::StreamExt;
use gpui::{
Animation, AnimationExt, AnyWindowHandle, App, AppContext, AsyncApp, Entity, Task,
@@ -121,13 +121,11 @@ struct PartialInput {
const DEFAULT_UI_TEXT: &str = "Editing file";
impl Tool for EditFileTool {
type Input = EditFileToolInput;
fn name(&self) -> String {
"edit_file".into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -147,20 +145,24 @@ impl Tool for EditFileTool {
json_schema_for::<EditFileToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
input.display_description.clone()
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<EditFileToolInput>(input.clone()) {
Ok(input) => input.display_description,
Err(_) => "Editing file".to_string(),
}
}
fn still_streaming_ui_text(&self, input: &Self::Input) -> String {
let description = input.display_description.trim();
if !description.is_empty() {
return description.to_string();
}
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
let description = input.display_description.trim();
if !description.is_empty() {
return description.to_string();
}
let path = input.path.to_string_lossy();
let path = path.trim();
if !path.is_empty() {
return path.to_string();
let path = input.path.trim();
if !path.is_empty() {
return path.to_string();
}
}
DEFAULT_UI_TEXT.to_string()
@@ -168,7 +170,7 @@ impl Tool for EditFileTool {
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
@@ -176,6 +178,11 @@ impl Tool for EditFileTool {
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<EditFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let project_path = match resolve_path(&input, project.clone(), cx) {
Ok(path) => path,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
@@ -816,7 +823,7 @@ impl ToolCard for EditFileToolCard {
let first_hunk_start =
first_hunk.multi_buffer_range().start;
editor.change_selections(
Default::default(),
Some(Autoscroll::fit()),
window,
cx,
|selections| {
@@ -1058,7 +1065,7 @@ fn markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
MarkdownStyle {
base_text_style: text_style.clone(),
selection_background_color: cx.theme().colors().element_selection_background,
selection_background_color: cx.theme().players().local().selection,
..Default::default()
}
}
@@ -1162,11 +1169,12 @@ mod tests {
let model = Arc::new(FakeLanguageModel::default());
let result = cx
.update(|cx| {
let input = EditFileToolInput {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Some edit".into(),
path: "root/nonexistent_file.txt".into(),
mode: EditFileMode::Edit,
};
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,
@@ -1280,22 +1288,24 @@ mod tests {
#[test]
fn still_streaming_ui_text_with_path() {
let input = EditFileToolInput {
path: "src/main.rs".into(),
display_description: "".into(),
mode: EditFileMode::Edit,
};
let input = json!({
"path": "src/main.rs",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(EditFileTool.still_streaming_ui_text(&input), "src/main.rs");
}
#[test]
fn still_streaming_ui_text_with_description() {
let input = EditFileToolInput {
path: "".into(),
display_description: "Fix error handling".into(),
mode: EditFileMode::Edit,
};
let input = json!({
"path": "",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
@@ -1305,11 +1315,12 @@ mod tests {
#[test]
fn still_streaming_ui_text_with_path_and_description() {
let input = EditFileToolInput {
path: "src/main.rs".into(),
display_description: "Fix error handling".into(),
mode: EditFileMode::Edit,
};
let input = json!({
"path": "src/main.rs",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
@@ -1319,11 +1330,12 @@ mod tests {
#[test]
fn still_streaming_ui_text_no_path_or_description() {
let input = EditFileToolInput {
path: "".into(),
display_description: "".into(),
mode: EditFileMode::Edit,
};
let input = json!({
"path": "",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
@@ -1333,11 +1345,7 @@ mod tests {
#[test]
fn still_streaming_ui_text_with_null() {
let input = EditFileToolInput {
path: "".into(),
display_description: "".into(),
mode: EditFileMode::Edit,
};
let input = serde_json::Value::Null;
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
@@ -1449,11 +1457,12 @@ mod tests {
// Have the model stream unformatted content
let edit_result = {
let edit_task = cx.update(|cx| {
let input = EditFileToolInput {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Create main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,
@@ -1512,11 +1521,12 @@ mod tests {
// Stream unformatted edits again
let edit_result = {
let edit_task = cx.update(|cx| {
let input = EditFileToolInput {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Update main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,
@@ -1590,11 +1600,12 @@ mod tests {
// Have the model stream content that contains trailing whitespace
let edit_result = {
let edit_task = cx.update(|cx| {
let input = EditFileToolInput {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Create main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,
@@ -1646,11 +1657,12 @@ mod tests {
// Stream edits again with trailing whitespace
let edit_result = {
let edit_task = cx.update(|cx| {
let input = EditFileToolInput {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Update main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,

View File

@@ -3,10 +3,10 @@ use std::sync::Arc;
use std::{borrow::Cow, cell::RefCell};
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, bail};
use anyhow::{Context as _, Result, anyhow, bail};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::AsyncReadExt as _;
use gpui::{AnyWindowHandle, App, AppContext as _, Entity};
use gpui::{AnyWindowHandle, App, AppContext as _, Entity, Task};
use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
use http_client::{AsyncBody, HttpClientWithUrl};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
@@ -113,13 +113,11 @@ impl FetchTool {
}
impl Tool for FetchTool {
type Input = FetchToolInput;
fn name(&self) -> String {
"fetch".to_string()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -139,13 +137,16 @@ impl Tool for FetchTool {
json_schema_for::<FetchToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
format!("Fetch {}", MarkdownEscaped(&input.url))
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<FetchToolInput>(input.clone()) {
Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)),
Err(_) => "Fetch URL".to_string(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -153,6 +154,11 @@ impl Tool for FetchTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<FetchToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let text = cx.background_spawn({
let http_client = self.http_client.clone();
async move { Self::build_message(http_client, &input.url).await }

View File

@@ -51,13 +51,11 @@ const RESULTS_PER_PAGE: usize = 50;
pub struct FindPathTool;
impl Tool for FindPathTool {
type Input = FindPathToolInput;
fn name(&self) -> String {
"find_path".into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -77,13 +75,16 @@ impl Tool for FindPathTool {
json_schema_for::<FindPathToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
format!("Find paths matching \"`{}`\"", input.glob)
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<FindPathToolInput>(input.clone()) {
Ok(input) => format!("Find paths matching “`{}`”", input.glob),
Err(_) => "Search paths".to_string(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -91,7 +92,10 @@ impl Tool for FindPathTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let (offset, glob) = (input.offset, input.glob);
let (offset, glob) = match serde_json::from_value::<FindPathToolInput>(input) {
Ok(input) => (input.offset, input.glob),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let (sender, receiver) = oneshot::channel();

View File

@@ -53,13 +53,11 @@ const RESULTS_PER_PAGE: u32 = 20;
pub struct GrepTool;
impl Tool for GrepTool {
type Input = GrepToolInput;
fn name(&self) -> String {
"grep".into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -79,25 +77,30 @@ impl Tool for GrepTool {
json_schema_for::<GrepToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
let page = input.page();
let regex_str = MarkdownInlineCode(&input.regex);
let case_info = if input.case_sensitive {
" (case-sensitive)"
} else {
""
};
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<GrepToolInput>(input.clone()) {
Ok(input) => {
let page = input.page();
let regex_str = MarkdownInlineCode(&input.regex);
let case_info = if input.case_sensitive {
" (case-sensitive)"
} else {
""
};
if page > 1 {
format!("Get page {page} of search results for regex {regex_str}{case_info}")
} else {
format!("Search files for regex {regex_str}{case_info}")
if page > 1 {
format!("Get page {page} of search results for regex {regex_str}{case_info}")
} else {
format!("Search files for regex {regex_str}{case_info}")
}
}
Err(_) => "Search with regex".to_string(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -108,6 +111,13 @@ impl Tool for GrepTool {
const CONTEXT_LINES: u32 = 2;
const MAX_ANCESTOR_LINES: u32 = 10;
let input = match serde_json::from_value::<GrepToolInput>(input) {
Ok(input) => input,
Err(error) => {
return Task::ready(Err(anyhow!("Failed to parse input: {error}"))).into();
}
};
let include_matcher = match PathMatcher::new(
input
.include_pattern
@@ -338,12 +348,13 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
// Test with include pattern for Rust files inside the root of the project
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "println".to_string(),
include_pattern: Some("root/**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
assert!(result.contains("main.rs"), "Should find matches in main.rs");
@@ -357,12 +368,13 @@ mod tests {
);
// Test with include pattern for src directory only
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "fn".to_string(),
include_pattern: Some("root/**/src/**".to_string()),
offset: 0,
case_sensitive: false,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
assert!(
@@ -379,12 +391,13 @@ mod tests {
);
// Test with empty include pattern (should default to all files)
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "fn".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
assert!(result.contains("main.rs"), "Should find matches in main.rs");
@@ -415,12 +428,13 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
// Test case-insensitive search (default)
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "uppercase".to_string(),
include_pattern: Some("**/*.txt".to_string()),
offset: 0,
case_sensitive: false,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
assert!(
@@ -429,12 +443,13 @@ mod tests {
);
// Test case-sensitive search
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "uppercase".to_string(),
include_pattern: Some("**/*.txt".to_string()),
offset: 0,
case_sensitive: true,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
assert!(
@@ -443,12 +458,13 @@ mod tests {
);
// Test case-sensitive search
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "LOWERCASE".to_string(),
include_pattern: Some("**/*.txt".to_string()),
offset: 0,
case_sensitive: true,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
@@ -458,12 +474,13 @@ mod tests {
);
// Test case-sensitive search for lowercase pattern
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "lowercase".to_string(),
include_pattern: Some("**/*.txt".to_string()),
offset: 0,
case_sensitive: true,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
assert!(
@@ -559,12 +576,13 @@ mod tests {
let project = setup_syntax_test(cx).await;
// Test: Line at the top level of the file
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "This is at the top level".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
@@ -588,12 +606,13 @@ mod tests {
let project = setup_syntax_test(cx).await;
// Test: Line inside a function body
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "Function in nested module".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
@@ -619,12 +638,13 @@ mod tests {
let project = setup_syntax_test(cx).await;
// Test: Line with a function argument
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "second_arg".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
@@ -654,12 +674,13 @@ mod tests {
let project = setup_syntax_test(cx).await;
// Test: Line inside an if block
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "Inside if block".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
@@ -684,12 +705,13 @@ mod tests {
let project = setup_syntax_test(cx).await;
// Test: Line in the middle of a long function - should show message about remaining lines
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "Line 5".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
@@ -724,12 +746,13 @@ mod tests {
let project = setup_syntax_test(cx).await;
// Test: Line in the long function
let input = GrepToolInput {
let input = serde_json::to_value(GrepToolInput {
regex: "Line 12".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
};
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
@@ -751,7 +774,7 @@ mod tests {
}
async fn run_grep_tool(
input: GrepToolInput,
input: serde_json::Value,
project: Entity<Project>,
cx: &mut TestAppContext,
) -> String {
@@ -853,12 +876,9 @@ mod tests {
// Searching for files outside the project worktree should return no results
let result = cx
.update(|cx| {
let input = GrepToolInput {
regex: "outside_function".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
let input = json!({
"regex": "outside_function"
});
Arc::new(GrepTool)
.run(
input,
@@ -882,12 +902,9 @@ mod tests {
// Searching within the project should succeed
let result = cx
.update(|cx| {
let input = GrepToolInput {
regex: "main".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
let input = json!({
"regex": "main"
});
Arc::new(GrepTool)
.run(
input,
@@ -911,12 +928,9 @@ mod tests {
// Searching files that match file_scan_exclusions should return no results
let result = cx
.update(|cx| {
let input = GrepToolInput {
regex: "special_configuration".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
let input = json!({
"regex": "special_configuration"
});
Arc::new(GrepTool)
.run(
input,
@@ -939,12 +953,9 @@ mod tests {
let result = cx
.update(|cx| {
let input = GrepToolInput {
regex: "custom_metadata".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
let input = json!({
"regex": "custom_metadata"
});
Arc::new(GrepTool)
.run(
input,
@@ -968,12 +979,9 @@ mod tests {
// Searching private files should return no results
let result = cx
.update(|cx| {
let input = GrepToolInput {
regex: "SECRET_KEY".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
let input = json!({
"regex": "SECRET_KEY"
});
Arc::new(GrepTool)
.run(
input,
@@ -996,12 +1004,9 @@ mod tests {
let result = cx
.update(|cx| {
let input = GrepToolInput {
regex: "private_key_content".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
let input = json!({
"regex": "private_key_content"
});
Arc::new(GrepTool)
.run(
input,
@@ -1024,12 +1029,9 @@ mod tests {
let result = cx
.update(|cx| {
let input = GrepToolInput {
regex: "sensitive_data".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
let input = json!({
"regex": "sensitive_data"
});
Arc::new(GrepTool)
.run(
input,
@@ -1053,12 +1055,9 @@ mod tests {
// Searching a normal file should still work, even with private_files configured
let result = cx
.update(|cx| {
let input = GrepToolInput {
regex: "normal_file_content".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
let input = json!({
"regex": "normal_file_content"
});
Arc::new(GrepTool)
.run(
input,
@@ -1082,12 +1081,10 @@ mod tests {
// Path traversal attempts with .. in include_pattern should not escape project
let result = cx
.update(|cx| {
let input = GrepToolInput {
regex: "outside_function".to_string(),
include_pattern: Some("../outside_project/**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
};
let input = json!({
"regex": "outside_function",
"include_pattern": "../outside_project/**/*.rs"
});
Arc::new(GrepTool)
.run(
input,
@@ -1188,12 +1185,10 @@ mod tests {
// Search for "secret" - should exclude files based on worktree-specific settings
let result = cx
.update(|cx| {
let input = GrepToolInput {
regex: "secret".to_string(),
include_pattern: None,
offset: 0,
case_sensitive: false,
};
let input = json!({
"regex": "secret",
"case_sensitive": false
});
Arc::new(GrepTool)
.run(
input,
@@ -1255,12 +1250,10 @@ mod tests {
// Test with `include_pattern` specific to one worktree
let result = cx
.update(|cx| {
let input = GrepToolInput {
regex: "secret".to_string(),
include_pattern: Some("worktree1/**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
};
let input = json!({
"regex": "secret",
"include_pattern": "worktree1/**/*.rs"
});
Arc::new(GrepTool)
.run(
input,

View File

@@ -41,13 +41,11 @@ pub struct ListDirectoryToolInput {
pub struct ListDirectoryTool;
impl Tool for ListDirectoryTool {
type Input = ListDirectoryToolInput;
fn name(&self) -> String {
"list_directory".into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -67,14 +65,19 @@ impl Tool for ListDirectoryTool {
json_schema_for::<ListDirectoryToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
let path = MarkdownInlineCode(&input.path);
format!("List the {path} directory's contents")
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<ListDirectoryToolInput>(input.clone()) {
Ok(input) => {
let path = MarkdownInlineCode(&input.path);
format!("List the {path} directory's contents")
}
Err(_) => "List directory".to_string(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -82,6 +85,11 @@ impl Tool for ListDirectoryTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<ListDirectoryToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
// Sometimes models will return these even though we tell it to give a path and not a glob.
// When this happens, just list the root worktree directories.
if matches!(input.path.as_str(), "." | "" | "./" | "*") {
@@ -277,9 +285,9 @@ mod tests {
let tool = Arc::new(ListDirectoryTool);
// Test listing root directory
let input = ListDirectoryToolInput {
path: "project".to_string(),
};
let input = json!({
"path": "project"
});
let result = cx
.update(|cx| {
@@ -312,9 +320,9 @@ mod tests {
);
// Test listing src directory
let input = ListDirectoryToolInput {
path: "project/src".to_string(),
};
let input = json!({
"path": "project/src"
});
let result = cx
.update(|cx| {
@@ -347,9 +355,9 @@ mod tests {
);
// Test listing directory with only files
let input = ListDirectoryToolInput {
path: "project/tests".to_string(),
};
let input = json!({
"path": "project/tests"
});
let result = cx
.update(|cx| {
@@ -391,9 +399,9 @@ mod tests {
let model = Arc::new(FakeLanguageModel::default());
let tool = Arc::new(ListDirectoryTool);
let input = ListDirectoryToolInput {
path: "project/empty_dir".to_string(),
};
let input = json!({
"path": "project/empty_dir"
});
let result = cx
.update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx))
@@ -424,9 +432,9 @@ mod tests {
let tool = Arc::new(ListDirectoryTool);
// Test non-existent path
let input = ListDirectoryToolInput {
path: "project/nonexistent".to_string(),
};
let input = json!({
"path": "project/nonexistent"
});
let result = cx
.update(|cx| {
@@ -447,9 +455,9 @@ mod tests {
assert!(result.unwrap_err().to_string().contains("Path not found"));
// Test trying to list a file instead of directory
let input = ListDirectoryToolInput {
path: "project/file.txt".to_string(),
};
let input = json!({
"path": "project/file.txt"
});
let result = cx
.update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx))
@@ -519,9 +527,9 @@ mod tests {
let tool = Arc::new(ListDirectoryTool);
// Listing root directory should exclude private and excluded files
let input = ListDirectoryToolInput {
path: "project".to_string(),
};
let input = json!({
"path": "project"
});
let result = cx
.update(|cx| {
@@ -560,9 +568,9 @@ mod tests {
);
// Trying to list an excluded directory should fail
let input = ListDirectoryToolInput {
path: "project/.secretdir".to_string(),
};
let input = json!({
"path": "project/.secretdir"
});
let result = cx
.update(|cx| {
@@ -592,9 +600,9 @@ mod tests {
);
// Listing a directory should exclude private files within it
let input = ListDirectoryToolInput {
path: "project/visible_dir".to_string(),
};
let input = json!({
"path": "project/visible_dir"
});
let result = cx
.update(|cx| {
@@ -712,9 +720,9 @@ mod tests {
let tool = Arc::new(ListDirectoryTool);
// Test listing worktree1/src - should exclude secret.rs and config.toml based on local settings
let input = ListDirectoryToolInput {
path: "worktree1/src".to_string(),
};
let input = json!({
"path": "worktree1/src"
});
let result = cx
.update(|cx| {
@@ -744,9 +752,9 @@ mod tests {
);
// Test listing worktree1/tests - should exclude fixture.sql based on local settings
let input = ListDirectoryToolInput {
path: "worktree1/tests".to_string(),
};
let input = json!({
"path": "worktree1/tests"
});
let result = cx
.update(|cx| {
@@ -772,9 +780,9 @@ mod tests {
);
// Test listing worktree2/lib - should exclude private.js and data.json based on local settings
let input = ListDirectoryToolInput {
path: "worktree2/lib".to_string(),
};
let input = json!({
"path": "worktree2/lib"
});
let result = cx
.update(|cx| {
@@ -804,9 +812,9 @@ mod tests {
);
// Test listing worktree2/docs - should exclude internal.md based on local settings
let input = ListDirectoryToolInput {
path: "worktree2/docs".to_string(),
};
let input = json!({
"path": "worktree2/docs"
});
let result = cx
.update(|cx| {
@@ -832,9 +840,9 @@ mod tests {
);
// Test trying to list an excluded directory directly
let input = ListDirectoryToolInput {
path: "worktree1/src/secret.rs".to_string(),
};
let input = json!({
"path": "worktree1/src/secret.rs"
});
let result = cx
.update(|cx| {

View File

@@ -38,13 +38,11 @@ pub struct MovePathToolInput {
pub struct MovePathTool;
impl Tool for MovePathTool {
type Input = MovePathToolInput;
fn name(&self) -> String {
"move_path".into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -64,29 +62,34 @@ impl Tool for MovePathTool {
json_schema_for::<MovePathToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
let src = MarkdownInlineCode(&input.source_path);
let dest = MarkdownInlineCode(&input.destination_path);
let src_path = Path::new(&input.source_path);
let dest_path = Path::new(&input.destination_path);
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<MovePathToolInput>(input.clone()) {
Ok(input) => {
let src = MarkdownInlineCode(&input.source_path);
let dest = MarkdownInlineCode(&input.destination_path);
let src_path = Path::new(&input.source_path);
let dest_path = Path::new(&input.destination_path);
match dest_path
.file_name()
.and_then(|os_str| os_str.to_os_string().into_string().ok())
{
Some(filename) if src_path.parent() == dest_path.parent() => {
let filename = MarkdownInlineCode(&filename);
format!("Rename {src} to {filename}")
}
_ => {
format!("Move {src} to {dest}")
match dest_path
.file_name()
.and_then(|os_str| os_str.to_os_string().into_string().ok())
{
Some(filename) if src_path.parent() == dest_path.parent() => {
let filename = MarkdownInlineCode(&filename);
format!("Rename {src} to {filename}")
}
_ => {
format!("Move {src} to {dest}")
}
}
}
Err(_) => "Move path".to_string(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -94,6 +97,10 @@ impl Tool for MovePathTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<MovePathToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let rename_task = project.update(cx, |project, cx| {
match project
.find_project_path(&input.source_path, cx)

View File

@@ -1,7 +1,7 @@
use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::Result;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use chrono::{Local, Utc};
use gpui::{AnyWindowHandle, App, Entity, Task};
@@ -29,13 +29,11 @@ pub struct NowToolInput {
pub struct NowTool;
impl Tool for NowTool {
type Input = NowToolInput;
fn name(&self) -> String {
"now".into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -55,13 +53,13 @@ impl Tool for NowTool {
json_schema_for::<NowToolInput>(format)
}
fn ui_text(&self, _input: &Self::Input) -> String {
fn ui_text(&self, _input: &serde_json::Value) -> String {
"Get current time".to_string()
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -69,6 +67,11 @@ impl Tool for NowTool {
_window: Option<AnyWindowHandle>,
_cx: &mut App,
) -> ToolResult {
let input: NowToolInput = match serde_json::from_value(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let now = match input.timezone {
Timezone::Utc => Utc::now().to_rfc3339(),
Timezone::Local => Local::now().to_rfc3339(),

View File

@@ -1,7 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, AppContext, Entity};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -19,13 +19,11 @@ pub struct OpenToolInput {
pub struct OpenTool;
impl Tool for OpenTool {
type Input = OpenToolInput;
fn name(&self) -> String {
"open".to_string()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
true
}
fn may_perform_edits(&self) -> bool {
@@ -43,13 +41,16 @@ impl Tool for OpenTool {
json_schema_for::<OpenToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
format!("Open `{}`", MarkdownEscaped(&input.path_or_url))
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<OpenToolInput>(input.clone()) {
Ok(input) => format!("Open `{}`", MarkdownEscaped(&input.path_or_url)),
Err(_) => "Open file or URL".to_string(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -57,6 +58,11 @@ impl Tool for OpenTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input: OpenToolInput = match serde_json::from_value(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
// If path_or_url turns out to be a path in the project, make it absolute.
let abs_path = to_absolute_path(&input.path_or_url, project, cx);

View File

@@ -51,13 +51,11 @@ pub struct ReadFileToolInput {
pub struct ReadFileTool;
impl Tool for ReadFileTool {
type Input = ReadFileToolInput;
fn name(&self) -> String {
"read_file".into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -77,18 +75,23 @@ impl Tool for ReadFileTool {
json_schema_for::<ReadFileToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
let path = MarkdownInlineCode(&input.path);
match (input.start_line, input.end_line) {
(Some(start), None) => format!("Read file {path} (from line {start})"),
(Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"),
_ => format!("Read file {path}"),
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<ReadFileToolInput>(input.clone()) {
Ok(input) => {
let path = MarkdownInlineCode(&input.path);
match (input.start_line, input.end_line) {
(Some(start), None) => format!("Read file {path} (from line {start})"),
(Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"),
_ => format!("Read file {path}"),
}
}
Err(_) => "Read file".to_string(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
@@ -96,6 +99,11 @@ impl Tool for ReadFileTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<ReadFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into();
};
@@ -300,12 +308,9 @@ mod test {
let model = Arc::new(FakeLanguageModel::default());
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/nonexistent_file.txt".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "root/nonexistent_file.txt"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -342,11 +347,9 @@ mod test {
let model = Arc::new(FakeLanguageModel::default());
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/small_file.txt".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "root/small_file.txt"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -386,11 +389,9 @@ mod test {
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/large_file.rs".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "root/large_file.rs"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -420,11 +421,10 @@ mod test {
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/large_file.rs".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "root/large_file.rs",
"offset": 1
});
Arc::new(ReadFileTool)
.run(
input,
@@ -477,11 +477,11 @@ mod test {
let model = Arc::new(FakeLanguageModel::default());
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/multiline.txt".to_string(),
start_line: Some(2),
end_line: Some(4),
};
let input = json!({
"path": "root/multiline.txt",
"start_line": 2,
"end_line": 4
});
Arc::new(ReadFileTool)
.run(
input,
@@ -520,11 +520,11 @@ mod test {
// start_line of 0 should be treated as 1
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/multiline.txt".to_string(),
start_line: Some(0),
end_line: Some(2),
};
let input = json!({
"path": "root/multiline.txt",
"start_line": 0,
"end_line": 2
});
Arc::new(ReadFileTool)
.run(
input,
@@ -543,11 +543,11 @@ mod test {
// end_line of 0 should result in at least 1 line
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/multiline.txt".to_string(),
start_line: Some(1),
end_line: Some(0),
};
let input = json!({
"path": "root/multiline.txt",
"start_line": 1,
"end_line": 0
});
Arc::new(ReadFileTool)
.run(
input,
@@ -566,11 +566,11 @@ mod test {
// when start_line > end_line, should still return at least 1 line
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/multiline.txt".to_string(),
start_line: Some(3),
end_line: Some(2),
};
let input = json!({
"path": "root/multiline.txt",
"start_line": 3,
"end_line": 2
});
Arc::new(ReadFileTool)
.run(
input,
@@ -694,11 +694,9 @@ mod test {
// Reading a file outside the project worktree should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "/outside_project/sensitive_file.txt".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "/outside_project/sensitive_file.txt"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -720,11 +718,9 @@ mod test {
// Reading a file within the project should succeed
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/allowed_file.txt".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "project_root/allowed_file.txt"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -746,11 +742,9 @@ mod test {
// Reading files that match file_scan_exclusions should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/.secretdir/config".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "project_root/.secretdir/config"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -771,11 +765,9 @@ mod test {
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/.mymetadata".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "project_root/.mymetadata"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -797,11 +789,9 @@ mod test {
// Reading private files should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/secrets/.mysecrets".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "project_root/.mysecrets"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -822,11 +812,9 @@ mod test {
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/subdir/special.privatekey".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "project_root/subdir/special.privatekey"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -847,11 +835,9 @@ mod test {
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/subdir/data.mysensitive".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "project_root/subdir/data.mysensitive"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -873,11 +859,9 @@ mod test {
// Reading a normal file should still work, even with private_files configured
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/subdir/normal_file.txt".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "project_root/subdir/normal_file.txt"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -900,11 +884,9 @@ mod test {
// Path traversal attempts with .. should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/../outside_project/sensitive_file.txt".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "project_root/../outside_project/sensitive_file.txt"
});
Arc::new(ReadFileTool)
.run(
input,
@@ -999,11 +981,9 @@ mod test {
let tool = Arc::new(ReadFileTool);
// Test reading allowed files in worktree1
let input = ReadFileToolInput {
path: "worktree1/src/main.rs".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "worktree1/src/main.rs"
});
let result = cx
.update(|cx| {
@@ -1027,11 +1007,9 @@ mod test {
);
// Test reading private file in worktree1 should fail
let input = ReadFileToolInput {
path: "worktree1/src/secret.rs".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "worktree1/src/secret.rs"
});
let result = cx
.update(|cx| {
@@ -1058,11 +1036,9 @@ mod test {
);
// Test reading excluded file in worktree1 should fail
let input = ReadFileToolInput {
path: "worktree1/tests/fixture.sql".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "worktree1/tests/fixture.sql"
});
let result = cx
.update(|cx| {
@@ -1089,11 +1065,9 @@ mod test {
);
// Test reading allowed files in worktree2
let input = ReadFileToolInput {
path: "worktree2/lib/public.js".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "worktree2/lib/public.js"
});
let result = cx
.update(|cx| {
@@ -1117,11 +1091,9 @@ mod test {
);
// Test reading private file in worktree2 should fail
let input = ReadFileToolInput {
path: "worktree2/lib/private.js".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "worktree2/lib/private.js"
});
let result = cx
.update(|cx| {
@@ -1148,11 +1120,9 @@ mod test {
);
// Test reading excluded file in worktree2 should fail
let input = ReadFileToolInput {
path: "worktree2/docs/internal.md".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "worktree2/docs/internal.md"
});
let result = cx
.update(|cx| {
@@ -1180,11 +1150,9 @@ mod test {
// Test that files allowed in one worktree but not in another are handled correctly
// (e.g., config.toml is private in worktree1 but doesn't exist in worktree2)
let input = ReadFileToolInput {
path: "worktree1/src/config.toml".to_string(),
start_line: None,
end_line: None,
};
let input = json!({
"path": "worktree1/src/config.toml"
});
let result = cx
.update(|cx| {

View File

@@ -1,9 +1,8 @@
use anyhow::Result;
use language_model::LanguageModelToolSchemaFormat;
use schemars::{
JsonSchema, Schema,
generate::SchemaSettings,
transform::{Transform, transform_subschemas},
JsonSchema,
schema::{RootSchema, Schema, SchemaObject},
};
pub fn json_schema_for<T: JsonSchema>(
@@ -14,7 +13,7 @@ pub fn json_schema_for<T: JsonSchema>(
}
fn schema_to_json(
schema: &Schema,
schema: &RootSchema,
format: LanguageModelToolSchemaFormat,
) -> Result<serde_json::Value> {
let mut value = serde_json::to_value(schema)?;
@@ -22,40 +21,58 @@ fn schema_to_json(
Ok(value)
}
fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> RootSchema {
let mut generator = match format {
LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07().into_generator(),
LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3()
.with(|settings| {
settings.meta_schema = None;
settings.inline_subschemas = true;
})
.with_transform(ToJsonSchemaSubsetTransform)
.into_generator(),
LanguageModelToolSchemaFormat::JsonSchema => schemars::SchemaGenerator::default(),
LanguageModelToolSchemaFormat::JsonSchemaSubset => {
schemars::r#gen::SchemaSettings::default()
.with(|settings| {
settings.meta_schema = None;
settings.inline_subschemas = true;
settings
.visitors
.push(Box::new(TransformToJsonSchemaSubsetVisitor));
})
.into_generator()
}
};
generator.root_schema_for::<T>()
}
#[derive(Debug, Clone)]
struct ToJsonSchemaSubsetTransform;
struct TransformToJsonSchemaSubsetVisitor;
impl Transform for ToJsonSchemaSubsetTransform {
fn transform(&mut self, schema: &mut Schema) {
impl schemars::visit::Visitor for TransformToJsonSchemaSubsetVisitor {
fn visit_root_schema(&mut self, root: &mut RootSchema) {
schemars::visit::visit_root_schema(self, root)
}
fn visit_schema(&mut self, schema: &mut Schema) {
schemars::visit::visit_schema(self, schema)
}
fn visit_schema_object(&mut self, schema: &mut SchemaObject) {
// Ensure that the type field is not an array, this happens when we use
// Option<T>, the type will be [T, "null"].
if let Some(type_field) = schema.get_mut("type") {
if let Some(types) = type_field.as_array() {
if let Some(first_type) = types.first() {
*type_field = first_type.clone();
if let Some(instance_type) = schema.instance_type.take() {
schema.instance_type = match instance_type {
schemars::schema::SingleOrVec::Single(t) => {
Some(schemars::schema::SingleOrVec::Single(t))
}
schemars::schema::SingleOrVec::Vec(items) => items
.into_iter()
.next()
.map(schemars::schema::SingleOrVec::from),
};
}
// One of is not supported, use anyOf instead.
if let Some(subschema) = schema.subschemas.as_mut() {
if let Some(one_of) = subschema.one_of.take() {
subschema.any_of = Some(one_of);
}
}
// oneOf is not supported, use anyOf instead
if let Some(one_of) = schema.remove("oneOf") {
schema.insert("anyOf".to_string(), one_of);
}
transform_subschemas(self, schema);
schemars::visit::visit_schema_object(self, schema)
}
}

View File

@@ -2,7 +2,7 @@ use crate::{
schema::json_schema_for,
ui::{COLLAPSED_LINES, ToolOutputPreview},
};
use anyhow::{Context as _, Result};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
use futures::{FutureExt as _, future::Shared};
use gpui::{
@@ -72,13 +72,11 @@ impl TerminalTool {
}
impl Tool for TerminalTool {
type Input = TerminalToolInput;
fn name(&self) -> String {
Self::NAME.to_string()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
true
}
@@ -98,24 +96,30 @@ impl Tool for TerminalTool {
json_schema_for::<TerminalToolInput>(format)
}
fn ui_text(&self, input: &Self::Input) -> String {
let mut lines = input.command.lines();
let first_line = lines.next().unwrap_or_default();
let remaining_line_count = lines.count();
match remaining_line_count {
0 => MarkdownInlineCode(&first_line).to_string(),
1 => MarkdownInlineCode(&format!(
"{} - {} more line",
first_line, remaining_line_count
))
.to_string(),
n => MarkdownInlineCode(&format!("{} - {} more lines", first_line, n)).to_string(),
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<TerminalToolInput>(input.clone()) {
Ok(input) => {
let mut lines = input.command.lines();
let first_line = lines.next().unwrap_or_default();
let remaining_line_count = lines.count();
match remaining_line_count {
0 => MarkdownInlineCode(&first_line).to_string(),
1 => MarkdownInlineCode(&format!(
"{} - {} more line",
first_line, remaining_line_count
))
.to_string(),
n => MarkdownInlineCode(&format!("{} - {} more lines", first_line, n))
.to_string(),
}
}
Err(_) => "Run terminal command".to_string(),
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -123,6 +127,11 @@ impl Tool for TerminalTool {
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input: TerminalToolInput = match serde_json::from_value(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let working_dir = match working_dir(&input, &project, cx) {
Ok(dir) => dir,
Err(err) => return Task::ready(Err(err)).into(),
@@ -682,7 +691,7 @@ fn markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
MarkdownStyle {
base_text_style: text_style.clone(),
selection_background_color: cx.theme().colors().element_selection_background,
selection_background_color: cx.theme().players().local().selection,
..Default::default()
}
}
@@ -747,7 +756,7 @@ mod tests {
let result = cx.update(|cx| {
TerminalTool::run(
Arc::new(TerminalTool::new(cx)),
input,
serde_json::to_value(input).unwrap(),
Arc::default(),
project.clone(),
action_log.clone(),
@@ -782,7 +791,7 @@ mod tests {
let check = |input, expected, cx: &mut App| {
let headless_result = TerminalTool::run(
Arc::new(TerminalTool::new(cx)),
input,
serde_json::to_value(input).unwrap(),
Arc::default(),
project.clone(),
action_log.clone(),

View File

@@ -1,7 +1,7 @@
use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::Result;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, Entity, Task};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
@@ -20,13 +20,11 @@ pub struct ThinkingToolInput {
pub struct ThinkingTool;
impl Tool for ThinkingTool {
type Input = ThinkingToolInput;
fn name(&self) -> String {
"thinking".to_string()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -46,13 +44,13 @@ impl Tool for ThinkingTool {
json_schema_for::<ThinkingToolInput>(format)
}
fn ui_text(&self, _input: &Self::Input) -> String {
fn ui_text(&self, _input: &serde_json::Value) -> String {
"Thinking".to_string()
}
fn run(
self: Arc<Self>,
_input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -61,6 +59,10 @@ impl Tool for ThinkingTool {
_cx: &mut App,
) -> ToolResult {
// This tool just "thinks out loud" and doesn't perform any actions.
Task::ready(Ok("Finished thinking.".to_string().into())).into()
Task::ready(match serde_json::from_value::<ThinkingToolInput>(input) {
Ok(_input) => Ok("Finished thinking.".to_string().into()),
Err(err) => Err(anyhow!(err)),
})
.into()
}
}

View File

@@ -28,13 +28,11 @@ pub struct WebSearchToolInput {
pub struct WebSearchTool;
impl Tool for WebSearchTool {
type Input = WebSearchToolInput;
fn name(&self) -> String {
"web_search".into()
}
fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
@@ -54,13 +52,13 @@ impl Tool for WebSearchTool {
json_schema_for::<WebSearchToolInput>(format)
}
fn ui_text(&self, _input: &Self::Input) -> String {
fn ui_text(&self, _input: &serde_json::Value) -> String {
"Searching the Web".to_string()
}
fn run(
self: Arc<Self>,
input: Self::Input,
input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
@@ -68,6 +66,10 @@ impl Tool for WebSearchTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<WebSearchToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
return Task::ready(Err(anyhow!("Web search is not available."))).into();
};

View File

@@ -28,17 +28,7 @@ use workspace::Workspace;
const SHOULD_SHOW_UPDATE_NOTIFICATION_KEY: &str = "auto-updater-should-show-updated-notification";
const POLL_INTERVAL: Duration = Duration::from_secs(60 * 60);
actions!(
auto_update,
[
/// Checks for available updates.
Check,
/// Dismisses the update error message.
DismissErrorMessage,
/// Opens the release notes for the current version in a browser.
ViewReleaseNotes,
]
);
actions!(auto_update, [Check, DismissErrorMessage, ViewReleaseNotes,]);
#[derive(Serialize)]
struct UpdateRequestBody {

View File

@@ -1,7 +1,7 @@
use auto_update::AutoUpdater;
use client::proto::UpdateNotification;
use editor::{Editor, MultiBuffer};
use gpui::{App, Context, DismissEvent, Entity, Window, actions, prelude::*};
use gpui::{App, Context, DismissEvent, Entity, SharedString, Window, actions, prelude::*};
use http_client::HttpClient;
use markdown_preview::markdown_preview_view::{MarkdownPreviewMode, MarkdownPreviewView};
use release_channel::{AppVersion, ReleaseChannel};
@@ -12,13 +12,7 @@ use workspace::Workspace;
use workspace::notifications::simple_message_notification::MessageNotification;
use workspace::notifications::{NotificationId, show_app_notification};
actions!(
auto_update,
[
/// Opens the release notes for the current version in a new tab.
ViewReleaseNotesLocally
]
);
actions!(auto_update, [ViewReleaseNotesLocally]);
pub fn init(cx: &mut App) {
notify_if_app_was_updated(cx);
@@ -100,6 +94,7 @@ fn view_release_notes_locally(
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
let tab_content = Some(SharedString::from(body.title.to_string()));
let editor = cx.new(|cx| {
Editor::for_multibuffer(buffer, Some(project), window, cx)
});
@@ -110,6 +105,7 @@ fn view_release_notes_locally(
editor,
workspace_handle,
language_registry,
tab_content,
window,
cx,
);
@@ -136,11 +132,6 @@ pub fn notify_if_app_was_updated(cx: &mut App) {
let Some(updater) = AutoUpdater::get(cx) else {
return;
};
if let ReleaseChannel::Nightly = ReleaseChannel::global(cx) {
return;
}
let should_show_notification = updater.read(cx).should_show_update_notification(cx);
cx.spawn(async move |cx| {
let should_show_notification = should_show_notification.await?;

View File

@@ -25,4 +25,5 @@ serde.workspace = true
serde_json.workspace = true
strum.workspace = true
thiserror.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
workspace-hack.workspace = true

View File

@@ -1,6 +1,9 @@
mod models;
use anyhow::{Context, Error, Result, anyhow};
use std::collections::HashMap;
use std::pin::Pin;
use anyhow::{Context as _, Error, Result, anyhow};
use aws_sdk_bedrockruntime as bedrock;
pub use aws_sdk_bedrockruntime as bedrock_client;
pub use aws_sdk_bedrockruntime::types::{
@@ -21,10 +24,9 @@ pub use bedrock::types::{
ToolResultContentBlock as BedrockToolResultContentBlock,
ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
};
use futures::stream::{self, BoxStream};
use futures::stream::{self, BoxStream, Stream};
use serde::{Deserialize, Serialize};
use serde_json::{Number, Value};
use std::collections::HashMap;
use thiserror::Error;
pub use crate::models::*;
@@ -32,59 +34,70 @@ pub use crate::models::*;
pub async fn stream_completion(
client: bedrock::Client,
request: Request,
handle: tokio::runtime::Handle,
) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
let mut response = bedrock::Client::converse_stream(&client)
.model_id(request.model.clone())
.set_messages(request.messages.into());
handle
.spawn(async move {
let mut response = bedrock::Client::converse_stream(&client)
.model_id(request.model.clone())
.set_messages(request.messages.into());
if let Some(Thinking::Enabled {
budget_tokens: Some(budget_tokens),
}) = request.thinking
{
let thinking_config = HashMap::from([
("type".to_string(), Document::String("enabled".to_string())),
(
"budget_tokens".to_string(),
Document::Number(AwsNumber::PosInt(budget_tokens)),
),
]);
response = response.additional_model_request_fields(Document::Object(HashMap::from([(
"thinking".to_string(),
Document::from(thinking_config),
)])));
}
if let Some(Thinking::Enabled {
budget_tokens: Some(budget_tokens),
}) = request.thinking
{
response =
response.additional_model_request_fields(Document::Object(HashMap::from([(
"thinking".to_string(),
Document::from(HashMap::from([
("type".to_string(), Document::String("enabled".to_string())),
(
"budget_tokens".to_string(),
Document::Number(AwsNumber::PosInt(budget_tokens)),
),
])),
)])));
}
if request
.tools
.as_ref()
.map_or(false, |t| !t.tools.is_empty())
{
response = response.set_tool_config(request.tools);
}
if request.tools.is_some() && !request.tools.as_ref().unwrap().tools.is_empty() {
response = response.set_tool_config(request.tools);
}
let output = response
.send()
.await
.context("Failed to send API request to Bedrock");
let response = response.send().await;
let stream = Box::pin(stream::unfold(
output?.stream,
move |mut stream| async move {
match stream.recv().await {
Ok(Some(output)) => Some((Ok(output), stream)),
Ok(None) => None,
Err(err) => Some((
Err(BedrockError::ClientError(anyhow!(
"{:?}",
aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
))),
stream,
match response {
Ok(output) => {
let stream: Pin<
Box<
dyn Stream<Item = Result<BedrockStreamingResponse, BedrockError>>
+ Send,
>,
> = Box::pin(stream::unfold(output.stream, |mut stream| async move {
match stream.recv().await {
Ok(Some(output)) => Some(({ Ok(output) }, stream)),
Ok(None) => None,
Err(err) => {
Some((
// TODO: Figure out how we can capture Throttling Exceptions
Err(BedrockError::ClientError(anyhow!(
"{:?}",
aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
))),
stream,
))
}
}
}));
Ok(stream)
}
Err(err) => Err(anyhow!(
"{:?}",
aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
)),
}
},
));
Ok(stream)
})
.await
.context("spawning a task")?
}
pub fn aws_document_to_value(document: &Document) -> Value {

View File

@@ -11,13 +11,6 @@ pub enum BedrockModelMode {
},
}
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct BedrockModelCacheConfiguration {
pub max_cache_anchors: usize,
pub min_total_token: u64,
}
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
@@ -111,7 +104,6 @@ pub enum Model {
display_name: Option<String>,
max_output_tokens: Option<u64>,
default_temperature: Option<f32>,
cache_configuration: Option<BedrockModelCacheConfiguration>,
},
}
@@ -409,56 +401,6 @@ impl Model {
}
}
pub fn supports_caching(&self) -> bool {
match self {
// Only Claude models on Bedrock support caching
// Nova models support only text caching
// https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html#prompt-caching-models
Self::Claude3_5Haiku
| Self::Claude3_7Sonnet
| Self::Claude3_7SonnetThinking
| Self::ClaudeSonnet4
| Self::ClaudeSonnet4Thinking
| Self::ClaudeOpus4
| Self::ClaudeOpus4Thinking => true,
// Custom models - check if they have cache configuration
Self::Custom {
cache_configuration,
..
} => cache_configuration.is_some(),
// All other models don't support caching
_ => false,
}
}
pub fn cache_configuration(&self) -> Option<BedrockModelCacheConfiguration> {
match self {
Self::Claude3_7Sonnet
| Self::Claude3_7SonnetThinking
| Self::ClaudeSonnet4
| Self::ClaudeSonnet4Thinking
| Self::ClaudeOpus4
| Self::ClaudeOpus4Thinking => Some(BedrockModelCacheConfiguration {
max_cache_anchors: 4,
min_total_token: 1024,
}),
Self::Claude3_5Haiku => Some(BedrockModelCacheConfiguration {
max_cache_anchors: 4,
min_total_token: 2048,
}),
Self::Custom {
cache_configuration,
..
} => cache_configuration.clone(),
_ => None,
}
}
pub fn mode(&self) -> BedrockModelMode {
match self {
Model::Claude3_7SonnetThinking => BedrockModelMode::Thinking {
@@ -718,7 +660,6 @@ mod tests {
display_name: Some("My Custom Model".to_string()),
max_output_tokens: Some(8192),
default_temperature: Some(0.7),
cache_configuration: None,
};
// Custom model should return its name unchanged

View File

@@ -1867,7 +1867,7 @@ mod tests {
let hunk = diff.hunks(&buffer, cx).next().unwrap();
let new_index_text = diff
.stage_or_unstage_hunks(true, std::slice::from_ref(&hunk), &buffer, true, cx)
.stage_or_unstage_hunks(true, &[hunk.clone()], &buffer, true, cx)
.unwrap()
.to_string();
assert_eq!(new_index_text, buffer_text);

View File

@@ -29,7 +29,7 @@ client.workspace = true
collections.workspace = true
fs.workspace = true
futures.workspace = true
gpui = { workspace = true, features = ["screen-capture"] }
gpui.workspace = true
language.workspace = true
log.workspace = true
postage.workspace = true

View File

@@ -12,6 +12,7 @@ pub struct CallSettings {
/// Configuration of voice calls in Zed.
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
#[schemars(deny_unknown_fields)]
pub struct CallSettingsContent {
/// Whether the microphone should be muted when joining a channel or a call.
///

View File

@@ -81,17 +81,7 @@ pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(10);
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
actions!(
client,
[
/// Signs in to Zed account.
SignIn,
/// Signs out of Zed account.
SignOut,
/// Reconnects to the collaboration server.
Reconnect
]
);
actions!(client, [SignIn, SignOut, Reconnect]);
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
pub struct ClientSettingsContent {

View File

@@ -35,7 +35,6 @@ dashmap.workspace = true
derive_more.workspace = true
envy = "0.4.2"
futures.workspace = true
gpui = { workspace = true, features = ["screen-capture"] }
hex.workspace = true
http_client.workspace = true
jsonwebtoken.workspace = true

View File

@@ -107,7 +107,7 @@ CREATE INDEX "index_worktree_entries_on_project_id" ON "worktree_entries" ("proj
CREATE INDEX "index_worktree_entries_on_project_id_and_worktree_id" ON "worktree_entries" ("project_id", "worktree_id");
CREATE TABLE "project_repositories" (
"project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE,
"project_id" INTEGER NOT NULL,
"abs_path" VARCHAR,
"id" INTEGER NOT NULL,
"entry_ids" VARCHAR,
@@ -124,7 +124,7 @@ CREATE TABLE "project_repositories" (
CREATE INDEX "index_project_repositories_on_project_id" ON "project_repositories" ("project_id");
CREATE TABLE "project_repository_statuses" (
"project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE,
"project_id" INTEGER NOT NULL,
"repository_id" INTEGER NOT NULL,
"repo_path" VARCHAR NOT NULL,
"status" INT8 NOT NULL,

View File

@@ -1,25 +0,0 @@
DELETE FROM project_repositories
WHERE project_id NOT IN (SELECT id FROM projects);
ALTER TABLE project_repositories
ADD CONSTRAINT fk_project_repositories_project_id
FOREIGN KEY (project_id)
REFERENCES projects (id)
ON DELETE CASCADE
NOT VALID;
ALTER TABLE project_repositories
VALIDATE CONSTRAINT fk_project_repositories_project_id;
DELETE FROM project_repository_statuses
WHERE project_id NOT IN (SELECT id FROM projects);
ALTER TABLE project_repository_statuses
ADD CONSTRAINT fk_project_repository_statuses_project_id
FOREIGN KEY (project_id)
REFERENCES projects (id)
ON DELETE CASCADE
NOT VALID;
ALTER TABLE project_repository_statuses
VALIDATE CONSTRAINT fk_project_repository_statuses_project_id;

View File

@@ -1404,9 +1404,6 @@ async fn sync_model_request_usage_with_stripe(
llm_db: &Arc<LlmDatabase>,
stripe_billing: &Arc<StripeBilling>,
) -> anyhow::Result<()> {
log::info!("Stripe usage sync: Starting");
let started_at = Utc::now();
let staff_users = app.db.get_staff_users().await?;
let staff_user_ids = staff_users
.iter()
@@ -1451,10 +1448,6 @@ async fn sync_model_request_usage_with_stripe(
.find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
.await?;
let usage_meter_count = usage_meters.len();
log::info!("Stripe usage sync: Syncing {usage_meter_count} usage meters");
for (usage_meter, usage) in usage_meters {
maybe!(async {
let Some((billing_customer, billing_subscription)) =
@@ -1511,10 +1504,5 @@ async fn sync_model_request_usage_with_stripe(
.log_err();
}
log::info!(
"Stripe usage sync: Synced {usage_meter_count} usage meters in {:?}",
Utc::now() - started_at
);
Ok(())
}

View File

@@ -4,19 +4,20 @@ mod tables;
#[cfg(test)]
pub mod tests;
use crate::{Error, Result};
use crate::{Error, Result, executor::Executor};
use anyhow::{Context as _, anyhow};
use collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use dashmap::DashMap;
use futures::StreamExt;
use project_repository_statuses::StatusKind;
use rand::{Rng, SeedableRng, prelude::StdRng};
use rpc::ExtensionProvides;
use rpc::{
ConnectionId, ExtensionMetadata,
proto::{self},
};
use sea_orm::{
ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr,
FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement,
TransactionTrait,
entity::prelude::*,
@@ -32,6 +33,7 @@ use std::{
ops::{Deref, DerefMut},
rc::Rc,
sync::Arc,
time::Duration,
};
use time::PrimitiveDateTime;
use tokio::sync::{Mutex, OwnedMutexGuard};
@@ -56,7 +58,6 @@ pub use tables::*;
#[cfg(test)]
pub struct DatabaseTestOptions {
pub executor: gpui::BackgroundExecutor,
pub runtime: tokio::runtime::Runtime,
pub query_failure_probability: parking_lot::Mutex<f64>,
}
@@ -68,6 +69,8 @@ pub struct Database {
pool: DatabaseConnection,
rooms: DashMap<RoomId, Arc<Mutex<()>>>,
projects: DashMap<ProjectId, Arc<Mutex<()>>>,
rng: Mutex<StdRng>,
executor: Executor,
notification_kinds_by_id: HashMap<NotificationKindId, &'static str>,
notification_kinds_by_name: HashMap<String, NotificationKindId>,
#[cfg(test)]
@@ -78,15 +81,17 @@ pub struct Database {
// separate files in the `queries` folder.
impl Database {
/// Connects to the database with the given options
pub async fn new(options: ConnectOptions) -> Result<Self> {
pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
sqlx::any::install_default_drivers();
Ok(Self {
options: options.clone(),
pool: sea_orm::Database::connect(options).await?,
rooms: DashMap::with_capacity(16384),
projects: DashMap::with_capacity(16384),
rng: Mutex::new(StdRng::seed_from_u64(0)),
notification_kinds_by_id: HashMap::default(),
notification_kinds_by_name: HashMap::default(),
executor,
#[cfg(test)]
test_options: None,
})
@@ -102,13 +107,48 @@ impl Database {
self.projects.clear();
}
/// Transaction runs things in a transaction. If you want to call other methods
/// and pass the transaction around you need to reborrow the transaction at each
/// call site with: `&*tx`.
pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
where
F: Send + Fn(TransactionHandle) -> Fut,
Fut: Send + Future<Output = Result<T>>,
{
let body = async {
let (tx, result) = self.with_transaction(&f).await?;
let mut i = 0;
loop {
let (tx, result) = self.with_transaction(&f).await?;
match result {
Ok(result) => match tx.commit().await.map_err(Into::into) {
Ok(()) => return Ok(result),
Err(error) => {
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
},
Err(error) => {
tx.rollback().await?;
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
}
i += 1;
}
};
self.run(body).await
}
pub async fn weak_transaction<F, Fut, T>(&self, f: F) -> Result<T>
where
F: Send + Fn(TransactionHandle) -> Fut,
Fut: Send + Future<Output = Result<T>>,
{
let body = async {
let (tx, result) = self.with_weak_transaction(&f).await?;
match result {
Ok(result) => match tx.commit().await.map_err(Into::into) {
Ok(()) => Ok(result),
@@ -134,28 +174,44 @@ impl Database {
Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
{
let body = async {
let (tx, result) = self.with_transaction(&f).await?;
match result {
Ok(Some((room_id, data))) => {
let lock = self.rooms.entry(room_id).or_default().clone();
let _guard = lock.lock_owned().await;
match tx.commit().await.map_err(Into::into) {
Ok(()) => Ok(Some(TransactionGuard {
data,
_guard,
_not_send: PhantomData,
})),
Err(error) => Err(error),
let mut i = 0;
loop {
let (tx, result) = self.with_transaction(&f).await?;
match result {
Ok(Some((room_id, data))) => {
let lock = self.rooms.entry(room_id).or_default().clone();
let _guard = lock.lock_owned().await;
match tx.commit().await.map_err(Into::into) {
Ok(()) => {
return Ok(Some(TransactionGuard {
data,
_guard,
_not_send: PhantomData,
}));
}
Err(error) => {
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
}
}
Ok(None) => match tx.commit().await.map_err(Into::into) {
Ok(()) => return Ok(None),
Err(error) => {
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
},
Err(error) => {
tx.rollback().await?;
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
}
Ok(None) => match tx.commit().await.map_err(Into::into) {
Ok(()) => Ok(None),
Err(error) => Err(error),
},
Err(error) => {
tx.rollback().await?;
Err(error)
}
i += 1;
}
};
@@ -173,26 +229,38 @@ impl Database {
{
let room_id = Database::room_id_for_project(self, project_id).await?;
let body = async {
let lock = if let Some(room_id) = room_id {
self.rooms.entry(room_id).or_default().clone()
} else {
self.projects.entry(project_id).or_default().clone()
};
let _guard = lock.lock_owned().await;
let (tx, result) = self.with_transaction(&f).await?;
match result {
Ok(data) => match tx.commit().await.map_err(Into::into) {
Ok(()) => Ok(TransactionGuard {
data,
_guard,
_not_send: PhantomData,
}),
Err(error) => Err(error),
},
Err(error) => {
tx.rollback().await?;
Err(error)
let mut i = 0;
loop {
let lock = if let Some(room_id) = room_id {
self.rooms.entry(room_id).or_default().clone()
} else {
self.projects.entry(project_id).or_default().clone()
};
let _guard = lock.lock_owned().await;
let (tx, result) = self.with_transaction(&f).await?;
match result {
Ok(data) => match tx.commit().await.map_err(Into::into) {
Ok(()) => {
return Ok(TransactionGuard {
data,
_guard,
_not_send: PhantomData,
});
}
Err(error) => {
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
},
Err(error) => {
tx.rollback().await?;
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
}
i += 1;
}
};
@@ -212,22 +280,34 @@ impl Database {
Fut: Send + Future<Output = Result<T>>,
{
let body = async {
let lock = self.rooms.entry(room_id).or_default().clone();
let _guard = lock.lock_owned().await;
let (tx, result) = self.with_transaction(&f).await?;
match result {
Ok(data) => match tx.commit().await.map_err(Into::into) {
Ok(()) => Ok(TransactionGuard {
data,
_guard,
_not_send: PhantomData,
}),
Err(error) => Err(error),
},
Err(error) => {
tx.rollback().await?;
Err(error)
let mut i = 0;
loop {
let lock = self.rooms.entry(room_id).or_default().clone();
let _guard = lock.lock_owned().await;
let (tx, result) = self.with_transaction(&f).await?;
match result {
Ok(data) => match tx.commit().await.map_err(Into::into) {
Ok(()) => {
return Ok(TransactionGuard {
data,
_guard,
_not_send: PhantomData,
});
}
Err(error) => {
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
},
Err(error) => {
tx.rollback().await?;
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
}
i += 1;
}
};
@@ -235,6 +315,28 @@ impl Database {
}
async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
where
F: Send + Fn(TransactionHandle) -> Fut,
Fut: Send + Future<Output = Result<T>>,
{
let tx = self
.pool
.begin_with_config(Some(IsolationLevel::Serializable), None)
.await?;
let mut tx = Arc::new(Some(tx));
let result = f(TransactionHandle(tx.clone())).await;
let tx = Arc::get_mut(&mut tx)
.and_then(|tx| tx.take())
.context("couldn't complete transaction because it's still in use")?;
Ok((tx, result))
}
async fn with_weak_transaction<F, Fut, T>(
&self,
f: &F,
) -> Result<(DatabaseTransaction, Result<T>)>
where
F: Send + Fn(TransactionHandle) -> Fut,
Fut: Send + Future<Output = Result<T>>,
@@ -259,13 +361,13 @@ impl Database {
{
#[cfg(test)]
{
use rand::prelude::*;
let test_options = self.test_options.as_ref().unwrap();
test_options.executor.simulate_random_delay().await;
let fail_probability = *test_options.query_failure_probability.lock();
if test_options.executor.rng().gen_bool(fail_probability) {
return Err(anyhow!("simulated query failure"))?;
if let Executor::Deterministic(executor) = &self.executor {
executor.simulate_random_delay().await;
let fail_probability = *test_options.query_failure_probability.lock();
if executor.rng().gen_bool(fail_probability) {
return Err(anyhow!("simulated query failure"))?;
}
}
test_options.runtime.block_on(future)
@@ -276,6 +378,46 @@ impl Database {
future.await
}
}
async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: usize) -> bool {
// If the error is due to a failure to serialize concurrent transactions, then retry
// this transaction after a delay. With each subsequent retry, double the delay duration.
// Also vary the delay randomly in order to ensure different database connections retry
// at different times.
const SLEEPS: [f32; 10] = [10., 20., 40., 80., 160., 320., 640., 1280., 2560., 5120.];
if is_serialization_error(error) && prev_attempt_count < SLEEPS.len() {
let base_delay = SLEEPS[prev_attempt_count];
let randomized_delay = base_delay * self.rng.lock().await.gen_range(0.5..=2.0);
log::warn!(
"retrying transaction after serialization error. delay: {} ms.",
randomized_delay
);
self.executor
.sleep(Duration::from_millis(randomized_delay as u64))
.await;
true
} else {
false
}
}
}
fn is_serialization_error(error: &Error) -> bool {
const SERIALIZATION_FAILURE_CODE: &str = "40001";
match error {
Error::Database(
DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
| DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
) if error
.as_database_error()
.and_then(|error| error.code())
.as_deref()
== Some(SERIALIZATION_FAILURE_CODE) =>
{
true
}
_ => false,
}
}
/// A handle to a [`DatabaseTransaction`].

View File

@@ -20,7 +20,7 @@ impl Database {
&self,
params: &CreateBillingCustomerParams,
) -> Result<billing_customer::Model> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let customer = billing_customer::Entity::insert(billing_customer::ActiveModel {
user_id: ActiveValue::set(params.user_id),
stripe_customer_id: ActiveValue::set(params.stripe_customer_id.clone()),
@@ -40,7 +40,7 @@ impl Database {
id: BillingCustomerId,
params: &UpdateBillingCustomerParams,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
billing_customer::Entity::update(billing_customer::ActiveModel {
id: ActiveValue::set(id),
user_id: params.user_id.clone(),
@@ -61,7 +61,7 @@ impl Database {
&self,
id: BillingCustomerId,
) -> Result<Option<billing_customer::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_customer::Entity::find()
.filter(billing_customer::Column::Id.eq(id))
.one(&*tx)
@@ -75,7 +75,7 @@ impl Database {
&self,
user_id: UserId,
) -> Result<Option<billing_customer::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_customer::Entity::find()
.filter(billing_customer::Column::UserId.eq(user_id))
.one(&*tx)
@@ -89,7 +89,7 @@ impl Database {
&self,
stripe_customer_id: &str,
) -> Result<Option<billing_customer::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_customer::Entity::find()
.filter(billing_customer::Column::StripeCustomerId.eq(stripe_customer_id))
.one(&*tx)

View File

@@ -22,7 +22,7 @@ impl Database {
&self,
user_id: UserId,
) -> Result<Option<billing_preference::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_preference::Entity::find()
.filter(billing_preference::Column::UserId.eq(user_id))
.one(&*tx)
@@ -37,7 +37,7 @@ impl Database {
user_id: UserId,
params: &CreateBillingPreferencesParams,
) -> Result<billing_preference::Model> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let preferences = billing_preference::Entity::insert(billing_preference::ActiveModel {
user_id: ActiveValue::set(user_id),
max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
@@ -65,7 +65,7 @@ impl Database {
user_id: UserId,
params: &UpdateBillingPreferencesParams,
) -> Result<billing_preference::Model> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let preferences = billing_preference::Entity::update_many()
.set(billing_preference::ActiveModel {
max_monthly_llm_usage_spending_in_cents: params

View File

@@ -35,7 +35,7 @@ impl Database {
&self,
params: &CreateBillingSubscriptionParams,
) -> Result<billing_subscription::Model> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel {
billing_customer_id: ActiveValue::set(params.billing_customer_id),
kind: ActiveValue::set(params.kind),
@@ -64,7 +64,7 @@ impl Database {
id: BillingSubscriptionId,
params: &UpdateBillingSubscriptionParams,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
billing_subscription::Entity::update(billing_subscription::ActiveModel {
id: ActiveValue::set(id),
billing_customer_id: params.billing_customer_id.clone(),
@@ -90,7 +90,7 @@ impl Database {
&self,
id: BillingSubscriptionId,
) -> Result<Option<billing_subscription::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_subscription::Entity::find_by_id(id)
.one(&*tx)
.await?)
@@ -103,7 +103,7 @@ impl Database {
&self,
stripe_subscription_id: &str,
) -> Result<Option<billing_subscription::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_subscription::Entity::find()
.filter(
billing_subscription::Column::StripeSubscriptionId.eq(stripe_subscription_id),
@@ -118,7 +118,7 @@ impl Database {
&self,
user_id: UserId,
) -> Result<Option<billing_subscription::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.filter(billing_customer::Column::UserId.eq(user_id))
@@ -152,7 +152,7 @@ impl Database {
&self,
user_id: UserId,
) -> Result<Vec<billing_subscription::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let subscriptions = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.filter(billing_customer::Column::UserId.eq(user_id))
@@ -169,7 +169,7 @@ impl Database {
&self,
user_ids: HashSet<UserId>,
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
self.transaction(|tx| {
self.weak_transaction(|tx| {
let user_ids = user_ids.clone();
async move {
let mut rows = billing_subscription::Entity::find()
@@ -201,7 +201,7 @@ impl Database {
&self,
user_ids: HashSet<UserId>,
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
self.transaction(|tx| {
self.weak_transaction(|tx| {
let user_ids = user_ids.clone();
async move {
let mut rows = billing_subscription::Entity::find()
@@ -236,7 +236,7 @@ impl Database {
/// Returns the count of the active billing subscriptions for the user with the specified ID.
pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let count = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.filter(

View File

@@ -501,8 +501,10 @@ impl Database {
/// Returns all channels for the user with the given ID.
pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
self.transaction(|tx| async move { self.get_user_channels(user_id, None, true, &tx).await })
.await
self.weak_transaction(
|tx| async move { self.get_user_channels(user_id, None, true, &tx).await },
)
.await
}
/// Returns all channels for the user with the given ID that are descendants
@@ -732,8 +734,8 @@ impl Database {
users.push(proto::User {
id: user.id.to_proto(),
avatar_url: format!(
"https://avatars.githubusercontent.com/u/{}?s=128&v=4",
user.github_user_id
"https://github.com/{}.png?size=128",
user.github_login
),
github_login: user.github_login,
name: user.name,

View File

@@ -15,7 +15,7 @@ impl Database {
user_b_busy: bool,
}
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let user_a_participant = Alias::new("user_a_participant");
let user_b_participant = Alias::new("user_b_participant");
let mut db_contacts = contact::Entity::find()
@@ -91,7 +91,7 @@ impl Database {
/// Returns whether the given user is a busy (on a call).
pub async fn is_user_busy(&self, user_id: UserId) -> Result<bool> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let participant = room_participant::Entity::find()
.filter(room_participant::Column::UserId.eq(user_id))
.one(&*tx)

View File

@@ -9,7 +9,7 @@ pub enum ContributorSelector {
impl Database {
/// Retrieves the GitHub logins of all users who have signed the CLA.
pub async fn get_contributors(&self) -> Result<Vec<String>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryGithubLogin {
GithubLogin,
@@ -32,7 +32,7 @@ impl Database {
&self,
selector: &ContributorSelector,
) -> Result<Option<DateTime>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let condition = match selector {
ContributorSelector::GitHubUserId { github_user_id } => {
user::Column::GithubUserId.eq(*github_user_id)
@@ -69,7 +69,7 @@ impl Database {
github_user_created_at: DateTimeUtc,
initial_channel_id: Option<ChannelId>,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let user = self
.update_or_create_user_by_github_account_tx(
github_login,

View File

@@ -8,7 +8,7 @@ impl Database {
model: &str,
digests: &[Vec<u8>],
) -> Result<HashMap<Vec<u8>, Vec<f32>>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let embeddings = {
let mut db_embeddings = embedding::Entity::find()
.filter(
@@ -52,7 +52,7 @@ impl Database {
model: &str,
embeddings: &HashMap<Vec<u8>, Vec<f32>>,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
embedding::Entity::insert_many(embeddings.iter().map(|(digest, dimensions)| {
let now_offset_datetime = OffsetDateTime::now_utc();
let retrieved_at =
@@ -78,7 +78,7 @@ impl Database {
}
pub async fn purge_old_embeddings(&self) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
embedding::Entity::delete_many()
.filter(
embedding::Column::RetrievedAt

View File

@@ -15,7 +15,7 @@ impl Database {
max_schema_version: i32,
limit: usize,
) -> Result<Vec<ExtensionMetadata>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let mut condition = Condition::all()
.add(
extension::Column::LatestVersion
@@ -43,7 +43,7 @@ impl Database {
ids: &[&str],
constraints: Option<&ExtensionVersionConstraints>,
) -> Result<Vec<ExtensionMetadata>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let extensions = extension::Entity::find()
.filter(extension::Column::ExternalId.is_in(ids.iter().copied()))
.all(&*tx)
@@ -123,7 +123,7 @@ impl Database {
&self,
extension_id: &str,
) -> Result<Vec<ExtensionMetadata>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let condition = extension::Column::ExternalId
.eq(extension_id)
.into_condition();
@@ -162,7 +162,7 @@ impl Database {
extension_id: &str,
constraints: Option<&ExtensionVersionConstraints>,
) -> Result<Option<ExtensionMetadata>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let extension = extension::Entity::find()
.filter(extension::Column::ExternalId.eq(extension_id))
.one(&*tx)
@@ -187,7 +187,7 @@ impl Database {
extension_id: &str,
version: &str,
) -> Result<Option<ExtensionMetadata>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let extension = extension::Entity::find()
.filter(extension::Column::ExternalId.eq(extension_id))
.filter(extension_version::Column::Version.eq(version))
@@ -204,7 +204,7 @@ impl Database {
}
pub async fn get_known_extension_versions(&self) -> Result<HashMap<String, Vec<String>>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let mut extension_external_ids_by_id = HashMap::default();
let mut rows = extension::Entity::find().stream(&*tx).await?;
@@ -242,7 +242,7 @@ impl Database {
&self,
versions_by_extension_id: &HashMap<&str, Vec<NewExtensionVersion>>,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
for (external_id, versions) in versions_by_extension_id {
if versions.is_empty() {
continue;
@@ -349,7 +349,7 @@ impl Database {
}
pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result<bool> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryId {
Id,

View File

@@ -13,7 +13,7 @@ impl Database {
&self,
params: &CreateProcessedStripeEventParams,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
processed_stripe_event::Entity::insert(processed_stripe_event::ActiveModel {
stripe_event_id: ActiveValue::set(params.stripe_event_id.clone()),
stripe_event_type: ActiveValue::set(params.stripe_event_type.clone()),
@@ -35,7 +35,7 @@ impl Database {
&self,
event_id: &str,
) -> Result<Option<processed_stripe_event::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(processed_stripe_event::Entity::find_by_id(event_id)
.one(&*tx)
.await?)
@@ -48,7 +48,7 @@ impl Database {
&self,
event_ids: &[&str],
) -> Result<Vec<processed_stripe_event::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(processed_stripe_event::Entity::find()
.filter(
processed_stripe_event::Column::StripeEventId.is_in(event_ids.iter().copied()),

Some files were not shown because too many files have changed in this diff Show More