Merge branch 'main' into websearch-tool

This commit is contained in:
Bennet Bo Fenner
2025-04-15 15:33:24 +02:00
317 changed files with 10368 additions and 6202 deletions

19
.github/ISSUE_TEMPLATE/99_other.yml vendored Normal file
View File

@@ -0,0 +1,19 @@
name: Other [Staff Only]
description: Zed Staff Only
body:
- type: textarea
attributes:
label: Summary
value: |
<!-- Please insert a one line summary of the issue below -->
SUMMARY_SENTENCE_HERE
### Description
IF YOU DO NOT WORK FOR ZED INDUSTRIES DO NOT CREATE ISSUES WITH THIS TEMPLATE.
THEY WILL BE AUTO-CLOSED AND MAY RESULT IN YOU BEING BANNED FROM THE ZED ISSUE TRACKER.
FEATURE REQUESTS / SUPPORT REQUESTS SHOULD BE OPENED AS DISCUSSIONS:
https://github.com/zed-industries/zed/discussions/new/choose
validations:
required: true

View File

@@ -594,7 +594,7 @@ jobs:
timeout-minutes: 60
name: Linux x86_x64 release bundle
runs-on:
- buildjet-16vcpu-ubuntu-2004
- buildjet-16vcpu-ubuntu-2004 # ubuntu 20.04 for minimal glibc
if: |
startsWith(github.ref, 'refs/tags/v')
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
@@ -622,26 +622,23 @@ jobs:
- name: Create Linux .tar.gz bundle
run: script/bundle-linux
- name: Upload Linux bundle to workflow run if main branch or specific label
- name: Upload Artifact to Workflow - zed (run-bundling)
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
if: |
github.ref == 'refs/heads/main'
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
if: contains(github.event.pull_request.labels.*.name, 'run-bundling')
with:
name: zed-${{ github.event.pull_request.head.sha || github.sha }}-x86_64-unknown-linux-gnu.tar.gz
path: target/release/zed-*.tar.gz
- name: Upload Linux remote server to workflow run if main branch or specific label
- name: Upload Artifact to Workflow - zed-remote-server (run-bundling)
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
if: |
github.ref == 'refs/heads/main'
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
if: contains(github.event.pull_request.labels.*.name, 'run-bundling')
with:
name: zed-remote-server-${{ github.event.pull_request.head.sha || github.sha }}-x86_64-unknown-linux-gnu.gz
path: target/zed-remote-server-linux-x86_64.gz
- name: Upload app bundle to release
- name: Upload Artifacts to release
uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 # v1
if: ${{ !(contains(github.event.pull_request.labels.*.name, 'run-bundling')) }}
with:
draft: true
prerelease: ${{ env.RELEASE_CHANNEL == 'preview' }}
@@ -680,29 +677,26 @@ jobs:
# This exports RELEASE_CHANNEL into env (GITHUB_ENV)
script/determine-release-channel
- name: Create and upload Linux .tar.gz bundle
- name: Create and upload Linux .tar.gz bundles
run: script/bundle-linux
- name: Upload Linux bundle to workflow run if main branch or specific label
- name: Upload Artifact to Workflow - zed (run-bundling)
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
if: |
github.ref == 'refs/heads/main'
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
if: contains(github.event.pull_request.labels.*.name, 'run-bundling')
with:
name: zed-${{ github.event.pull_request.head.sha || github.sha }}-aarch64-unknown-linux-gnu.tar.gz
path: target/release/zed-*.tar.gz
- name: Upload Linux remote server to workflow run if main branch or specific label
- name: Upload Artifact to Workflow - zed-remote-server (run-bundling)
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
if: |
github.ref == 'refs/heads/main'
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
if: contains(github.event.pull_request.labels.*.name, 'run-bundling')
with:
name: zed-remote-server-${{ github.event.pull_request.head.sha || github.sha }}-aarch64-unknown-linux-gnu.gz
path: target/zed-remote-server-linux-aarch64.gz
- name: Upload app bundle to release
- name: Upload Artifacts to release
uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 # v1
if: ${{ !(contains(github.event.pull_request.labels.*.name, 'run-bundling')) }}
with:
draft: true
prerelease: ${{ env.RELEASE_CHANNEL == 'preview' }}

View File

@@ -117,12 +117,10 @@ jobs:
export ZED_KUBE_NAMESPACE=production
export ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT=10
export ZED_API_LOAD_BALANCER_SIZE_UNIT=2
export ZED_LLM_LOAD_BALANCER_SIZE_UNIT=2
elif [[ $GITHUB_REF_NAME = "collab-staging" ]]; then
export ZED_KUBE_NAMESPACE=staging
export ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT=1
export ZED_API_LOAD_BALANCER_SIZE_UNIT=1
export ZED_LLM_LOAD_BALANCER_SIZE_UNIT=1
else
echo "cowardly refusing to deploy from an unknown branch"
exit 1
@@ -147,9 +145,3 @@ jobs:
envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f -
kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch
echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}"
export ZED_SERVICE_NAME=llm
export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_LLM_LOAD_BALANCER_SIZE_UNIT
envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f -
kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch
echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}"

45
Cargo.lock generated
View File

@@ -326,7 +326,6 @@ dependencies = [
"serde_json",
"strum",
"thiserror 2.0.12",
"util",
"workspace-hack",
]
@@ -1183,6 +1182,18 @@ dependencies = [
"workspace-hack",
]
[[package]]
name = "auto_update_helper"
version = "0.1.0"
dependencies = [
"anyhow",
"log",
"simplelog",
"windows 0.61.1",
"winresource",
"workspace-hack",
]
[[package]]
name = "auto_update_ui"
version = "0.1.0"
@@ -2932,7 +2943,6 @@ dependencies = [
name = "collab"
version = "0.44.0"
dependencies = [
"anthropic",
"anyhow",
"assistant",
"assistant_context_editor",
@@ -3176,14 +3186,18 @@ dependencies = [
name = "component_preview"
version = "0.1.0"
dependencies = [
"anyhow",
"client",
"collections",
"component",
"db",
"gpui",
"languages",
"notifications",
"project",
"serde",
"ui",
"ui_input",
"workspace",
"workspace-hack",
]
@@ -3988,7 +4002,6 @@ dependencies = [
"node_runtime",
"parking_lot",
"paths",
"regex",
"schemars",
"serde",
"serde_json",
@@ -4020,7 +4033,6 @@ dependencies = [
"gpui",
"language",
"paths",
"regex",
"serde",
"serde_json",
"task",
@@ -4164,6 +4176,7 @@ dependencies = [
"collections",
"command_palette_hooks",
"dap",
"db",
"editor",
"env_logger 0.11.8",
"feature_flags",
@@ -4863,25 +4876,37 @@ dependencies = [
"assistant_settings",
"assistant_tool",
"assistant_tools",
"async-watch",
"chrono",
"clap",
"client",
"collections",
"context_server",
"dap",
"env_logger 0.11.8",
"extension",
"fs",
"futures 0.3.31",
"gpui",
"gpui_tokio",
"handlebars 4.5.0",
"language",
"language_extension",
"language_model",
"language_models",
"languages",
"node_runtime",
"paths",
"project",
"prompt_store",
"release_channel",
"reqwest_client",
"serde",
"settings",
"shellexpand 2.1.2",
"toml 0.8.20",
"unindent",
"util",
"workspace-hack",
]
@@ -4976,10 +5001,10 @@ dependencies = [
"async-tar",
"async-trait",
"collections",
"convert_case 0.8.0",
"fs",
"futures 0.3.31",
"gpui",
"heck 0.5.0",
"http_client",
"language",
"log",
@@ -7654,6 +7679,7 @@ dependencies = [
name = "language_model_selector"
version = "0.1.0"
dependencies = [
"collections",
"feature_flags",
"gpui",
"language_model",
@@ -7704,6 +7730,7 @@ dependencies = [
"smol",
"strum",
"theme",
"thiserror 2.0.12",
"tiktoken-rs",
"tokio",
"ui",
@@ -17628,6 +17655,7 @@ dependencies = [
"ui",
"util",
"uuid",
"windows 0.61.1",
"workspace-hack",
"zed_actions",
]
@@ -17790,6 +17818,8 @@ dependencies = [
"wasmtime-cranelift",
"wasmtime-environ",
"winapi",
"windows-core 0.61.0",
"windows-numerics",
"windows-sys 0.48.0",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
@@ -18134,7 +18164,7 @@ dependencies = [
[[package]]
name = "zed"
version = "0.182.0"
version = "0.183.0"
dependencies = [
"activity_indicator",
"agent",
@@ -18230,7 +18260,6 @@ dependencies = [
"settings",
"settings_ui",
"shellexpand 2.1.2",
"simplelog",
"smol",
"snippet_provider",
"snippets_ui",
@@ -18581,7 +18610,9 @@ name = "zlog"
version = "0.1.0"
dependencies = [
"anyhow",
"chrono",
"log",
"tempfile",
"workspace-hack",
]

View File

@@ -15,6 +15,7 @@ members = [
"crates/assistant_tools",
"crates/audio",
"crates/auto_update",
"crates/auto_update_helper",
"crates/auto_update_ui",
"crates/aws_http_client",
"crates/bedrock",
@@ -224,6 +225,7 @@ assistant_tool = { path = "crates/assistant_tool" }
assistant_tools = { path = "crates/assistant_tools" }
audio = { path = "crates/audio" }
auto_update = { path = "crates/auto_update" }
auto_update_helper = { path = "crates/auto_update_helper" }
auto_update_ui = { path = "crates/auto_update_ui" }
aws_http_client = { path = "crates/aws_http_client" }
bedrock = { path = "crates/bedrock" }
@@ -403,8 +405,12 @@ async-tungstenite = "0.29.1"
async-watch = "0.3.1"
async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] }
aws-config = { version = "1.6.1", features = ["behavior-version-latest"] }
aws-credential-types = { version = "1.2.2", features = ["hardcoded-credentials"] }
aws-sdk-bedrockruntime = { version = "1.80.0", features = ["behavior-version-latest"] }
aws-credential-types = { version = "1.2.2", features = [
"hardcoded-credentials",
] }
aws-sdk-bedrockruntime = { version = "1.80.0", features = [
"behavior-version-latest",
] }
aws-smithy-runtime-api = { version = "1.7.4", features = ["http-1x", "client"] }
aws-smithy-types = { version = "1.3.0", features = ["http-body-1-x"] }
base64 = "0.22"
@@ -443,6 +449,7 @@ futures-lite = "1.13"
git2 = { version = "0.20.1", default-features = false }
globset = "0.4"
handlebars = "4.3"
heck = "0.5"
heed = { version = "0.21.0", features = ["read-txn-no-tls"] }
hex = "0.4.3"
html5ever = "0.27.0"
@@ -619,12 +626,10 @@ features = [
[workspace.dependencies.windows]
version = "0.61"
features = [
"Foundation_Collections",
"Foundation_Numerics",
"Storage_Search",
"Storage_Streams",
"System_Threading",
"UI_StartScreen",
"UI_ViewManagement",
"Wdk_System_SystemServices",
"Win32_Globalization",
@@ -651,6 +656,7 @@ features = [
"Win32_System_SystemInformation",
"Win32_System_SystemServices",
"Win32_System_Threading",
"Win32_System_Variant",
"Win32_System_WinRT",
"Win32_UI_Controls",
"Win32_UI_HiDpi",
@@ -658,6 +664,7 @@ features = [
"Win32_UI_Input_KeyboardAndMouse",
"Win32_UI_Shell",
"Win32_UI_Shell_Common",
"Win32_UI_Shell_PropertiesSystem",
"Win32_UI_WindowsAndMessaging",
]
@@ -781,4 +788,12 @@ let_underscore_future = "allow"
too_many_arguments = "allow"
[workspace.metadata.cargo-machete]
ignored = ["bindgen", "cbindgen", "prost_build", "serde", "component", "linkme", "workspace-hack"]
ignored = [
"bindgen",
"cbindgen",
"prost_build",
"serde",
"component",
"linkme",
"workspace-hack",
]

View File

@@ -1,12 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="16" height="16" rx="2" fill="black" fill-opacity="0.2"/>
<g clip-path="url(#clip0_1916_18)">
<path d="M10.652 3.79999H8.816L12.164 12.2H14L10.652 3.79999Z" fill="#1F1F1E"/>
<path d="M5.348 3.79999L2 12.2H3.872L4.55672 10.436H8.05927L8.744 12.2H10.616L7.268 3.79999H5.348ZM5.16224 8.87599L6.308 5.92399L7.45374 8.87599H5.16224Z" fill="#1F1F1E"/>
</g>
<defs>
<clipPath id="clip0_1916_18">
<rect width="12" height="8.4" fill="white" transform="translate(2 3.79999)"/>
</clipPath>
</defs>
</svg>

Before

Width:  |  Height:  |  Size: 601 B

5
assets/icons/layout.svg Normal file
View File

@@ -0,0 +1,5 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M20 14H4C3.44772 14 3 14.4477 3 15V20C3 20.5523 3.44772 21 4 21H20C20.5523 21 21 20.5523 21 20V15C21 14.4477 20.5523 14 20 14Z" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M11 3H4C3.44772 3 3 3.44772 3 4V9C3 9.55228 3.44772 10 4 10H11C11.5523 10 12 9.55228 12 9V4C12 3.44772 11.5523 3 11 3Z" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M20 3H17C16.4477 3 16 3.44772 16 4V9C16 9.55228 16.4477 10 17 10H20C20.5523 10 21 9.55228 21 9V4C21 3.44772 20.5523 3 20 3Z" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

After

Width:  |  Height:  |  Size: 746 B

View File

@@ -150,7 +150,9 @@
"context": "AgentDiff",
"bindings": {
"ctrl-y": "agent::Keep",
"ctrl-n": "agent::Reject"
"ctrl-n": "agent::Reject",
"ctrl-shift-y": "agent::KeepAll",
"ctrl-shift-n": "agent::RejectAll"
}
},
{
@@ -352,11 +354,11 @@
"alt-shift-left": "editor::SelectSmallerSyntaxNode", // Shrink Selection
"ctrl-shift-l": "editor::SelectAllMatches", // Select all occurrences of current selection
"ctrl-f2": "editor::SelectAllMatches", // Select all occurrences of current word
"ctrl-d": ["editor::SelectNext", { "replace_newest": false }],
"ctrl-shift-down": ["editor::SelectNext", { "replace_newest": false }], // Add selection to Next Find Match
"ctrl-shift-up": ["editor::SelectPrevious", { "replace_newest": false }],
"ctrl-k ctrl-d": ["editor::SelectNext", { "replace_newest": true }],
"ctrl-k ctrl-shift-d": ["editor::SelectPrevious", { "replace_newest": true }],
"ctrl-d": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand
"ctrl-shift-down": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch
"ctrl-shift-up": ["editor::SelectPrevious", { "replace_newest": false }], // editor.action.addSelectionToPreviousFindMatch
"ctrl-k ctrl-d": ["editor::SelectNext", { "replace_newest": true }], // editor.action.moveSelectionToNextFindMatch / find_under_expand_skip
"ctrl-k ctrl-shift-d": ["editor::SelectPrevious", { "replace_newest": true }], // editor.action.moveSelectionToPreviousFindMatch
"ctrl-k ctrl-i": "editor::Hover",
"ctrl-/": ["editor::ToggleComments", { "advance_downwards": false }],
"ctrl-u": "editor::UndoSelection",
@@ -780,6 +782,7 @@
"shift-tab": "git_panel::FocusEditor",
"escape": "git_panel::ToggleFocus",
"ctrl-enter": "git::Commit",
"ctrl-shift-enter": "git::Amend",
"alt-enter": "menu::SecondaryConfirm",
"delete": ["git::RestoreFile", { "skip_prompt": false }],
"backspace": ["git::RestoreFile", { "skip_prompt": false }],
@@ -788,12 +791,20 @@
"ctrl-delete": ["git::RestoreFile", { "skip_prompt": false }]
}
},
{
"context": "GitPanel && CommitEditor",
"use_key_equivalents": true,
"bindings": {
"escape": "git::Cancel"
}
},
{
"context": "GitCommit > Editor",
"bindings": {
"escape": "menu::Cancel",
"enter": "editor::Newline",
"ctrl-enter": "git::Commit",
"ctrl-shift-enter": "git::Amend",
"alt-l": "git::GenerateCommitMessage"
}
},
@@ -815,6 +826,7 @@
"context": "GitDiff > Editor",
"bindings": {
"ctrl-enter": "git::Commit",
"ctrl-shift-enter": "git::Amend",
"ctrl-space": "git::StageAll",
"ctrl-shift-space": "git::UnstageAll"
}
@@ -833,6 +845,7 @@
"shift-tab": "git_panel::FocusChanges",
"enter": "editor::Newline",
"ctrl-enter": "git::Commit",
"ctrl-shift-enter": "git::Amend",
"alt-up": "git_panel::FocusChanges",
"alt-l": "git::GenerateCommitMessage"
}

View File

@@ -242,7 +242,9 @@
"use_key_equivalents": true,
"bindings": {
"cmd-y": "agent::Keep",
"cmd-n": "agent::Reject"
"cmd-n": "agent::Reject",
"cmd-shift-y": "agent::KeepAll",
"cmd-shift-n": "agent::RejectAll"
}
},
{
@@ -489,12 +491,15 @@
"alt-shift-down": "editor::DuplicateLineDown",
"ctrl-shift-right": "editor::SelectLargerSyntaxNode", // Expand Selection
"ctrl-shift-left": "editor::SelectSmallerSyntaxNode", // Shrink Selection
"cmd-d": ["editor::SelectNext", { "replace_newest": false }], // Add selection to Next Find Match
"cmd-d": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand
"cmd-shift-l": "editor::SelectAllMatches", // Select all occurrences of current selection
"cmd-f2": "editor::SelectAllMatches", // Select all occurrences of current word
"ctrl-cmd-d": ["editor::SelectPrevious", { "replace_newest": false }],
"cmd-k cmd-d": ["editor::SelectNext", { "replace_newest": true }],
"cmd-k ctrl-cmd-d": ["editor::SelectPrevious", { "replace_newest": true }],
"cmd-k cmd-d": ["editor::SelectNext", { "replace_newest": true }], // editor.action.moveSelectionToNextFindMatch / find_under_expand_skip
// macOS binds `ctrl-cmd-d` to Show Dictionary which breaks these two binds
// To use `ctrl-cmd-d` or `ctrl-k ctrl-cmd-d` in Zed you must execute this command and then restart:
// defaults write com.apple.symbolichotkeys AppleSymbolicHotKeys -dict-add 70 '<dict><key>enabled</key><false/></dict>'
"ctrl-cmd-d": ["editor::SelectPrevious", { "replace_newest": false }], // editor.action.addSelectionToPreviousFindMatch
"cmd-k ctrl-cmd-d": ["editor::SelectPrevious", { "replace_newest": true }], // editor.action.moveSelectionToPreviousFindMatch
"cmd-k cmd-i": "editor::Hover",
"cmd-/": ["editor::ToggleComments", { "advance_downwards": false }],
"cmd-u": "editor::UndoSelection",
@@ -850,17 +855,26 @@
"shift-tab": "git_panel::FocusEditor",
"escape": "git_panel::ToggleFocus",
"cmd-enter": "git::Commit",
"cmd-shift-enter": "git::Amend",
"backspace": ["git::RestoreFile", { "skip_prompt": false }],
"delete": ["git::RestoreFile", { "skip_prompt": false }],
"cmd-backspace": ["git::RestoreFile", { "skip_prompt": true }],
"cmd-delete": ["git::RestoreFile", { "skip_prompt": true }]
}
},
{
"context": "GitPanel && CommitEditor",
"use_key_equivalents": true,
"bindings": {
"escape": "git::Cancel"
}
},
{
"context": "GitDiff > Editor",
"use_key_equivalents": true,
"bindings": {
"cmd-enter": "git::Commit",
"cmd-shift-enter": "git::Amend",
"cmd-ctrl-y": "git::StageAll",
"cmd-ctrl-shift-y": "git::UnstageAll"
}
@@ -871,6 +885,7 @@
"bindings": {
"enter": "editor::Newline",
"cmd-enter": "git::Commit",
"cmd-shift-enter": "git::Amend",
"tab": "git_panel::FocusChanges",
"shift-tab": "git_panel::FocusChanges",
"alt-up": "git_panel::FocusChanges",
@@ -900,6 +915,7 @@
"enter": "editor::Newline",
"escape": "menu::Cancel",
"cmd-enter": "git::Commit",
"cmd-shift-enter": "git::Amend",
"alt-tab": "git::GenerateCommitMessage"
}
},

View File

@@ -37,6 +37,8 @@
"ctrl-shift-a": "editor::SelectLargerSyntaxNode",
"ctrl-shift-d": "editor::DuplicateSelection",
"alt-f3": "editor::SelectAllMatches", // find_all_under
// "ctrl-f3": "", // find_under (cancels any selections)
// "cmd-alt-shift-g": "" // find_under_prev (cancels any selections)
"f9": "editor::SortLinesCaseSensitive",
"ctrl-f9": "editor::SortLinesCaseInsensitive",
"f12": "editor::GoToDefinition",

View File

@@ -38,6 +38,8 @@
"cmd-shift-a": "editor::SelectLargerSyntaxNode",
"cmd-shift-d": "editor::DuplicateSelection",
"ctrl-cmd-g": "editor::SelectAllMatches", // find_all_under
// "cmd-alt-g": "", // find_under (cancels any selections)
// "cmd-alt-shift-g": "" // find_under_prev (cancels any selections)
"f5": "editor::SortLinesCaseSensitive",
"ctrl-f5": "editor::SortLinesCaseInsensitive",
"shift-f12": "editor::FindAllReferences",

View File

@@ -539,6 +539,7 @@
"bindings": {
"d": "vim::CurrentLine",
"s": "vim::PushDeleteSurrounds",
"v": "vim::PushForcedMotion", // "d v"
"o": "editor::ToggleSelectedDiffHunks", // "d o"
"shift-o": "git::ToggleStaged",
"p": "git::Restore", // "d p"
@@ -587,6 +588,7 @@
"context": "vim_operator == y",
"bindings": {
"y": "vim::CurrentLine",
"v": "vim::PushForcedMotion",
"s": ["vim::PushAddSurrounds", {}]
}
},

View File

@@ -80,6 +80,8 @@
// Values are clamped to the [0.0, 1.0] range.
"inactive_opacity": 1.0
},
// Layout mode of the bottom dock. Defaults to "contained"
"bottom_dock_layout": "contained",
// The direction that you want to split panes horizontally. Defaults to "up"
"pane_split_direction_horizontal": "up",
// The direction that you want to split panes horizontally. Defaults to "left"
@@ -642,6 +644,7 @@
// We don't know which of the context server tools are safe for the "Ask" profile, so we don't enable them by default.
// "enable_all_context_servers": true,
"tools": {
"contents": true,
"diagnostics": true,
"fetch": true,
"list_directory": false,
@@ -661,6 +664,7 @@
"batch_tool": true,
"code_actions": true,
"code_symbols": true,
"contents": true,
"copy_path": false,
"create_file": true,
"delete_path": false,

View File

@@ -13,18 +13,18 @@ use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting};
use assistant_tool::ToolUseStatus;
use collections::{HashMap, HashSet};
use editor::scroll::Autoscroll;
use editor::{Editor, MultiBuffer};
use editor::{Editor, EditorElement, EditorStyle, MultiBuffer};
use gpui::{
AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, ClipboardItem,
DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Hsla, ListAlignment, ListState,
MouseButton, PlatformDisplay, ScrollHandle, Stateful, StyleRefinement, Subscription, Task,
TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle,
TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle,
linear_color_stop, linear_gradient, list, percentage, pulsating_between,
};
use language::{Buffer, LanguageRegistry};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role, StopReason};
use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
use markdown::{Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
use project::ProjectItem as _;
use rope::Point;
use settings::{Settings as _, update_settings_file};
@@ -34,7 +34,9 @@ use std::sync::Arc;
use std::time::Duration;
use text::ToPoint;
use theme::ThemeSettings;
use ui::{Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, Tooltip, prelude::*};
use ui::{
Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, TextSize, Tooltip, prelude::*,
};
use util::ResultExt as _;
use workspace::{OpenOptions, Workspace};
@@ -66,8 +68,6 @@ pub struct ActiveThread {
open_feedback_editors: HashMap<MessageId, Entity<Editor>>,
}
const MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK: usize = 5;
struct RenderedMessage {
language_registry: Arc<LanguageRegistry>,
segments: Vec<RenderedMessageSegment>,
@@ -176,11 +176,37 @@ fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
});
MarkdownStyle {
base_text_style: text_style,
base_text_style: text_style.clone(),
syntax: cx.theme().syntax().clone(),
selection_background_color: cx.theme().players().local().selection,
code_block_overflow_x_scroll: true,
table_overflow_x_scroll: true,
heading_level_styles: Some(HeadingLevelStyles {
h1: Some(TextStyleRefinement {
font_size: Some(rems(1.15).into()),
..Default::default()
}),
h2: Some(TextStyleRefinement {
font_size: Some(rems(1.1).into()),
..Default::default()
}),
h3: Some(TextStyleRefinement {
font_size: Some(rems(1.05).into()),
..Default::default()
}),
h4: Some(TextStyleRefinement {
font_size: Some(rems(1.).into()),
..Default::default()
}),
h5: Some(TextStyleRefinement {
font_size: Some(rems(0.95).into()),
..Default::default()
}),
h6: Some(TextStyleRefinement {
font_size: Some(rems(0.875).into()),
..Default::default()
}),
}),
code_block: StyleRefinement {
padding: EdgesRefinement {
top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
@@ -292,6 +318,8 @@ fn tool_use_markdown_style(window: &Window, cx: &mut App) -> MarkdownStyle {
}
}
const MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK: usize = 10;
fn render_markdown_code_block(
message_id: MessageId,
ix: usize,
@@ -578,7 +606,7 @@ fn render_markdown_code_block(
if is_expanded {
this.h_full()
} else {
this.max_h_40()
this.max_h_80()
}
},
)
@@ -1497,12 +1525,36 @@ impl ActiveThread {
.when(!message_is_empty, |parent| {
parent.child(
if let Some(edit_message_editor) = edit_message_editor.clone() {
let settings = ThemeSettings::get_global(cx);
let font_size = TextSize::Small.rems(cx);
let line_height = font_size.to_pixels(window.rem_size()) * 1.5;
let text_style = TextStyle {
color: cx.theme().colors().text,
font_family: settings.buffer_font.family.clone(),
font_fallbacks: settings.buffer_font.fallbacks.clone(),
font_features: settings.buffer_font.features.clone(),
font_size: font_size.into(),
line_height: line_height.into(),
..Default::default()
};
div()
.key_context("EditMessageEditor")
.on_action(cx.listener(Self::cancel_editing_message))
.on_action(cx.listener(Self::confirm_editing_message))
.min_h_6()
.child(edit_message_editor)
.pt_1()
.child(EditorElement::new(
&edit_message_editor,
EditorStyle {
background: colors.editor_background,
local_player: cx.theme().players().local(),
text: text_style,
syntax: cx.theme().syntax().clone(),
..Default::default()
},
))
.into_any()
} else {
div()
@@ -1667,11 +1719,9 @@ impl ActiveThread {
),
Role::Assistant => v_flex()
.id(("message-container", ix))
.ml_2()
.ml_2p5()
.pl_2()
.pr_4()
.border_l_1()
.border_color(cx.theme().colors().border_variant)
.children(message_content)
.when(has_tool_uses, |parent| {
parent.children(

View File

@@ -1,4 +1,4 @@
use crate::{Keep, Reject, Thread, ThreadEvent};
use crate::{Keep, KeepAll, Reject, RejectAll, Thread, ThreadEvent};
use anyhow::Result;
use buffer_diff::DiffHunkStatus;
use collections::HashSet;
@@ -843,7 +843,7 @@ impl ToolbarItemView for AgentDiffToolbar {
}
impl Render for AgentDiffToolbar {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let agent_diff = match self.agent_diff(cx) {
Some(ad) => ad,
None => return div(),
@@ -855,6 +855,8 @@ impl Render for AgentDiffToolbar {
return div();
}
let focus_handle = agent_diff.focus_handle(cx);
h_group_xl()
.my_neg_1()
.items_center()
@@ -864,15 +866,25 @@ impl Render for AgentDiffToolbar {
.child(
h_group_sm()
.child(
Button::new("reject-all", "Reject All").on_click(cx.listener(
|this, _, window, cx| {
this.dispatch_action(&crate::RejectAll, window, cx)
},
)),
Button::new("reject-all", "Reject All")
.key_binding({
KeyBinding::for_action_in(&RejectAll, &focus_handle, window, cx)
.map(|kb| kb.size(rems_from_px(12.)))
})
.on_click(cx.listener(|this, _, window, cx| {
this.dispatch_action(&RejectAll, window, cx)
})),
)
.child(Button::new("keep-all", "Keep All").on_click(cx.listener(
|this, _, window, cx| this.dispatch_action(&crate::KeepAll, window, cx),
))),
.child(
Button::new("keep-all", "Keep All")
.key_binding({
KeyBinding::for_action_in(&KeepAll, &focus_handle, window, cx)
.map(|kb| kb.size(rems_from_px(12.)))
})
.on_click(cx.listener(|this, _, window, cx| {
this.dispatch_action(&KeepAll, window, cx)
})),
),
)
}
}
@@ -882,6 +894,7 @@ mod tests {
use super::*;
use crate::{ThreadStore, thread_store};
use assistant_settings::AssistantSettings;
use assistant_tool::ToolWorkingSet;
use context_server::ContextServerSettings;
use editor::EditorSettings;
use gpui::TestAppContext;
@@ -925,7 +938,7 @@ mod tests {
.update(|cx| {
ThreadStore::load(
project.clone(),
Arc::default(),
cx.new(|_| ToolWorkingSet::default()),
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)

View File

@@ -12,7 +12,9 @@ use fs::Fs;
use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
use settings::{Settings, update_settings_file};
use ui::{Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch, prelude::*};
use ui::{
Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch, Tooltip, prelude::*,
};
use util::ResultExt as _;
use zed_actions::ExtensionCategoryFilter;
@@ -27,7 +29,7 @@ pub struct AssistantConfiguration {
configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
context_server_manager: Entity<ContextServerManager>,
expanded_context_server_tools: HashMap<Arc<str>, bool>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
_registry_subscription: Subscription,
}
@@ -35,7 +37,7 @@ impl AssistantConfiguration {
pub fn new(
fs: Arc<dyn Fs>,
context_server_manager: Entity<ContextServerManager>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -224,7 +226,7 @@ impl AssistantConfiguration {
fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
let context_servers = self.context_server_manager.read(cx).all_servers().clone();
let tools_by_source = self.tools.tools_by_source(cx);
let tools_by_source = self.tools.read(cx).tools_by_source(cx);
let empty = Vec::new();
const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
@@ -236,7 +238,10 @@ impl AssistantConfiguration {
.child(
v_flex()
.gap_0p5()
.child(Headline::new("Context Servers (MCP)").size(HeadlineSize::Small))
.child(
Headline::new("Model Context Protocol (MCP) Servers")
.size(HeadlineSize::Small),
)
.child(Label::new(SUBHEADING).color(Color::Muted)),
)
.children(context_servers.into_iter().map(|context_server| {
@@ -262,10 +267,9 @@ impl AssistantConfiguration {
.bg(cx.theme().colors().editor_background)
.child(
h_flex()
.p_1()
.justify_between()
.px_2()
.py_1()
.when(are_tools_expanded, |element| {
.when(are_tools_expanded && tool_count > 1, |element| {
element
.border_b_1()
.border_color(cx.theme().colors().border)
@@ -275,6 +279,7 @@ impl AssistantConfiguration {
.gap_2()
.child(
Disclosure::new("tool-list-disclosure", are_tools_expanded)
.disabled(tool_count == 0)
.on_click(cx.listener({
let context_server_id = context_server.id();
move |this, _event, _window, _cx| {
@@ -295,10 +300,11 @@ impl AssistantConfiguration {
.child(Label::new(context_server.id()))
.child(
Label::new(format!("{tool_count} tools"))
.color(Color::Muted),
.color(Color::Muted)
.size(LabelSize::Small),
),
)
.child(h_flex().child(
.child(
Switch::new("context-server-switch", is_running.into()).on_click({
let context_server_manager =
self.context_server_manager.clone();
@@ -334,7 +340,7 @@ impl AssistantConfiguration {
}
}
}),
)),
),
)
.map(|parent| {
if !are_tools_expanded {
@@ -344,14 +350,29 @@ impl AssistantConfiguration {
parent.child(v_flex().children(tools.into_iter().enumerate().map(
|(ix, tool)| {
h_flex()
.px_2()
.id("tool-item")
.pl_2()
.pr_1()
.py_1()
.gap_2()
.justify_between()
.when(ix < tool_count - 1, |element| {
element
.border_b_1()
.border_color(cx.theme().colors().border)
.border_color(cx.theme().colors().border_variant)
})
.child(Label::new(tool.name()))
.child(
Label::new(tool.name())
.buffer_font(cx)
.size(LabelSize::Small),
)
.child(
IconButton::new(("tool-description", ix), IconName::Info)
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::Small)
.icon_color(Color::Ignored)
.tooltip(Tooltip::text(tool.description())),
)
},
)))
})
@@ -362,7 +383,7 @@ impl AssistantConfiguration {
.gap_2()
.child(
h_flex().w_full().child(
Button::new("add-context-server", "Add Context Server")
Button::new("add-context-server", "Add MCPs Directly")
.style(ButtonStyle::Filled)
.layer(ElevationIndex::ModalSurface)
.full_width()
@@ -378,7 +399,7 @@ impl AssistantConfiguration {
h_flex().w_full().child(
Button::new(
"install-context-server-extensions",
"Install Context Server Extensions",
"Install MCP Extensions",
)
.style(ButtonStyle::Filled)
.layer(ElevationIndex::ModalSurface)

View File

@@ -84,7 +84,7 @@ pub struct NewProfileMode {
pub struct ManageProfilesModal {
fs: Arc<dyn Fs>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
thread_store: WeakEntity<ThreadStore>,
focus_handle: FocusHandle,
mode: Mode,
@@ -117,7 +117,7 @@ impl ManageProfilesModal {
pub fn new(
fs: Arc<dyn Fs>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
thread_store: WeakEntity<ThreadStore>,
window: &mut Window,
cx: &mut Context<Self>,

View File

@@ -60,7 +60,7 @@ pub struct ToolPickerDelegate {
impl ToolPickerDelegate {
pub fn new(
fs: Arc<dyn Fs>,
tool_set: Arc<ToolWorkingSet>,
tool_set: Entity<ToolWorkingSet>,
thread_store: WeakEntity<ThreadStore>,
profile_id: AgentProfileId,
profile: AgentProfile,
@@ -68,7 +68,7 @@ impl ToolPickerDelegate {
) -> Self {
let mut tool_entries = Vec::new();
for (source, tools) in tool_set.tools_by_source(cx) {
for (source, tools) in tool_set.read(cx).tools_by_source(cx) {
tool_entries.extend(tools.into_iter().map(|tool| ToolEntry {
name: tool.name().into(),
source: source.clone(),
@@ -192,7 +192,7 @@ impl PickerDelegate for ToolPickerDelegate {
if active_profile_id == &self.profile_id {
self.thread_store
.update(cx, |this, cx| {
this.load_profile(&self.profile, cx);
this.load_profile(self.profile.clone(), cx);
})
.log_err();
}

View File

@@ -80,17 +80,16 @@ impl AssistantModelSelector {
impl Render for AssistantModelSelector {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let model_registry = LanguageModelRegistry::read_global(cx);
let focus_handle = self.focus_handle.clone();
let model_registry = LanguageModelRegistry::read_global(cx);
let model = match self.model_type {
ModelType::Default => model_registry.default_model(),
ModelType::InlineAssistant => model_registry.inline_assistant_model(),
};
let focus_handle = self.focus_handle.clone();
let model_name = match model {
Some(model) => model.model.name().0,
_ => SharedString::from("No model selected"),
let (model_name, model_icon) = match model {
Some(model) => (model.model.name().0, Some(model.provider.icon())),
_ => (SharedString::from("No model selected"), None),
};
LanguageModelSelectorPopoverMenu::new(
@@ -100,10 +99,16 @@ impl Render for AssistantModelSelector {
.child(
h_flex()
.gap_0p5()
.children(
model_icon.map(|icon| {
Icon::new(icon).color(Color::Muted).size(IconSize::Small)
}),
)
.child(
Label::new(model_name)
.size(LabelSize::Small)
.color(Color::Muted),
.color(Color::Muted)
.ml_1(),
)
.child(
Icon::new(IconName::ChevronDown)

View File

@@ -44,8 +44,8 @@ use crate::thread::{Thread, ThreadError, ThreadId, TokenUsageRatio};
use crate::thread_history::{PastContext, PastThread, ThreadHistory};
use crate::thread_store::ThreadStore;
use crate::{
AgentDiff, InlineAssistant, NewTextThread, NewThread, OpenActiveThreadAsMarkdown,
OpenAgentDiff, OpenHistory, ThreadEvent, ToggleContextPicker,
AgentDiff, ExpandMessageEditor, InlineAssistant, NewTextThread, NewThread,
OpenActiveThreadAsMarkdown, OpenAgentDiff, OpenHistory, ThreadEvent, ToggleContextPicker,
};
pub fn init(cx: &mut App) {
@@ -90,6 +90,16 @@ pub fn init(cx: &mut App) {
let thread = panel.read(cx).thread.read(cx).thread().clone();
AgentDiff::deploy_in_workspace(thread, workspace, window, cx);
}
})
.register_action(|workspace, _: &ExpandMessageEditor, window, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
workspace.focus_panel::<AssistantPanel>(window, cx);
panel.update(cx, |panel, cx| {
panel.message_editor.update(cx, |editor, cx| {
editor.expand_message_editor(&ExpandMessageEditor, window, cx);
});
});
}
});
},
)
@@ -193,7 +203,7 @@ impl AssistantPanel {
cx: AsyncWindowContext,
) -> Task<Result<Entity<Self>>> {
cx.spawn(async move |cx| {
let tools = Arc::new(ToolWorkingSet::default());
let tools = cx.new(|_| ToolWorkingSet::default())?;
let thread_store = workspace
.update(cx, |workspace, cx| {
let project = workspace.project().clone();
@@ -559,6 +569,7 @@ impl AssistantPanel {
ActiveView::Configuration | ActiveView::History => {
self.active_view =
ActiveView::thread(self.thread.read(cx).thread().clone(), window, cx);
self.message_editor.focus_handle(cx).focus(window);
cx.notify();
}
_ => {}
@@ -1088,20 +1099,30 @@ impl AssistantPanel {
window,
cx,
|menu, _window, _cx| {
menu.action(
menu
.when(!is_empty, |menu| {
menu.action(
"Start New From Summary",
Box::new(NewThread {
from_thread_id: Some(thread_id.clone()),
}),
).separator()
})
.action(
"New Text Thread",
NewTextThread.boxed_clone(),
)
.when(!is_empty, |menu| {
menu.action(
"Continue in New Thread",
Box::new(NewThread {
from_thread_id: Some(thread_id.clone()),
}),
)
})
.separator()
.action("Settings", OpenConfiguration.boxed_clone())
.separator()
.action(
"Install MCPs",
zed_actions::Extensions {
category_filter: Some(
zed_actions::ExtensionCategoryFilter::ContextServers,
),
}
.boxed_clone(),
)
},
))
}),

View File

@@ -34,12 +34,6 @@ use crate::context_store::ContextStore;
use crate::thread::ThreadId;
use crate::thread_store::ThreadStore;
#[derive(Debug, Clone, Copy)]
pub enum ConfirmBehavior {
KeepOpen,
Close,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ContextPickerMode {
File,
@@ -105,7 +99,6 @@ pub(super) struct ContextPicker {
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
thread_store: Option<WeakEntity<ThreadStore>>,
confirm_behavior: ConfirmBehavior,
_subscriptions: Vec<Subscription>,
}
@@ -114,7 +107,6 @@ impl ContextPicker {
workspace: WeakEntity<Workspace>,
thread_store: Option<WeakEntity<ThreadStore>>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -143,7 +135,6 @@ impl ContextPicker {
workspace,
context_store,
thread_store,
confirm_behavior,
_subscriptions: subscriptions,
}
}
@@ -166,37 +157,32 @@ impl ContextPicker {
let modes = supported_context_picker_modes(&self.thread_store);
let menu = menu
.when(has_recent, |menu| {
menu.custom_row(|_, _| {
div()
.mb_1()
.child(
Label::new("Recent")
.color(Color::Muted)
.size(LabelSize::Small),
)
.into_any_element()
})
menu.when(has_recent, |menu| {
menu.custom_row(|_, _| {
div()
.mb_1()
.child(
Label::new("Recent")
.color(Color::Muted)
.size(LabelSize::Small),
)
.into_any_element()
})
.extend(recent_entries)
.when(has_recent, |menu| menu.separator())
.extend(modes.into_iter().map(|mode| {
let context_picker = context_picker.clone();
})
.extend(recent_entries)
.when(has_recent, |menu| menu.separator())
.extend(modes.into_iter().map(|mode| {
let context_picker = context_picker.clone();
ContextMenuEntry::new(mode.label())
.icon(mode.icon())
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.handler(move |window, cx| {
context_picker.update(cx, |this, cx| this.select_mode(mode, window, cx))
})
}));
match self.confirm_behavior {
ConfirmBehavior::KeepOpen => menu.keep_open_on_confirm(),
ConfirmBehavior::Close => menu,
}
ContextMenuEntry::new(mode.label())
.icon(mode.icon())
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.handler(move |window, cx| {
context_picker.update(cx, |this, cx| this.select_mode(mode, window, cx))
})
}))
.keep_open_on_confirm()
});
cx.subscribe(&menu, move |_, _, _: &DismissEvent, cx| {
@@ -227,7 +213,6 @@ impl ContextPicker {
context_picker.clone(),
self.workspace.clone(),
self.context_store.clone(),
self.confirm_behavior,
window,
cx,
)
@@ -239,7 +224,6 @@ impl ContextPicker {
context_picker.clone(),
self.workspace.clone(),
self.context_store.clone(),
self.confirm_behavior,
window,
cx,
)
@@ -251,7 +235,6 @@ impl ContextPicker {
context_picker.clone(),
self.workspace.clone(),
self.context_store.clone(),
self.confirm_behavior,
window,
cx,
)
@@ -264,7 +247,6 @@ impl ContextPicker {
thread_store.clone(),
context_picker.clone(),
self.context_store.clone(),
self.confirm_behavior,
window,
cx,
)

View File

@@ -11,7 +11,7 @@ use picker::{Picker, PickerDelegate};
use ui::{Context, ListItem, Window, prelude::*};
use workspace::Workspace;
use crate::context_picker::{ConfirmBehavior, ContextPicker};
use crate::context_picker::ContextPicker;
use crate::context_store::ContextStore;
pub struct FetchContextPicker {
@@ -23,16 +23,10 @@ impl FetchContextPicker {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let delegate = FetchContextPickerDelegate::new(
context_picker,
workspace,
context_store,
confirm_behavior,
);
let delegate = FetchContextPickerDelegate::new(context_picker, workspace, context_store);
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
Self { picker }
@@ -62,7 +56,6 @@ pub struct FetchContextPickerDelegate {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
url: String,
}
@@ -71,13 +64,11 @@ impl FetchContextPickerDelegate {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
) -> Self {
FetchContextPickerDelegate {
context_picker,
workspace,
context_store,
confirm_behavior,
url: String::new(),
}
}
@@ -204,25 +195,15 @@ impl PickerDelegate for FetchContextPickerDelegate {
let http_client = workspace.read(cx).client().http_client().clone();
let url = self.url.clone();
let confirm_behavior = self.confirm_behavior;
cx.spawn_in(window, async move |this, cx| {
let text = cx
.background_spawn(fetch_url_content(http_client, url.clone()))
.await?;
this.update_in(cx, |this, window, cx| {
this.delegate
.context_store
.update(cx, |context_store, cx| {
context_store.add_fetched_url(url, text, cx)
})?;
match confirm_behavior {
ConfirmBehavior::KeepOpen => {}
ConfirmBehavior::Close => this.delegate.dismissed(window, cx),
}
anyhow::Ok(())
this.update(cx, |this, cx| {
this.delegate.context_store.update(cx, |context_store, cx| {
context_store.add_fetched_url(url, text, cx)
})
})??;
anyhow::Ok(())

View File

@@ -11,9 +11,9 @@ use picker::{Picker, PickerDelegate};
use project::{PathMatchCandidateSet, ProjectPath, WorktreeId};
use ui::{ListItem, Tooltip, prelude::*};
use util::ResultExt as _;
use workspace::{Workspace, notifications::NotifyResultExt};
use workspace::Workspace;
use crate::context_picker::{ConfirmBehavior, ContextPicker};
use crate::context_picker::ContextPicker;
use crate::context_store::{ContextStore, FileInclusion};
pub struct FileContextPicker {
@@ -25,16 +25,10 @@ impl FileContextPicker {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let delegate = FileContextPickerDelegate::new(
context_picker,
workspace,
context_store,
confirm_behavior,
);
let delegate = FileContextPickerDelegate::new(context_picker, workspace, context_store);
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
Self { picker }
@@ -57,7 +51,6 @@ pub struct FileContextPickerDelegate {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
matches: Vec<FileMatch>,
selected_index: usize,
}
@@ -67,13 +60,11 @@ impl FileContextPickerDelegate {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
) -> Self {
Self {
context_picker,
workspace,
context_store,
confirm_behavior,
matches: Vec::new(),
selected_index: 0,
}
@@ -127,7 +118,7 @@ impl PickerDelegate for FileContextPickerDelegate {
})
}
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
let Some(FileMatch { mat, .. }) = self.matches.get(self.selected_index) else {
return;
};
@@ -153,17 +144,7 @@ impl PickerDelegate for FileContextPickerDelegate {
return;
};
let confirm_behavior = self.confirm_behavior;
cx.spawn_in(window, async move |this, cx| {
match task.await.notify_async_err(cx) {
None => anyhow::Ok(()),
Some(()) => this.update_in(cx, |this, window, cx| match confirm_behavior {
ConfirmBehavior::KeepOpen => {}
ConfirmBehavior::Close => this.delegate.dismissed(window, cx),
}),
}
})
.detach_and_log_err(cx);
task.detach_and_log_err(cx);
}
fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {

View File

@@ -15,7 +15,7 @@ use ui::{ListItem, prelude::*};
use util::ResultExt as _;
use workspace::Workspace;
use crate::context_picker::{ConfirmBehavior, ContextPicker};
use crate::context_picker::ContextPicker;
use crate::context_store::ContextStore;
pub struct SymbolContextPicker {
@@ -27,16 +27,10 @@ impl SymbolContextPicker {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let delegate = SymbolContextPickerDelegate::new(
context_picker,
workspace,
context_store,
confirm_behavior,
);
let delegate = SymbolContextPickerDelegate::new(context_picker, workspace, context_store);
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
Self { picker }
@@ -59,7 +53,6 @@ pub struct SymbolContextPickerDelegate {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
matches: Vec<SymbolEntry>,
selected_index: usize,
}
@@ -69,13 +62,11 @@ impl SymbolContextPickerDelegate {
context_picker: WeakEntity<ContextPicker>,
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
) -> Self {
Self {
context_picker,
workspace,
context_store,
confirm_behavior,
matches: Vec::new(),
selected_index: 0,
}
@@ -135,7 +126,7 @@ impl PickerDelegate for SymbolContextPickerDelegate {
})
}
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
let Some(mat) = self.matches.get(self.selected_index) else {
return;
};
@@ -143,7 +134,6 @@ impl PickerDelegate for SymbolContextPickerDelegate {
return;
};
let confirm_behavior = self.confirm_behavior;
let add_symbol_task = add_symbol(
mat.symbol.clone(),
true,
@@ -153,16 +143,12 @@ impl PickerDelegate for SymbolContextPickerDelegate {
);
let selected_index = self.selected_index;
cx.spawn_in(window, async move |this, cx| {
cx.spawn(async move |this, cx| {
let included = add_symbol_task.await?;
this.update_in(cx, |this, window, cx| {
this.update(cx, |this, _| {
if let Some(mat) = this.delegate.matches.get_mut(selected_index) {
mat.is_included = included;
}
match confirm_behavior {
ConfirmBehavior::KeepOpen => {}
ConfirmBehavior::Close => this.delegate.dismissed(window, cx),
}
})
})
.detach_and_log_err(cx);

View File

@@ -6,7 +6,7 @@ use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity};
use picker::{Picker, PickerDelegate};
use ui::{ListItem, prelude::*};
use crate::context_picker::{ConfirmBehavior, ContextPicker};
use crate::context_picker::ContextPicker;
use crate::context_store::{self, ContextStore};
use crate::thread::ThreadId;
use crate::thread_store::ThreadStore;
@@ -20,16 +20,11 @@ impl ThreadContextPicker {
thread_store: WeakEntity<ThreadStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
confirm_behavior: ConfirmBehavior,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let delegate = ThreadContextPickerDelegate::new(
thread_store,
context_picker,
context_store,
confirm_behavior,
);
let delegate =
ThreadContextPickerDelegate::new(thread_store, context_picker, context_store);
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
ThreadContextPicker { picker }
@@ -58,7 +53,6 @@ pub struct ThreadContextPickerDelegate {
thread_store: WeakEntity<ThreadStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
confirm_behavior: ConfirmBehavior,
matches: Vec<ThreadContextEntry>,
selected_index: usize,
}
@@ -68,13 +62,11 @@ impl ThreadContextPickerDelegate {
thread_store: WeakEntity<ThreadStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
confirm_behavior: ConfirmBehavior,
) -> Self {
ThreadContextPickerDelegate {
thread_store,
context_picker,
context_store,
confirm_behavior,
matches: Vec::new(),
selected_index: 0,
}
@@ -127,7 +119,7 @@ impl PickerDelegate for ThreadContextPickerDelegate {
})
}
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
let Some(entry) = self.matches.get(self.selected_index) else {
return;
};
@@ -138,20 +130,15 @@ impl PickerDelegate for ThreadContextPickerDelegate {
let open_thread_task = thread_store.update(cx, |this, cx| this.open_thread(&entry.id, cx));
cx.spawn_in(window, async move |this, cx| {
cx.spawn(async move |this, cx| {
let thread = open_thread_task.await?;
this.update_in(cx, |this, window, cx| {
this.update(cx, |this, cx| {
this.delegate
.context_store
.update(cx, |context_store, cx| {
context_store.add_thread(thread, true, cx)
})
.ok();
match this.delegate.confirm_behavior {
ConfirmBehavior::KeepOpen => {}
ConfirmBehavior::Close => this.delegate.dismissed(window, cx),
}
})
})
.detach_and_log_err(cx);

View File

@@ -15,7 +15,7 @@ use ui::{KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*};
use workspace::{Workspace, notifications::NotifyResultExt};
use crate::context::{ContextId, ContextKind};
use crate::context_picker::{ConfirmBehavior, ContextPicker};
use crate::context_picker::ContextPicker;
use crate::context_store::ContextStore;
use crate::thread::Thread;
use crate::thread_store::ThreadStore;
@@ -52,7 +52,6 @@ impl ContextStrip {
workspace.clone(),
thread_store.clone(),
context_store.downgrade(),
ConfirmBehavior::KeepOpen,
window,
cx,
)

File diff suppressed because it is too large Load Diff

View File

@@ -86,7 +86,7 @@ impl ProfileSelector {
thread_store
.update(cx, |this, cx| {
this.load_profile_by_id(&profile_id, cx);
this.load_profile_by_id(profile_id.clone(), cx);
})
.log_err();
}

View File

@@ -15,10 +15,11 @@ use futures::{FutureExt, StreamExt as _};
use git::repository::DiffType;
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
PaymentRequiredError, Role, StopReason, TokenUsage,
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
Role, StopReason, TokenUsage,
};
use project::Project;
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
@@ -228,7 +229,7 @@ pub struct TotalTokenUsage {
pub ratio: TokenUsageRatio,
}
#[derive(Default, PartialEq, Eq)]
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
#[default]
Normal,
@@ -253,22 +254,31 @@ pub struct Thread {
pending_completions: Vec<PendingCompletion>,
project: Entity<Project>,
prompt_builder: Arc<PromptBuilder>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
tool_use: ToolUseState,
action_log: Entity<ActionLog>,
last_restore_checkpoint: Option<LastRestoreCheckpoint>,
pending_checkpoint: Option<ThreadCheckpoint>,
initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
cumulative_token_usage: TokenUsage,
exceeded_window_error: Option<ExceededWindowError>,
feedback: Option<ThreadFeedback>,
message_feedback: HashMap<MessageId, ThreadFeedback>,
last_auto_capture_at: Option<Instant>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
/// Model used when last message exceeded context window
model_id: LanguageModelId,
/// Token count including last message
token_count: usize,
}
impl Thread {
pub fn new(
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
system_prompt: SharedProjectContext,
cx: &mut Context<Self>,
@@ -301,6 +311,7 @@ impl Thread {
.shared()
},
cumulative_token_usage: TokenUsage::default(),
exceeded_window_error: None,
feedback: None,
message_feedback: HashMap::default(),
last_auto_capture_at: None,
@@ -311,7 +322,7 @@ impl Thread {
id: ThreadId,
serialized: SerializedThread,
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
project_context: SharedProjectContext,
cx: &mut Context<Self>,
@@ -367,6 +378,7 @@ impl Thread {
action_log: cx.new(|_| ActionLog::new(project)),
initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
cumulative_token_usage: serialized.cumulative_token_usage,
exceeded_window_error: None,
feedback: None,
message_feedback: HashMap::default(),
last_auto_capture_at: None,
@@ -446,7 +458,7 @@ impl Thread {
!self.pending_completions.is_empty() || !self.all_tools_finished()
}
pub fn tools(&self) -> &Arc<ToolWorkingSet> {
pub fn tools(&self) -> &Entity<ToolWorkingSet> {
&self.tools
}
@@ -819,8 +831,9 @@ impl Thread {
})
.collect(),
initial_project_snapshot,
cumulative_token_usage: this.cumulative_token_usage.clone(),
cumulative_token_usage: this.cumulative_token_usage,
detailed_summary_state: this.detailed_summary_state.clone(),
exceeded_window_error: this.exceeded_window_error.clone(),
})
})
}
@@ -835,13 +848,21 @@ impl Thread {
if model.supports_tools() {
request.tools = {
let mut tools = Vec::new();
tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
LanguageModelRequestTool {
name: tool.name(),
description: tool.description(),
input_schema: tool.input_schema(model.tool_input_format()),
}
}));
tools.extend(
self.tools()
.read(cx)
.enabled_tools(cx)
.into_iter()
.filter_map(|tool| {
// Skip tools that cannot be supported
let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
Some(LanguageModelRequestTool {
name: tool.name(),
description: tool.description(),
input_schema,
})
}),
);
tools
};
@@ -1000,7 +1021,7 @@ impl Thread {
let task = cx.spawn(async move |thread, cx| {
let stream = model.stream_completion(request, &cx);
let initial_token_usage =
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
let stream_completion = async {
let mut events = stream.await?;
let mut stop_reason = StopReason::EndTurn;
@@ -1022,9 +1043,9 @@ impl Thread {
stop_reason = reason;
}
LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
thread.cumulative_token_usage =
thread.cumulative_token_usage.clone() + token_usage.clone()
- current_token_usage.clone();
thread.cumulative_token_usage = thread.cumulative_token_usage
+ token_usage
- current_token_usage;
current_token_usage = token_usage;
}
LanguageModelCompletionEvent::Text(chunk) => {
@@ -1133,6 +1154,20 @@ impl Thread {
cx.emit(ThreadEvent::ShowError(
ThreadError::MaxMonthlySpendReached,
));
} else if let Some(known_error) =
error.downcast_ref::<LanguageModelKnownError>()
{
match known_error {
LanguageModelKnownError::ContextWindowLimitExceeded {
tokens,
} => {
thread.exceeded_window_error = Some(ExceededWindowError {
model_id: model.id(),
token_count: *tokens,
});
cx.notify();
}
}
} else {
let error_message = error
.chain()
@@ -1153,7 +1188,7 @@ impl Thread {
thread.auto_capture_telemetry(cx);
if let Ok(initial_usage) = initial_token_usage {
let usage = thread.cumulative_token_usage.clone() - initial_usage;
let usage = thread.cumulative_token_usage - initial_usage;
telemetry::event!(
"Assistant Thread Completion",
@@ -1324,7 +1359,7 @@ impl Thread {
.collect::<Vec<_>>();
for tool_use in pending_tool_uses.iter() {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
if tool.needs_confirmation(&tool_use.input, cx)
&& !AssistantSettings::get_global(cx).always_allow_tool_actions
{
@@ -1376,7 +1411,7 @@ impl Thread {
) -> Task<()> {
let tool_name: Arc<str> = tool.name().into();
let tool_result = if self.tools.is_disabled(&tool.source(), &tool_name) {
let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
ToolResult {
output: Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))),
card: None,
@@ -1500,6 +1535,7 @@ impl Thread {
let enabled_tool_names: Vec<String> = self
.tools()
.read(cx)
.enabled_tools(cx)
.iter()
.map(|tool| tool.name().to_string())
@@ -1797,10 +1833,6 @@ impl Thread {
&self.project
}
pub fn cumulative_token_usage(&self) -> TokenUsage {
self.cumulative_token_usage.clone()
}
pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
return;
@@ -1845,6 +1877,10 @@ impl Thread {
.detach();
}
pub fn cumulative_token_usage(&self) -> TokenUsage {
self.cumulative_token_usage
}
pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(model) = model_registry.default_model() else {
@@ -1853,6 +1889,16 @@ impl Thread {
let max = model.model.max_token_count();
if let Some(exceeded_error) = &self.exceeded_window_error {
if model.model.id() == exceeded_error.model_id {
return TotalTokenUsage {
total: exceeded_error.token_count,
max,
ratio: TokenUsageRatio::Exceeded,
};
}
}
#[cfg(debug_assertions)]
let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
.unwrap_or("0.8".to_string())
@@ -2310,7 +2356,7 @@ fn main() {{
.update(|_, cx| {
ThreadStore::load(
project.clone(),
Arc::default(),
cx.new(|_| ToolWorkingSet::default()),
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)

View File

@@ -4,11 +4,14 @@ use assistant_context_editor::SavedContextMetadata;
use editor::{Editor, EditorEvent};
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{
App, Entity, FocusHandle, Focusable, ScrollStrategy, Task, UniformListScrollHandle, WeakEntity,
Window, uniform_list,
App, Entity, FocusHandle, Focusable, ScrollStrategy, Stateful, Task, UniformListScrollHandle,
WeakEntity, Window, uniform_list,
};
use time::{OffsetDateTime, UtcOffset};
use ui::{HighlightedLabel, IconButtonShape, ListItem, ListItemSpacing, Tooltip, prelude::*};
use ui::{
HighlightedLabel, IconButtonShape, ListItem, ListItemSpacing, Scrollbar, ScrollbarState,
Tooltip, prelude::*,
};
use util::ResultExt;
use crate::history_store::{HistoryEntry, HistoryStore};
@@ -26,6 +29,8 @@ pub struct ThreadHistory {
matches: Vec<StringMatch>,
_subscriptions: Vec<gpui::Subscription>,
_search_task: Option<Task<()>>,
scrollbar_visibility: bool,
scrollbar_state: ScrollbarState,
}
impl ThreadHistory {
@@ -58,10 +63,13 @@ impl ThreadHistory {
this.update_all_entries(cx);
});
let scroll_handle = UniformListScrollHandle::default();
let scrollbar_state = ScrollbarState::new(scroll_handle.clone());
Self {
assistant_panel,
history_store,
scroll_handle: UniformListScrollHandle::default(),
scroll_handle,
selected_index: 0,
search_query: SharedString::new_static(""),
all_entries: entries,
@@ -69,6 +77,8 @@ impl ThreadHistory {
search_editor,
_subscriptions: vec![search_editor_subscription, history_store_subscription],
_search_task: None,
scrollbar_visibility: true,
scrollbar_state,
}
}
@@ -220,6 +230,43 @@ impl ThreadHistory {
cx.notify();
}
fn render_scrollbar(&self, cx: &mut Context<Self>) -> Option<Stateful<Div>> {
if !(self.scrollbar_visibility || self.scrollbar_state.is_dragging()) {
return None;
}
Some(
div()
.occlude()
.id("thread-history-scroll")
.h_full()
.bg(cx.theme().colors().panel_background.opacity(0.8))
.border_l_1()
.border_color(cx.theme().colors().border_variant)
.absolute()
.right_1()
.top_0()
.bottom_0()
.w_4()
.pl_1()
.cursor_default()
.on_mouse_move(cx.listener(|_, _, _window, cx| {
cx.notify();
cx.stop_propagation()
}))
.on_hover(|_, _window, cx| {
cx.stop_propagation();
})
.on_any_mouse_down(|_, _window, cx| {
cx.stop_propagation();
})
.on_scroll_wheel(cx.listener(|_, _, _window, cx| {
cx.notify();
}))
.children(Scrollbar::vertical(self.scrollbar_state.clone())),
)
}
fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
if let Some(entry) = self.get_match(self.selected_index) {
let task_result = match entry {
@@ -305,7 +352,11 @@ impl Render for ThreadHistory {
)
})
.child({
let view = v_flex().overflow_hidden().flex_grow();
let view = v_flex()
.id("list-container")
.relative()
.overflow_hidden()
.flex_grow();
if self.all_entries.is_empty() {
view.justify_center()
@@ -322,59 +373,70 @@ impl Render for ThreadHistory {
),
)
} else {
view.p_1().child(
uniform_list(
cx.entity().clone(),
"thread-history",
self.matched_count(),
move |history, range, _window, _cx| {
let range_start = range.start;
let assistant_panel = history.assistant_panel.clone();
view.pr_5()
.child(
uniform_list(
cx.entity().clone(),
"thread-history",
self.matched_count(),
move |history, range, _window, _cx| {
let range_start = range.start;
let assistant_panel = history.assistant_panel.clone();
let render_item = |index: usize,
entry: &HistoryEntry,
highlight_positions: Vec<usize>|
-> Div {
h_flex().w_full().pb_1().child(match entry {
HistoryEntry::Thread(thread) => PastThread::new(
thread.clone(),
assistant_panel.clone(),
selected_index == index + range_start,
highlight_positions,
)
.into_any_element(),
HistoryEntry::Context(context) => PastContext::new(
context.clone(),
assistant_panel.clone(),
selected_index == index + range_start,
highlight_positions,
)
.into_any_element(),
})
};
if history.has_search_query() {
history.matches[range]
.iter()
.enumerate()
.filter_map(|(index, m)| {
history.all_entries.get(m.candidate_id).map(|entry| {
render_item(index, entry, m.positions.clone())
})
let render_item = |index: usize,
entry: &HistoryEntry,
highlight_positions: Vec<usize>|
-> Div {
h_flex().w_full().pb_1().child(match entry {
HistoryEntry::Thread(thread) => PastThread::new(
thread.clone(),
assistant_panel.clone(),
selected_index == index + range_start,
highlight_positions,
)
.into_any_element(),
HistoryEntry::Context(context) => PastContext::new(
context.clone(),
assistant_panel.clone(),
selected_index == index + range_start,
highlight_positions,
)
.into_any_element(),
})
.collect()
} else {
history.all_entries[range]
.iter()
.enumerate()
.map(|(index, entry)| render_item(index, entry, vec![]))
.collect()
}
},
};
if history.has_search_query() {
history.matches[range]
.iter()
.enumerate()
.filter_map(|(index, m)| {
history.all_entries.get(m.candidate_id).map(
|entry| {
render_item(
index,
entry,
m.positions.clone(),
)
},
)
})
.collect()
} else {
history.all_entries[range]
.iter()
.enumerate()
.map(|(index, entry)| render_item(index, entry, vec![]))
.collect()
}
},
)
.p_1()
.track_scroll(self.scroll_handle.clone())
.flex_grow(),
)
.track_scroll(self.scroll_handle.clone())
.flex_grow(),
)
.when_some(self.render_scrollbar(cx), |div, scrollbar| {
div.child(scrollbar)
})
}
})
}
@@ -440,6 +502,7 @@ impl RenderOnce for PastThread {
IconButton::new("delete", IconName::TrashAlt)
.shape(IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.tooltip(move |window, cx| {
Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx)
})
@@ -531,6 +594,7 @@ impl RenderOnce for PastContext {
IconButton::new("delete", IconName::TrashAlt)
.shape(IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.tooltip(move |window, cx| {
Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx)
})

View File

@@ -27,7 +27,9 @@ use serde::{Deserialize, Serialize};
use settings::{Settings as _, SettingsStore};
use util::ResultExt as _;
use crate::thread::{DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadId};
use crate::thread::{
DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
};
const RULES_FILE_NAMES: [&'static str; 6] = [
".rules",
@@ -54,7 +56,7 @@ impl SharedProjectContext {
pub struct ThreadStore {
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
context_server_manager: Entity<ContextServerManager>,
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
@@ -72,7 +74,7 @@ impl EventEmitter<RulesLoadingError> for ThreadStore {}
impl ThreadStore {
pub fn load(
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut App,
) -> Task<Entity<Self>> {
@@ -86,7 +88,7 @@ impl ThreadStore {
fn new(
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut Context<Self>,
) -> Self {
@@ -246,7 +248,7 @@ impl ThreadStore {
self.context_server_manager.clone()
}
pub fn tools(&self) -> Arc<ToolWorkingSet> {
pub fn tools(&self) -> Entity<ToolWorkingSet> {
self.tools.clone()
}
@@ -353,52 +355,60 @@ impl ThreadStore {
})
}
fn load_default_profile(&self, cx: &Context<Self>) {
fn load_default_profile(&self, cx: &mut Context<Self>) {
let assistant_settings = AssistantSettings::get_global(cx);
self.load_profile_by_id(&assistant_settings.default_profile, cx);
self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
}
pub fn load_profile_by_id(&self, profile_id: &AgentProfileId, cx: &Context<Self>) {
pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
let assistant_settings = AssistantSettings::get_global(cx);
if let Some(profile) = assistant_settings.profiles.get(profile_id) {
self.load_profile(profile, cx);
if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
self.load_profile(profile.clone(), cx);
}
}
pub fn load_profile(&self, profile: &AgentProfile, cx: &Context<Self>) {
self.tools.disable_all_tools();
self.tools.enable(
ToolSource::Native,
&profile
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
);
pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
self.tools.update(cx, |tools, cx| {
tools.disable_all_tools(cx);
tools.enable(
ToolSource::Native,
&profile
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
cx,
);
});
if profile.enable_all_context_servers {
for context_server in self.context_server_manager.read(cx).all_servers() {
self.tools.enable_source(
ToolSource::ContextServer {
id: context_server.id().into(),
},
cx,
);
self.tools.update(cx, |tools, cx| {
tools.enable_source(
ToolSource::ContextServer {
id: context_server.id().into(),
},
cx,
);
});
}
} else {
for (context_server_id, preset) in &profile.context_servers {
self.tools.enable(
ToolSource::ContextServer {
id: context_server_id.clone().into(),
},
&preset
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
)
self.tools.update(cx, |tools, cx| {
tools.enable(
ToolSource::ContextServer {
id: context_server_id.clone().into(),
},
&preset
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
cx,
)
})
}
}
}
@@ -432,29 +442,36 @@ impl ThreadStore {
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
if let Some(tools) = protocol.list_tools().await.log_err() {
let tool_ids = tools
.tools
.into_iter()
.map(|tool| {
log::info!(
"registering context server tool: {:?}",
tool.name
);
tool_working_set.insert(Arc::new(
ContextServerTool::new(
context_server_manager.clone(),
server.id(),
tool,
),
))
let tool_ids = tool_working_set
.update(cx, |tool_working_set, _| {
tools
.tools
.into_iter()
.map(|tool| {
log::info!(
"registering context server tool: {:?}",
tool.name
);
tool_working_set.insert(Arc::new(
ContextServerTool::new(
context_server_manager.clone(),
server.id(),
tool,
),
))
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
.log_err();
this.update(cx, |this, cx| {
this.context_server_tool_ids.insert(server_id, tool_ids);
this.load_default_profile(cx);
})
.log_err();
if let Some(tool_ids) = tool_ids {
this.update(cx, |this, cx| {
this.context_server_tool_ids
.insert(server_id, tool_ids);
this.load_default_profile(cx);
})
.log_err();
}
}
}
}
@@ -464,7 +481,9 @@ impl ThreadStore {
}
context_server::manager::Event::ServerStopped { server_id } => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.remove(&tool_ids);
tool_working_set.update(cx, |tool_working_set, _| {
tool_working_set.remove(&tool_ids);
});
self.load_default_profile(cx);
}
}
@@ -491,6 +510,8 @@ pub struct SerializedThread {
pub cumulative_token_usage: TokenUsage,
#[serde(default)]
pub detailed_summary_state: DetailedSummaryState,
#[serde(default)]
pub exceeded_window_error: Option<ExceededWindowError>,
}
impl SerializedThread {
@@ -577,6 +598,7 @@ impl LegacySerializedThread {
initial_project_snapshot: self.initial_project_snapshot,
cumulative_token_usage: TokenUsage::default(),
detailed_summary_state: DetailedSummaryState::default(),
exceeded_window_error: None,
}
}
}

View File

@@ -5,7 +5,7 @@ use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
use collections::HashMap;
use futures::FutureExt as _;
use futures::future::Shared;
use gpui::{App, SharedString, Task};
use gpui::{App, Entity, SharedString, Task};
use language_model::{
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
@@ -30,7 +30,7 @@ pub struct ToolUse {
pub const USING_TOOL_MARKER: &str = "<using_tool>";
pub struct ToolUseState {
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
@@ -39,7 +39,7 @@ pub struct ToolUseState {
}
impl ToolUseState {
pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
Self {
tools,
tool_uses_by_assistant_message: HashMap::default(),
@@ -54,7 +54,7 @@ impl ToolUseState {
///
/// Accepts a function to filter the tools that should be used to populate the state.
pub fn from_serialized_messages(
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
messages: &[SerializedMessage],
mut filter_by_tool_name: impl FnMut(&str) -> bool,
) -> Self {
@@ -180,12 +180,12 @@ impl ToolUseState {
}
})();
let (icon, needs_confirmation) = if let Some(tool) = self.tools.tool(&tool_use.name, cx)
{
(tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
} else {
(IconName::Cog, false)
};
let (icon, needs_confirmation) =
if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
(tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
} else {
(IconName::Cog, false)
};
tool_uses.push(ToolUse {
id: tool_use.id.clone(),
@@ -207,7 +207,7 @@ impl ToolUseState {
input: &serde_json::Value,
cx: &App,
) -> SharedString {
if let Some(tool) = self.tools.tool(tool_name, cx) {
if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
tool.ui_text(input).into()
} else {
format!("Unknown tool {tool_name:?}").into()

View File

@@ -25,5 +25,4 @@ serde.workspace = true
serde_json.workspace = true
strum.workspace = true
thiserror.workspace = true
util.workspace = true
workspace-hack.workspace = true

View File

@@ -10,7 +10,6 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use strum::{EnumIter, EnumString};
use thiserror::Error;
use util::ResultExt as _;
pub use supported_countries::*;
@@ -363,11 +362,25 @@ pub struct RateLimitInfo {
impl RateLimitInfo {
fn from_headers(headers: &HeaderMap<HeaderValue>) -> Self {
// Check if any rate limit headers exist
let has_rate_limit_headers = headers
.keys()
.any(|k| k.as_str().starts_with("anthropic-ratelimit-"));
if !has_rate_limit_headers {
return Self {
requests: None,
tokens: None,
input_tokens: None,
output_tokens: None,
};
}
Self {
requests: RateLimit::from_headers("requests", headers).log_err(),
tokens: RateLimit::from_headers("tokens", headers).log_err(),
input_tokens: RateLimit::from_headers("input-tokens", headers).log_err(),
output_tokens: RateLimit::from_headers("output-tokens", headers).log_err(),
requests: RateLimit::from_headers("requests", headers).ok(),
tokens: RateLimit::from_headers("tokens", headers).ok(),
input_tokens: RateLimit::from_headers("input-tokens", headers).ok(),
output_tokens: RateLimit::from_headers("output-tokens", headers).ok(),
}
}
}
@@ -724,4 +737,54 @@ impl ApiError {
pub fn is_rate_limit_error(&self) -> bool {
matches!(self.error_type.as_str(), "rate_limit_error")
}
pub fn match_window_exceeded(&self) -> Option<usize> {
let Some(ApiErrorCode::InvalidRequestError) = self.code() else {
return None;
};
parse_prompt_too_long(&self.message)
}
}
pub fn parse_prompt_too_long(message: &str) -> Option<usize> {
message
.strip_prefix("prompt is too long: ")?
.split_once(" tokens")?
.0
.parse::<usize>()
.ok()
}
#[test]
fn test_match_window_exceeded() {
let error = ApiError {
error_type: "invalid_request_error".to_string(),
message: "prompt is too long: 220000 tokens > 200000".to_string(),
};
assert_eq!(error.match_window_exceeded(), Some(220_000));
let error = ApiError {
error_type: "invalid_request_error".to_string(),
message: "prompt is too long: 1234953 tokens".to_string(),
};
assert_eq!(error.match_window_exceeded(), Some(1234953));
let error = ApiError {
error_type: "invalid_request_error".to_string(),
message: "not a prompt length error".to_string(),
};
assert_eq!(error.match_window_exceeded(), None);
let error = ApiError {
error_type: "rate_limit_error".to_string(),
message: "prompt is too long: 12345 tokens".to_string(),
};
assert_eq!(error.match_window_exceeded(), None);
let error = ApiError {
error_type: "invalid_request_error".to_string(),
message: "prompt is too long: invalid tokens".to_string(),
};
assert_eq!(error.match_window_exceeded(), None);
}

View File

@@ -4,7 +4,7 @@ use collections::BTreeMap;
use futures::{StreamExt, channel::mpsc};
use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task, WeakEntity};
use language::{Anchor, Buffer, BufferEvent, DiskState, Point};
use project::{Project, ProjectItem};
use project::{Project, ProjectItem, lsp_store::OpenLspBufferHandle};
use std::{cmp, ops::Range, sync::Arc};
use text::{Edit, Patch, Rope};
use util::RangeExt;
@@ -49,6 +49,10 @@ impl ActionLog {
.tracked_buffers
.entry(buffer.clone())
.or_insert_with(|| {
let open_lsp_handle = self.project.update(cx, |project, cx| {
project.register_buffer_with_language_servers(&buffer, cx)
});
let text_snapshot = buffer.read(cx).text_snapshot();
let diff = cx.new(|cx| BufferDiff::new(&text_snapshot, cx));
let (diff_update_tx, diff_update_rx) = mpsc::unbounded();
@@ -76,6 +80,7 @@ impl ActionLog {
version: buffer.read(cx).version(),
diff,
diff_update: diff_update_tx,
_open_lsp_handle: open_lsp_handle,
_maintain_diff: cx.spawn({
let buffer = buffer.clone();
async move |this, cx| {
@@ -615,6 +620,7 @@ struct TrackedBuffer {
diff: Entity<BufferDiff>,
snapshot: text::BufferSnapshot,
diff_update: mpsc::UnboundedSender<(ChangeAuthor, text::BufferSnapshot)>,
_open_lsp_handle: OpenLspBufferHandle,
_maintain_diff: Task<()>,
_subscription: Subscription,
}

View File

@@ -1,5 +1,6 @@
mod action_log;
mod tool_registry;
mod tool_schema;
mod tool_working_set;
use std::fmt;
@@ -20,6 +21,7 @@ use project::Project;
pub use crate::action_log::*;
pub use crate::tool_registry::*;
pub use crate::tool_schema::*;
pub use crate::tool_working_set::*;
pub fn init(cx: &mut App) {
@@ -139,8 +141,8 @@ pub trait Tool: 'static + Send + Sync {
fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
/// Returns the JSON schema that describes the tool's input.
fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> serde_json::Value {
serde_json::Value::Object(serde_json::Map::default())
fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
Ok(serde_json::Value::Object(serde_json::Map::default()))
}
/// Returns markdown to be displayed in the UI for this tool.

View File

@@ -0,0 +1,236 @@
use anyhow::Result;
use serde_json::Value;
use crate::LanguageModelToolSchemaFormat;
/// Tries to adapt a JSON schema representation to be compatible with the specified format.
///
/// If the json cannot be made compatible with the specified format, an error is returned.
pub fn adapt_schema_to_format(
json: &mut Value,
format: LanguageModelToolSchemaFormat,
) -> Result<()> {
match format {
LanguageModelToolSchemaFormat::JsonSchema => Ok(()),
LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
}
}
/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
if let Value::Object(obj) = json {
const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
for key in UNSUPPORTED_KEYS {
if obj.contains_key(key) {
return Err(anyhow::anyhow!(
"Schema cannot be made compatible because it contains \"{}\" ",
key
));
}
}
const KEYS_TO_REMOVE: [&str; 2] = ["format", "$schema"];
for key in KEYS_TO_REMOVE {
obj.remove(key);
}
if let Some(default) = obj.get("default") {
let is_null = default.is_null();
// Default is not supported, so we need to remove it
obj.remove("default");
if is_null {
obj.insert("nullable".to_string(), Value::Bool(true));
}
}
// If a type is not specified for an input parameter, add a default type
if obj.contains_key("description")
&& !obj.contains_key("type")
&& !(obj.contains_key("anyOf")
|| obj.contains_key("oneOf")
|| obj.contains_key("allOf"))
{
obj.insert("type".to_string(), Value::String("string".to_string()));
}
// Handle oneOf -> anyOf conversion
if let Some(subschemas) = obj.get_mut("oneOf") {
if subschemas.is_array() {
let subschemas_clone = subschemas.clone();
obj.remove("oneOf");
obj.insert("anyOf".to_string(), subschemas_clone);
}
}
// Recursively process all nested objects and arrays
for (_, value) in obj.iter_mut() {
if let Value::Object(_) | Value::Array(_) = value {
adapt_to_json_schema_subset(value)?;
}
}
} else if let Value::Array(arr) = json {
for item in arr.iter_mut() {
adapt_to_json_schema_subset(item)?;
}
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_transform_default_null_to_nullable() {
let mut json = json!({
"description": "A test field",
"type": "string",
"default": null
});
adapt_to_json_schema_subset(&mut json).unwrap();
assert_eq!(
json,
json!({
"description": "A test field",
"type": "string",
"nullable": true
})
);
}
#[test]
fn test_transform_adds_type_when_missing() {
let mut json = json!({
"description": "A test field without type"
});
adapt_to_json_schema_subset(&mut json).unwrap();
assert_eq!(
json,
json!({
"description": "A test field without type",
"type": "string"
})
);
}
#[test]
fn test_transform_removes_format() {
let mut json = json!({
"description": "A test field",
"type": "integer",
"format": "uint32"
});
adapt_to_json_schema_subset(&mut json).unwrap();
assert_eq!(
json,
json!({
"description": "A test field",
"type": "integer"
})
);
}
#[test]
fn test_transform_one_of_to_any_of() {
let mut json = json!({
"description": "A test field",
"oneOf": [
{ "type": "string" },
{ "type": "integer" }
]
});
adapt_to_json_schema_subset(&mut json).unwrap();
assert_eq!(
json,
json!({
"description": "A test field",
"anyOf": [
{ "type": "string" },
{ "type": "integer" }
]
})
);
}
#[test]
fn test_transform_nested_objects() {
let mut json = json!({
"type": "object",
"properties": {
"nested": {
"oneOf": [
{ "type": "string" },
{ "type": "null" }
],
"format": "email"
}
}
});
adapt_to_json_schema_subset(&mut json).unwrap();
assert_eq!(
json,
json!({
"type": "object",
"properties": {
"nested": {
"anyOf": [
{ "type": "string" },
{ "type": "null" }
]
}
}
})
);
}
#[test]
fn test_transform_fails_if_unsupported_keys_exist() {
let mut json = json!({
"type": "object",
"properties": {
"$ref": "#/definitions/User",
}
});
assert!(adapt_to_json_schema_subset(&mut json).is_err());
let mut json = json!({
"type": "object",
"properties": {
"if": "...",
}
});
assert!(adapt_to_json_schema_subset(&mut json).is_err());
let mut json = json!({
"type": "object",
"properties": {
"then": "...",
}
});
assert!(adapt_to_json_schema_subset(&mut json).is_err());
let mut json = json!({
"type": "object",
"properties": {
"else": "...",
}
});
assert!(adapt_to_json_schema_subset(&mut json).is_err());
}
}

View File

@@ -1,8 +1,7 @@
use std::sync::Arc;
use collections::{HashMap, HashSet, IndexMap};
use gpui::App;
use parking_lot::Mutex;
use gpui::{App, Context, EventEmitter};
use crate::{Tool, ToolRegistry, ToolSource};
@@ -12,11 +11,6 @@ pub struct ToolId(usize);
/// A working set of tools for use in one instance of the Assistant Panel.
#[derive(Default)]
pub struct ToolWorkingSet {
state: Mutex<WorkingSetState>,
}
#[derive(Default)]
struct WorkingSetState {
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
enabled_sources: HashSet<ToolSource>,
@@ -24,99 +18,27 @@ struct WorkingSetState {
next_tool_id: ToolId,
}
pub enum ToolWorkingSetEvent {
EnabledToolsChanged,
}
impl EventEmitter<ToolWorkingSetEvent> for ToolWorkingSet {}
impl ToolWorkingSet {
pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
self.state
.lock()
.context_server_tools_by_name
self.context_server_tools_by_name
.get(name)
.cloned()
.or_else(|| ToolRegistry::global(cx).tool(name))
}
pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
self.state.lock().tools(cx)
}
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
self.state.lock().tools_by_source(cx)
}
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
self.state.lock().enabled_tools(cx)
}
pub fn disable_all_tools(&self) {
let mut state = self.state.lock();
state.disable_all_tools();
}
pub fn enable_source(&self, source: ToolSource, cx: &App) {
let mut state = self.state.lock();
state.enable_source(source, cx);
}
pub fn disable_source(&self, source: &ToolSource) {
let mut state = self.state.lock();
state.disable_source(source);
}
pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
let mut state = self.state.lock();
let tool_id = state.next_tool_id;
state.next_tool_id.0 += 1;
state
.context_server_tools_by_id
.insert(tool_id, tool.clone());
state.tools_changed();
tool_id
}
pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
self.state.lock().is_enabled(source, name)
}
pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
self.state.lock().is_disabled(source, name)
}
pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
let mut state = self.state.lock();
state.enable(source, tools_to_enable);
}
pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
let mut state = self.state.lock();
state.disable(source, tools_to_disable);
}
pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
let mut state = self.state.lock();
state
.context_server_tools_by_id
.retain(|id, _| !tool_ids_to_remove.contains(id));
state.tools_changed();
}
}
impl WorkingSetState {
fn tools_changed(&mut self) {
self.context_server_tools_by_name.clear();
self.context_server_tools_by_name.extend(
self.context_server_tools_by_id
.values()
.map(|tool| (tool.name(), tool.clone())),
);
}
fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
let mut tools = ToolRegistry::global(cx).tools();
tools.extend(self.context_server_tools_by_id.values().cloned());
tools
}
fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
let mut tools_by_source = IndexMap::default();
for tool in self.tools(cx) {
@@ -135,7 +57,7 @@ impl WorkingSetState {
tools_by_source
}
fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
let all_tools = self.tools(cx);
all_tools
@@ -144,31 +66,12 @@ impl WorkingSetState {
.collect()
}
fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
self.enabled_tools_by_source
.get(source)
.map_or(false, |enabled_tools| enabled_tools.contains(name))
pub fn disable_all_tools(&mut self, cx: &mut Context<Self>) {
self.enabled_tools_by_source.clear();
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
}
fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
!self.is_enabled(source, name)
}
fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
self.enabled_tools_by_source
.entry(source)
.or_default()
.extend(tools_to_enable.into_iter().cloned());
}
fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
self.enabled_tools_by_source
.entry(source)
.or_default()
.retain(|name| !tools_to_disable.contains(name));
}
fn enable_source(&mut self, source: ToolSource, cx: &App) {
pub fn enable_source(&mut self, source: ToolSource, cx: &mut Context<Self>) {
self.enabled_sources.insert(source.clone());
let tools_by_source = self.tools_by_source(cx);
@@ -181,14 +84,72 @@ impl WorkingSetState {
.collect::<HashSet<_>>(),
);
}
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
}
fn disable_source(&mut self, source: &ToolSource) {
pub fn disable_source(&mut self, source: &ToolSource, cx: &mut Context<Self>) {
self.enabled_sources.remove(source);
self.enabled_tools_by_source.remove(source);
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
}
fn disable_all_tools(&mut self) {
self.enabled_tools_by_source.clear();
pub fn insert(&mut self, tool: Arc<dyn Tool>) -> ToolId {
let tool_id = self.next_tool_id;
self.next_tool_id.0 += 1;
self.context_server_tools_by_id
.insert(tool_id, tool.clone());
self.tools_changed();
tool_id
}
pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
self.enabled_tools_by_source
.get(source)
.map_or(false, |enabled_tools| enabled_tools.contains(name))
}
pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
!self.is_enabled(source, name)
}
pub fn enable(
&mut self,
source: ToolSource,
tools_to_enable: &[Arc<str>],
cx: &mut Context<Self>,
) {
self.enabled_tools_by_source
.entry(source)
.or_default()
.extend(tools_to_enable.into_iter().cloned());
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
}
pub fn disable(
&mut self,
source: ToolSource,
tools_to_disable: &[Arc<str>],
cx: &mut Context<Self>,
) {
self.enabled_tools_by_source
.entry(source)
.or_default()
.retain(|name| !tools_to_disable.contains(name));
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
}
pub fn remove(&mut self, tool_ids_to_remove: &[ToolId]) {
self.context_server_tools_by_id
.retain(|id, _| !tool_ids_to_remove.contains(id));
self.tools_changed();
}
fn tools_changed(&mut self) {
self.context_server_tools_by_name.clear();
self.context_server_tools_by_name.extend(
self.context_server_tools_by_id
.values()
.map(|tool| (tool.name(), tool.clone())),
);
}
}

View File

@@ -1,6 +1,7 @@
mod batch_tool;
mod code_action_tool;
mod code_symbols_tool;
mod contents_tool;
mod copy_path_tool;
mod create_directory_tool;
mod create_file_tool;
@@ -35,6 +36,7 @@ use web_search_tool::WebSearchTool;
use crate::batch_tool::BatchTool;
use crate::code_action_tool::CodeActionTool;
use crate::code_symbols_tool::CodeSymbolsTool;
use crate::contents_tool::ContentsTool;
use crate::create_directory_tool::CreateDirectoryTool;
use crate::create_file_tool::CreateFileTool;
use crate::delete_path_tool::DeletePathTool;
@@ -59,6 +61,7 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(BatchTool);
registry.register_tool(CodeActionTool);
registry.register_tool(CodeSymbolsTool);
registry.register_tool(ContentsTool);
registry.register_tool(CopyPathTool);
registry.register_tool(CreateDirectoryTool);
registry.register_tool(CreateFileTool);
@@ -79,3 +82,42 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(ThinkingTool);
registry.register_tool(WebSearchTool);
}
#[cfg(test)]
mod tests {
use http_client::FakeHttpClient;
use super::*;
#[gpui::test]
fn test_builtin_tool_schema_compatibility(cx: &mut App) {
crate::init(
Arc::new(http_client::HttpClientWithUrl::new(
FakeHttpClient::with_200_response(),
"https://zed.dev",
None,
)),
cx,
);
for tool in ToolRegistry::global(cx).tools() {
let actual_schema = tool
.input_schema(language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset)
.unwrap();
let mut expected_schema = actual_schema.clone();
assistant_tool::adapt_schema_to_format(
&mut expected_schema,
language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
)
.unwrap();
let error_message = format!(
"Tool schema for `{}` is not compatible with `language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset` (Gemini Models).\n\
Are you using `schema::json_schema_for<T>(format)` to generate the schema?",
tool.name(),
);
assert_eq!(actual_schema, expected_schema, "{}", error_message)
}
}
}

View File

@@ -172,7 +172,7 @@ impl Tool for BatchTool {
IconName::Cog
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<BatchToolInput>(format)
}

View File

@@ -2,7 +2,7 @@ use anyhow::{Context as _, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use language::{self, Anchor, Buffer, ToPointUtf16};
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{self, LspAction, Project};
use regex::Regex;
use schemars::JsonSchema;
@@ -10,6 +10,8 @@ use serde::{Deserialize, Serialize};
use std::{ops::Range, sync::Arc};
use ui::IconName;
use crate::schema::json_schema_for;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct CodeActionToolInput {
/// The relative path to the file containing the text range.
@@ -95,12 +97,8 @@ impl Tool for CodeActionTool {
IconName::Wand
}
fn input_schema(
&self,
_format: language_model::LanguageModelToolSchemaFormat,
) -> serde_json::Value {
let schema = schemars::schema_for!(CodeActionToolInput);
serde_json::to_value(&schema).unwrap()
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<CodeActionToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {

View File

@@ -91,7 +91,7 @@ impl Tool for CodeSymbolsTool {
IconName::Code
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<CodeSymbolsInput>(format)
}

View File

@@ -0,0 +1,239 @@
use std::sync::Arc;
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use itertools::Itertools;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{fmt::Write, path::Path};
use ui::IconName;
use util::markdown::MarkdownString;
/// If the model requests to read a file whose size exceeds this, then
/// the tool will return the file's symbol outline instead of its contents,
/// and suggest trying again using line ranges from the outline.
const MAX_FILE_SIZE_TO_READ: usize = 16384;
/// If the model requests to list the entries in a directory with more
/// entries than this, then the tool will return a subset of the entries
/// and suggest trying again.
const MAX_DIR_ENTRIES: usize = 1024;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ContentsToolInput {
/// The relative path of the file or directory to access.
///
/// This path should never be absolute, and the first component
/// of the path should always be a root directory in a project.
///
/// <example>
/// If the project has the following root directories:
///
/// - directory1
/// - directory2
///
/// If you want to access `file.txt` in `directory1`, you should use the path `directory1/file.txt`.
/// If you want to list contents in the directory `directory2/subfolder`, you should use the path `directory2/subfolder`.
/// </example>
pub path: String,
/// Optional position (1-based index) to start reading on, if you want to read a subset of the contents.
/// When reading a file, this refers to a line number in the file (e.g. 1 is the first line).
/// When reading a directory, this refers to the number of the directory entry (e.g. 1 is the first entry).
///
/// Defaults to 1.
pub start: Option<u32>,
/// Optional position (1-based index) to end reading on, if you want to read a subset of the contents.
/// When reading a file, this refers to a line number in the file (e.g. 1 is the first line).
/// When reading a directory, this refers to the number of the directory entry (e.g. 1 is the first entry).
///
/// Defaults to reading until the end of the file or directory.
pub end: Option<u32>,
}
pub struct ContentsTool;
impl Tool for ContentsTool {
fn name(&self) -> String {
"contents".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
fn description(&self) -> String {
include_str!("./contents_tool/description.md").into()
}
fn icon(&self) -> IconName {
IconName::FileSearch
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<ContentsToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<ContentsToolInput>(input.clone()) {
Ok(input) => {
let path = MarkdownString::inline_code(&input.path);
match (input.start, input.end) {
(Some(start), None) => format!("Read {path} (from line {start})"),
(Some(start), Some(end)) => {
format!("Read {path} (lines {start}-{end})")
}
_ => format!("Read {path}"),
}
}
Err(_) => "Read file or directory".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let input = match serde_json::from_value::<ContentsToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
};
// Sometimes models will return these even though we tell it to give a path and not a glob.
// When this happens, just list the root worktree directories.
if matches!(input.path.as_str(), "." | "" | "./" | "*") {
let output = project
.read(cx)
.worktrees(cx)
.filter_map(|worktree| {
worktree.read(cx).root_entry().and_then(|entry| {
if entry.is_dir() {
entry.path.to_str()
} else {
None
}
})
})
.collect::<Vec<_>>()
.join("\n");
return Task::ready(Ok(output));
}
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path)));
};
let Some(worktree) = project
.read(cx)
.worktree_for_id(project_path.worktree_id, cx)
else {
return Task::ready(Err(anyhow!("Worktree not found")));
};
let worktree = worktree.read(cx);
let Some(entry) = worktree.entry_for_path(&project_path.path) else {
return Task::ready(Err(anyhow!("Path not found: {}", input.path)));
};
// If it's a directory, list its contents
if entry.is_dir() {
let mut output = String::new();
let start_index = input
.start
.map(|line| (line as usize).saturating_sub(1))
.unwrap_or(0);
let end_index = input
.end
.map(|line| (line as usize).saturating_sub(1))
.unwrap_or(MAX_DIR_ENTRIES);
let mut skipped = 0;
for (index, entry) in worktree.child_entries(&project_path.path).enumerate() {
if index >= start_index && index <= end_index {
writeln!(
output,
"{}",
Path::new(worktree.root_name()).join(&entry.path).display(),
)
.unwrap();
} else {
skipped += 1;
}
}
if output.is_empty() {
output.push_str(&input.path);
output.push_str(" is empty.");
}
if skipped > 0 {
write!(
output,
"\n\nNote: Skipped {skipped} entries. Adjust start and end to see other entries.",
).ok();
}
Task::ready(Ok(output))
} else {
// It's a file, so read its contents
let file_path = input.path.clone();
cx.spawn(async move |cx| {
let buffer = cx
.update(|cx| {
project.update(cx, |project, cx| project.open_buffer(project_path, cx))
})?
.await?;
if input.start.is_some() || input.end.is_some() {
let result = buffer.read_with(cx, |buffer, _cx| {
let text = buffer.text();
let start = input.start.unwrap_or(1);
let lines = text.split('\n').skip(start as usize - 1);
if let Some(end) = input.end {
let count = end.saturating_sub(start).max(1); // Ensure at least 1 line
Itertools::intersperse(lines.take(count as usize), "\n").collect()
} else {
Itertools::intersperse(lines, "\n").collect()
}
})?;
action_log.update(cx, |log, cx| {
log.buffer_read(buffer, cx);
})?;
Ok(result)
} else {
// No line ranges specified, so check file size to see if it's too big.
let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
if file_size <= MAX_FILE_SIZE_TO_READ {
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
action_log.update(cx, |log, cx| {
log.buffer_read(buffer, cx);
})?;
Ok(result)
} else {
// File is too big, so return its outline and a suggestion to
// read again with a line number range specified.
let outline = file_outline(project, file_path, action_log, None, 0, cx).await?;
Ok(format!("This file was too big to read all at once. Here is an outline of its symbols:\n\n{outline}\n\nUsing the line numbers in this outline, you can call this tool again while specifying the start and end fields to see the implementations of symbols in the outline."))
}
}
})
}
}
}

View File

@@ -0,0 +1,9 @@
Reads the contents of a path on the filesystem.
If the path is a directory, this lists all files and directories within that path.
If the path is a file, this returns the file's contents.
When reading a file, if the file is too big and no line range is specified, an outline of the file's code symbols is listed instead, which can be used to request specific line ranges in a subsequent call.
Similarly, if a directory has too many entries to show at once, a subset of entries will be shown,
and subsequent requests can use starting and ending line numbers to get other subsets.

View File

@@ -55,7 +55,7 @@ impl Tool for CopyPathTool {
IconName::Clipboard
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<CopyPathToolInput>(format)
}

View File

@@ -45,7 +45,7 @@ impl Tool for CreateDirectoryTool {
IconName::Folder
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<CreateDirectoryToolInput>(format)
}

View File

@@ -52,7 +52,7 @@ impl Tool for CreateFileTool {
IconName::FileCreate
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<CreateFileToolInput>(format)
}

View File

@@ -45,7 +45,7 @@ impl Tool for DeletePathTool {
IconName::FileDelete
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<DeletePathToolInput>(format)
}

View File

@@ -58,7 +58,7 @@ impl Tool for DiagnosticsTool {
IconName::XCircle
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<DiagnosticsToolInput>(format)
}

View File

@@ -128,7 +128,7 @@ impl Tool for FetchTool {
IconName::Globe
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<FetchToolInput>(format)
}

View File

@@ -151,7 +151,7 @@ impl Tool for FindReplaceFileTool {
IconName::Pencil
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<FindReplaceFileToolInput>(format)
}

View File

@@ -56,7 +56,7 @@ impl Tool for ListDirectoryTool {
IconName::Folder
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<ListDirectoryToolInput>(format)
}

View File

@@ -54,7 +54,7 @@ impl Tool for MovePathTool {
IconName::ArrowRightLeft
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<MovePathToolInput>(format)
}

View File

@@ -45,7 +45,7 @@ impl Tool for NowTool {
IconName::Info
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<NowToolInput>(format)
}

View File

@@ -35,7 +35,7 @@ impl Tool for OpenTool {
IconName::ArrowUpRight
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<OpenToolInput>(format)
}

View File

@@ -53,7 +53,7 @@ impl Tool for PathSearchTool {
IconName::SearchCode
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<PathSearchToolInput>(format)
}

View File

@@ -63,7 +63,7 @@ impl Tool for ReadFileTool {
IconName::FileSearch
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<ReadFileToolInput>(format)
}

View File

@@ -60,7 +60,7 @@ impl Tool for RegexSearchTool {
IconName::Regex
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<RegexSearchToolInput>(format)
}

View File

@@ -2,13 +2,15 @@ use anyhow::{Context as _, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use language::{self, Buffer, ToPointUtf16};
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use ui::IconName;
use crate::schema::json_schema_for;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RenameToolInput {
/// The relative path to the file containing the symbol to rename.
@@ -66,12 +68,8 @@ impl Tool for RenameTool {
IconName::Pencil
}
fn input_schema(
&self,
_format: language_model::LanguageModelToolSchemaFormat,
) -> serde_json::Value {
let schema = schemars::schema_for!(RenameToolInput);
serde_json::to_value(&schema).unwrap()
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<RenameToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {

View File

@@ -5,23 +5,20 @@ use schemars::{
schema::{RootSchema, Schema, SchemaObject},
};
pub fn json_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> serde_json::Value {
pub fn json_schema_for<T: JsonSchema>(
format: LanguageModelToolSchemaFormat,
) -> Result<serde_json::Value> {
let schema = root_schema_for::<T>(format);
schema_to_json(&schema, format).expect("Failed to convert tool calling schema to JSON")
schema_to_json(&schema, format)
}
pub fn schema_to_json(
fn schema_to_json(
schema: &RootSchema,
format: LanguageModelToolSchemaFormat,
) -> Result<serde_json::Value> {
let mut value = serde_json::to_value(schema)?;
match format {
LanguageModelToolSchemaFormat::JsonSchema => Ok(value),
LanguageModelToolSchemaFormat::JsonSchemaSubset => {
transform_fields_to_json_schema_subset(&mut value);
Ok(value)
}
}
assistant_tool::adapt_schema_to_format(&mut value, format)?;
Ok(value)
}
fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> RootSchema {
@@ -79,42 +76,3 @@ impl schemars::visit::Visitor for TransformToJsonSchemaSubsetVisitor {
schemars::visit::visit_schema_object(self, schema)
}
}
fn transform_fields_to_json_schema_subset(json: &mut serde_json::Value) {
if let serde_json::Value::Object(obj) = json {
if let Some(default) = obj.get("default") {
let is_null = default.is_null();
//Default is not supported, so we need to remove it.
obj.remove("default");
if is_null {
obj.insert("nullable".to_string(), serde_json::Value::Bool(true));
}
}
// If a type is not specified for an input parameter we need to add it.
if obj.contains_key("description")
&& !obj.contains_key("type")
&& !(obj.contains_key("anyOf")
|| obj.contains_key("oneOf")
|| obj.contains_key("allOf"))
{
obj.insert(
"type".to_string(),
serde_json::Value::String("string".to_string()),
);
}
//Format field is only partially supported (e.g. not uint compatibility)
obj.remove("format");
for (_, value) in obj.iter_mut() {
if let serde_json::Value::Object(_) | serde_json::Value::Array(_) = value {
transform_fields_to_json_schema_subset(value);
}
}
} else if let serde_json::Value::Array(arr) = json {
for item in arr.iter_mut() {
transform_fields_to_json_schema_subset(item);
}
}
}

View File

@@ -84,7 +84,7 @@ impl Tool for SymbolInfoTool {
IconName::Code
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<SymbolInfoToolInput>(format)
}

View File

@@ -44,7 +44,7 @@ impl Tool for TerminalTool {
IconName::Terminal
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<TerminalToolInput>(format)
}

View File

@@ -36,7 +36,7 @@ impl Tool for ThinkingTool {
IconName::LightBulb
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<ThinkingToolInput>(format)
}

View File

@@ -27,6 +27,8 @@ serde_json.workspace = true
settings.workspace = true
smol.workspace = true
tempfile.workspace = true
which.workspace = true
workspace.workspace = true
workspace-hack.workspace = true
[target.'cfg(not(target_os = "windows"))'.dependencies]
which.workspace = true

View File

@@ -23,7 +23,6 @@ use std::{
sync::Arc,
time::Duration,
};
use which::which;
use workspace::Workspace;
const SHOULD_SHOW_UPDATE_NOTIFICATION_KEY: &str = "auto-updater-should-show-updated-notification";
@@ -63,7 +62,7 @@ pub struct AutoUpdater {
pending_poll: Option<Task<Option<()>>>,
}
#[derive(Deserialize)]
#[derive(Deserialize, Debug)]
pub struct JsonRelease {
pub version: String,
pub url: String,
@@ -237,6 +236,46 @@ pub fn view_release_notes(_: &ViewReleaseNotes, cx: &mut App) -> Option<()> {
None
}
#[cfg(not(target_os = "windows"))]
struct InstallerDir(tempfile::TempDir);
#[cfg(not(target_os = "windows"))]
impl InstallerDir {
async fn new() -> Result<Self> {
Ok(Self(
tempfile::Builder::new()
.prefix("zed-auto-update")
.tempdir()?,
))
}
fn path(&self) -> &Path {
self.0.path()
}
}
#[cfg(target_os = "windows")]
struct InstallerDir(PathBuf);
#[cfg(target_os = "windows")]
impl InstallerDir {
async fn new() -> Result<Self> {
let installer_dir = std::env::current_exe()?
.parent()
.context("No parent dir for Zed.exe")?
.join("updates");
if smol::fs::metadata(&installer_dir).await.is_ok() {
smol::fs::remove_dir_all(&installer_dir).await?;
}
smol::fs::create_dir(&installer_dir).await?;
Ok(Self(installer_dir))
}
fn path(&self) -> &Path {
self.0.as_path()
}
}
impl AutoUpdater {
pub fn get(cx: &mut App) -> Option<Entity<Self>> {
cx.default_global::<GlobalAutoUpdate>().0.clone()
@@ -469,22 +508,21 @@ impl AutoUpdater {
cx.notify();
})?;
let temp_dir = tempfile::Builder::new()
.prefix("zed-auto-update")
.tempdir()?;
let installer_dir = InstallerDir::new().await?;
let filename = match OS {
"macos" => Ok("Zed.dmg"),
"linux" => Ok("zed.tar.gz"),
"windows" => Ok("ZedUpdateInstaller.exe"),
_ => Err(anyhow!("not supported: {:?}", OS)),
}?;
#[cfg(not(target_os = "windows"))]
anyhow::ensure!(
which("rsync").is_ok(),
which::which("rsync").is_ok(),
"Aborting. Could not find rsync which is required for auto-updates."
);
let downloaded_asset = temp_dir.path().join(filename);
let downloaded_asset = installer_dir.path().join(filename);
download_release(&downloaded_asset, release, client, &cx).await?;
this.update(&mut cx, |this, cx| {
@@ -493,8 +531,9 @@ impl AutoUpdater {
})?;
let binary_path = match OS {
"macos" => install_release_macos(&temp_dir, downloaded_asset, &cx).await,
"linux" => install_release_linux(&temp_dir, downloaded_asset, &cx).await,
"macos" => install_release_macos(&installer_dir, downloaded_asset, &cx).await,
"linux" => install_release_linux(&installer_dir, downloaded_asset, &cx).await,
"windows" => install_release_windows(downloaded_asset).await,
_ => Err(anyhow!("not supported: {:?}", OS)),
}?;
@@ -629,7 +668,7 @@ async fn download_release(
}
async fn install_release_linux(
temp_dir: &tempfile::TempDir,
temp_dir: &InstallerDir,
downloaded_tar_gz: PathBuf,
cx: &AsyncApp,
) -> Result<PathBuf> {
@@ -696,7 +735,7 @@ async fn install_release_linux(
}
async fn install_release_macos(
temp_dir: &tempfile::TempDir,
temp_dir: &InstallerDir,
downloaded_dmg: PathBuf,
cx: &AsyncApp,
) -> Result<PathBuf> {
@@ -743,3 +782,41 @@ async fn install_release_macos(
Ok(running_app_path)
}
async fn install_release_windows(downloaded_installer: PathBuf) -> Result<PathBuf> {
let output = Command::new(downloaded_installer)
.arg("/verysilent")
.arg("/update=true")
.arg("!desktopicon")
.arg("!quicklaunchicon")
.output()
.await?;
anyhow::ensure!(
output.status.success(),
"failed to start installer: {:?}",
String::from_utf8_lossy(&output.stderr)
);
Ok(std::env::current_exe()?)
}
pub fn check_pending_installation() -> bool {
let Some(installer_path) = std::env::current_exe()
.ok()
.and_then(|p| p.parent().map(|p| p.join("updates")))
else {
return false;
};
// The installer will create a flag file after it finishes updating
let flag_file = installer_path.join("versions.txt");
if flag_file.exists() {
if let Some(helper) = installer_path
.parent()
.map(|p| p.join("tools\\auto_update_helper.exe"))
{
let _ = std::process::Command::new(helper).spawn();
return true;
}
}
false
}

View File

@@ -0,0 +1,29 @@
[package]
name = "auto_update_helper"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[[bin]]
name = "auto_update_helper"
path = "src/auto_update_helper.rs"
doctest = false
[dependencies]
anyhow.workspace = true
log.workspace = true
simplelog.workspace = true
workspace-hack.workspace = true
[target.'cfg(target_os = "windows")'.dependencies]
windows.workspace = true
[target.'cfg(target_os = "windows")'.build-dependencies]
winresource = "0.1"
[package.metadata.docs.rs]
targets = ["x86_64-pc-windows-msvc"]

View File

@@ -0,0 +1 @@
../../LICENSE-GPL

Binary file not shown.

After

Width:  |  Height:  |  Size: 577 KiB

View File

@@ -0,0 +1,15 @@
fn main() {
#[cfg(target_os = "windows")]
{
println!("cargo:rerun-if-changed=manifest.xml");
let mut res = winresource::WindowsResource::new();
res.set_manifest_file("manifest.xml");
res.set_icon("app-icon.ico");
if let Err(e) = res.compile() {
eprintln!("{}", e);
std::process::exit(1);
}
}
}

View File

@@ -0,0 +1,16 @@
<assembly xmlns="urn:schemas-microsoft-com:asm.v1" manifestVersion="1.0" xmlns:asmv3="urn:schemas-microsoft-com:asm.v3">
<asmv3:application>
<asmv3:windowsSettings>
<dpiAware xmlns="http://schemas.microsoft.com/SMI/2005/WindowsSettings">true</dpiAware>
<dpiAwareness xmlns="http://schemas.microsoft.com/SMI/2016/WindowsSettings">PerMonitorV2</dpiAwareness>
</asmv3:windowsSettings>
</asmv3:application>
<dependency>
<dependentAssembly>
<assemblyIdentity type='win32'
name='Microsoft.Windows.Common-Controls'
version='6.0.0.0' processorArchitecture='*'
publicKeyToken='6595b64144ccf1df' />
</dependentAssembly>
</dependency>
</assembly>

View File

@@ -0,0 +1,94 @@
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]
#[cfg(target_os = "windows")]
mod dialog;
#[cfg(target_os = "windows")]
mod updater;
#[cfg(target_os = "windows")]
fn main() {
if let Err(e) = windows_impl::run() {
log::error!("Error: Zed update failed, {:?}", e);
windows_impl::show_error(format!("Error: {:?}", e));
}
}
#[cfg(not(target_os = "windows"))]
fn main() {}
#[cfg(target_os = "windows")]
mod windows_impl {
use std::path::Path;
use super::dialog::create_dialog_window;
use super::updater::perform_update;
use anyhow::{Context, Result};
use windows::{
Win32::{
Foundation::{HWND, LPARAM, WPARAM},
UI::WindowsAndMessaging::{
DispatchMessageW, GetMessageW, MB_ICONERROR, MB_SYSTEMMODAL, MSG, MessageBoxW,
PostMessageW, WM_USER,
},
},
core::HSTRING,
};
pub(crate) const WM_JOB_UPDATED: u32 = WM_USER + 1;
pub(crate) const WM_TERMINATE: u32 = WM_USER + 2;
pub(crate) fn run() -> Result<()> {
let helper_dir = std::env::current_exe()?
.parent()
.context("No parent directory")?
.to_path_buf();
init_log(&helper_dir)?;
let app_dir = helper_dir
.parent()
.context("No parent directory")?
.to_path_buf();
log::info!("======= Starting Zed update =======");
let (tx, rx) = std::sync::mpsc::channel();
let hwnd = create_dialog_window(rx)?.0 as isize;
std::thread::spawn(move || {
let result = perform_update(app_dir.as_path(), Some(hwnd));
tx.send(result).ok();
unsafe { PostMessageW(Some(HWND(hwnd as _)), WM_TERMINATE, WPARAM(0), LPARAM(0)) }.ok();
});
unsafe {
let mut message = MSG::default();
while GetMessageW(&mut message, None, 0, 0).as_bool() {
DispatchMessageW(&message);
}
}
Ok(())
}
fn init_log(helper_dir: &Path) -> Result<()> {
simplelog::WriteLogger::init(
simplelog::LevelFilter::Info,
simplelog::Config::default(),
std::fs::File::options()
.append(true)
.create(true)
.open(helper_dir.join("auto_update_helper.log"))?,
)?;
Ok(())
}
pub(crate) fn show_error(mut content: String) {
if content.len() > 600 {
content.truncate(600);
content.push_str("...\n");
}
let _ = unsafe {
MessageBoxW(
None,
&HSTRING::from(content),
windows::core::w!("Error: Zed update failed."),
MB_ICONERROR | MB_SYSTEMMODAL,
)
};
}
}

View File

@@ -0,0 +1,236 @@
use std::{cell::RefCell, sync::mpsc::Receiver};
use anyhow::{Context as _, Result};
use windows::{
Win32::{
Foundation::{HWND, LPARAM, LRESULT, RECT, WPARAM},
Graphics::Gdi::{
BeginPaint, CLEARTYPE_QUALITY, CLIP_DEFAULT_PRECIS, CreateFontW, DEFAULT_CHARSET,
DeleteObject, EndPaint, FW_NORMAL, LOGFONTW, OUT_TT_ONLY_PRECIS, PAINTSTRUCT,
ReleaseDC, SelectObject, TextOutW,
},
System::LibraryLoader::GetModuleHandleW,
UI::{
Controls::{PBM_SETRANGE, PBM_SETSTEP, PBM_STEPIT, PROGRESS_CLASS},
WindowsAndMessaging::{
CREATESTRUCTW, CS_HREDRAW, CS_VREDRAW, CreateWindowExW, DefWindowProcW,
GWLP_USERDATA, GetDesktopWindow, GetWindowLongPtrW, GetWindowRect, HICON,
IMAGE_ICON, LR_DEFAULTSIZE, LR_SHARED, LoadImageW, PostQuitMessage, RegisterClassW,
SPI_GETICONTITLELOGFONT, SYSTEM_PARAMETERS_INFO_UPDATE_FLAGS, SendMessageW,
SetWindowLongPtrW, SystemParametersInfoW, WINDOW_EX_STYLE, WM_CLOSE, WM_CREATE,
WM_DESTROY, WM_NCCREATE, WM_PAINT, WNDCLASSW, WS_CAPTION, WS_CHILD, WS_EX_TOPMOST,
WS_POPUP, WS_VISIBLE,
},
},
},
core::HSTRING,
};
use crate::{
updater::JOBS,
windows_impl::{WM_JOB_UPDATED, WM_TERMINATE, show_error},
};
#[repr(C)]
#[derive(Debug)]
struct DialogInfo {
rx: Receiver<Result<()>>,
progress_bar: isize,
}
pub(crate) fn create_dialog_window(receiver: Receiver<Result<()>>) -> Result<HWND> {
unsafe {
let class_name = windows::core::w!("Zed-Auto-Updater-Dialog-Class");
let module = GetModuleHandleW(None).context("unable to get module handle")?;
let handle = LoadImageW(
Some(module.into()),
windows::core::PCWSTR(1 as _),
IMAGE_ICON,
0,
0,
LR_DEFAULTSIZE | LR_SHARED,
)
.context("unable to load icon file")?;
let wc = WNDCLASSW {
lpfnWndProc: Some(wnd_proc),
lpszClassName: class_name,
style: CS_HREDRAW | CS_VREDRAW,
hIcon: HICON(handle.0),
..Default::default()
};
RegisterClassW(&wc);
let mut rect = RECT::default();
GetWindowRect(GetDesktopWindow(), &mut rect)
.context("unable to get desktop window rect")?;
let width = 400;
let height = 150;
let info = Box::new(RefCell::new(DialogInfo {
rx: receiver,
progress_bar: 0,
}));
let hwnd = CreateWindowExW(
WS_EX_TOPMOST,
class_name,
windows::core::w!("Zed Editor"),
WS_VISIBLE | WS_POPUP | WS_CAPTION,
rect.right / 2 - width / 2,
rect.bottom / 2 - height / 2,
width,
height,
None,
None,
None,
Some(Box::into_raw(info) as _),
)
.context("unable to create dialog window")?;
Ok(hwnd)
}
}
macro_rules! return_if_failed {
($e:expr) => {
match $e {
Ok(v) => v,
Err(e) => {
return LRESULT(e.code().0 as _);
}
}
};
}
macro_rules! make_lparam {
($l:expr, $h:expr) => {
LPARAM(($l as u32 | ($h as u32) << 16) as isize)
};
}
unsafe extern "system" fn wnd_proc(
hwnd: HWND,
msg: u32,
wparam: WPARAM,
lparam: LPARAM,
) -> LRESULT {
match msg {
WM_NCCREATE => unsafe {
let create_struct = lparam.0 as *const CREATESTRUCTW;
let info = (*create_struct).lpCreateParams as *mut RefCell<DialogInfo>;
let info = Box::from_raw(info);
SetWindowLongPtrW(hwnd, GWLP_USERDATA, Box::into_raw(info) as _);
DefWindowProcW(hwnd, msg, wparam, lparam)
},
WM_CREATE => unsafe {
// Create progress bar
let mut rect = RECT::default();
return_if_failed!(GetWindowRect(hwnd, &mut rect));
let progress_bar = return_if_failed!(CreateWindowExW(
WINDOW_EX_STYLE(0),
PROGRESS_CLASS,
None,
WS_CHILD | WS_VISIBLE,
20,
50,
340,
35,
Some(hwnd),
None,
None,
None,
));
SendMessageW(
progress_bar,
PBM_SETRANGE,
None,
Some(make_lparam!(0, JOBS.len() * 10)),
);
SendMessageW(progress_bar, PBM_SETSTEP, Some(WPARAM(10)), None);
with_dialog_data(hwnd, |data| {
data.borrow_mut().progress_bar = progress_bar.0 as isize
});
LRESULT(0)
},
WM_PAINT => unsafe {
let mut ps = PAINTSTRUCT::default();
let hdc = BeginPaint(hwnd, &mut ps);
let font_name = get_system_ui_font_name();
let font = CreateFontW(
24,
0,
0,
0,
FW_NORMAL.0 as _,
0,
0,
0,
DEFAULT_CHARSET,
OUT_TT_ONLY_PRECIS,
CLIP_DEFAULT_PRECIS,
CLEARTYPE_QUALITY,
0,
&HSTRING::from(font_name),
);
let temp = SelectObject(hdc, font.into());
let string = HSTRING::from("Zed Editor is updating...");
return_if_failed!(TextOutW(hdc, 20, 15, &string).ok());
return_if_failed!(DeleteObject(temp).ok());
return_if_failed!(EndPaint(hwnd, &ps).ok());
ReleaseDC(Some(hwnd), hdc);
LRESULT(0)
},
WM_JOB_UPDATED => with_dialog_data(hwnd, |data| {
let progress_bar = data.borrow().progress_bar;
unsafe { SendMessageW(HWND(progress_bar as _), PBM_STEPIT, None, None) }
}),
WM_TERMINATE => {
with_dialog_data(hwnd, |data| {
if let Ok(result) = data.borrow_mut().rx.recv() {
if let Err(e) = result {
log::error!("Failed to update Zed: {:?}", e);
show_error(format!("Error: {:?}", e));
}
}
});
unsafe { PostQuitMessage(0) };
LRESULT(0)
}
WM_CLOSE => LRESULT(0), // Prevent user occasionally closing the window
WM_DESTROY => {
unsafe { PostQuitMessage(0) };
LRESULT(0)
}
_ => unsafe { DefWindowProcW(hwnd, msg, wparam, lparam) },
}
}
fn with_dialog_data<F, T>(hwnd: HWND, f: F) -> T
where
F: FnOnce(&RefCell<DialogInfo>) -> T,
{
let raw = unsafe { GetWindowLongPtrW(hwnd, GWLP_USERDATA) as *mut RefCell<DialogInfo> };
let data = unsafe { Box::from_raw(raw) };
let result = f(data.as_ref());
unsafe { SetWindowLongPtrW(hwnd, GWLP_USERDATA, Box::into_raw(data) as _) };
result
}
fn get_system_ui_font_name() -> String {
unsafe {
let mut info: LOGFONTW = std::mem::zeroed();
if SystemParametersInfoW(
SPI_GETICONTITLELOGFONT,
std::mem::size_of::<LOGFONTW>() as u32,
Some(&mut info as *mut _ as _),
SYSTEM_PARAMETERS_INFO_UPDATE_FLAGS(0),
)
.is_ok()
{
let font_name = String::from_utf16_lossy(&info.lfFaceName);
font_name.trim_matches(char::from(0)).to_owned()
} else {
"MS Shell Dlg".to_owned()
}
}
}

View File

@@ -0,0 +1,171 @@
use std::{
os::windows::process::CommandExt,
path::Path,
time::{Duration, Instant},
};
use anyhow::{Context, Result};
use windows::Win32::{
Foundation::{HWND, LPARAM, WPARAM},
System::Threading::CREATE_NEW_PROCESS_GROUP,
UI::WindowsAndMessaging::PostMessageW,
};
use crate::windows_impl::WM_JOB_UPDATED;
type Job = fn(&Path) -> Result<()>;
#[cfg(not(test))]
pub(crate) const JOBS: [Job; 6] = [
// Delete old files
|app_dir| {
let zed_executable = app_dir.join("Zed.exe");
log::info!("Removing old file: {}", zed_executable.display());
std::fs::remove_file(&zed_executable).context(format!(
"Failed to remove old file {}",
zed_executable.display()
))
},
|app_dir| {
let zed_cli = app_dir.join("bin\\zed.exe");
log::info!("Removing old file: {}", zed_cli.display());
std::fs::remove_file(&zed_cli)
.context(format!("Failed to remove old file {}", zed_cli.display()))
},
// Copy new files
|app_dir| {
let zed_executable_source = app_dir.join("install\\Zed.exe");
let zed_executable_dest = app_dir.join("Zed.exe");
log::info!(
"Copying new file {} to {}",
zed_executable_source.display(),
zed_executable_dest.display()
);
std::fs::copy(&zed_executable_source, &zed_executable_dest)
.map(|_| ())
.context(format!(
"Failed to copy new file {} to {}",
zed_executable_source.display(),
zed_executable_dest.display()
))
},
|app_dir| {
let zed_cli_source = app_dir.join("install\\bin\\zed.exe");
let zed_cli_dest = app_dir.join("bin\\zed.exe");
log::info!(
"Copying new file {} to {}",
zed_cli_source.display(),
zed_cli_dest.display()
);
std::fs::copy(&zed_cli_source, &zed_cli_dest)
.map(|_| ())
.context(format!(
"Failed to copy new file {} to {}",
zed_cli_source.display(),
zed_cli_dest.display()
))
},
// Clean up installer folder and updates folder
|app_dir| {
let updates_folder = app_dir.join("updates");
log::info!("Cleaning up: {}", updates_folder.display());
std::fs::remove_dir_all(&updates_folder).context(format!(
"Failed to remove updates folder {}",
updates_folder.display()
))
},
|app_dir| {
let installer_folder = app_dir.join("install");
log::info!("Cleaning up: {}", installer_folder.display());
std::fs::remove_dir_all(&installer_folder).context(format!(
"Failed to remove installer folder {}",
installer_folder.display()
))
},
];
#[cfg(test)]
pub(crate) const JOBS: [Job; 2] = [
|_| {
std::thread::sleep(Duration::from_millis(1000));
if let Ok(config) = std::env::var("ZED_AUTO_UPDATE") {
match config.as_str() {
"err" => Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Simulated error",
))
.context("Anyhow!"),
_ => panic!("Unknown ZED_AUTO_UPDATE value: {}", config),
}
} else {
Ok(())
}
},
|_| {
std::thread::sleep(Duration::from_millis(1000));
if let Ok(config) = std::env::var("ZED_AUTO_UPDATE") {
match config.as_str() {
"err" => Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Simulated error",
))
.context("Anyhow!"),
_ => panic!("Unknown ZED_AUTO_UPDATE value: {}", config),
}
} else {
Ok(())
}
},
];
pub(crate) fn perform_update(app_dir: &Path, hwnd: Option<isize>) -> Result<()> {
let hwnd = hwnd.map(|ptr| HWND(ptr as _));
for job in JOBS.iter() {
let start = Instant::now();
loop {
if start.elapsed().as_secs() > 2 {
return Err(anyhow::anyhow!("Timed out"));
}
match (*job)(app_dir) {
Ok(_) => {
unsafe { PostMessageW(hwnd, WM_JOB_UPDATED, WPARAM(0), LPARAM(0))? };
break;
}
Err(err) => {
// Check if it's a "not found" error
let io_err = err.downcast_ref::<std::io::Error>().unwrap();
if io_err.kind() == std::io::ErrorKind::NotFound {
log::warn!("File or folder not found.");
unsafe { PostMessageW(hwnd, WM_JOB_UPDATED, WPARAM(0), LPARAM(0))? };
break;
}
log::error!("Operation failed: {}", err);
std::thread::sleep(Duration::from_millis(50));
}
}
}
}
let _ = std::process::Command::new(app_dir.join("Zed.exe"))
.creation_flags(CREATE_NEW_PROCESS_GROUP.0)
.spawn();
log::info!("Update completed successfully");
Ok(())
}
#[cfg(test)]
mod test {
use super::perform_update;
#[test]
fn test_perform_update() {
let app_dir = std::path::Path::new("C:/");
assert!(perform_update(app_dir, None).is_ok());
// Simulate a timeout
unsafe { std::env::set_var("ZED_AUTO_UPDATE", "err") };
let ret = perform_update(app_dir, None);
assert!(ret.is_err_and(|e| e.to_string().as_str() == "Timed out"));
}
}

View File

@@ -7,8 +7,6 @@ fn main() {
if cfg!(target_os = "macos") {
println!("cargo:rustc-env=MACOSX_DEPLOYMENT_TARGET=10.15.7");
// Weakly link ScreenCaptureKit to ensure can be used on macOS 10.15+.
println!("cargo:rustc-link-arg=-Wl,-weak_framework,ScreenCaptureKit");
}
// Populate git sha environment variable if git is available

View File

@@ -18,7 +18,6 @@ sqlite = ["sea-orm/sqlx-sqlite", "sqlx/sqlite"]
test-support = ["sqlite"]
[dependencies]
anthropic.workspace = true
anyhow.workspace = true
async-stripe.workspace = true
async-tungstenite.workspace = true

View File

@@ -253,7 +253,6 @@ impl Config {
pub enum ServiceMode {
Api,
Collab,
Llm,
All,
}
@@ -265,10 +264,6 @@ impl ServiceMode {
pub fn is_api(&self) -> bool {
matches!(self, Self::Api | Self::All)
}
pub fn is_llm(&self) -> bool {
matches!(self, Self::Llm | Self::All)
}
}
pub struct AppState {

View File

@@ -1,448 +1,10 @@
mod authorization;
pub mod db;
mod token;
use crate::api::CloudflareIpCountryHeader;
use crate::api::events::SnowflakeRow;
use crate::build_kinesis_client;
use crate::rpc::MIN_ACCOUNT_AGE_FOR_LLM_USE;
use crate::{Cents, Config, Error, Result, db::UserId, executor::Executor};
use anyhow::{Context as _, anyhow};
use authorization::authorize_access_to_language_model;
use axum::routing::get;
use axum::{
Extension, Json, Router, TypedHeader,
body::Body,
http::{self, HeaderName, HeaderValue, Request, StatusCode},
middleware::{self, Next},
response::{IntoResponse, Response},
routing::post,
};
use chrono::{DateTime, Duration, Utc};
use collections::HashMap;
use db::TokenUsage;
use db::{ActiveUserCount, LlmDatabase, usage_measure::UsageMeasure};
use futures::{Stream, StreamExt as _};
use reqwest_client::ReqwestClient;
use rpc::{
EXPIRED_LLM_TOKEN_HEADER_NAME, LanguageModelProvider, PerformCompletionParams, proto::Plan,
};
use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
use serde_json::json;
use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use strum::IntoEnumIterator;
use tokio::sync::RwLock;
use util::ResultExt;
use crate::Cents;
pub use token::*;
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
pub struct LlmState {
pub config: Config,
pub executor: Executor,
pub db: Arc<LlmDatabase>,
pub http_client: ReqwestClient,
pub kinesis_client: Option<aws_sdk_kinesis::Client>,
active_user_count_by_model:
RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
}
impl LlmState {
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
let database_url = config
.llm_database_url
.as_ref()
.ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
let max_connections = config
.llm_database_max_connections
.ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
let mut db_options = db::ConnectOptions::new(database_url);
db_options.max_connections(max_connections);
let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
db.initialize().await?;
let db = Arc::new(db);
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
let http_client =
ReqwestClient::user_agent(&user_agent).context("failed to construct http client")?;
let this = Self {
executor,
db,
http_client,
kinesis_client: if config.kinesis_access_key.is_some() {
build_kinesis_client(&config).await.log_err()
} else {
None
},
active_user_count_by_model: RwLock::new(HashMap::default()),
config,
};
Ok(Arc::new(this))
}
pub async fn get_active_user_count(
&self,
provider: LanguageModelProvider,
model: &str,
) -> Result<ActiveUserCount> {
let now = Utc::now();
{
let active_user_count_by_model = self.active_user_count_by_model.read().await;
if let Some((last_updated, count)) =
active_user_count_by_model.get(&(provider, model.to_string()))
{
if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
return Ok(*count);
}
}
}
let mut cache = self.active_user_count_by_model.write().await;
let new_count = self.db.get_active_user_count(provider, model, now).await?;
cache.insert((provider, model.to_string()), (now, new_count));
Ok(new_count)
}
}
pub fn routes() -> Router<(), Body> {
Router::new()
.route("/models", get(list_models))
.route("/completion", post(perform_completion))
.layer(middleware::from_fn(validate_api_token))
}
async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
let token = req
.headers()
.get(http::header::AUTHORIZATION)
.and_then(|header| header.to_str().ok())
.ok_or_else(|| {
Error::http(
StatusCode::BAD_REQUEST,
"missing authorization header".to_string(),
)
})?
.strip_prefix("Bearer ")
.ok_or_else(|| {
Error::http(
StatusCode::BAD_REQUEST,
"invalid authorization header".to_string(),
)
})?;
let state = req.extensions().get::<Arc<LlmState>>().unwrap();
match LlmTokenClaims::validate(token, &state.config) {
Ok(claims) => {
if state.db.is_access_token_revoked(&claims.jti).await? {
return Err(Error::http(
StatusCode::UNAUTHORIZED,
"unauthorized".to_string(),
));
}
tracing::Span::current()
.record("user_id", claims.user_id)
.record("login", claims.github_user_login.clone())
.record("authn.jti", &claims.jti)
.record("is_staff", claims.is_staff);
req.extensions_mut().insert(claims);
Ok::<_, Error>(next.run(req).await.into_response())
}
Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
StatusCode::UNAUTHORIZED,
"unauthorized".to_string(),
[(
HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
HeaderValue::from_static("true"),
)]
.into_iter()
.collect(),
)),
Err(_err) => Err(Error::http(
StatusCode::UNAUTHORIZED,
"unauthorized".to_string(),
)),
}
}
async fn list_models(
Extension(state): Extension<Arc<LlmState>>,
Extension(claims): Extension<LlmTokenClaims>,
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
) -> Result<Json<ListModelsResponse>> {
let country_code = country_code_header.map(|header| header.to_string());
let mut accessible_models = Vec::new();
for (provider, model) in state.db.all_models() {
let authorize_result = authorize_access_to_language_model(
&state.config,
&claims,
country_code.as_deref(),
provider,
&model.name,
);
if authorize_result.is_ok() {
accessible_models.push(rpc::LanguageModel {
provider,
name: model.name,
});
}
}
Ok(Json(ListModelsResponse {
models: accessible_models,
}))
}
async fn perform_completion(
Extension(state): Extension<Arc<LlmState>>,
Extension(claims): Extension<LlmTokenClaims>,
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
let model = normalize_model_name(
state.db.model_names_for_provider(params.provider),
params.model,
);
let bypass_account_age_check = claims.has_llm_subscription || claims.bypass_account_age_check;
if !bypass_account_age_check {
if Utc::now().naive_utc() - claims.account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
Err(anyhow!("account too young"))?
}
}
authorize_access_to_language_model(
&state.config,
&claims,
country_code_header
.map(|header| header.to_string())
.as_deref(),
params.provider,
&model,
)?;
check_usage_limit(&state, params.provider, &model, &claims).await?;
let stream = match params.provider {
LanguageModelProvider::Anthropic => {
let api_key = if claims.is_staff {
state
.config
.anthropic_staff_api_key
.as_ref()
.context("no Anthropic AI staff API key configured on the server")?
} else {
state
.config
.anthropic_api_key
.as_ref()
.context("no Anthropic AI API key configured on the server")?
};
let mut request: anthropic::Request =
serde_json::from_str(params.provider_request.get())?;
// Override the model on the request with the latest version of the model that is
// known to the server.
//
// Right now, we use the version that's defined in `model.id()`, but we will likely
// want to change this code once a new version of an Anthropic model is released,
// so that users can use the new version, without having to update Zed.
request.model = match model.as_str() {
"claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(),
"claude-3-7-sonnet" => anthropic::Model::Claude3_7Sonnet.id().to_string(),
"claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(),
"claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(),
"claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(),
_ => request.model,
};
let (chunks, rate_limit_info) = anthropic::stream_completion_with_rate_limit_info(
&state.http_client,
anthropic::ANTHROPIC_API_URL,
api_key,
request,
)
.await
.map_err(|err| match err {
anthropic::AnthropicError::ApiError(ref api_error) => match api_error.code() {
Some(anthropic::ApiErrorCode::RateLimitError) => {
tracing::info!(
target: "upstream rate limit exceeded",
user_id = claims.user_id,
login = claims.github_user_login,
authn.jti = claims.jti,
is_staff = claims.is_staff,
provider = params.provider.to_string(),
model = model
);
Error::http(
StatusCode::TOO_MANY_REQUESTS,
"Upstream Anthropic rate limit exceeded.".to_string(),
)
}
Some(anthropic::ApiErrorCode::InvalidRequestError) => {
Error::http(StatusCode::BAD_REQUEST, api_error.message.clone())
}
Some(anthropic::ApiErrorCode::OverloadedError) => {
Error::http(StatusCode::SERVICE_UNAVAILABLE, api_error.message.clone())
}
Some(_) => {
Error::http(StatusCode::INTERNAL_SERVER_ERROR, api_error.message.clone())
}
None => Error::Internal(anyhow!(err)),
},
anthropic::AnthropicError::Other(err) => Error::Internal(err),
})?;
if let Some(rate_limit_info) = rate_limit_info {
tracing::info!(
target: "upstream rate limit",
is_staff = claims.is_staff,
provider = params.provider.to_string(),
model = model,
tokens_remaining = rate_limit_info.tokens.as_ref().map(|limits| limits.remaining),
input_tokens_remaining = rate_limit_info.input_tokens.as_ref().map(|limits| limits.remaining),
output_tokens_remaining = rate_limit_info.output_tokens.as_ref().map(|limits| limits.remaining),
requests_remaining = rate_limit_info.requests.as_ref().map(|limits| limits.remaining),
requests_reset = ?rate_limit_info.requests.as_ref().map(|limits| limits.reset),
tokens_reset = ?rate_limit_info.tokens.as_ref().map(|limits| limits.reset),
input_tokens_reset = ?rate_limit_info.input_tokens.as_ref().map(|limits| limits.reset),
output_tokens_reset = ?rate_limit_info.output_tokens.as_ref().map(|limits| limits.reset),
);
}
chunks
.map(move |event| {
let chunk = event?;
let (
input_tokens,
output_tokens,
cache_creation_input_tokens,
cache_read_input_tokens,
) = match &chunk {
anthropic::Event::MessageStart {
message: anthropic::Response { usage, .. },
}
| anthropic::Event::MessageDelta { usage, .. } => (
usage.input_tokens.unwrap_or(0) as usize,
usage.output_tokens.unwrap_or(0) as usize,
usage.cache_creation_input_tokens.unwrap_or(0) as usize,
usage.cache_read_input_tokens.unwrap_or(0) as usize,
),
_ => (0, 0, 0, 0),
};
anyhow::Ok(CompletionChunk {
bytes: serde_json::to_vec(&chunk).unwrap(),
input_tokens,
output_tokens,
cache_creation_input_tokens,
cache_read_input_tokens,
})
})
.boxed()
}
LanguageModelProvider::OpenAi => {
let api_key = state
.config
.openai_api_key
.as_ref()
.context("no OpenAI API key configured on the server")?;
let chunks = open_ai::stream_completion(
&state.http_client,
open_ai::OPEN_AI_API_URL,
api_key,
serde_json::from_str(params.provider_request.get())?,
)
.await?;
chunks
.map(|event| {
event.map(|chunk| {
let input_tokens =
chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
let output_tokens =
chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
CompletionChunk {
bytes: serde_json::to_vec(&chunk).unwrap(),
input_tokens,
output_tokens,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
}
})
})
.boxed()
}
LanguageModelProvider::Google => {
let api_key = state
.config
.google_ai_api_key
.as_ref()
.context("no Google AI API key configured on the server")?;
let chunks = google_ai::stream_generate_content(
&state.http_client,
google_ai::API_URL,
api_key,
serde_json::from_str(params.provider_request.get())?,
)
.await?;
chunks
.map(|event| {
event.map(|chunk| {
// TODO - implement token counting for Google AI
CompletionChunk {
bytes: serde_json::to_vec(&chunk).unwrap(),
input_tokens: 0,
output_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
}
})
})
.boxed()
}
};
Ok(Response::new(Body::wrap_stream(TokenCountingStream {
state,
claims,
provider: params.provider,
model,
tokens: TokenUsage::default(),
inner_stream: stream,
})))
}
fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
if let Some(known_model_name) = known_models
.iter()
.filter(|known_model_name| name.starts_with(known_model_name.as_str()))
.max_by_key(|known_model_name| known_model_name.len())
{
known_model_name.to_string()
} else {
name
}
}
/// The maximum monthly spending an individual user can reach on the free tier
/// before they have to pay.
pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);
@@ -452,330 +14,3 @@ pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);
///
/// Used to prevent surprise bills.
pub const DEFAULT_MAX_MONTHLY_SPEND: Cents = Cents::from_dollars(10);
async fn check_usage_limit(
state: &Arc<LlmState>,
provider: LanguageModelProvider,
model_name: &str,
claims: &LlmTokenClaims,
) -> Result<()> {
if claims.is_staff {
return Ok(());
}
let user_id = UserId::from_proto(claims.user_id);
let model = state.db.model(provider, model_name)?;
let free_tier = claims.free_tier_monthly_spending_limit();
let spending_this_month = state
.db
.get_user_spending_for_month(user_id, Utc::now())
.await?;
if spending_this_month >= free_tier {
if !claims.has_llm_subscription {
return Err(Error::http(
StatusCode::PAYMENT_REQUIRED,
"Maximum spending limit reached for this month.".to_string(),
));
}
let monthly_spend = spending_this_month.saturating_sub(free_tier);
if monthly_spend >= Cents(claims.max_monthly_spend_in_cents) {
return Err(Error::Http(
StatusCode::FORBIDDEN,
"Maximum spending limit reached for this month.".to_string(),
[(
HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME),
HeaderValue::from_static("true"),
)]
.into_iter()
.collect(),
));
}
}
let active_users = state.get_active_user_count(provider, model_name).await?;
let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
let users_in_recent_days = active_users.users_in_recent_days.max(1);
let per_user_max_requests_per_minute =
model.max_requests_per_minute as usize / users_in_recent_minutes;
let per_user_max_tokens_per_minute =
model.max_tokens_per_minute as usize / users_in_recent_minutes;
let per_user_max_input_tokens_per_minute =
model.max_input_tokens_per_minute as usize / users_in_recent_minutes;
let per_user_max_output_tokens_per_minute =
model.max_output_tokens_per_minute as usize / users_in_recent_minutes;
let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days;
let usage = state
.db
.get_usage(user_id, provider, model_name, Utc::now())
.await?;
let checks = match (provider, model_name) {
(LanguageModelProvider::Anthropic, "claude-3-7-sonnet") => vec![
(
usage.requests_this_minute,
per_user_max_requests_per_minute,
UsageMeasure::RequestsPerMinute,
),
(
usage.input_tokens_this_minute,
per_user_max_tokens_per_minute,
UsageMeasure::InputTokensPerMinute,
),
(
usage.output_tokens_this_minute,
per_user_max_tokens_per_minute,
UsageMeasure::OutputTokensPerMinute,
),
(
usage.tokens_this_day,
per_user_max_tokens_per_day,
UsageMeasure::TokensPerDay,
),
],
_ => vec![
(
usage.requests_this_minute,
per_user_max_requests_per_minute,
UsageMeasure::RequestsPerMinute,
),
(
usage.tokens_this_minute,
per_user_max_tokens_per_minute,
UsageMeasure::TokensPerMinute,
),
(
usage.tokens_this_day,
per_user_max_tokens_per_day,
UsageMeasure::TokensPerDay,
),
],
};
for (used, limit, usage_measure) in checks {
if used > limit {
let resource = match usage_measure {
UsageMeasure::RequestsPerMinute => "requests_per_minute",
UsageMeasure::TokensPerMinute => "tokens_per_minute",
UsageMeasure::InputTokensPerMinute => "input_tokens_per_minute",
UsageMeasure::OutputTokensPerMinute => "output_tokens_per_minute",
UsageMeasure::TokensPerDay => "tokens_per_day",
};
tracing::info!(
target: "user rate limit",
user_id = claims.user_id,
login = claims.github_user_login,
authn.jti = claims.jti,
is_staff = claims.is_staff,
provider = provider.to_string(),
model = model.name,
usage_measure = resource,
requests_this_minute = usage.requests_this_minute,
tokens_this_minute = usage.tokens_this_minute,
input_tokens_this_minute = usage.input_tokens_this_minute,
output_tokens_this_minute = usage.output_tokens_this_minute,
tokens_this_day = usage.tokens_this_day,
users_in_recent_minutes = users_in_recent_minutes,
users_in_recent_days = users_in_recent_days,
max_requests_per_minute = per_user_max_requests_per_minute,
max_tokens_per_minute = per_user_max_tokens_per_minute,
max_input_tokens_per_minute = per_user_max_input_tokens_per_minute,
max_output_tokens_per_minute = per_user_max_output_tokens_per_minute,
max_tokens_per_day = per_user_max_tokens_per_day,
);
SnowflakeRow::new(
"Language Model Rate Limited",
Some(claims.metrics_id),
claims.is_staff,
claims.system_id.clone(),
json!({
"usage": usage,
"users_in_recent_minutes": users_in_recent_minutes,
"users_in_recent_days": users_in_recent_days,
"max_requests_per_minute": per_user_max_requests_per_minute,
"max_tokens_per_minute": per_user_max_tokens_per_minute,
"max_input_tokens_per_minute": per_user_max_input_tokens_per_minute,
"max_output_tokens_per_minute": per_user_max_output_tokens_per_minute,
"max_tokens_per_day": per_user_max_tokens_per_day,
"plan": match claims.plan {
Plan::Free => "free".to_string(),
Plan::ZedPro => "zed_pro".to_string(),
},
"model": model.name.clone(),
"provider": provider.to_string(),
"usage_measure": resource.to_string(),
}),
)
.write(&state.kinesis_client, &state.config.kinesis_stream)
.await
.log_err();
return Err(Error::http(
StatusCode::TOO_MANY_REQUESTS,
format!("Rate limit exceeded. Maximum {} reached.", resource),
));
}
}
Ok(())
}
struct CompletionChunk {
bytes: Vec<u8>,
input_tokens: usize,
output_tokens: usize,
cache_creation_input_tokens: usize,
cache_read_input_tokens: usize,
}
struct TokenCountingStream<S> {
state: Arc<LlmState>,
claims: LlmTokenClaims,
provider: LanguageModelProvider,
model: String,
tokens: TokenUsage,
inner_stream: S,
}
impl<S> Stream for TokenCountingStream<S>
where
S: Stream<Item = Result<CompletionChunk, anyhow::Error>> + Unpin,
{
type Item = Result<Vec<u8>, anyhow::Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match Pin::new(&mut self.inner_stream).poll_next(cx) {
Poll::Ready(Some(Ok(mut chunk))) => {
chunk.bytes.push(b'\n');
self.tokens.input += chunk.input_tokens;
self.tokens.output += chunk.output_tokens;
self.tokens.input_cache_creation += chunk.cache_creation_input_tokens;
self.tokens.input_cache_read += chunk.cache_read_input_tokens;
Poll::Ready(Some(Ok(chunk.bytes)))
}
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}
impl<S> Drop for TokenCountingStream<S> {
fn drop(&mut self) {
let state = self.state.clone();
let claims = self.claims.clone();
let provider = self.provider;
let model = std::mem::take(&mut self.model);
let tokens = self.tokens;
self.state.executor.spawn_detached(async move {
let usage = state
.db
.record_usage(
UserId::from_proto(claims.user_id),
claims.is_staff,
provider,
&model,
tokens,
claims.has_llm_subscription,
Cents(claims.max_monthly_spend_in_cents),
claims.free_tier_monthly_spending_limit(),
Utc::now(),
)
.await
.log_err();
if let Some(usage) = usage {
tracing::info!(
target: "user usage",
user_id = claims.user_id,
login = claims.github_user_login,
authn.jti = claims.jti,
is_staff = claims.is_staff,
provider = provider.to_string(),
model = model,
requests_this_minute = usage.requests_this_minute,
tokens_this_minute = usage.tokens_this_minute,
input_tokens_this_minute = usage.input_tokens_this_minute,
output_tokens_this_minute = usage.output_tokens_this_minute,
);
let properties = json!({
"has_llm_subscription": claims.has_llm_subscription,
"max_monthly_spend_in_cents": claims.max_monthly_spend_in_cents,
"plan": match claims.plan {
Plan::Free => "free".to_string(),
Plan::ZedPro => "zed_pro".to_string(),
},
"model": model,
"provider": provider,
"usage": usage,
"tokens": tokens
});
SnowflakeRow::new(
"Language Model Used",
Some(claims.metrics_id),
claims.is_staff,
claims.system_id.clone(),
properties,
)
.write(&state.kinesis_client, &state.config.kinesis_stream)
.await
.log_err();
}
})
}
}
pub fn log_usage_periodically(state: Arc<LlmState>) {
state.executor.clone().spawn_detached(async move {
loop {
state
.executor
.sleep(std::time::Duration::from_secs(30))
.await;
for provider in LanguageModelProvider::iter() {
for model in state.db.model_names_for_provider(provider) {
if let Some(active_user_count) = state
.get_active_user_count(provider, &model)
.await
.log_err()
{
tracing::info!(
target: "active user counts",
provider = provider.to_string(),
model = model,
users_in_recent_minutes = active_user_count.users_in_recent_minutes,
users_in_recent_days = active_user_count.users_in_recent_days,
);
}
}
}
if let Some(usages) = state
.db
.get_application_wide_usages_by_model(Utc::now())
.await
.log_err()
{
for usage in usages {
tracing::info!(
target: "computed usage",
provider = usage.provider.to_string(),
model = usage.model,
requests_this_minute = usage.requests_this_minute,
tokens_this_minute = usage.tokens_this_minute,
input_tokens_this_minute = usage.input_tokens_this_minute,
output_tokens_this_minute = usage.output_tokens_this_minute,
);
}
}
}
})
}

View File

@@ -1,330 +0,0 @@
use reqwest::StatusCode;
use rpc::LanguageModelProvider;
use crate::llm::LlmTokenClaims;
use crate::{Config, Error, Result};
pub fn authorize_access_to_language_model(
config: &Config,
claims: &LlmTokenClaims,
country_code: Option<&str>,
provider: LanguageModelProvider,
model: &str,
) -> Result<()> {
authorize_access_for_country(config, country_code, provider)?;
authorize_access_to_model(config, claims, provider, model)?;
Ok(())
}
fn authorize_access_to_model(
config: &Config,
claims: &LlmTokenClaims,
provider: LanguageModelProvider,
model: &str,
) -> Result<()> {
if claims.is_staff {
return Ok(());
}
if provider == LanguageModelProvider::Anthropic {
if model == "claude-3-5-sonnet" || model == "claude-3-7-sonnet" {
return Ok(());
}
if claims.has_llm_closed_beta_feature_flag
&& Some(model) == config.llm_closed_beta_model_name.as_deref()
{
return Ok(());
}
}
Err(Error::http(
StatusCode::FORBIDDEN,
format!("access to model {model:?} is not included in your plan"),
))
}
fn authorize_access_for_country(
config: &Config,
country_code: Option<&str>,
provider: LanguageModelProvider,
) -> Result<()> {
// In development we won't have the `CF-IPCountry` header, so we can't check
// the country code.
//
// This shouldn't be necessary, as anyone running in development will need to provide
// their own API credentials in order to use an LLM provider.
if config.is_development() {
return Ok(());
}
// https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
let country_code = match country_code {
// `XX` - Used for clients without country code data.
None | Some("XX") => Err(Error::http(
StatusCode::BAD_REQUEST,
"no country code".to_string(),
))?,
// `T1` - Used for clients using the Tor network.
Some("T1") => Err(Error::http(
StatusCode::FORBIDDEN,
format!("access to {provider:?} models is not available over Tor"),
))?,
Some(country_code) => country_code,
};
let is_country_supported_by_provider = match provider {
LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
};
if !is_country_supported_by_provider {
Err(Error::http(
StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
format!(
"access to {provider:?} models is not available in your region ({country_code})"
),
))?
}
Ok(())
}
#[cfg(test)]
mod tests {
use axum::response::IntoResponse;
use pretty_assertions::assert_eq;
use rpc::proto::Plan;
use super::*;
#[gpui::test]
async fn test_authorize_access_to_language_model_with_supported_country(
_cx: &mut gpui::TestAppContext,
) {
let config = Config::test();
let claims = LlmTokenClaims {
user_id: 99,
plan: Plan::ZedPro,
is_staff: true,
..Default::default()
};
let cases = vec![
(LanguageModelProvider::Anthropic, "US"), // United States
(LanguageModelProvider::Anthropic, "GB"), // United Kingdom
(LanguageModelProvider::OpenAi, "US"), // United States
(LanguageModelProvider::OpenAi, "GB"), // United Kingdom
(LanguageModelProvider::Google, "US"), // United States
(LanguageModelProvider::Google, "GB"), // United Kingdom
];
for (provider, country_code) in cases {
authorize_access_to_language_model(
&config,
&claims,
Some(country_code),
provider,
"the-model",
)
.unwrap_or_else(|_| {
panic!("expected authorization to return Ok for {provider:?}: {country_code}")
})
}
}
#[gpui::test]
async fn test_authorize_access_to_language_model_with_unsupported_country(
_cx: &mut gpui::TestAppContext,
) {
let config = Config::test();
let claims = LlmTokenClaims {
user_id: 99,
plan: Plan::ZedPro,
..Default::default()
};
let cases = vec![
(LanguageModelProvider::Anthropic, "AF"), // Afghanistan
(LanguageModelProvider::Anthropic, "BY"), // Belarus
(LanguageModelProvider::Anthropic, "CF"), // Central African Republic
(LanguageModelProvider::Anthropic, "CN"), // China
(LanguageModelProvider::Anthropic, "CU"), // Cuba
(LanguageModelProvider::Anthropic, "ER"), // Eritrea
(LanguageModelProvider::Anthropic, "ET"), // Ethiopia
(LanguageModelProvider::Anthropic, "IR"), // Iran
(LanguageModelProvider::Anthropic, "KP"), // North Korea
(LanguageModelProvider::Anthropic, "XK"), // Kosovo
(LanguageModelProvider::Anthropic, "LY"), // Libya
(LanguageModelProvider::Anthropic, "MM"), // Myanmar
(LanguageModelProvider::Anthropic, "RU"), // Russia
(LanguageModelProvider::Anthropic, "SO"), // Somalia
(LanguageModelProvider::Anthropic, "SS"), // South Sudan
(LanguageModelProvider::Anthropic, "SD"), // Sudan
(LanguageModelProvider::Anthropic, "SY"), // Syria
(LanguageModelProvider::Anthropic, "VE"), // Venezuela
(LanguageModelProvider::Anthropic, "YE"), // Yemen
(LanguageModelProvider::OpenAi, "KP"), // North Korea
(LanguageModelProvider::Google, "KP"), // North Korea
];
for (provider, country_code) in cases {
let error_response = authorize_access_to_language_model(
&config,
&claims,
Some(country_code),
provider,
"the-model",
)
.expect_err(&format!(
"expected authorization to return an error for {provider:?}: {country_code}"
))
.into_response();
assert_eq!(
error_response.status(),
StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
);
let response_body = hyper::body::to_bytes(error_response.into_body())
.await
.unwrap()
.to_vec();
assert_eq!(
String::from_utf8(response_body).unwrap(),
format!(
"access to {provider:?} models is not available in your region ({country_code})"
)
);
}
}
#[gpui::test]
async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
let config = Config::test();
let claims = LlmTokenClaims {
user_id: 99,
plan: Plan::ZedPro,
..Default::default()
};
let cases = vec![
(LanguageModelProvider::Anthropic, "T1"), // Tor
(LanguageModelProvider::OpenAi, "T1"), // Tor
(LanguageModelProvider::Google, "T1"), // Tor
];
for (provider, country_code) in cases {
let error_response = authorize_access_to_language_model(
&config,
&claims,
Some(country_code),
provider,
"the-model",
)
.expect_err(&format!(
"expected authorization to return an error for {provider:?}: {country_code}"
))
.into_response();
assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
let response_body = hyper::body::to_bytes(error_response.into_body())
.await
.unwrap()
.to_vec();
assert_eq!(
String::from_utf8(response_body).unwrap(),
format!("access to {provider:?} models is not available over Tor")
);
}
}
#[gpui::test]
async fn test_authorize_access_to_language_model_based_on_plan() {
let config = Config::test();
let test_cases = vec![
// Pro plan should have access to claude-3.5-sonnet
(
Plan::ZedPro,
LanguageModelProvider::Anthropic,
"claude-3-5-sonnet",
true,
),
// Free plan should have access to claude-3.5-sonnet
(
Plan::Free,
LanguageModelProvider::Anthropic,
"claude-3-5-sonnet",
true,
),
// Pro plan should NOT have access to other Anthropic models
(
Plan::ZedPro,
LanguageModelProvider::Anthropic,
"claude-3-opus",
false,
),
];
for (plan, provider, model, expected_access) in test_cases {
let claims = LlmTokenClaims {
plan,
..Default::default()
};
let result =
authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
if expected_access {
assert!(
result.is_ok(),
"Expected access to be granted for plan {:?}, provider {:?}, model {}",
plan,
provider,
model
);
} else {
let error = result.expect_err(&format!(
"Expected access to be denied for plan {:?}, provider {:?}, model {}",
plan, provider, model
));
let response = error.into_response();
assert_eq!(response.status(), StatusCode::FORBIDDEN);
}
}
}
#[gpui::test]
async fn test_authorize_access_to_language_model_for_staff() {
let config = Config::test();
let claims = LlmTokenClaims {
is_staff: true,
..Default::default()
};
// Staff should have access to all models
let test_cases = vec![
(LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
(LanguageModelProvider::Anthropic, "claude-2"),
(LanguageModelProvider::Anthropic, "claude-123-agi"),
(LanguageModelProvider::OpenAi, "gpt-4"),
(LanguageModelProvider::Google, "gemini-pro"),
];
for (provider, model) in test_cases {
let result =
authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
assert!(
result.is_ok(),
"Expected staff to have access to provider {:?}, model {}",
provider,
model
);
}
}
}

View File

@@ -20,7 +20,6 @@ use std::future::Future;
use std::sync::Arc;
use anyhow::anyhow;
pub use queries::usages::{ActiveUserCount, TokenUsage};
pub use sea_orm::ConnectOptions;
use sea_orm::prelude::*;
use sea_orm::{

View File

@@ -2,5 +2,4 @@ use super::*;
pub mod billing_events;
pub mod providers;
pub mod revoked_access_tokens;
pub mod usages;

View File

@@ -1,15 +0,0 @@
use super::*;
impl LlmDatabase {
/// Returns whether the access token with the given `jti` has been revoked.
pub async fn is_access_token_revoked(&self, jti: &str) -> Result<bool> {
self.transaction(|tx| async move {
Ok(revoked_access_token::Entity::find()
.filter(revoked_access_token::Column::Jti.eq(jti))
.one(&*tx)
.await?
.is_some())
})
.await
}
}

View File

@@ -1,56 +1,12 @@
use crate::db::UserId;
use crate::llm::Cents;
use chrono::{Datelike, Duration};
use chrono::Datelike;
use futures::StreamExt as _;
use rpc::LanguageModelProvider;
use sea_orm::QuerySelect;
use std::{iter, str::FromStr};
use std::str::FromStr;
use strum::IntoEnumIterator as _;
use super::*;
#[derive(Debug, PartialEq, Clone, Copy, Default, serde::Serialize)]
pub struct TokenUsage {
pub input: usize,
pub input_cache_creation: usize,
pub input_cache_read: usize,
pub output: usize,
}
impl TokenUsage {
pub fn total(&self) -> usize {
self.input + self.input_cache_creation + self.input_cache_read + self.output
}
}
#[derive(Debug, PartialEq, Clone, Copy, serde::Serialize)]
pub struct Usage {
pub requests_this_minute: usize,
pub tokens_this_minute: usize,
pub input_tokens_this_minute: usize,
pub output_tokens_this_minute: usize,
pub tokens_this_day: usize,
pub tokens_this_month: TokenUsage,
pub spending_this_month: Cents,
pub lifetime_spending: Cents,
}
#[derive(Debug, PartialEq, Clone)]
pub struct ApplicationWideUsage {
pub provider: LanguageModelProvider,
pub model: String,
pub requests_this_minute: usize,
pub tokens_this_minute: usize,
pub input_tokens_this_minute: usize,
pub output_tokens_this_minute: usize,
}
#[derive(Clone, Copy, Debug, Default)]
pub struct ActiveUserCount {
pub users_in_recent_minutes: usize,
pub users_in_recent_days: usize,
}
impl LlmDatabase {
pub async fn initialize_usage_measures(&mut self) -> Result<()> {
let all_measures = self
@@ -90,100 +46,6 @@ impl LlmDatabase {
Ok(())
}
pub async fn get_application_wide_usages_by_model(
&self,
now: DateTimeUtc,
) -> Result<Vec<ApplicationWideUsage>> {
self.transaction(|tx| async move {
let past_minute = now - Duration::minutes(1);
let requests_per_minute = self.usage_measure_ids[&UsageMeasure::RequestsPerMinute];
let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
let input_tokens_per_minute =
self.usage_measure_ids[&UsageMeasure::InputTokensPerMinute];
let output_tokens_per_minute =
self.usage_measure_ids[&UsageMeasure::OutputTokensPerMinute];
let mut results = Vec::new();
for ((provider, model_name), model) in self.models.iter() {
let mut usages = usage::Entity::find()
.filter(
usage::Column::Timestamp
.gte(past_minute.naive_utc())
.and(usage::Column::IsStaff.eq(false))
.and(usage::Column::ModelId.eq(model.id))
.and(
usage::Column::MeasureId
.eq(requests_per_minute)
.or(usage::Column::MeasureId.eq(tokens_per_minute)),
),
)
.stream(&*tx)
.await?;
let mut requests_this_minute = 0;
let mut tokens_this_minute = 0;
let mut input_tokens_this_minute = 0;
let mut output_tokens_this_minute = 0;
while let Some(usage) = usages.next().await {
let usage = usage?;
if usage.measure_id == requests_per_minute {
requests_this_minute += Self::get_live_buckets(
&usage,
now.naive_utc(),
UsageMeasure::RequestsPerMinute,
)
.0
.iter()
.copied()
.sum::<i64>() as usize;
} else if usage.measure_id == tokens_per_minute {
tokens_this_minute += Self::get_live_buckets(
&usage,
now.naive_utc(),
UsageMeasure::TokensPerMinute,
)
.0
.iter()
.copied()
.sum::<i64>() as usize;
} else if usage.measure_id == input_tokens_per_minute {
input_tokens_this_minute += Self::get_live_buckets(
&usage,
now.naive_utc(),
UsageMeasure::InputTokensPerMinute,
)
.0
.iter()
.copied()
.sum::<i64>() as usize;
} else if usage.measure_id == output_tokens_per_minute {
output_tokens_this_minute += Self::get_live_buckets(
&usage,
now.naive_utc(),
UsageMeasure::OutputTokensPerMinute,
)
.0
.iter()
.copied()
.sum::<i64>() as usize;
}
}
results.push(ApplicationWideUsage {
provider: *provider,
model: model_name.clone(),
requests_this_minute,
tokens_this_minute,
input_tokens_this_minute,
output_tokens_this_minute,
})
}
Ok(results)
})
.await
}
pub async fn get_user_spending_for_month(
&self,
user_id: UserId,
@@ -223,499 +85,6 @@ impl LlmDatabase {
})
.await
}
pub async fn get_usage(
&self,
user_id: UserId,
provider: LanguageModelProvider,
model_name: &str,
now: DateTimeUtc,
) -> Result<Usage> {
self.transaction(|tx| async move {
let model = self
.models
.get(&(provider, model_name.to_string()))
.ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?;
let usages = usage::Entity::find()
.filter(
usage::Column::UserId
.eq(user_id)
.and(usage::Column::ModelId.eq(model.id)),
)
.all(&*tx)
.await?;
let month = now.date_naive().month() as i32;
let year = now.date_naive().year();
let monthly_usage = monthly_usage::Entity::find()
.filter(
monthly_usage::Column::UserId
.eq(user_id)
.and(monthly_usage::Column::ModelId.eq(model.id))
.and(monthly_usage::Column::Month.eq(month))
.and(monthly_usage::Column::Year.eq(year)),
)
.one(&*tx)
.await?;
let lifetime_usage = lifetime_usage::Entity::find()
.filter(
lifetime_usage::Column::UserId
.eq(user_id)
.and(lifetime_usage::Column::ModelId.eq(model.id)),
)
.one(&*tx)
.await?;
let requests_this_minute =
self.get_usage_for_measure(&usages, now, UsageMeasure::RequestsPerMinute)?;
let tokens_this_minute =
self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMinute)?;
let input_tokens_this_minute =
self.get_usage_for_measure(&usages, now, UsageMeasure::InputTokensPerMinute)?;
let output_tokens_this_minute =
self.get_usage_for_measure(&usages, now, UsageMeasure::OutputTokensPerMinute)?;
let tokens_this_day =
self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerDay)?;
let spending_this_month = if let Some(monthly_usage) = &monthly_usage {
calculate_spending(
model,
monthly_usage.input_tokens as usize,
monthly_usage.cache_creation_input_tokens as usize,
monthly_usage.cache_read_input_tokens as usize,
monthly_usage.output_tokens as usize,
)
} else {
Cents::ZERO
};
let lifetime_spending = if let Some(lifetime_usage) = &lifetime_usage {
calculate_spending(
model,
lifetime_usage.input_tokens as usize,
lifetime_usage.cache_creation_input_tokens as usize,
lifetime_usage.cache_read_input_tokens as usize,
lifetime_usage.output_tokens as usize,
)
} else {
Cents::ZERO
};
Ok(Usage {
requests_this_minute,
tokens_this_minute,
input_tokens_this_minute,
output_tokens_this_minute,
tokens_this_day,
tokens_this_month: TokenUsage {
input: monthly_usage
.as_ref()
.map_or(0, |usage| usage.input_tokens as usize),
input_cache_creation: monthly_usage
.as_ref()
.map_or(0, |usage| usage.cache_creation_input_tokens as usize),
input_cache_read: monthly_usage
.as_ref()
.map_or(0, |usage| usage.cache_read_input_tokens as usize),
output: monthly_usage
.as_ref()
.map_or(0, |usage| usage.output_tokens as usize),
},
spending_this_month,
lifetime_spending,
})
})
.await
}
pub async fn record_usage(
&self,
user_id: UserId,
is_staff: bool,
provider: LanguageModelProvider,
model_name: &str,
tokens: TokenUsage,
has_llm_subscription: bool,
max_monthly_spend: Cents,
free_tier_monthly_spending_limit: Cents,
now: DateTimeUtc,
) -> Result<Usage> {
self.transaction(|tx| async move {
let model = self.model(provider, model_name)?;
let usages = usage::Entity::find()
.filter(
usage::Column::UserId
.eq(user_id)
.and(usage::Column::ModelId.eq(model.id)),
)
.all(&*tx)
.await?;
let requests_this_minute = self
.update_usage_for_measure(
user_id,
is_staff,
model.id,
&usages,
UsageMeasure::RequestsPerMinute,
now,
1,
&tx,
)
.await?;
let tokens_this_minute = self
.update_usage_for_measure(
user_id,
is_staff,
model.id,
&usages,
UsageMeasure::TokensPerMinute,
now,
tokens.total(),
&tx,
)
.await?;
let input_tokens_this_minute = self
.update_usage_for_measure(
user_id,
is_staff,
model.id,
&usages,
UsageMeasure::InputTokensPerMinute,
now,
// Cache read input tokens are not counted for the purposes of rate limits (but they are still billed).
tokens.input + tokens.input_cache_creation,
&tx,
)
.await?;
let output_tokens_this_minute = self
.update_usage_for_measure(
user_id,
is_staff,
model.id,
&usages,
UsageMeasure::OutputTokensPerMinute,
now,
tokens.output,
&tx,
)
.await?;
let tokens_this_day = self
.update_usage_for_measure(
user_id,
is_staff,
model.id,
&usages,
UsageMeasure::TokensPerDay,
now,
tokens.total(),
&tx,
)
.await?;
let month = now.date_naive().month() as i32;
let year = now.date_naive().year();
// Update monthly usage
let monthly_usage = monthly_usage::Entity::find()
.filter(
monthly_usage::Column::UserId
.eq(user_id)
.and(monthly_usage::Column::ModelId.eq(model.id))
.and(monthly_usage::Column::Month.eq(month))
.and(monthly_usage::Column::Year.eq(year)),
)
.one(&*tx)
.await?;
let monthly_usage = match monthly_usage {
Some(usage) => {
monthly_usage::Entity::update(monthly_usage::ActiveModel {
id: ActiveValue::unchanged(usage.id),
input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
cache_creation_input_tokens: ActiveValue::set(
usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
),
cache_read_input_tokens: ActiveValue::set(
usage.cache_read_input_tokens + tokens.input_cache_read as i64,
),
output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
..Default::default()
})
.exec(&*tx)
.await?
}
None => {
monthly_usage::ActiveModel {
user_id: ActiveValue::set(user_id),
model_id: ActiveValue::set(model.id),
month: ActiveValue::set(month),
year: ActiveValue::set(year),
input_tokens: ActiveValue::set(tokens.input as i64),
cache_creation_input_tokens: ActiveValue::set(
tokens.input_cache_creation as i64,
),
cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
output_tokens: ActiveValue::set(tokens.output as i64),
..Default::default()
}
.insert(&*tx)
.await?
}
};
let spending_this_month = calculate_spending(
model,
monthly_usage.input_tokens as usize,
monthly_usage.cache_creation_input_tokens as usize,
monthly_usage.cache_read_input_tokens as usize,
monthly_usage.output_tokens as usize,
);
if !is_staff
&& spending_this_month > free_tier_monthly_spending_limit
&& has_llm_subscription
&& (spending_this_month - free_tier_monthly_spending_limit) <= max_monthly_spend
{
billing_event::ActiveModel {
id: ActiveValue::not_set(),
idempotency_key: ActiveValue::not_set(),
user_id: ActiveValue::set(user_id),
model_id: ActiveValue::set(model.id),
input_tokens: ActiveValue::set(tokens.input as i64),
input_cache_creation_tokens: ActiveValue::set(
tokens.input_cache_creation as i64,
),
input_cache_read_tokens: ActiveValue::set(tokens.input_cache_read as i64),
output_tokens: ActiveValue::set(tokens.output as i64),
}
.insert(&*tx)
.await?;
}
// Update lifetime usage
let lifetime_usage = lifetime_usage::Entity::find()
.filter(
lifetime_usage::Column::UserId
.eq(user_id)
.and(lifetime_usage::Column::ModelId.eq(model.id)),
)
.one(&*tx)
.await?;
let lifetime_usage = match lifetime_usage {
Some(usage) => {
lifetime_usage::Entity::update(lifetime_usage::ActiveModel {
id: ActiveValue::unchanged(usage.id),
input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
cache_creation_input_tokens: ActiveValue::set(
usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
),
cache_read_input_tokens: ActiveValue::set(
usage.cache_read_input_tokens + tokens.input_cache_read as i64,
),
output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
..Default::default()
})
.exec(&*tx)
.await?
}
None => {
lifetime_usage::ActiveModel {
user_id: ActiveValue::set(user_id),
model_id: ActiveValue::set(model.id),
input_tokens: ActiveValue::set(tokens.input as i64),
cache_creation_input_tokens: ActiveValue::set(
tokens.input_cache_creation as i64,
),
cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
output_tokens: ActiveValue::set(tokens.output as i64),
..Default::default()
}
.insert(&*tx)
.await?
}
};
let lifetime_spending = calculate_spending(
model,
lifetime_usage.input_tokens as usize,
lifetime_usage.cache_creation_input_tokens as usize,
lifetime_usage.cache_read_input_tokens as usize,
lifetime_usage.output_tokens as usize,
);
Ok(Usage {
requests_this_minute,
tokens_this_minute,
input_tokens_this_minute,
output_tokens_this_minute,
tokens_this_day,
tokens_this_month: TokenUsage {
input: monthly_usage.input_tokens as usize,
input_cache_creation: monthly_usage.cache_creation_input_tokens as usize,
input_cache_read: monthly_usage.cache_read_input_tokens as usize,
output: monthly_usage.output_tokens as usize,
},
spending_this_month,
lifetime_spending,
})
})
.await
}
/// Returns the active user count for the specified model.
pub async fn get_active_user_count(
&self,
provider: LanguageModelProvider,
model_name: &str,
now: DateTimeUtc,
) -> Result<ActiveUserCount> {
self.transaction(|tx| async move {
let minute_since = now - Duration::minutes(5);
let day_since = now - Duration::days(5);
let model = self
.models
.get(&(provider, model_name.to_string()))
.ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?;
let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
let users_in_recent_minutes = usage::Entity::find()
.filter(
usage::Column::ModelId
.eq(model.id)
.and(usage::Column::MeasureId.eq(tokens_per_minute))
.and(usage::Column::Timestamp.gte(minute_since.naive_utc()))
.and(usage::Column::IsStaff.eq(false)),
)
.select_only()
.column(usage::Column::UserId)
.group_by(usage::Column::UserId)
.count(&*tx)
.await? as usize;
let users_in_recent_days = usage::Entity::find()
.filter(
usage::Column::ModelId
.eq(model.id)
.and(usage::Column::MeasureId.eq(tokens_per_minute))
.and(usage::Column::Timestamp.gte(day_since.naive_utc()))
.and(usage::Column::IsStaff.eq(false)),
)
.select_only()
.column(usage::Column::UserId)
.group_by(usage::Column::UserId)
.count(&*tx)
.await? as usize;
Ok(ActiveUserCount {
users_in_recent_minutes,
users_in_recent_days,
})
})
.await
}
async fn update_usage_for_measure(
&self,
user_id: UserId,
is_staff: bool,
model_id: ModelId,
usages: &[usage::Model],
usage_measure: UsageMeasure,
now: DateTimeUtc,
usage_to_add: usize,
tx: &DatabaseTransaction,
) -> Result<usize> {
let now = now.naive_utc();
let measure_id = *self
.usage_measure_ids
.get(&usage_measure)
.ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
let mut id = None;
let mut timestamp = now;
let mut buckets = vec![0_i64];
if let Some(old_usage) = usages.iter().find(|usage| usage.measure_id == measure_id) {
id = Some(old_usage.id);
let (live_buckets, buckets_since) =
Self::get_live_buckets(old_usage, now, usage_measure);
if !live_buckets.is_empty() {
buckets.clear();
buckets.extend_from_slice(live_buckets);
buckets.extend(iter::repeat(0).take(buckets_since));
timestamp =
old_usage.timestamp + (usage_measure.bucket_duration() * buckets_since as i32);
}
}
*buckets.last_mut().unwrap() += usage_to_add as i64;
let total_usage = buckets.iter().sum::<i64>() as usize;
let mut model = usage::ActiveModel {
user_id: ActiveValue::set(user_id),
is_staff: ActiveValue::set(is_staff),
model_id: ActiveValue::set(model_id),
measure_id: ActiveValue::set(measure_id),
timestamp: ActiveValue::set(timestamp),
buckets: ActiveValue::set(buckets),
..Default::default()
};
if let Some(id) = id {
model.id = ActiveValue::unchanged(id);
model.update(tx).await?;
} else {
usage::Entity::insert(model)
.exec_without_returning(tx)
.await?;
}
Ok(total_usage)
}
fn get_usage_for_measure(
&self,
usages: &[usage::Model],
now: DateTimeUtc,
usage_measure: UsageMeasure,
) -> Result<usize> {
let now = now.naive_utc();
let measure_id = *self
.usage_measure_ids
.get(&usage_measure)
.ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
let Some(usage) = usages.iter().find(|usage| usage.measure_id == measure_id) else {
return Ok(0);
};
let (live_buckets, _) = Self::get_live_buckets(usage, now, usage_measure);
Ok(live_buckets.iter().sum::<i64>() as _)
}
fn get_live_buckets(
usage: &usage::Model,
now: chrono::NaiveDateTime,
measure: UsageMeasure,
) -> (&[i64], usize) {
let seconds_since_usage = (now - usage.timestamp).num_seconds().max(0);
let buckets_since_usage =
seconds_since_usage as f32 / measure.bucket_duration().num_seconds() as f32;
let buckets_since_usage = buckets_since_usage.ceil() as usize;
let mut live_buckets = &[] as &[i64];
if buckets_since_usage < measure.bucket_count() {
let expired_bucket_count =
(usage.buckets.len() + buckets_since_usage).saturating_sub(measure.bucket_count());
live_buckets = &usage.buckets[expired_bucket_count..];
while live_buckets.first() == Some(&0) {
live_buckets = &live_buckets[1..];
}
}
(live_buckets, buckets_since_usage)
}
}
fn calculate_spending(
@@ -741,32 +110,3 @@ fn calculate_spending(
+ output_token_cost;
Cents::new(spending as u32)
}
const MINUTE_BUCKET_COUNT: usize = 12;
const DAY_BUCKET_COUNT: usize = 48;
impl UsageMeasure {
fn bucket_count(&self) -> usize {
match self {
UsageMeasure::RequestsPerMinute => MINUTE_BUCKET_COUNT,
UsageMeasure::TokensPerMinute
| UsageMeasure::InputTokensPerMinute
| UsageMeasure::OutputTokensPerMinute => MINUTE_BUCKET_COUNT,
UsageMeasure::TokensPerDay => DAY_BUCKET_COUNT,
}
}
fn total_duration(&self) -> Duration {
match self {
UsageMeasure::RequestsPerMinute => Duration::minutes(1),
UsageMeasure::TokensPerMinute
| UsageMeasure::InputTokensPerMinute
| UsageMeasure::OutputTokensPerMinute => Duration::minutes(1),
UsageMeasure::TokensPerDay => Duration::hours(24),
}
}
fn bucket_duration(&self) -> Duration {
self.total_duration() / self.bucket_count() as i32
}
}

View File

@@ -1,8 +1,6 @@
pub mod billing_event;
pub mod lifetime_usage;
pub mod model;
pub mod monthly_usage;
pub mod provider;
pub mod revoked_access_token;
pub mod usage;
pub mod usage_measure;

View File

@@ -1,20 +0,0 @@
use crate::{db::UserId, llm::db::ModelId};
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "lifetime_usages")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub user_id: UserId,
pub model_id: ModelId,
pub input_tokens: i64,
pub cache_creation_input_tokens: i64,
pub cache_read_input_tokens: i64,
pub output_tokens: i64,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -1,19 +0,0 @@
use chrono::NaiveDateTime;
use sea_orm::entity::prelude::*;
use crate::llm::db::RevokedAccessTokenId;
/// A revoked access token.
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "revoked_access_tokens")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: RevokedAccessTokenId,
pub jti: String,
pub revoked_at: NaiveDateTime,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -1,6 +1,4 @@
mod billing_tests;
mod provider_tests;
mod usage_tests;
use gpui::BackgroundExecutor;
use parking_lot::Mutex;

View File

@@ -1,152 +0,0 @@
use crate::{
Cents,
db::UserId,
llm::{
FREE_TIER_MONTHLY_SPENDING_LIMIT,
db::{LlmDatabase, TokenUsage, queries::providers::ModelParams},
},
test_llm_db,
};
use chrono::{DateTime, Utc};
use pretty_assertions::assert_eq;
use rpc::LanguageModelProvider;
test_llm_db!(
test_billing_limit_exceeded,
test_billing_limit_exceeded_postgres
);
async fn test_billing_limit_exceeded(db: &mut LlmDatabase) {
let provider = LanguageModelProvider::Anthropic;
let model = "fake-claude-limerick";
const PRICE_PER_MILLION_INPUT_TOKENS: i32 = 5;
const PRICE_PER_MILLION_OUTPUT_TOKENS: i32 = 5;
// Initialize the database and insert the model
db.initialize().await.unwrap();
db.insert_models(&[ModelParams {
provider,
name: model.to_string(),
max_requests_per_minute: 5,
max_tokens_per_minute: 10_000,
max_tokens_per_day: 50_000,
price_per_million_input_tokens: PRICE_PER_MILLION_INPUT_TOKENS,
price_per_million_output_tokens: PRICE_PER_MILLION_OUTPUT_TOKENS,
}])
.await
.unwrap();
// Set a fixed datetime for consistent testing
let now = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z")
.unwrap()
.with_timezone(&Utc);
let user_id = UserId::from_proto(123);
let max_monthly_spend = Cents::from_dollars(11);
// Record usage that brings us close to the limit but doesn't exceed it
// Let's say we use $10.50 worth of tokens
let tokens_to_use = 210_000_000; // This will cost $10.50 at $0.05 per 1 million tokens
let usage = TokenUsage {
input: tokens_to_use,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
};
// Verify that before we record any usage, there are 0 billing events
let billing_events = db.get_billing_events().await.unwrap();
assert_eq!(billing_events.len(), 0);
db.record_usage(
user_id,
false,
provider,
model,
usage,
true,
max_monthly_spend,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
// Verify the recorded usage and spending
let recorded_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
// Verify that we exceeded the free tier usage
assert_eq!(recorded_usage.spending_this_month, Cents::new(1050));
assert!(recorded_usage.spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT);
// Verify that there is one `billing_event` record
let billing_events = db.get_billing_events().await.unwrap();
assert_eq!(billing_events.len(), 1);
let (billing_event, _model) = &billing_events[0];
assert_eq!(billing_event.user_id, user_id);
assert_eq!(billing_event.input_tokens, tokens_to_use as i64);
assert_eq!(billing_event.input_cache_creation_tokens, 0);
assert_eq!(billing_event.input_cache_read_tokens, 0);
assert_eq!(billing_event.output_tokens, 0);
// Record usage that puts us at $20.50
let usage_2 = TokenUsage {
input: 200_000_000, // This will cost $10 more, pushing us from $10.50 to $20.50,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
};
db.record_usage(
user_id,
false,
provider,
model,
usage_2,
true,
max_monthly_spend,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
// Verify the updated usage and spending
let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(updated_usage.spending_this_month, Cents::new(2050));
// Verify that there are now two billing events
let billing_events = db.get_billing_events().await.unwrap();
assert_eq!(billing_events.len(), 2);
let tokens_to_exceed = 20_000_000; // This will cost $1.00 more, pushing us from $20.50 to $21.50, which is over the $11 monthly maximum limit
let usage_exceeding = TokenUsage {
input: tokens_to_exceed,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
};
// This should still create a billing event as it's the first request that exceeds the limit
db.record_usage(
user_id,
false,
provider,
model,
usage_exceeding,
true,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
max_monthly_spend,
now,
)
.await
.unwrap();
// Verify the updated usage and spending
let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(updated_usage.spending_this_month, Cents::new(2150));
// Verify that we never exceed the user max spending for the user
// and avoid charging them.
let billing_events = db.get_billing_events().await.unwrap();
assert_eq!(billing_events.len(), 2);
}

View File

@@ -1,306 +0,0 @@
use crate::llm::FREE_TIER_MONTHLY_SPENDING_LIMIT;
use crate::{
Cents,
db::UserId,
llm::db::{
LlmDatabase, TokenUsage,
queries::{providers::ModelParams, usages::Usage},
},
test_llm_db,
};
use chrono::{DateTime, Duration, Utc};
use pretty_assertions::assert_eq;
use rpc::LanguageModelProvider;
test_llm_db!(test_tracking_usage, test_tracking_usage_postgres);
async fn test_tracking_usage(db: &mut LlmDatabase) {
let provider = LanguageModelProvider::Anthropic;
let model = "claude-3-5-sonnet";
db.initialize().await.unwrap();
db.insert_models(&[ModelParams {
provider,
name: model.to_string(),
max_requests_per_minute: 5,
max_tokens_per_minute: 10_000,
max_tokens_per_day: 50_000,
price_per_million_input_tokens: 50,
price_per_million_output_tokens: 50,
}])
.await
.unwrap();
// We're using a fixed datetime to prevent flakiness based on the clock.
let t0 = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z")
.unwrap()
.with_timezone(&Utc);
let user_id = UserId::from_proto(123);
let now = t0;
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 1000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
let now = t0 + Duration::seconds(10);
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 2000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 2,
tokens_this_minute: 3000,
input_tokens_this_minute: 3000,
output_tokens_this_minute: 0,
tokens_this_day: 3000,
tokens_this_month: TokenUsage {
input: 3000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
let now = t0 + Duration::seconds(60);
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 1,
tokens_this_minute: 2000,
input_tokens_this_minute: 2000,
output_tokens_this_minute: 0,
tokens_this_day: 3000,
tokens_this_month: TokenUsage {
input: 3000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
let now = t0 + Duration::seconds(60);
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 3000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 2,
tokens_this_minute: 5000,
input_tokens_this_minute: 5000,
output_tokens_this_minute: 0,
tokens_this_day: 6000,
tokens_this_month: TokenUsage {
input: 6000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
let t1 = t0 + Duration::hours(24);
let now = t1;
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 0,
tokens_this_minute: 0,
input_tokens_this_minute: 0,
output_tokens_this_minute: 0,
tokens_this_day: 5000,
tokens_this_month: TokenUsage {
input: 6000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 4000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 1,
tokens_this_minute: 4000,
input_tokens_this_minute: 4000,
output_tokens_this_minute: 0,
tokens_this_day: 9000,
tokens_this_month: TokenUsage {
input: 10000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
// We're using a fixed datetime to prevent flakiness based on the clock.
let now = DateTime::parse_from_rfc3339("2024-10-08T22:15:58Z")
.unwrap()
.with_timezone(&Utc);
// Test cache creation input tokens
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 1000,
input_cache_creation: 500,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 1,
tokens_this_minute: 1500,
input_tokens_this_minute: 1500,
output_tokens_this_minute: 0,
tokens_this_day: 1500,
tokens_this_month: TokenUsage {
input: 1000,
input_cache_creation: 500,
input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
// Test cache read input tokens
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 1000,
input_cache_creation: 0,
input_cache_read: 300,
output: 0,
},
false,
Cents::ZERO,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 2,
tokens_this_minute: 2800,
input_tokens_this_minute: 2500,
output_tokens_this_minute: 0,
tokens_this_day: 2800,
tokens_this_month: TokenUsage {
input: 2000,
input_cache_creation: 500,
input_cache_read: 300,
output: 0,
},
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
}

View File

@@ -9,14 +9,14 @@ use axum::{
use collab::api::CloudflareIpCountryHeader;
use collab::api::billing::sync_llm_usage_with_stripe_periodically;
use collab::llm::{db::LlmDatabase, log_usage_periodically};
use collab::llm::db::LlmDatabase;
use collab::migrations::run_database_migrations;
use collab::user_backfiller::spawn_user_backfiller;
use collab::{
AppState, Config, RateLimiter, Result, api::fetch_extensions_from_blob_store_periodically, db,
env, executor::Executor, rpc::ResultExt,
};
use collab::{ServiceMode, api::billing::poll_stripe_events_periodically, llm::LlmState};
use collab::{ServiceMode, api::billing::poll_stripe_events_periodically};
use db::Database;
use std::{
env::args,
@@ -74,11 +74,10 @@ async fn main() -> Result<()> {
let mode = match args.next().as_deref() {
Some("collab") => ServiceMode::Collab,
Some("api") => ServiceMode::Api,
Some("llm") => ServiceMode::Llm,
Some("all") => ServiceMode::All,
_ => {
return Err(anyhow!(
"usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
"usage: collab <version | migrate | seed | serve <api|collab|all>>"
))?;
}
};
@@ -97,20 +96,9 @@ async fn main() -> Result<()> {
let mut on_shutdown = None;
if mode.is_llm() {
setup_llm_database(&config).await?;
let state = LlmState::new(config.clone(), Executor::Production).await?;
log_usage_periodically(state.clone());
app = app
.merge(collab::llm::routes())
.layer(Extension(state.clone()));
}
if mode.is_collab() || mode.is_api() {
setup_app_database(&config).await?;
setup_llm_database(&config).await?;
let state = AppState::new(config, Executor::Production).await?;
@@ -336,18 +324,11 @@ async fn handle_root(Extension(mode): Extension<ServiceMode>) -> String {
format!("zed:{mode} v{VERSION} ({})", REVISION.unwrap_or("unknown"))
}
async fn handle_liveness_probe(
app_state: Option<Extension<Arc<AppState>>>,
llm_state: Option<Extension<Arc<LlmState>>>,
) -> Result<String> {
async fn handle_liveness_probe(app_state: Option<Extension<Arc<AppState>>>) -> Result<String> {
if let Some(state) = app_state {
state.db.get_all_users(0, 1).await?;
}
if let Some(llm_state) = llm_state {
llm_state.db.list_providers().await?;
}
Ok("ok".to_string())
}

View File

@@ -694,7 +694,15 @@ async fn test_collaborating_with_code_actions(
// Confirming the code action will trigger a resolve request.
let confirm_action = editor_b
.update_in(cx_b, |editor, window, cx| {
Editor::confirm_code_action(editor, &ConfirmCodeAction { item_ix: Some(0) }, window, cx)
Editor::confirm_code_action(
editor,
&ConfirmCodeAction {
item_ix: Some(0),
from_mouse_context_menu: false,
},
window,
cx,
)
})
.unwrap();
fake_language_server.set_request_handler::<lsp::request::CodeActionResolveRequest, _, _>(

View File

@@ -191,6 +191,14 @@ pub fn components() -> AllComponents {
all_components
}
// #[derive(Debug, Clone, PartialEq, Eq, Hash)]
// pub enum ComponentStatus {
// WorkInProgress,
// EngineeringReady,
// Live,
// Deprecated,
// }
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ComponentScope {
Collaboration,
@@ -241,24 +249,30 @@ pub struct ComponentExample {
impl RenderOnce for ComponentExample {
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
div()
.pt_2()
.w_full()
.flex()
.flex_col()
.gap_3()
.child(
div()
.child(self.variant_name.clone())
.text_size(rems(1.25))
.text_color(cx.theme().colors().text),
.flex()
.flex_col()
.child(
div()
.child(self.variant_name.clone())
.text_size(rems(1.0))
.text_color(cx.theme().colors().text),
)
.when_some(self.description, |this, description| {
this.child(
div()
.text_size(rems(0.875))
.text_color(cx.theme().colors().text_muted)
.child(description.clone()),
)
}),
)
.when_some(self.description, |this, description| {
this.child(
div()
.text_size(rems(0.9375))
.text_color(cx.theme().colors().text_muted)
.child(description.clone()),
)
})
.child(
div()
.flex()
@@ -268,11 +282,11 @@ impl RenderOnce for ComponentExample {
.justify_center()
.p_8()
.border_1()
.border_color(cx.theme().colors().border)
.border_color(cx.theme().colors().border.opacity(0.5))
.bg(pattern_slash(
cx.theme().colors().surface_background.opacity(0.5),
24.0,
24.0,
12.0,
12.0,
))
.shadow_sm()
.child(self.element),

View File

@@ -16,12 +16,16 @@ default = []
[dependencies]
client.workspace = true
collections.workspace = true
component.workspace = true
gpui.workspace = true
languages.workspace = true
notifications.workspace = true
project.workspace = true
ui.workspace = true
workspace.workspace = true
notifications.workspace = true
collections.workspace = true
ui_input.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
db.workspace = true
anyhow.workspace = true
serde.workspace = true

View File

@@ -2,6 +2,8 @@
//!
//! A view for exploring Zed components.
mod persistence;
use std::iter::Iterator;
use std::sync::Arc;
@@ -9,24 +11,27 @@ use client::UserStore;
use component::{ComponentId, ComponentMetadata, components};
use gpui::{
App, Entity, EventEmitter, FocusHandle, Focusable, Task, WeakEntity, Window, list, prelude::*,
uniform_list,
};
use collections::HashMap;
use gpui::{ListState, ScrollHandle, UniformListScrollHandle};
use gpui::{ListState, ScrollHandle, ScrollStrategy, UniformListScrollHandle};
use languages::LanguageRegistry;
use notifications::status_toast::{StatusToast, ToastIcon};
use persistence::COMPONENT_PREVIEW_DB;
use project::Project;
use ui::{Divider, ListItem, ListSubHeader, prelude::*};
use ui::{Divider, HighlightedLabel, ListItem, ListSubHeader, prelude::*};
use ui_input::SingleLineInput;
use workspace::{AppState, ItemId, SerializableItem};
use workspace::{Item, Workspace, WorkspaceId, item::ItemEvent};
pub fn init(app_state: Arc<AppState>, cx: &mut App) {
workspace::register_serializable_item::<ComponentPreview>(cx);
let app_state = app_state.clone();
cx.observe_new(move |workspace: &mut Workspace, _, cx| {
cx.observe_new(move |workspace: &mut Workspace, _window, cx| {
let app_state = app_state.clone();
let weak_workspace = cx.entity().downgrade();
@@ -44,6 +49,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
user_store,
None,
None,
window,
cx,
)
});
@@ -64,13 +70,13 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
enum PreviewEntry {
AllComponents,
Separator,
Component(ComponentMetadata),
Component(ComponentMetadata, Option<Vec<usize>>),
SectionHeader(SharedString),
}
impl From<ComponentMetadata> for PreviewEntry {
fn from(component: ComponentMetadata) -> Self {
PreviewEntry::Component(component)
PreviewEntry::Component(component, None)
}
}
@@ -88,6 +94,7 @@ enum PreviewPage {
}
struct ComponentPreview {
workspace_id: Option<WorkspaceId>,
focus_handle: FocusHandle,
_view_scroll_handle: ScrollHandle,
nav_scroll_handle: UniformListScrollHandle,
@@ -99,6 +106,8 @@ struct ComponentPreview {
language_registry: Arc<LanguageRegistry>,
workspace: WeakEntity<Workspace>,
user_store: Entity<UserStore>,
filter_editor: Entity<SingleLineInput>,
filter_text: String,
}
impl ComponentPreview {
@@ -108,11 +117,14 @@ impl ComponentPreview {
user_store: Entity<UserStore>,
selected_index: impl Into<Option<usize>>,
active_page: Option<PreviewPage>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let sorted_components = components().all_sorted();
let selected_index = selected_index.into().unwrap_or(0);
let active_page = active_page.unwrap_or(PreviewPage::AllComponents);
let filter_editor =
cx.new(|cx| SingleLineInput::new(window, cx, "Find components or usages…"));
let component_list = ListState::new(
sorted_components.len(),
@@ -132,6 +144,7 @@ impl ComponentPreview {
);
let mut component_preview = Self {
workspace_id: None,
focus_handle: cx.focus_handle(),
_view_scroll_handle: ScrollHandle::new(),
nav_scroll_handle: UniformListScrollHandle::new(),
@@ -143,6 +156,8 @@ impl ComponentPreview {
components: sorted_components,
component_list,
cursor_index: selected_index,
filter_editor,
filter_text: String::new(),
};
if component_preview.cursor_index > 0 {
@@ -154,6 +169,13 @@ impl ComponentPreview {
component_preview
}
pub fn active_page_id(&self, _cx: &App) -> ActivePageId {
match &self.active_page {
PreviewPage::AllComponents => ActivePageId::default(),
PreviewPage::Component(component_id) => ActivePageId(component_id.0.to_string()),
}
}
fn scroll_to_preview(&mut self, ix: usize, cx: &mut Context<Self>) {
self.component_list.scroll_to_reveal_item(ix);
self.cursor_index = ix;
@@ -162,6 +184,7 @@ impl ComponentPreview {
fn set_active_page(&mut self, page: PreviewPage, cx: &mut Context<Self>) {
self.active_page = page;
cx.emit(ItemEvent::UpdateTab);
cx.notify();
}
@@ -169,20 +192,94 @@ impl ComponentPreview {
self.components[ix].clone()
}
fn filtered_components(&self) -> Vec<ComponentMetadata> {
if self.filter_text.is_empty() {
return self.components.clone();
}
let filter = self.filter_text.to_lowercase();
self.components
.iter()
.filter(|component| {
let component_name = component.name().to_lowercase();
let scope_name = component.scope().to_string().to_lowercase();
let description = component
.description()
.map(|d| d.to_lowercase())
.unwrap_or_default();
component_name.contains(&filter)
|| scope_name.contains(&filter)
|| description.contains(&filter)
})
.cloned()
.collect()
}
fn scope_ordered_entries(&self) -> Vec<PreviewEntry> {
use std::collections::HashMap;
let mut scope_groups: HashMap<ComponentScope, Vec<ComponentMetadata>> = HashMap::default();
let mut scope_groups: HashMap<
ComponentScope,
Vec<(ComponentMetadata, Option<Vec<usize>>)>,
> = HashMap::default();
let lowercase_filter = self.filter_text.to_lowercase();
for component in &self.components {
scope_groups
.entry(component.scope())
.or_insert_with(Vec::new)
.push(component.clone());
if self.filter_text.is_empty() {
scope_groups
.entry(component.scope())
.or_insert_with(Vec::new)
.push((component.clone(), None));
continue;
}
// let full_component_name = component.name();
let scopeless_name = component.scopeless_name();
let scope_name = component.scope().to_string();
let description = component.description().unwrap_or_default();
let lowercase_scopeless = scopeless_name.to_lowercase();
let lowercase_scope = scope_name.to_lowercase();
let lowercase_desc = description.to_lowercase();
if lowercase_scopeless.contains(&lowercase_filter) {
if let Some(index) = lowercase_scopeless.find(&lowercase_filter) {
let end = index + lowercase_filter.len();
if end <= scopeless_name.len() {
let mut positions = Vec::new();
for i in index..end {
if scopeless_name.is_char_boundary(i) {
positions.push(i);
}
}
if !positions.is_empty() {
scope_groups
.entry(component.scope())
.or_insert_with(Vec::new)
.push((component.clone(), Some(positions)));
continue;
}
}
}
}
if lowercase_scopeless.contains(&lowercase_filter)
|| lowercase_scope.contains(&lowercase_filter)
|| lowercase_desc.contains(&lowercase_filter)
{
scope_groups
.entry(component.scope())
.or_insert_with(Vec::new)
.push((component.clone(), None));
}
}
// Sort the components in each group
for components in scope_groups.values_mut() {
components.sort_by_key(|c| c.name().to_lowercase());
components.sort_by_key(|(c, _)| c.sort_name());
}
let mut entries = Vec::new();
@@ -204,10 +301,10 @@ impl ComponentPreview {
if !components.is_empty() {
entries.push(PreviewEntry::SectionHeader(scope.to_string().into()));
let mut sorted_components = components;
sorted_components.sort_by_key(|component| component.sort_name());
sorted_components.sort_by_key(|(component, _)| component.sort_name());
for component in sorted_components {
entries.push(PreviewEntry::Component(component));
for (component, positions) in sorted_components {
entries.push(PreviewEntry::Component(component, positions));
}
}
}
@@ -219,10 +316,10 @@ impl ComponentPreview {
entries.push(PreviewEntry::Separator);
entries.push(PreviewEntry::SectionHeader("Uncategorized".into()));
let mut sorted_components = components.clone();
sorted_components.sort_by_key(|c| c.sort_name());
sorted_components.sort_by_key(|(c, _)| c.sort_name());
for component in sorted_components {
entries.push(PreviewEntry::Component(component.clone()));
for (component, positions) in sorted_components {
entries.push(PreviewEntry::Component(component, positions));
}
}
}
@@ -237,14 +334,33 @@ impl ComponentPreview {
cx: &Context<Self>,
) -> impl IntoElement + use<> {
match entry {
PreviewEntry::Component(component_metadata) => {
PreviewEntry::Component(component_metadata, highlight_positions) => {
let id = component_metadata.id();
let selected = self.active_page == PreviewPage::Component(id.clone());
let name = component_metadata.scopeless_name();
ListItem::new(ix)
.child(
Label::new(component_metadata.scopeless_name().clone())
.color(Color::Default),
)
.child(if let Some(_positions) = highlight_positions {
let name_lower = name.to_lowercase();
let filter_lower = self.filter_text.to_lowercase();
let valid_positions = if let Some(start) = name_lower.find(&filter_lower) {
let end = start + filter_lower.len();
(start..end).collect()
} else {
Vec::new()
};
if valid_positions.is_empty() {
Label::new(name.clone())
.color(Color::Default)
.into_any_element()
} else {
HighlightedLabel::new(name.clone(), valid_positions).into_any_element()
}
} else {
Label::new(name.clone())
.color(Color::Default)
.into_any_element()
})
.selectable(true)
.toggle_state(selected)
.inset(true)
@@ -282,20 +398,70 @@ impl ComponentPreview {
}
fn update_component_list(&mut self, cx: &mut Context<Self>) {
let new_len = self.scope_ordered_entries().len();
let entries = self.scope_ordered_entries();
let new_len = entries.len();
let weak_entity = cx.entity().downgrade();
if new_len > 0 {
self.nav_scroll_handle
.scroll_to_item(0, ScrollStrategy::Top);
}
let filtered_components = self.filtered_components();
if !self.filter_text.is_empty() && !matches!(self.active_page, PreviewPage::AllComponents) {
if let PreviewPage::Component(ref component_id) = self.active_page {
let component_still_visible = filtered_components
.iter()
.any(|component| component.id() == *component_id);
if !component_still_visible {
if !filtered_components.is_empty() {
let first_component = &filtered_components[0];
self.set_active_page(PreviewPage::Component(first_component.id()), cx);
} else {
self.set_active_page(PreviewPage::AllComponents, cx);
}
}
}
}
self.component_list = ListState::new(
filtered_components.len(),
gpui::ListAlignment::Top,
px(1500.0),
{
let components = filtered_components.clone();
let this = cx.entity().downgrade();
move |ix, window: &mut Window, cx: &mut App| {
if ix >= components.len() {
return div().w_full().h_0().into_any_element();
}
this.update(cx, |this, cx| {
let component = &components[ix];
this.render_preview(component, window, cx)
.into_any_element()
})
.unwrap()
}
},
);
let new_list = ListState::new(
new_len,
gpui::ListAlignment::Top,
px(1500.0),
move |ix, window, cx| {
if ix >= entries.len() {
return div().w_full().h_0().into_any_element();
}
let entry = &entries[ix];
weak_entity
.update(cx, |this, cx| match entry {
PreviewEntry::Component(component) => this
PreviewEntry::Component(component, _) => this
.render_preview(component, window, cx)
.into_any_element(),
PreviewEntry::SectionHeader(shared_string) => this
@@ -309,6 +475,7 @@ impl ComponentPreview {
);
self.component_list = new_list;
cx.emit(ItemEvent::UpdateTab);
}
fn render_scope_header(
@@ -377,16 +544,27 @@ impl ComponentPreview {
.into_any_element()
}
fn render_all_components(&self) -> impl IntoElement {
fn render_all_components(&self, cx: &Context<Self>) -> impl IntoElement {
v_flex()
.id("component-list")
.px_8()
.pt_4()
.size_full()
.child(
list(self.component_list.clone())
.flex_grow()
.with_sizing_behavior(gpui::ListSizingBehavior::Auto),
if self.filtered_components().is_empty() && !self.filter_text.is_empty() {
div()
.size_full()
.items_center()
.justify_center()
.text_color(cx.theme().colors().text_muted)
.child(format!("No components matching '{}'.", self.filter_text))
.into_any_element()
} else {
list(self.component_list.clone())
.flex_grow()
.with_sizing_behavior(gpui::ListSizingBehavior::Auto)
.into_any_element()
},
)
}
@@ -432,6 +610,19 @@ impl ComponentPreview {
impl Render for ComponentPreview {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
// TODO: move this into the struct
let current_filter = self.filter_editor.update(cx, |input, cx| {
if input.is_empty(cx) {
String::new()
} else {
input.editor().read(cx).text(cx).to_string()
}
});
if current_filter != self.filter_text {
self.filter_text = current_filter;
self.update_component_list(cx);
}
let sidebar_entries = self.scope_ordered_entries();
let active_page = self.active_page.clone();
@@ -449,14 +640,22 @@ impl Render for ComponentPreview {
.border_color(cx.theme().colors().border)
.h_full()
.child(
uniform_list(
gpui::uniform_list(
cx.entity().clone(),
"component-nav",
sidebar_entries.len(),
move |this, range, _window, cx| {
range
.map(|ix| {
this.render_sidebar_entry(ix, &sidebar_entries[ix], cx)
.filter_map(|ix| {
if ix < sidebar_entries.len() {
Some(this.render_sidebar_entry(
ix,
&sidebar_entries[ix],
cx,
))
} else {
None
}
})
.collect()
},
@@ -481,12 +680,29 @@ impl Render for ComponentPreview {
),
),
)
.child(match active_page {
PreviewPage::AllComponents => self.render_all_components().into_any_element(),
PreviewPage::Component(id) => self
.render_component_page(&id, window, cx)
.into_any_element(),
})
.child(
v_flex()
.id("content-area")
.flex_1()
.size_full()
.overflow_hidden()
.child(
div()
.p_2()
.w_full()
.border_b_1()
.border_color(cx.theme().colors().border)
.child(self.filter_editor.clone()),
)
.child(match active_page {
PreviewPage::AllComponents => {
self.render_all_components(cx).into_any_element()
}
PreviewPage::Component(id) => self
.render_component_page(&id, window, cx)
.into_any_element(),
}),
)
}
}
@@ -498,6 +714,21 @@ impl Focusable for ComponentPreview {
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ActivePageId(pub String);
impl Default for ActivePageId {
fn default() -> Self {
ActivePageId("AllComponents".to_string())
}
}
impl From<ComponentId> for ActivePageId {
fn from(id: ComponentId) -> Self {
ActivePageId(id.0.to_string())
}
}
impl Item for ComponentPreview {
type Event = ItemEvent;
@@ -516,7 +747,7 @@ impl Item for ComponentPreview {
fn clone_on_split(
&self,
_workspace_id: Option<WorkspaceId>,
_window: &mut Window,
window: &mut Window,
cx: &mut Context<Self>,
) -> Option<gpui::Entity<Self>>
where
@@ -535,6 +766,7 @@ impl Item for ComponentPreview {
user_store,
selected_index,
Some(active_page),
window,
cx,
)
}))
@@ -543,6 +775,15 @@ impl Item for ComponentPreview {
fn to_item_events(event: &Self::Event, mut f: impl FnMut(workspace::item::ItemEvent)) {
f(*event)
}
fn added_to_workspace(
&mut self,
workspace: &mut Workspace,
_window: &mut Window,
_cx: &mut Context<Self>,
) {
self.workspace_id = workspace.database_id();
}
}
impl SerializableItem for ComponentPreview {
@@ -553,26 +794,53 @@ impl SerializableItem for ComponentPreview {
fn deserialize(
project: Entity<Project>,
workspace: WeakEntity<Workspace>,
_workspace_id: WorkspaceId,
_item_id: ItemId,
workspace_id: WorkspaceId,
item_id: ItemId,
window: &mut Window,
cx: &mut App,
) -> Task<gpui::Result<Entity<Self>>> {
let deserialized_active_page =
match COMPONENT_PREVIEW_DB.get_active_page(item_id, workspace_id) {
Ok(page) => {
if let Some(page) = page {
ActivePageId(page)
} else {
ActivePageId::default()
}
}
Err(_) => ActivePageId::default(),
};
let user_store = project.read(cx).user_store().clone();
let language_registry = project.read(cx).languages().clone();
let preview_page = if deserialized_active_page.0 == ActivePageId::default().0 {
Some(PreviewPage::default())
} else {
let component_str = deserialized_active_page.0;
let component_registry = components();
let all_components = component_registry.all();
let found_component = all_components.iter().find(|c| c.id().0 == component_str);
if let Some(component) = found_component {
Some(PreviewPage::Component(component.id().clone()))
} else {
Some(PreviewPage::default())
}
};
window.spawn(cx, async move |cx| {
let user_store = user_store.clone();
let language_registry = language_registry.clone();
let weak_workspace = workspace.clone();
cx.update(|_, cx| {
cx.update(move |window, cx| {
Ok(cx.new(|cx| {
ComponentPreview::new(
weak_workspace,
language_registry,
user_store,
None,
None,
preview_page,
window,
cx,
)
}))
@@ -581,34 +849,41 @@ impl SerializableItem for ComponentPreview {
}
fn cleanup(
_workspace_id: WorkspaceId,
_alive_items: Vec<ItemId>,
workspace_id: WorkspaceId,
alive_items: Vec<ItemId>,
_window: &mut Window,
_cx: &mut App,
cx: &mut App,
) -> Task<gpui::Result<()>> {
Task::ready(Ok(()))
// window.spawn(cx, |_| {
// ...
// })
cx.background_spawn(async move {
COMPONENT_PREVIEW_DB
.delete_unloaded_items(workspace_id, alive_items)
.await
})
}
fn serialize(
&mut self,
_workspace: &mut Workspace,
_item_id: ItemId,
item_id: ItemId,
_closing: bool,
_window: &mut Window,
_cx: &mut Context<Self>,
cx: &mut Context<Self>,
) -> Option<Task<gpui::Result<()>>> {
// TODO: Serialize the active index so we can re-open to the same place
None
let active_page = self.active_page_id(cx);
let workspace_id = self.workspace_id?;
Some(cx.background_spawn(async move {
COMPONENT_PREVIEW_DB
.save_active_page(item_id, workspace_id, active_page.0)
.await
}))
}
fn should_serialize(&self, _event: &Self::Event) -> bool {
false
fn should_serialize(&self, event: &Self::Event) -> bool {
matches!(event, ItemEvent::UpdateTab)
}
}
// TODO: use language registry to allow rendering markdown
#[derive(IntoElement)]
pub struct ComponentPreviewPage {
// languages: Arc<LanguageRegistry>,

View File

@@ -0,0 +1,73 @@
use anyhow::Result;
use db::{define_connection, query, sqlez::statement::Statement, sqlez_macros::sql};
use workspace::{ItemId, WorkspaceDb, WorkspaceId};
define_connection! {
pub static ref COMPONENT_PREVIEW_DB: ComponentPreviewDb<WorkspaceDb> =
&[sql!(
CREATE TABLE component_previews (
workspace_id INTEGER,
item_id INTEGER UNIQUE,
active_page_id TEXT,
PRIMARY KEY(workspace_id, item_id),
FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
ON DELETE CASCADE
) STRICT;
)];
}
impl ComponentPreviewDb {
pub async fn save_active_page(
&self,
item_id: ItemId,
workspace_id: WorkspaceId,
active_page_id: String,
) -> Result<()> {
let query = "INSERT INTO component_previews(item_id, workspace_id, active_page_id)
VALUES (?1, ?2, ?3)
ON CONFLICT DO UPDATE SET
active_page_id = ?3";
self.write(move |conn| {
let mut statement = Statement::prepare(conn, query)?;
let mut next_index = statement.bind(&item_id, 1)?;
next_index = statement.bind(&workspace_id, next_index)?;
statement.bind(&active_page_id, next_index)?;
statement.exec()
})
.await
}
query! {
pub fn get_active_page(item_id: ItemId, workspace_id: WorkspaceId) -> Result<Option<String>> {
SELECT active_page_id
FROM component_previews
WHERE item_id = ? AND workspace_id = ?
}
}
pub async fn delete_unloaded_items(
&self,
workspace: WorkspaceId,
alive_items: Vec<ItemId>,
) -> Result<()> {
let placeholders = alive_items
.iter()
.map(|_| "?")
.collect::<Vec<&str>>()
.join(", ");
let query = format!(
"DELETE FROM component_previews WHERE workspace_id = ? AND item_id NOT IN ({placeholders})"
);
self.write(move |conn| {
let mut statement = Statement::prepare(conn, query)?;
let mut next_index = statement.bind(&workspace, 1)?;
for id in alive_items {
next_index = statement.bind(&id, next_index)?;
}
statement.exec()
})
.await
}
}

View File

@@ -53,16 +53,18 @@ impl Tool for ContextServerTool {
true
}
fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> serde_json::Value {
match &self.tool.input_schema {
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
let mut schema = self.tool.input_schema.clone();
assistant_tool::adapt_schema_to_format(&mut schema, format)?;
Ok(match schema {
serde_json::Value::Null => {
serde_json::json!({ "type": "object", "properties": [] })
}
serde_json::Value::Object(map) if map.is_empty() => {
serde_json::json!({ "type": "object", "properties": [] })
}
_ => self.tool.input_schema.clone(),
}
_ => schema,
})
}
fn ui_text(&self, _input: &serde_json::Value) -> String {

View File

@@ -33,6 +33,8 @@ pub enum Model {
Gpt4o,
#[serde(alias = "gpt-4", rename = "gpt-4")]
Gpt4,
#[serde(alias = "gpt-4.1", rename = "gpt-4.1")]
Gpt4_1,
#[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")]
Gpt3_5Turbo,
#[serde(alias = "o1", rename = "o1")]
@@ -50,6 +52,8 @@ pub enum Model {
Claude3_7SonnetThinking,
#[serde(alias = "gemini-2.0-flash", rename = "gemini-2.0-flash-001")]
Gemini20Flash,
#[serde(alias = "gemini-2.5-pro", rename = "gemini-2.5-pro")]
Gemini25Pro,
}
impl Model {
@@ -57,11 +61,12 @@ impl Model {
match self {
Self::Gpt4o
| Self::Gpt4
| Self::Gpt4_1
| Self::Gpt3_5Turbo
| Self::Claude3_5Sonnet
| Self::Claude3_7Sonnet
| Self::Claude3_7SonnetThinking => true,
Self::O3Mini | Self::O1 | Self::Gemini20Flash => false,
Self::O3Mini | Self::O1 | Self::Gemini20Flash | Self::Gemini25Pro => false,
}
}
@@ -69,6 +74,7 @@ impl Model {
match id {
"gpt-4o" => Ok(Self::Gpt4o),
"gpt-4" => Ok(Self::Gpt4),
"gpt-4.1" => Ok(Self::Gpt4_1),
"gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo),
"o1" => Ok(Self::O1),
"o3-mini" => Ok(Self::O3Mini),
@@ -76,6 +82,7 @@ impl Model {
"claude-3-7-sonnet" => Ok(Self::Claude3_7Sonnet),
"claude-3.7-sonnet-thought" => Ok(Self::Claude3_7SonnetThinking),
"gemini-2.0-flash-001" => Ok(Self::Gemini20Flash),
"gemini-2.5-pro" => Ok(Self::Gemini25Pro),
_ => Err(anyhow!("Invalid model id: {}", id)),
}
}
@@ -84,6 +91,7 @@ impl Model {
match self {
Self::Gpt3_5Turbo => "gpt-3.5-turbo",
Self::Gpt4 => "gpt-4",
Self::Gpt4_1 => "gpt-4.1",
Self::Gpt4o => "gpt-4o",
Self::O3Mini => "o3-mini",
Self::O1 => "o1",
@@ -91,6 +99,7 @@ impl Model {
Self::Claude3_7Sonnet => "claude-3-7-sonnet",
Self::Claude3_7SonnetThinking => "claude-3.7-sonnet-thought",
Self::Gemini20Flash => "gemini-2.0-flash-001",
Self::Gemini25Pro => "gemini-2.5-pro",
}
}
@@ -98,6 +107,7 @@ impl Model {
match self {
Self::Gpt3_5Turbo => "GPT-3.5",
Self::Gpt4 => "GPT-4",
Self::Gpt4_1 => "GPT-4.1",
Self::Gpt4o => "GPT-4o",
Self::O3Mini => "o3-mini",
Self::O1 => "o1",
@@ -105,6 +115,7 @@ impl Model {
Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
Self::Gemini20Flash => "Gemini 2.0 Flash",
Self::Gemini25Pro => "Gemini 2.5 Pro",
}
}
@@ -112,13 +123,15 @@ impl Model {
match self {
Self::Gpt4o => 64_000,
Self::Gpt4 => 32_768,
Self::Gpt4_1 => 1_047_576,
Self::Gpt3_5Turbo => 12_288,
Self::O3Mini => 64_000,
Self::O1 => 20_000,
Self::Claude3_5Sonnet => 200_000,
Self::Claude3_7Sonnet => 90_000,
Self::Claude3_7SonnetThinking => 90_000,
Model::Gemini20Flash => 128_000,
Self::Gemini20Flash => 128_000,
Self::Gemini25Pro => 128_000,
}
}
}

View File

@@ -39,7 +39,6 @@ log.workspace = true
node_runtime.workspace = true
parking_lot.workspace = true
paths.workspace = true
regex.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true

View File

@@ -20,7 +20,7 @@ use std::{
net::Ipv4Addr,
ops::Deref,
path::PathBuf,
sync::{Arc, LazyLock},
sync::Arc,
};
use task::{DebugAdapterConfig, DebugTaskDefinition};
use util::ResultExt;
@@ -291,14 +291,7 @@ pub trait DebugAdapter: 'static + Send + Sync {
/// Should return base configuration to make the debug adapter work
fn request_args(&self, config: &DebugTaskDefinition) -> Value;
fn attach_processes_filter(&self) -> regex::Regex {
EMPTY_REGEX.clone()
}
}
static EMPTY_REGEX: LazyLock<regex::Regex> =
LazyLock::new(|| regex::Regex::new("").expect("Regex compilation to succeed"));
#[cfg(any(test, feature = "test-support"))]
pub struct FakeAdapter {}
@@ -375,10 +368,4 @@ impl DebugAdapter for FakeAdapter {
},
})
}
fn attach_processes_filter(&self) -> regex::Regex {
static REGEX: LazyLock<regex::Regex> =
LazyLock::new(|| regex::Regex::new("^fake-binary").unwrap());
REGEX.clone()
}
}

View File

@@ -8,7 +8,7 @@ struct DapRegistryState {
adapters: BTreeMap<DebugAdapterName, Arc<dyn DebugAdapter>>,
}
#[derive(Default)]
#[derive(Clone, Default)]
/// Stores available debug adapters.
pub struct DapRegistry(Arc<RwLock<DapRegistryState>>);

View File

@@ -27,7 +27,6 @@ dap.workspace = true
gpui.workspace = true
language.workspace = true
paths.workspace = true
regex.workspace = true
serde.workspace = true
serde_json.workspace = true
task.workspace = true

Some files were not shown because too many files have changed in this diff Show More