Merge branch 'main' into websearch-tool
This commit is contained in:
19
.github/ISSUE_TEMPLATE/99_other.yml
vendored
Normal file
19
.github/ISSUE_TEMPLATE/99_other.yml
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
name: Other [Staff Only]
|
||||
description: Zed Staff Only
|
||||
body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Summary
|
||||
value: |
|
||||
<!-- Please insert a one line summary of the issue below -->
|
||||
SUMMARY_SENTENCE_HERE
|
||||
|
||||
### Description
|
||||
|
||||
IF YOU DO NOT WORK FOR ZED INDUSTRIES DO NOT CREATE ISSUES WITH THIS TEMPLATE.
|
||||
THEY WILL BE AUTO-CLOSED AND MAY RESULT IN YOU BEING BANNED FROM THE ZED ISSUE TRACKER.
|
||||
|
||||
FEATURE REQUESTS / SUPPORT REQUESTS SHOULD BE OPENED AS DISCUSSIONS:
|
||||
https://github.com/zed-industries/zed/discussions/new/choose
|
||||
validations:
|
||||
required: true
|
||||
34
.github/workflows/ci.yml
vendored
34
.github/workflows/ci.yml
vendored
@@ -594,7 +594,7 @@ jobs:
|
||||
timeout-minutes: 60
|
||||
name: Linux x86_x64 release bundle
|
||||
runs-on:
|
||||
- buildjet-16vcpu-ubuntu-2004
|
||||
- buildjet-16vcpu-ubuntu-2004 # ubuntu 20.04 for minimal glibc
|
||||
if: |
|
||||
startsWith(github.ref, 'refs/tags/v')
|
||||
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
|
||||
@@ -622,26 +622,23 @@ jobs:
|
||||
- name: Create Linux .tar.gz bundle
|
||||
run: script/bundle-linux
|
||||
|
||||
- name: Upload Linux bundle to workflow run if main branch or specific label
|
||||
- name: Upload Artifact to Workflow - zed (run-bundling)
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
|
||||
if: |
|
||||
github.ref == 'refs/heads/main'
|
||||
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
|
||||
if: contains(github.event.pull_request.labels.*.name, 'run-bundling')
|
||||
with:
|
||||
name: zed-${{ github.event.pull_request.head.sha || github.sha }}-x86_64-unknown-linux-gnu.tar.gz
|
||||
path: target/release/zed-*.tar.gz
|
||||
|
||||
- name: Upload Linux remote server to workflow run if main branch or specific label
|
||||
- name: Upload Artifact to Workflow - zed-remote-server (run-bundling)
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
|
||||
if: |
|
||||
github.ref == 'refs/heads/main'
|
||||
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
|
||||
if: contains(github.event.pull_request.labels.*.name, 'run-bundling')
|
||||
with:
|
||||
name: zed-remote-server-${{ github.event.pull_request.head.sha || github.sha }}-x86_64-unknown-linux-gnu.gz
|
||||
path: target/zed-remote-server-linux-x86_64.gz
|
||||
|
||||
- name: Upload app bundle to release
|
||||
- name: Upload Artifacts to release
|
||||
uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 # v1
|
||||
if: ${{ !(contains(github.event.pull_request.labels.*.name, 'run-bundling')) }}
|
||||
with:
|
||||
draft: true
|
||||
prerelease: ${{ env.RELEASE_CHANNEL == 'preview' }}
|
||||
@@ -680,29 +677,26 @@ jobs:
|
||||
# This exports RELEASE_CHANNEL into env (GITHUB_ENV)
|
||||
script/determine-release-channel
|
||||
|
||||
- name: Create and upload Linux .tar.gz bundle
|
||||
- name: Create and upload Linux .tar.gz bundles
|
||||
run: script/bundle-linux
|
||||
|
||||
- name: Upload Linux bundle to workflow run if main branch or specific label
|
||||
- name: Upload Artifact to Workflow - zed (run-bundling)
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
|
||||
if: |
|
||||
github.ref == 'refs/heads/main'
|
||||
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
|
||||
if: contains(github.event.pull_request.labels.*.name, 'run-bundling')
|
||||
with:
|
||||
name: zed-${{ github.event.pull_request.head.sha || github.sha }}-aarch64-unknown-linux-gnu.tar.gz
|
||||
path: target/release/zed-*.tar.gz
|
||||
|
||||
- name: Upload Linux remote server to workflow run if main branch or specific label
|
||||
- name: Upload Artifact to Workflow - zed-remote-server (run-bundling)
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4
|
||||
if: |
|
||||
github.ref == 'refs/heads/main'
|
||||
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
|
||||
if: contains(github.event.pull_request.labels.*.name, 'run-bundling')
|
||||
with:
|
||||
name: zed-remote-server-${{ github.event.pull_request.head.sha || github.sha }}-aarch64-unknown-linux-gnu.gz
|
||||
path: target/zed-remote-server-linux-aarch64.gz
|
||||
|
||||
- name: Upload app bundle to release
|
||||
- name: Upload Artifacts to release
|
||||
uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 # v1
|
||||
if: ${{ !(contains(github.event.pull_request.labels.*.name, 'run-bundling')) }}
|
||||
with:
|
||||
draft: true
|
||||
prerelease: ${{ env.RELEASE_CHANNEL == 'preview' }}
|
||||
|
||||
8
.github/workflows/deploy_collab.yml
vendored
8
.github/workflows/deploy_collab.yml
vendored
@@ -117,12 +117,10 @@ jobs:
|
||||
export ZED_KUBE_NAMESPACE=production
|
||||
export ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT=10
|
||||
export ZED_API_LOAD_BALANCER_SIZE_UNIT=2
|
||||
export ZED_LLM_LOAD_BALANCER_SIZE_UNIT=2
|
||||
elif [[ $GITHUB_REF_NAME = "collab-staging" ]]; then
|
||||
export ZED_KUBE_NAMESPACE=staging
|
||||
export ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT=1
|
||||
export ZED_API_LOAD_BALANCER_SIZE_UNIT=1
|
||||
export ZED_LLM_LOAD_BALANCER_SIZE_UNIT=1
|
||||
else
|
||||
echo "cowardly refusing to deploy from an unknown branch"
|
||||
exit 1
|
||||
@@ -147,9 +145,3 @@ jobs:
|
||||
envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f -
|
||||
kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch
|
||||
echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}"
|
||||
|
||||
export ZED_SERVICE_NAME=llm
|
||||
export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_LLM_LOAD_BALANCER_SIZE_UNIT
|
||||
envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f -
|
||||
kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch
|
||||
echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}"
|
||||
|
||||
45
Cargo.lock
generated
45
Cargo.lock
generated
@@ -326,7 +326,6 @@ dependencies = [
|
||||
"serde_json",
|
||||
"strum",
|
||||
"thiserror 2.0.12",
|
||||
"util",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
@@ -1183,6 +1182,18 @@ dependencies = [
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "auto_update_helper"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"log",
|
||||
"simplelog",
|
||||
"windows 0.61.1",
|
||||
"winresource",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "auto_update_ui"
|
||||
version = "0.1.0"
|
||||
@@ -2932,7 +2943,6 @@ dependencies = [
|
||||
name = "collab"
|
||||
version = "0.44.0"
|
||||
dependencies = [
|
||||
"anthropic",
|
||||
"anyhow",
|
||||
"assistant",
|
||||
"assistant_context_editor",
|
||||
@@ -3176,14 +3186,18 @@ dependencies = [
|
||||
name = "component_preview"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"client",
|
||||
"collections",
|
||||
"component",
|
||||
"db",
|
||||
"gpui",
|
||||
"languages",
|
||||
"notifications",
|
||||
"project",
|
||||
"serde",
|
||||
"ui",
|
||||
"ui_input",
|
||||
"workspace",
|
||||
"workspace-hack",
|
||||
]
|
||||
@@ -3988,7 +4002,6 @@ dependencies = [
|
||||
"node_runtime",
|
||||
"parking_lot",
|
||||
"paths",
|
||||
"regex",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -4020,7 +4033,6 @@ dependencies = [
|
||||
"gpui",
|
||||
"language",
|
||||
"paths",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"task",
|
||||
@@ -4164,6 +4176,7 @@ dependencies = [
|
||||
"collections",
|
||||
"command_palette_hooks",
|
||||
"dap",
|
||||
"db",
|
||||
"editor",
|
||||
"env_logger 0.11.8",
|
||||
"feature_flags",
|
||||
@@ -4863,25 +4876,37 @@ dependencies = [
|
||||
"assistant_settings",
|
||||
"assistant_tool",
|
||||
"assistant_tools",
|
||||
"async-watch",
|
||||
"chrono",
|
||||
"clap",
|
||||
"client",
|
||||
"collections",
|
||||
"context_server",
|
||||
"dap",
|
||||
"env_logger 0.11.8",
|
||||
"extension",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"gpui_tokio",
|
||||
"handlebars 4.5.0",
|
||||
"language",
|
||||
"language_extension",
|
||||
"language_model",
|
||||
"language_models",
|
||||
"languages",
|
||||
"node_runtime",
|
||||
"paths",
|
||||
"project",
|
||||
"prompt_store",
|
||||
"release_channel",
|
||||
"reqwest_client",
|
||||
"serde",
|
||||
"settings",
|
||||
"shellexpand 2.1.2",
|
||||
"toml 0.8.20",
|
||||
"unindent",
|
||||
"util",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
@@ -4976,10 +5001,10 @@ dependencies = [
|
||||
"async-tar",
|
||||
"async-trait",
|
||||
"collections",
|
||||
"convert_case 0.8.0",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"heck 0.5.0",
|
||||
"http_client",
|
||||
"language",
|
||||
"log",
|
||||
@@ -7654,6 +7679,7 @@ dependencies = [
|
||||
name = "language_model_selector"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"collections",
|
||||
"feature_flags",
|
||||
"gpui",
|
||||
"language_model",
|
||||
@@ -7704,6 +7730,7 @@ dependencies = [
|
||||
"smol",
|
||||
"strum",
|
||||
"theme",
|
||||
"thiserror 2.0.12",
|
||||
"tiktoken-rs",
|
||||
"tokio",
|
||||
"ui",
|
||||
@@ -17628,6 +17655,7 @@ dependencies = [
|
||||
"ui",
|
||||
"util",
|
||||
"uuid",
|
||||
"windows 0.61.1",
|
||||
"workspace-hack",
|
||||
"zed_actions",
|
||||
]
|
||||
@@ -17790,6 +17818,8 @@ dependencies = [
|
||||
"wasmtime-cranelift",
|
||||
"wasmtime-environ",
|
||||
"winapi",
|
||||
"windows-core 0.61.0",
|
||||
"windows-numerics",
|
||||
"windows-sys 0.48.0",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.59.0",
|
||||
@@ -18134,7 +18164,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zed"
|
||||
version = "0.182.0"
|
||||
version = "0.183.0"
|
||||
dependencies = [
|
||||
"activity_indicator",
|
||||
"agent",
|
||||
@@ -18230,7 +18260,6 @@ dependencies = [
|
||||
"settings",
|
||||
"settings_ui",
|
||||
"shellexpand 2.1.2",
|
||||
"simplelog",
|
||||
"smol",
|
||||
"snippet_provider",
|
||||
"snippets_ui",
|
||||
@@ -18581,7 +18610,9 @@ name = "zlog"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"chrono",
|
||||
"log",
|
||||
"tempfile",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
|
||||
25
Cargo.toml
25
Cargo.toml
@@ -15,6 +15,7 @@ members = [
|
||||
"crates/assistant_tools",
|
||||
"crates/audio",
|
||||
"crates/auto_update",
|
||||
"crates/auto_update_helper",
|
||||
"crates/auto_update_ui",
|
||||
"crates/aws_http_client",
|
||||
"crates/bedrock",
|
||||
@@ -224,6 +225,7 @@ assistant_tool = { path = "crates/assistant_tool" }
|
||||
assistant_tools = { path = "crates/assistant_tools" }
|
||||
audio = { path = "crates/audio" }
|
||||
auto_update = { path = "crates/auto_update" }
|
||||
auto_update_helper = { path = "crates/auto_update_helper" }
|
||||
auto_update_ui = { path = "crates/auto_update_ui" }
|
||||
aws_http_client = { path = "crates/aws_http_client" }
|
||||
bedrock = { path = "crates/bedrock" }
|
||||
@@ -403,8 +405,12 @@ async-tungstenite = "0.29.1"
|
||||
async-watch = "0.3.1"
|
||||
async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] }
|
||||
aws-config = { version = "1.6.1", features = ["behavior-version-latest"] }
|
||||
aws-credential-types = { version = "1.2.2", features = ["hardcoded-credentials"] }
|
||||
aws-sdk-bedrockruntime = { version = "1.80.0", features = ["behavior-version-latest"] }
|
||||
aws-credential-types = { version = "1.2.2", features = [
|
||||
"hardcoded-credentials",
|
||||
] }
|
||||
aws-sdk-bedrockruntime = { version = "1.80.0", features = [
|
||||
"behavior-version-latest",
|
||||
] }
|
||||
aws-smithy-runtime-api = { version = "1.7.4", features = ["http-1x", "client"] }
|
||||
aws-smithy-types = { version = "1.3.0", features = ["http-body-1-x"] }
|
||||
base64 = "0.22"
|
||||
@@ -443,6 +449,7 @@ futures-lite = "1.13"
|
||||
git2 = { version = "0.20.1", default-features = false }
|
||||
globset = "0.4"
|
||||
handlebars = "4.3"
|
||||
heck = "0.5"
|
||||
heed = { version = "0.21.0", features = ["read-txn-no-tls"] }
|
||||
hex = "0.4.3"
|
||||
html5ever = "0.27.0"
|
||||
@@ -619,12 +626,10 @@ features = [
|
||||
[workspace.dependencies.windows]
|
||||
version = "0.61"
|
||||
features = [
|
||||
"Foundation_Collections",
|
||||
"Foundation_Numerics",
|
||||
"Storage_Search",
|
||||
"Storage_Streams",
|
||||
"System_Threading",
|
||||
"UI_StartScreen",
|
||||
"UI_ViewManagement",
|
||||
"Wdk_System_SystemServices",
|
||||
"Win32_Globalization",
|
||||
@@ -651,6 +656,7 @@ features = [
|
||||
"Win32_System_SystemInformation",
|
||||
"Win32_System_SystemServices",
|
||||
"Win32_System_Threading",
|
||||
"Win32_System_Variant",
|
||||
"Win32_System_WinRT",
|
||||
"Win32_UI_Controls",
|
||||
"Win32_UI_HiDpi",
|
||||
@@ -658,6 +664,7 @@ features = [
|
||||
"Win32_UI_Input_KeyboardAndMouse",
|
||||
"Win32_UI_Shell",
|
||||
"Win32_UI_Shell_Common",
|
||||
"Win32_UI_Shell_PropertiesSystem",
|
||||
"Win32_UI_WindowsAndMessaging",
|
||||
]
|
||||
|
||||
@@ -781,4 +788,12 @@ let_underscore_future = "allow"
|
||||
too_many_arguments = "allow"
|
||||
|
||||
[workspace.metadata.cargo-machete]
|
||||
ignored = ["bindgen", "cbindgen", "prost_build", "serde", "component", "linkme", "workspace-hack"]
|
||||
ignored = [
|
||||
"bindgen",
|
||||
"cbindgen",
|
||||
"prost_build",
|
||||
"serde",
|
||||
"component",
|
||||
"linkme",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect width="16" height="16" rx="2" fill="black" fill-opacity="0.2"/>
|
||||
<g clip-path="url(#clip0_1916_18)">
|
||||
<path d="M10.652 3.79999H8.816L12.164 12.2H14L10.652 3.79999Z" fill="#1F1F1E"/>
|
||||
<path d="M5.348 3.79999L2 12.2H3.872L4.55672 10.436H8.05927L8.744 12.2H10.616L7.268 3.79999H5.348ZM5.16224 8.87599L6.308 5.92399L7.45374 8.87599H5.16224Z" fill="#1F1F1E"/>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="clip0_1916_18">
|
||||
<rect width="12" height="8.4" fill="white" transform="translate(2 3.79999)"/>
|
||||
</clipPath>
|
||||
</defs>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 601 B |
5
assets/icons/layout.svg
Normal file
5
assets/icons/layout.svg
Normal file
@@ -0,0 +1,5 @@
|
||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M20 14H4C3.44772 14 3 14.4477 3 15V20C3 20.5523 3.44772 21 4 21H20C20.5523 21 21 20.5523 21 20V15C21 14.4477 20.5523 14 20 14Z" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M11 3H4C3.44772 3 3 3.44772 3 4V9C3 9.55228 3.44772 10 4 10H11C11.5523 10 12 9.55228 12 9V4C12 3.44772 11.5523 3 11 3Z" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M20 3H17C16.4477 3 16 3.44772 16 4V9C16 9.55228 16.4477 10 17 10H20C20.5523 10 21 9.55228 21 9V4C21 3.44772 20.5523 3 20 3Z" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 746 B |
@@ -150,7 +150,9 @@
|
||||
"context": "AgentDiff",
|
||||
"bindings": {
|
||||
"ctrl-y": "agent::Keep",
|
||||
"ctrl-n": "agent::Reject"
|
||||
"ctrl-n": "agent::Reject",
|
||||
"ctrl-shift-y": "agent::KeepAll",
|
||||
"ctrl-shift-n": "agent::RejectAll"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -352,11 +354,11 @@
|
||||
"alt-shift-left": "editor::SelectSmallerSyntaxNode", // Shrink Selection
|
||||
"ctrl-shift-l": "editor::SelectAllMatches", // Select all occurrences of current selection
|
||||
"ctrl-f2": "editor::SelectAllMatches", // Select all occurrences of current word
|
||||
"ctrl-d": ["editor::SelectNext", { "replace_newest": false }],
|
||||
"ctrl-shift-down": ["editor::SelectNext", { "replace_newest": false }], // Add selection to Next Find Match
|
||||
"ctrl-shift-up": ["editor::SelectPrevious", { "replace_newest": false }],
|
||||
"ctrl-k ctrl-d": ["editor::SelectNext", { "replace_newest": true }],
|
||||
"ctrl-k ctrl-shift-d": ["editor::SelectPrevious", { "replace_newest": true }],
|
||||
"ctrl-d": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand
|
||||
"ctrl-shift-down": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch
|
||||
"ctrl-shift-up": ["editor::SelectPrevious", { "replace_newest": false }], // editor.action.addSelectionToPreviousFindMatch
|
||||
"ctrl-k ctrl-d": ["editor::SelectNext", { "replace_newest": true }], // editor.action.moveSelectionToNextFindMatch / find_under_expand_skip
|
||||
"ctrl-k ctrl-shift-d": ["editor::SelectPrevious", { "replace_newest": true }], // editor.action.moveSelectionToPreviousFindMatch
|
||||
"ctrl-k ctrl-i": "editor::Hover",
|
||||
"ctrl-/": ["editor::ToggleComments", { "advance_downwards": false }],
|
||||
"ctrl-u": "editor::UndoSelection",
|
||||
@@ -780,6 +782,7 @@
|
||||
"shift-tab": "git_panel::FocusEditor",
|
||||
"escape": "git_panel::ToggleFocus",
|
||||
"ctrl-enter": "git::Commit",
|
||||
"ctrl-shift-enter": "git::Amend",
|
||||
"alt-enter": "menu::SecondaryConfirm",
|
||||
"delete": ["git::RestoreFile", { "skip_prompt": false }],
|
||||
"backspace": ["git::RestoreFile", { "skip_prompt": false }],
|
||||
@@ -788,12 +791,20 @@
|
||||
"ctrl-delete": ["git::RestoreFile", { "skip_prompt": false }]
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "GitPanel && CommitEditor",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"escape": "git::Cancel"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "GitCommit > Editor",
|
||||
"bindings": {
|
||||
"escape": "menu::Cancel",
|
||||
"enter": "editor::Newline",
|
||||
"ctrl-enter": "git::Commit",
|
||||
"ctrl-shift-enter": "git::Amend",
|
||||
"alt-l": "git::GenerateCommitMessage"
|
||||
}
|
||||
},
|
||||
@@ -815,6 +826,7 @@
|
||||
"context": "GitDiff > Editor",
|
||||
"bindings": {
|
||||
"ctrl-enter": "git::Commit",
|
||||
"ctrl-shift-enter": "git::Amend",
|
||||
"ctrl-space": "git::StageAll",
|
||||
"ctrl-shift-space": "git::UnstageAll"
|
||||
}
|
||||
@@ -833,6 +845,7 @@
|
||||
"shift-tab": "git_panel::FocusChanges",
|
||||
"enter": "editor::Newline",
|
||||
"ctrl-enter": "git::Commit",
|
||||
"ctrl-shift-enter": "git::Amend",
|
||||
"alt-up": "git_panel::FocusChanges",
|
||||
"alt-l": "git::GenerateCommitMessage"
|
||||
}
|
||||
|
||||
@@ -242,7 +242,9 @@
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"cmd-y": "agent::Keep",
|
||||
"cmd-n": "agent::Reject"
|
||||
"cmd-n": "agent::Reject",
|
||||
"cmd-shift-y": "agent::KeepAll",
|
||||
"cmd-shift-n": "agent::RejectAll"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -489,12 +491,15 @@
|
||||
"alt-shift-down": "editor::DuplicateLineDown",
|
||||
"ctrl-shift-right": "editor::SelectLargerSyntaxNode", // Expand Selection
|
||||
"ctrl-shift-left": "editor::SelectSmallerSyntaxNode", // Shrink Selection
|
||||
"cmd-d": ["editor::SelectNext", { "replace_newest": false }], // Add selection to Next Find Match
|
||||
"cmd-d": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand
|
||||
"cmd-shift-l": "editor::SelectAllMatches", // Select all occurrences of current selection
|
||||
"cmd-f2": "editor::SelectAllMatches", // Select all occurrences of current word
|
||||
"ctrl-cmd-d": ["editor::SelectPrevious", { "replace_newest": false }],
|
||||
"cmd-k cmd-d": ["editor::SelectNext", { "replace_newest": true }],
|
||||
"cmd-k ctrl-cmd-d": ["editor::SelectPrevious", { "replace_newest": true }],
|
||||
"cmd-k cmd-d": ["editor::SelectNext", { "replace_newest": true }], // editor.action.moveSelectionToNextFindMatch / find_under_expand_skip
|
||||
// macOS binds `ctrl-cmd-d` to Show Dictionary which breaks these two binds
|
||||
// To use `ctrl-cmd-d` or `ctrl-k ctrl-cmd-d` in Zed you must execute this command and then restart:
|
||||
// defaults write com.apple.symbolichotkeys AppleSymbolicHotKeys -dict-add 70 '<dict><key>enabled</key><false/></dict>'
|
||||
"ctrl-cmd-d": ["editor::SelectPrevious", { "replace_newest": false }], // editor.action.addSelectionToPreviousFindMatch
|
||||
"cmd-k ctrl-cmd-d": ["editor::SelectPrevious", { "replace_newest": true }], // editor.action.moveSelectionToPreviousFindMatch
|
||||
"cmd-k cmd-i": "editor::Hover",
|
||||
"cmd-/": ["editor::ToggleComments", { "advance_downwards": false }],
|
||||
"cmd-u": "editor::UndoSelection",
|
||||
@@ -850,17 +855,26 @@
|
||||
"shift-tab": "git_panel::FocusEditor",
|
||||
"escape": "git_panel::ToggleFocus",
|
||||
"cmd-enter": "git::Commit",
|
||||
"cmd-shift-enter": "git::Amend",
|
||||
"backspace": ["git::RestoreFile", { "skip_prompt": false }],
|
||||
"delete": ["git::RestoreFile", { "skip_prompt": false }],
|
||||
"cmd-backspace": ["git::RestoreFile", { "skip_prompt": true }],
|
||||
"cmd-delete": ["git::RestoreFile", { "skip_prompt": true }]
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "GitPanel && CommitEditor",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"escape": "git::Cancel"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "GitDiff > Editor",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"cmd-enter": "git::Commit",
|
||||
"cmd-shift-enter": "git::Amend",
|
||||
"cmd-ctrl-y": "git::StageAll",
|
||||
"cmd-ctrl-shift-y": "git::UnstageAll"
|
||||
}
|
||||
@@ -871,6 +885,7 @@
|
||||
"bindings": {
|
||||
"enter": "editor::Newline",
|
||||
"cmd-enter": "git::Commit",
|
||||
"cmd-shift-enter": "git::Amend",
|
||||
"tab": "git_panel::FocusChanges",
|
||||
"shift-tab": "git_panel::FocusChanges",
|
||||
"alt-up": "git_panel::FocusChanges",
|
||||
@@ -900,6 +915,7 @@
|
||||
"enter": "editor::Newline",
|
||||
"escape": "menu::Cancel",
|
||||
"cmd-enter": "git::Commit",
|
||||
"cmd-shift-enter": "git::Amend",
|
||||
"alt-tab": "git::GenerateCommitMessage"
|
||||
}
|
||||
},
|
||||
|
||||
@@ -37,6 +37,8 @@
|
||||
"ctrl-shift-a": "editor::SelectLargerSyntaxNode",
|
||||
"ctrl-shift-d": "editor::DuplicateSelection",
|
||||
"alt-f3": "editor::SelectAllMatches", // find_all_under
|
||||
// "ctrl-f3": "", // find_under (cancels any selections)
|
||||
// "cmd-alt-shift-g": "" // find_under_prev (cancels any selections)
|
||||
"f9": "editor::SortLinesCaseSensitive",
|
||||
"ctrl-f9": "editor::SortLinesCaseInsensitive",
|
||||
"f12": "editor::GoToDefinition",
|
||||
|
||||
@@ -38,6 +38,8 @@
|
||||
"cmd-shift-a": "editor::SelectLargerSyntaxNode",
|
||||
"cmd-shift-d": "editor::DuplicateSelection",
|
||||
"ctrl-cmd-g": "editor::SelectAllMatches", // find_all_under
|
||||
// "cmd-alt-g": "", // find_under (cancels any selections)
|
||||
// "cmd-alt-shift-g": "" // find_under_prev (cancels any selections)
|
||||
"f5": "editor::SortLinesCaseSensitive",
|
||||
"ctrl-f5": "editor::SortLinesCaseInsensitive",
|
||||
"shift-f12": "editor::FindAllReferences",
|
||||
|
||||
@@ -539,6 +539,7 @@
|
||||
"bindings": {
|
||||
"d": "vim::CurrentLine",
|
||||
"s": "vim::PushDeleteSurrounds",
|
||||
"v": "vim::PushForcedMotion", // "d v"
|
||||
"o": "editor::ToggleSelectedDiffHunks", // "d o"
|
||||
"shift-o": "git::ToggleStaged",
|
||||
"p": "git::Restore", // "d p"
|
||||
@@ -587,6 +588,7 @@
|
||||
"context": "vim_operator == y",
|
||||
"bindings": {
|
||||
"y": "vim::CurrentLine",
|
||||
"v": "vim::PushForcedMotion",
|
||||
"s": ["vim::PushAddSurrounds", {}]
|
||||
}
|
||||
},
|
||||
|
||||
@@ -80,6 +80,8 @@
|
||||
// Values are clamped to the [0.0, 1.0] range.
|
||||
"inactive_opacity": 1.0
|
||||
},
|
||||
// Layout mode of the bottom dock. Defaults to "contained"
|
||||
"bottom_dock_layout": "contained",
|
||||
// The direction that you want to split panes horizontally. Defaults to "up"
|
||||
"pane_split_direction_horizontal": "up",
|
||||
// The direction that you want to split panes horizontally. Defaults to "left"
|
||||
@@ -642,6 +644,7 @@
|
||||
// We don't know which of the context server tools are safe for the "Ask" profile, so we don't enable them by default.
|
||||
// "enable_all_context_servers": true,
|
||||
"tools": {
|
||||
"contents": true,
|
||||
"diagnostics": true,
|
||||
"fetch": true,
|
||||
"list_directory": false,
|
||||
@@ -661,6 +664,7 @@
|
||||
"batch_tool": true,
|
||||
"code_actions": true,
|
||||
"code_symbols": true,
|
||||
"contents": true,
|
||||
"copy_path": false,
|
||||
"create_file": true,
|
||||
"delete_path": false,
|
||||
|
||||
@@ -13,18 +13,18 @@ use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting};
|
||||
use assistant_tool::ToolUseStatus;
|
||||
use collections::{HashMap, HashSet};
|
||||
use editor::scroll::Autoscroll;
|
||||
use editor::{Editor, MultiBuffer};
|
||||
use editor::{Editor, EditorElement, EditorStyle, MultiBuffer};
|
||||
use gpui::{
|
||||
AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, ClipboardItem,
|
||||
DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Hsla, ListAlignment, ListState,
|
||||
MouseButton, PlatformDisplay, ScrollHandle, Stateful, StyleRefinement, Subscription, Task,
|
||||
TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle,
|
||||
TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle,
|
||||
linear_color_stop, linear_gradient, list, percentage, pulsating_between,
|
||||
};
|
||||
use language::{Buffer, LanguageRegistry};
|
||||
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role, StopReason};
|
||||
use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
|
||||
use markdown::{Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
|
||||
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
|
||||
use project::ProjectItem as _;
|
||||
use rope::Point;
|
||||
use settings::{Settings as _, update_settings_file};
|
||||
@@ -34,7 +34,9 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use text::ToPoint;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, Tooltip, prelude::*};
|
||||
use ui::{
|
||||
Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, TextSize, Tooltip, prelude::*,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
use workspace::{OpenOptions, Workspace};
|
||||
|
||||
@@ -66,8 +68,6 @@ pub struct ActiveThread {
|
||||
open_feedback_editors: HashMap<MessageId, Entity<Editor>>,
|
||||
}
|
||||
|
||||
const MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK: usize = 5;
|
||||
|
||||
struct RenderedMessage {
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
segments: Vec<RenderedMessageSegment>,
|
||||
@@ -176,11 +176,37 @@ fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
|
||||
});
|
||||
|
||||
MarkdownStyle {
|
||||
base_text_style: text_style,
|
||||
base_text_style: text_style.clone(),
|
||||
syntax: cx.theme().syntax().clone(),
|
||||
selection_background_color: cx.theme().players().local().selection,
|
||||
code_block_overflow_x_scroll: true,
|
||||
table_overflow_x_scroll: true,
|
||||
heading_level_styles: Some(HeadingLevelStyles {
|
||||
h1: Some(TextStyleRefinement {
|
||||
font_size: Some(rems(1.15).into()),
|
||||
..Default::default()
|
||||
}),
|
||||
h2: Some(TextStyleRefinement {
|
||||
font_size: Some(rems(1.1).into()),
|
||||
..Default::default()
|
||||
}),
|
||||
h3: Some(TextStyleRefinement {
|
||||
font_size: Some(rems(1.05).into()),
|
||||
..Default::default()
|
||||
}),
|
||||
h4: Some(TextStyleRefinement {
|
||||
font_size: Some(rems(1.).into()),
|
||||
..Default::default()
|
||||
}),
|
||||
h5: Some(TextStyleRefinement {
|
||||
font_size: Some(rems(0.95).into()),
|
||||
..Default::default()
|
||||
}),
|
||||
h6: Some(TextStyleRefinement {
|
||||
font_size: Some(rems(0.875).into()),
|
||||
..Default::default()
|
||||
}),
|
||||
}),
|
||||
code_block: StyleRefinement {
|
||||
padding: EdgesRefinement {
|
||||
top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
|
||||
@@ -292,6 +318,8 @@ fn tool_use_markdown_style(window: &Window, cx: &mut App) -> MarkdownStyle {
|
||||
}
|
||||
}
|
||||
|
||||
const MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK: usize = 10;
|
||||
|
||||
fn render_markdown_code_block(
|
||||
message_id: MessageId,
|
||||
ix: usize,
|
||||
@@ -578,7 +606,7 @@ fn render_markdown_code_block(
|
||||
if is_expanded {
|
||||
this.h_full()
|
||||
} else {
|
||||
this.max_h_40()
|
||||
this.max_h_80()
|
||||
}
|
||||
},
|
||||
)
|
||||
@@ -1497,12 +1525,36 @@ impl ActiveThread {
|
||||
.when(!message_is_empty, |parent| {
|
||||
parent.child(
|
||||
if let Some(edit_message_editor) = edit_message_editor.clone() {
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let font_size = TextSize::Small.rems(cx);
|
||||
let line_height = font_size.to_pixels(window.rem_size()) * 1.5;
|
||||
|
||||
let text_style = TextStyle {
|
||||
color: cx.theme().colors().text,
|
||||
font_family: settings.buffer_font.family.clone(),
|
||||
font_fallbacks: settings.buffer_font.fallbacks.clone(),
|
||||
font_features: settings.buffer_font.features.clone(),
|
||||
font_size: font_size.into(),
|
||||
line_height: line_height.into(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
div()
|
||||
.key_context("EditMessageEditor")
|
||||
.on_action(cx.listener(Self::cancel_editing_message))
|
||||
.on_action(cx.listener(Self::confirm_editing_message))
|
||||
.min_h_6()
|
||||
.child(edit_message_editor)
|
||||
.pt_1()
|
||||
.child(EditorElement::new(
|
||||
&edit_message_editor,
|
||||
EditorStyle {
|
||||
background: colors.editor_background,
|
||||
local_player: cx.theme().players().local(),
|
||||
text: text_style,
|
||||
syntax: cx.theme().syntax().clone(),
|
||||
..Default::default()
|
||||
},
|
||||
))
|
||||
.into_any()
|
||||
} else {
|
||||
div()
|
||||
@@ -1667,11 +1719,9 @@ impl ActiveThread {
|
||||
),
|
||||
Role::Assistant => v_flex()
|
||||
.id(("message-container", ix))
|
||||
.ml_2()
|
||||
.ml_2p5()
|
||||
.pl_2()
|
||||
.pr_4()
|
||||
.border_l_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.children(message_content)
|
||||
.when(has_tool_uses, |parent| {
|
||||
parent.children(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::{Keep, Reject, Thread, ThreadEvent};
|
||||
use crate::{Keep, KeepAll, Reject, RejectAll, Thread, ThreadEvent};
|
||||
use anyhow::Result;
|
||||
use buffer_diff::DiffHunkStatus;
|
||||
use collections::HashSet;
|
||||
@@ -843,7 +843,7 @@ impl ToolbarItemView for AgentDiffToolbar {
|
||||
}
|
||||
|
||||
impl Render for AgentDiffToolbar {
|
||||
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let agent_diff = match self.agent_diff(cx) {
|
||||
Some(ad) => ad,
|
||||
None => return div(),
|
||||
@@ -855,6 +855,8 @@ impl Render for AgentDiffToolbar {
|
||||
return div();
|
||||
}
|
||||
|
||||
let focus_handle = agent_diff.focus_handle(cx);
|
||||
|
||||
h_group_xl()
|
||||
.my_neg_1()
|
||||
.items_center()
|
||||
@@ -864,15 +866,25 @@ impl Render for AgentDiffToolbar {
|
||||
.child(
|
||||
h_group_sm()
|
||||
.child(
|
||||
Button::new("reject-all", "Reject All").on_click(cx.listener(
|
||||
|this, _, window, cx| {
|
||||
this.dispatch_action(&crate::RejectAll, window, cx)
|
||||
},
|
||||
)),
|
||||
Button::new("reject-all", "Reject All")
|
||||
.key_binding({
|
||||
KeyBinding::for_action_in(&RejectAll, &focus_handle, window, cx)
|
||||
.map(|kb| kb.size(rems_from_px(12.)))
|
||||
})
|
||||
.on_click(cx.listener(|this, _, window, cx| {
|
||||
this.dispatch_action(&RejectAll, window, cx)
|
||||
})),
|
||||
)
|
||||
.child(Button::new("keep-all", "Keep All").on_click(cx.listener(
|
||||
|this, _, window, cx| this.dispatch_action(&crate::KeepAll, window, cx),
|
||||
))),
|
||||
.child(
|
||||
Button::new("keep-all", "Keep All")
|
||||
.key_binding({
|
||||
KeyBinding::for_action_in(&KeepAll, &focus_handle, window, cx)
|
||||
.map(|kb| kb.size(rems_from_px(12.)))
|
||||
})
|
||||
.on_click(cx.listener(|this, _, window, cx| {
|
||||
this.dispatch_action(&KeepAll, window, cx)
|
||||
})),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -882,6 +894,7 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::{ThreadStore, thread_store};
|
||||
use assistant_settings::AssistantSettings;
|
||||
use assistant_tool::ToolWorkingSet;
|
||||
use context_server::ContextServerSettings;
|
||||
use editor::EditorSettings;
|
||||
use gpui::TestAppContext;
|
||||
@@ -925,7 +938,7 @@ mod tests {
|
||||
.update(|cx| {
|
||||
ThreadStore::load(
|
||||
project.clone(),
|
||||
Arc::default(),
|
||||
cx.new(|_| ToolWorkingSet::default()),
|
||||
Arc::new(PromptBuilder::new(None).unwrap()),
|
||||
cx,
|
||||
)
|
||||
|
||||
@@ -12,7 +12,9 @@ use fs::Fs;
|
||||
use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
|
||||
use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
|
||||
use settings::{Settings, update_settings_file};
|
||||
use ui::{Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch, prelude::*};
|
||||
use ui::{
|
||||
Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch, Tooltip, prelude::*,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
use zed_actions::ExtensionCategoryFilter;
|
||||
|
||||
@@ -27,7 +29,7 @@ pub struct AssistantConfiguration {
|
||||
configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
|
||||
context_server_manager: Entity<ContextServerManager>,
|
||||
expanded_context_server_tools: HashMap<Arc<str>, bool>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
_registry_subscription: Subscription,
|
||||
}
|
||||
|
||||
@@ -35,7 +37,7 @@ impl AssistantConfiguration {
|
||||
pub fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
context_server_manager: Entity<ContextServerManager>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
@@ -224,7 +226,7 @@ impl AssistantConfiguration {
|
||||
|
||||
fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let context_servers = self.context_server_manager.read(cx).all_servers().clone();
|
||||
let tools_by_source = self.tools.tools_by_source(cx);
|
||||
let tools_by_source = self.tools.read(cx).tools_by_source(cx);
|
||||
let empty = Vec::new();
|
||||
|
||||
const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
|
||||
@@ -236,7 +238,10 @@ impl AssistantConfiguration {
|
||||
.child(
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.child(Headline::new("Context Servers (MCP)").size(HeadlineSize::Small))
|
||||
.child(
|
||||
Headline::new("Model Context Protocol (MCP) Servers")
|
||||
.size(HeadlineSize::Small),
|
||||
)
|
||||
.child(Label::new(SUBHEADING).color(Color::Muted)),
|
||||
)
|
||||
.children(context_servers.into_iter().map(|context_server| {
|
||||
@@ -262,10 +267,9 @@ impl AssistantConfiguration {
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.child(
|
||||
h_flex()
|
||||
.p_1()
|
||||
.justify_between()
|
||||
.px_2()
|
||||
.py_1()
|
||||
.when(are_tools_expanded, |element| {
|
||||
.when(are_tools_expanded && tool_count > 1, |element| {
|
||||
element
|
||||
.border_b_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
@@ -275,6 +279,7 @@ impl AssistantConfiguration {
|
||||
.gap_2()
|
||||
.child(
|
||||
Disclosure::new("tool-list-disclosure", are_tools_expanded)
|
||||
.disabled(tool_count == 0)
|
||||
.on_click(cx.listener({
|
||||
let context_server_id = context_server.id();
|
||||
move |this, _event, _window, _cx| {
|
||||
@@ -295,10 +300,11 @@ impl AssistantConfiguration {
|
||||
.child(Label::new(context_server.id()))
|
||||
.child(
|
||||
Label::new(format!("{tool_count} tools"))
|
||||
.color(Color::Muted),
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small),
|
||||
),
|
||||
)
|
||||
.child(h_flex().child(
|
||||
.child(
|
||||
Switch::new("context-server-switch", is_running.into()).on_click({
|
||||
let context_server_manager =
|
||||
self.context_server_manager.clone();
|
||||
@@ -334,7 +340,7 @@ impl AssistantConfiguration {
|
||||
}
|
||||
}
|
||||
}),
|
||||
)),
|
||||
),
|
||||
)
|
||||
.map(|parent| {
|
||||
if !are_tools_expanded {
|
||||
@@ -344,14 +350,29 @@ impl AssistantConfiguration {
|
||||
parent.child(v_flex().children(tools.into_iter().enumerate().map(
|
||||
|(ix, tool)| {
|
||||
h_flex()
|
||||
.px_2()
|
||||
.id("tool-item")
|
||||
.pl_2()
|
||||
.pr_1()
|
||||
.py_1()
|
||||
.gap_2()
|
||||
.justify_between()
|
||||
.when(ix < tool_count - 1, |element| {
|
||||
element
|
||||
.border_b_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
})
|
||||
.child(Label::new(tool.name()))
|
||||
.child(
|
||||
Label::new(tool.name())
|
||||
.buffer_font(cx)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.child(
|
||||
IconButton::new(("tool-description", ix), IconName::Info)
|
||||
.shape(ui::IconButtonShape::Square)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Ignored)
|
||||
.tooltip(Tooltip::text(tool.description())),
|
||||
)
|
||||
},
|
||||
)))
|
||||
})
|
||||
@@ -362,7 +383,7 @@ impl AssistantConfiguration {
|
||||
.gap_2()
|
||||
.child(
|
||||
h_flex().w_full().child(
|
||||
Button::new("add-context-server", "Add Context Server")
|
||||
Button::new("add-context-server", "Add MCPs Directly")
|
||||
.style(ButtonStyle::Filled)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.full_width()
|
||||
@@ -378,7 +399,7 @@ impl AssistantConfiguration {
|
||||
h_flex().w_full().child(
|
||||
Button::new(
|
||||
"install-context-server-extensions",
|
||||
"Install Context Server Extensions",
|
||||
"Install MCP Extensions",
|
||||
)
|
||||
.style(ButtonStyle::Filled)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
|
||||
@@ -84,7 +84,7 @@ pub struct NewProfileMode {
|
||||
|
||||
pub struct ManageProfilesModal {
|
||||
fs: Arc<dyn Fs>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
focus_handle: FocusHandle,
|
||||
mode: Mode,
|
||||
@@ -117,7 +117,7 @@ impl ManageProfilesModal {
|
||||
|
||||
pub fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
|
||||
@@ -60,7 +60,7 @@ pub struct ToolPickerDelegate {
|
||||
impl ToolPickerDelegate {
|
||||
pub fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
tool_set: Arc<ToolWorkingSet>,
|
||||
tool_set: Entity<ToolWorkingSet>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
profile_id: AgentProfileId,
|
||||
profile: AgentProfile,
|
||||
@@ -68,7 +68,7 @@ impl ToolPickerDelegate {
|
||||
) -> Self {
|
||||
let mut tool_entries = Vec::new();
|
||||
|
||||
for (source, tools) in tool_set.tools_by_source(cx) {
|
||||
for (source, tools) in tool_set.read(cx).tools_by_source(cx) {
|
||||
tool_entries.extend(tools.into_iter().map(|tool| ToolEntry {
|
||||
name: tool.name().into(),
|
||||
source: source.clone(),
|
||||
@@ -192,7 +192,7 @@ impl PickerDelegate for ToolPickerDelegate {
|
||||
if active_profile_id == &self.profile_id {
|
||||
self.thread_store
|
||||
.update(cx, |this, cx| {
|
||||
this.load_profile(&self.profile, cx);
|
||||
this.load_profile(self.profile.clone(), cx);
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
|
||||
@@ -80,17 +80,16 @@ impl AssistantModelSelector {
|
||||
|
||||
impl Render for AssistantModelSelector {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let focus_handle = self.focus_handle.clone();
|
||||
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let model = match self.model_type {
|
||||
ModelType::Default => model_registry.default_model(),
|
||||
ModelType::InlineAssistant => model_registry.inline_assistant_model(),
|
||||
};
|
||||
|
||||
let focus_handle = self.focus_handle.clone();
|
||||
let model_name = match model {
|
||||
Some(model) => model.model.name().0,
|
||||
_ => SharedString::from("No model selected"),
|
||||
let (model_name, model_icon) = match model {
|
||||
Some(model) => (model.model.name().0, Some(model.provider.icon())),
|
||||
_ => (SharedString::from("No model selected"), None),
|
||||
};
|
||||
|
||||
LanguageModelSelectorPopoverMenu::new(
|
||||
@@ -100,10 +99,16 @@ impl Render for AssistantModelSelector {
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_0p5()
|
||||
.children(
|
||||
model_icon.map(|icon| {
|
||||
Icon::new(icon).color(Color::Muted).size(IconSize::Small)
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
Label::new(model_name)
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
.color(Color::Muted)
|
||||
.ml_1(),
|
||||
)
|
||||
.child(
|
||||
Icon::new(IconName::ChevronDown)
|
||||
|
||||
@@ -44,8 +44,8 @@ use crate::thread::{Thread, ThreadError, ThreadId, TokenUsageRatio};
|
||||
use crate::thread_history::{PastContext, PastThread, ThreadHistory};
|
||||
use crate::thread_store::ThreadStore;
|
||||
use crate::{
|
||||
AgentDiff, InlineAssistant, NewTextThread, NewThread, OpenActiveThreadAsMarkdown,
|
||||
OpenAgentDiff, OpenHistory, ThreadEvent, ToggleContextPicker,
|
||||
AgentDiff, ExpandMessageEditor, InlineAssistant, NewTextThread, NewThread,
|
||||
OpenActiveThreadAsMarkdown, OpenAgentDiff, OpenHistory, ThreadEvent, ToggleContextPicker,
|
||||
};
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
@@ -90,6 +90,16 @@ pub fn init(cx: &mut App) {
|
||||
let thread = panel.read(cx).thread.read(cx).thread().clone();
|
||||
AgentDiff::deploy_in_workspace(thread, workspace, window, cx);
|
||||
}
|
||||
})
|
||||
.register_action(|workspace, _: &ExpandMessageEditor, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
workspace.focus_panel::<AssistantPanel>(window, cx);
|
||||
panel.update(cx, |panel, cx| {
|
||||
panel.message_editor.update(cx, |editor, cx| {
|
||||
editor.expand_message_editor(&ExpandMessageEditor, window, cx);
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
},
|
||||
)
|
||||
@@ -193,7 +203,7 @@ impl AssistantPanel {
|
||||
cx: AsyncWindowContext,
|
||||
) -> Task<Result<Entity<Self>>> {
|
||||
cx.spawn(async move |cx| {
|
||||
let tools = Arc::new(ToolWorkingSet::default());
|
||||
let tools = cx.new(|_| ToolWorkingSet::default())?;
|
||||
let thread_store = workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
let project = workspace.project().clone();
|
||||
@@ -559,6 +569,7 @@ impl AssistantPanel {
|
||||
ActiveView::Configuration | ActiveView::History => {
|
||||
self.active_view =
|
||||
ActiveView::thread(self.thread.read(cx).thread().clone(), window, cx);
|
||||
self.message_editor.focus_handle(cx).focus(window);
|
||||
cx.notify();
|
||||
}
|
||||
_ => {}
|
||||
@@ -1088,20 +1099,30 @@ impl AssistantPanel {
|
||||
window,
|
||||
cx,
|
||||
|menu, _window, _cx| {
|
||||
menu.action(
|
||||
menu
|
||||
.when(!is_empty, |menu| {
|
||||
menu.action(
|
||||
"Start New From Summary",
|
||||
Box::new(NewThread {
|
||||
from_thread_id: Some(thread_id.clone()),
|
||||
}),
|
||||
).separator()
|
||||
})
|
||||
.action(
|
||||
"New Text Thread",
|
||||
NewTextThread.boxed_clone(),
|
||||
)
|
||||
.when(!is_empty, |menu| {
|
||||
menu.action(
|
||||
"Continue in New Thread",
|
||||
Box::new(NewThread {
|
||||
from_thread_id: Some(thread_id.clone()),
|
||||
}),
|
||||
)
|
||||
})
|
||||
.separator()
|
||||
.action("Settings", OpenConfiguration.boxed_clone())
|
||||
.separator()
|
||||
.action(
|
||||
"Install MCPs",
|
||||
zed_actions::Extensions {
|
||||
category_filter: Some(
|
||||
zed_actions::ExtensionCategoryFilter::ContextServers,
|
||||
),
|
||||
}
|
||||
.boxed_clone(),
|
||||
)
|
||||
},
|
||||
))
|
||||
}),
|
||||
|
||||
@@ -34,12 +34,6 @@ use crate::context_store::ContextStore;
|
||||
use crate::thread::ThreadId;
|
||||
use crate::thread_store::ThreadStore;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum ConfirmBehavior {
|
||||
KeepOpen,
|
||||
Close,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum ContextPickerMode {
|
||||
File,
|
||||
@@ -105,7 +99,6 @@ pub(super) struct ContextPicker {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: WeakEntity<ContextStore>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
confirm_behavior: ConfirmBehavior,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
@@ -114,7 +107,6 @@ impl ContextPicker {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
context_store: WeakEntity<ContextStore>,
|
||||
confirm_behavior: ConfirmBehavior,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
@@ -143,7 +135,6 @@ impl ContextPicker {
|
||||
workspace,
|
||||
context_store,
|
||||
thread_store,
|
||||
confirm_behavior,
|
||||
_subscriptions: subscriptions,
|
||||
}
|
||||
}
|
||||
@@ -166,37 +157,32 @@ impl ContextPicker {
|
||||
|
||||
let modes = supported_context_picker_modes(&self.thread_store);
|
||||
|
||||
let menu = menu
|
||||
.when(has_recent, |menu| {
|
||||
menu.custom_row(|_, _| {
|
||||
div()
|
||||
.mb_1()
|
||||
.child(
|
||||
Label::new("Recent")
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.into_any_element()
|
||||
})
|
||||
menu.when(has_recent, |menu| {
|
||||
menu.custom_row(|_, _| {
|
||||
div()
|
||||
.mb_1()
|
||||
.child(
|
||||
Label::new("Recent")
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.into_any_element()
|
||||
})
|
||||
.extend(recent_entries)
|
||||
.when(has_recent, |menu| menu.separator())
|
||||
.extend(modes.into_iter().map(|mode| {
|
||||
let context_picker = context_picker.clone();
|
||||
})
|
||||
.extend(recent_entries)
|
||||
.when(has_recent, |menu| menu.separator())
|
||||
.extend(modes.into_iter().map(|mode| {
|
||||
let context_picker = context_picker.clone();
|
||||
|
||||
ContextMenuEntry::new(mode.label())
|
||||
.icon(mode.icon())
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
context_picker.update(cx, |this, cx| this.select_mode(mode, window, cx))
|
||||
})
|
||||
}));
|
||||
|
||||
match self.confirm_behavior {
|
||||
ConfirmBehavior::KeepOpen => menu.keep_open_on_confirm(),
|
||||
ConfirmBehavior::Close => menu,
|
||||
}
|
||||
ContextMenuEntry::new(mode.label())
|
||||
.icon(mode.icon())
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
context_picker.update(cx, |this, cx| this.select_mode(mode, window, cx))
|
||||
})
|
||||
}))
|
||||
.keep_open_on_confirm()
|
||||
});
|
||||
|
||||
cx.subscribe(&menu, move |_, _, _: &DismissEvent, cx| {
|
||||
@@ -227,7 +213,6 @@ impl ContextPicker {
|
||||
context_picker.clone(),
|
||||
self.workspace.clone(),
|
||||
self.context_store.clone(),
|
||||
self.confirm_behavior,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
@@ -239,7 +224,6 @@ impl ContextPicker {
|
||||
context_picker.clone(),
|
||||
self.workspace.clone(),
|
||||
self.context_store.clone(),
|
||||
self.confirm_behavior,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
@@ -251,7 +235,6 @@ impl ContextPicker {
|
||||
context_picker.clone(),
|
||||
self.workspace.clone(),
|
||||
self.context_store.clone(),
|
||||
self.confirm_behavior,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
@@ -264,7 +247,6 @@ impl ContextPicker {
|
||||
thread_store.clone(),
|
||||
context_picker.clone(),
|
||||
self.context_store.clone(),
|
||||
self.confirm_behavior,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ use picker::{Picker, PickerDelegate};
|
||||
use ui::{Context, ListItem, Window, prelude::*};
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::context_picker::{ConfirmBehavior, ContextPicker};
|
||||
use crate::context_picker::ContextPicker;
|
||||
use crate::context_store::ContextStore;
|
||||
|
||||
pub struct FetchContextPicker {
|
||||
@@ -23,16 +23,10 @@ impl FetchContextPicker {
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: WeakEntity<ContextStore>,
|
||||
confirm_behavior: ConfirmBehavior,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let delegate = FetchContextPickerDelegate::new(
|
||||
context_picker,
|
||||
workspace,
|
||||
context_store,
|
||||
confirm_behavior,
|
||||
);
|
||||
let delegate = FetchContextPickerDelegate::new(context_picker, workspace, context_store);
|
||||
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
|
||||
|
||||
Self { picker }
|
||||
@@ -62,7 +56,6 @@ pub struct FetchContextPickerDelegate {
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: WeakEntity<ContextStore>,
|
||||
confirm_behavior: ConfirmBehavior,
|
||||
url: String,
|
||||
}
|
||||
|
||||
@@ -71,13 +64,11 @@ impl FetchContextPickerDelegate {
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: WeakEntity<ContextStore>,
|
||||
confirm_behavior: ConfirmBehavior,
|
||||
) -> Self {
|
||||
FetchContextPickerDelegate {
|
||||
context_picker,
|
||||
workspace,
|
||||
context_store,
|
||||
confirm_behavior,
|
||||
url: String::new(),
|
||||
}
|
||||
}
|
||||
@@ -204,25 +195,15 @@ impl PickerDelegate for FetchContextPickerDelegate {
|
||||
|
||||
let http_client = workspace.read(cx).client().http_client().clone();
|
||||
let url = self.url.clone();
|
||||
let confirm_behavior = self.confirm_behavior;
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let text = cx
|
||||
.background_spawn(fetch_url_content(http_client, url.clone()))
|
||||
.await?;
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.delegate
|
||||
.context_store
|
||||
.update(cx, |context_store, cx| {
|
||||
context_store.add_fetched_url(url, text, cx)
|
||||
})?;
|
||||
|
||||
match confirm_behavior {
|
||||
ConfirmBehavior::KeepOpen => {}
|
||||
ConfirmBehavior::Close => this.delegate.dismissed(window, cx),
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
this.update(cx, |this, cx| {
|
||||
this.delegate.context_store.update(cx, |context_store, cx| {
|
||||
context_store.add_fetched_url(url, text, cx)
|
||||
})
|
||||
})??;
|
||||
|
||||
anyhow::Ok(())
|
||||
|
||||
@@ -11,9 +11,9 @@ use picker::{Picker, PickerDelegate};
|
||||
use project::{PathMatchCandidateSet, ProjectPath, WorktreeId};
|
||||
use ui::{ListItem, Tooltip, prelude::*};
|
||||
use util::ResultExt as _;
|
||||
use workspace::{Workspace, notifications::NotifyResultExt};
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::context_picker::{ConfirmBehavior, ContextPicker};
|
||||
use crate::context_picker::ContextPicker;
|
||||
use crate::context_store::{ContextStore, FileInclusion};
|
||||
|
||||
pub struct FileContextPicker {
|
||||
@@ -25,16 +25,10 @@ impl FileContextPicker {
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: WeakEntity<ContextStore>,
|
||||
confirm_behavior: ConfirmBehavior,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let delegate = FileContextPickerDelegate::new(
|
||||
context_picker,
|
||||
workspace,
|
||||
context_store,
|
||||
confirm_behavior,
|
||||
);
|
||||
let delegate = FileContextPickerDelegate::new(context_picker, workspace, context_store);
|
||||
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
|
||||
|
||||
Self { picker }
|
||||
@@ -57,7 +51,6 @@ pub struct FileContextPickerDelegate {
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: WeakEntity<ContextStore>,
|
||||
confirm_behavior: ConfirmBehavior,
|
||||
matches: Vec<FileMatch>,
|
||||
selected_index: usize,
|
||||
}
|
||||
@@ -67,13 +60,11 @@ impl FileContextPickerDelegate {
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: WeakEntity<ContextStore>,
|
||||
confirm_behavior: ConfirmBehavior,
|
||||
) -> Self {
|
||||
Self {
|
||||
context_picker,
|
||||
workspace,
|
||||
context_store,
|
||||
confirm_behavior,
|
||||
matches: Vec::new(),
|
||||
selected_index: 0,
|
||||
}
|
||||
@@ -127,7 +118,7 @@ impl PickerDelegate for FileContextPickerDelegate {
|
||||
})
|
||||
}
|
||||
|
||||
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
let Some(FileMatch { mat, .. }) = self.matches.get(self.selected_index) else {
|
||||
return;
|
||||
};
|
||||
@@ -153,17 +144,7 @@ impl PickerDelegate for FileContextPickerDelegate {
|
||||
return;
|
||||
};
|
||||
|
||||
let confirm_behavior = self.confirm_behavior;
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
match task.await.notify_async_err(cx) {
|
||||
None => anyhow::Ok(()),
|
||||
Some(()) => this.update_in(cx, |this, window, cx| match confirm_behavior {
|
||||
ConfirmBehavior::KeepOpen => {}
|
||||
ConfirmBehavior::Close => this.delegate.dismissed(window, cx),
|
||||
}),
|
||||
}
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
task.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
|
||||
@@ -15,7 +15,7 @@ use ui::{ListItem, prelude::*};
|
||||
use util::ResultExt as _;
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::context_picker::{ConfirmBehavior, ContextPicker};
|
||||
use crate::context_picker::ContextPicker;
|
||||
use crate::context_store::ContextStore;
|
||||
|
||||
pub struct SymbolContextPicker {
|
||||
@@ -27,16 +27,10 @@ impl SymbolContextPicker {
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: WeakEntity<ContextStore>,
|
||||
confirm_behavior: ConfirmBehavior,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let delegate = SymbolContextPickerDelegate::new(
|
||||
context_picker,
|
||||
workspace,
|
||||
context_store,
|
||||
confirm_behavior,
|
||||
);
|
||||
let delegate = SymbolContextPickerDelegate::new(context_picker, workspace, context_store);
|
||||
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
|
||||
|
||||
Self { picker }
|
||||
@@ -59,7 +53,6 @@ pub struct SymbolContextPickerDelegate {
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: WeakEntity<ContextStore>,
|
||||
confirm_behavior: ConfirmBehavior,
|
||||
matches: Vec<SymbolEntry>,
|
||||
selected_index: usize,
|
||||
}
|
||||
@@ -69,13 +62,11 @@ impl SymbolContextPickerDelegate {
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: WeakEntity<ContextStore>,
|
||||
confirm_behavior: ConfirmBehavior,
|
||||
) -> Self {
|
||||
Self {
|
||||
context_picker,
|
||||
workspace,
|
||||
context_store,
|
||||
confirm_behavior,
|
||||
matches: Vec::new(),
|
||||
selected_index: 0,
|
||||
}
|
||||
@@ -135,7 +126,7 @@ impl PickerDelegate for SymbolContextPickerDelegate {
|
||||
})
|
||||
}
|
||||
|
||||
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
let Some(mat) = self.matches.get(self.selected_index) else {
|
||||
return;
|
||||
};
|
||||
@@ -143,7 +134,6 @@ impl PickerDelegate for SymbolContextPickerDelegate {
|
||||
return;
|
||||
};
|
||||
|
||||
let confirm_behavior = self.confirm_behavior;
|
||||
let add_symbol_task = add_symbol(
|
||||
mat.symbol.clone(),
|
||||
true,
|
||||
@@ -153,16 +143,12 @@ impl PickerDelegate for SymbolContextPickerDelegate {
|
||||
);
|
||||
|
||||
let selected_index = self.selected_index;
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
cx.spawn(async move |this, cx| {
|
||||
let included = add_symbol_task.await?;
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.update(cx, |this, _| {
|
||||
if let Some(mat) = this.delegate.matches.get_mut(selected_index) {
|
||||
mat.is_included = included;
|
||||
}
|
||||
match confirm_behavior {
|
||||
ConfirmBehavior::KeepOpen => {}
|
||||
ConfirmBehavior::Close => this.delegate.dismissed(window, cx),
|
||||
}
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
@@ -6,7 +6,7 @@ use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity};
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use ui::{ListItem, prelude::*};
|
||||
|
||||
use crate::context_picker::{ConfirmBehavior, ContextPicker};
|
||||
use crate::context_picker::ContextPicker;
|
||||
use crate::context_store::{self, ContextStore};
|
||||
use crate::thread::ThreadId;
|
||||
use crate::thread_store::ThreadStore;
|
||||
@@ -20,16 +20,11 @@ impl ThreadContextPicker {
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
context_store: WeakEntity<context_store::ContextStore>,
|
||||
confirm_behavior: ConfirmBehavior,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let delegate = ThreadContextPickerDelegate::new(
|
||||
thread_store,
|
||||
context_picker,
|
||||
context_store,
|
||||
confirm_behavior,
|
||||
);
|
||||
let delegate =
|
||||
ThreadContextPickerDelegate::new(thread_store, context_picker, context_store);
|
||||
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
|
||||
|
||||
ThreadContextPicker { picker }
|
||||
@@ -58,7 +53,6 @@ pub struct ThreadContextPickerDelegate {
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
context_store: WeakEntity<context_store::ContextStore>,
|
||||
confirm_behavior: ConfirmBehavior,
|
||||
matches: Vec<ThreadContextEntry>,
|
||||
selected_index: usize,
|
||||
}
|
||||
@@ -68,13 +62,11 @@ impl ThreadContextPickerDelegate {
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
context_store: WeakEntity<context_store::ContextStore>,
|
||||
confirm_behavior: ConfirmBehavior,
|
||||
) -> Self {
|
||||
ThreadContextPickerDelegate {
|
||||
thread_store,
|
||||
context_picker,
|
||||
context_store,
|
||||
confirm_behavior,
|
||||
matches: Vec::new(),
|
||||
selected_index: 0,
|
||||
}
|
||||
@@ -127,7 +119,7 @@ impl PickerDelegate for ThreadContextPickerDelegate {
|
||||
})
|
||||
}
|
||||
|
||||
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
let Some(entry) = self.matches.get(self.selected_index) else {
|
||||
return;
|
||||
};
|
||||
@@ -138,20 +130,15 @@ impl PickerDelegate for ThreadContextPickerDelegate {
|
||||
|
||||
let open_thread_task = thread_store.update(cx, |this, cx| this.open_thread(&entry.id, cx));
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
cx.spawn(async move |this, cx| {
|
||||
let thread = open_thread_task.await?;
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.update(cx, |this, cx| {
|
||||
this.delegate
|
||||
.context_store
|
||||
.update(cx, |context_store, cx| {
|
||||
context_store.add_thread(thread, true, cx)
|
||||
})
|
||||
.ok();
|
||||
|
||||
match this.delegate.confirm_behavior {
|
||||
ConfirmBehavior::KeepOpen => {}
|
||||
ConfirmBehavior::Close => this.delegate.dismissed(window, cx),
|
||||
}
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
@@ -15,7 +15,7 @@ use ui::{KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*};
|
||||
use workspace::{Workspace, notifications::NotifyResultExt};
|
||||
|
||||
use crate::context::{ContextId, ContextKind};
|
||||
use crate::context_picker::{ConfirmBehavior, ContextPicker};
|
||||
use crate::context_picker::ContextPicker;
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::thread::Thread;
|
||||
use crate::thread_store::ThreadStore;
|
||||
@@ -52,7 +52,6 @@ impl ContextStrip {
|
||||
workspace.clone(),
|
||||
thread_store.clone(),
|
||||
context_store.downgrade(),
|
||||
ConfirmBehavior::KeepOpen,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -86,7 +86,7 @@ impl ProfileSelector {
|
||||
|
||||
thread_store
|
||||
.update(cx, |this, cx| {
|
||||
this.load_profile_by_id(&profile_id, cx);
|
||||
this.load_profile_by_id(profile_id.clone(), cx);
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
|
||||
@@ -15,10 +15,11 @@ use futures::{FutureExt, StreamExt as _};
|
||||
use git::repository::DiffType;
|
||||
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
|
||||
use language_model::{
|
||||
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||
LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
|
||||
PaymentRequiredError, Role, StopReason, TokenUsage,
|
||||
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
|
||||
Role, StopReason, TokenUsage,
|
||||
};
|
||||
use project::Project;
|
||||
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
|
||||
@@ -228,7 +229,7 @@ pub struct TotalTokenUsage {
|
||||
pub ratio: TokenUsageRatio,
|
||||
}
|
||||
|
||||
#[derive(Default, PartialEq, Eq)]
|
||||
#[derive(Debug, Default, PartialEq, Eq)]
|
||||
pub enum TokenUsageRatio {
|
||||
#[default]
|
||||
Normal,
|
||||
@@ -253,22 +254,31 @@ pub struct Thread {
|
||||
pending_completions: Vec<PendingCompletion>,
|
||||
project: Entity<Project>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
tool_use: ToolUseState,
|
||||
action_log: Entity<ActionLog>,
|
||||
last_restore_checkpoint: Option<LastRestoreCheckpoint>,
|
||||
pending_checkpoint: Option<ThreadCheckpoint>,
|
||||
initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
|
||||
cumulative_token_usage: TokenUsage,
|
||||
exceeded_window_error: Option<ExceededWindowError>,
|
||||
feedback: Option<ThreadFeedback>,
|
||||
message_feedback: HashMap<MessageId, ThreadFeedback>,
|
||||
last_auto_capture_at: Option<Instant>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExceededWindowError {
|
||||
/// Model used when last message exceeded context window
|
||||
model_id: LanguageModelId,
|
||||
/// Token count including last message
|
||||
token_count: usize,
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
pub fn new(
|
||||
project: Entity<Project>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
system_prompt: SharedProjectContext,
|
||||
cx: &mut Context<Self>,
|
||||
@@ -301,6 +311,7 @@ impl Thread {
|
||||
.shared()
|
||||
},
|
||||
cumulative_token_usage: TokenUsage::default(),
|
||||
exceeded_window_error: None,
|
||||
feedback: None,
|
||||
message_feedback: HashMap::default(),
|
||||
last_auto_capture_at: None,
|
||||
@@ -311,7 +322,7 @@ impl Thread {
|
||||
id: ThreadId,
|
||||
serialized: SerializedThread,
|
||||
project: Entity<Project>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
project_context: SharedProjectContext,
|
||||
cx: &mut Context<Self>,
|
||||
@@ -367,6 +378,7 @@ impl Thread {
|
||||
action_log: cx.new(|_| ActionLog::new(project)),
|
||||
initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
|
||||
cumulative_token_usage: serialized.cumulative_token_usage,
|
||||
exceeded_window_error: None,
|
||||
feedback: None,
|
||||
message_feedback: HashMap::default(),
|
||||
last_auto_capture_at: None,
|
||||
@@ -446,7 +458,7 @@ impl Thread {
|
||||
!self.pending_completions.is_empty() || !self.all_tools_finished()
|
||||
}
|
||||
|
||||
pub fn tools(&self) -> &Arc<ToolWorkingSet> {
|
||||
pub fn tools(&self) -> &Entity<ToolWorkingSet> {
|
||||
&self.tools
|
||||
}
|
||||
|
||||
@@ -819,8 +831,9 @@ impl Thread {
|
||||
})
|
||||
.collect(),
|
||||
initial_project_snapshot,
|
||||
cumulative_token_usage: this.cumulative_token_usage.clone(),
|
||||
cumulative_token_usage: this.cumulative_token_usage,
|
||||
detailed_summary_state: this.detailed_summary_state.clone(),
|
||||
exceeded_window_error: this.exceeded_window_error.clone(),
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -835,13 +848,21 @@ impl Thread {
|
||||
if model.supports_tools() {
|
||||
request.tools = {
|
||||
let mut tools = Vec::new();
|
||||
tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
|
||||
LanguageModelRequestTool {
|
||||
name: tool.name(),
|
||||
description: tool.description(),
|
||||
input_schema: tool.input_schema(model.tool_input_format()),
|
||||
}
|
||||
}));
|
||||
tools.extend(
|
||||
self.tools()
|
||||
.read(cx)
|
||||
.enabled_tools(cx)
|
||||
.into_iter()
|
||||
.filter_map(|tool| {
|
||||
// Skip tools that cannot be supported
|
||||
let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
|
||||
Some(LanguageModelRequestTool {
|
||||
name: tool.name(),
|
||||
description: tool.description(),
|
||||
input_schema,
|
||||
})
|
||||
}),
|
||||
);
|
||||
|
||||
tools
|
||||
};
|
||||
@@ -1000,7 +1021,7 @@ impl Thread {
|
||||
let task = cx.spawn(async move |thread, cx| {
|
||||
let stream = model.stream_completion(request, &cx);
|
||||
let initial_token_usage =
|
||||
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
|
||||
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
|
||||
let stream_completion = async {
|
||||
let mut events = stream.await?;
|
||||
let mut stop_reason = StopReason::EndTurn;
|
||||
@@ -1022,9 +1043,9 @@ impl Thread {
|
||||
stop_reason = reason;
|
||||
}
|
||||
LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
|
||||
thread.cumulative_token_usage =
|
||||
thread.cumulative_token_usage.clone() + token_usage.clone()
|
||||
- current_token_usage.clone();
|
||||
thread.cumulative_token_usage = thread.cumulative_token_usage
|
||||
+ token_usage
|
||||
- current_token_usage;
|
||||
current_token_usage = token_usage;
|
||||
}
|
||||
LanguageModelCompletionEvent::Text(chunk) => {
|
||||
@@ -1133,6 +1154,20 @@ impl Thread {
|
||||
cx.emit(ThreadEvent::ShowError(
|
||||
ThreadError::MaxMonthlySpendReached,
|
||||
));
|
||||
} else if let Some(known_error) =
|
||||
error.downcast_ref::<LanguageModelKnownError>()
|
||||
{
|
||||
match known_error {
|
||||
LanguageModelKnownError::ContextWindowLimitExceeded {
|
||||
tokens,
|
||||
} => {
|
||||
thread.exceeded_window_error = Some(ExceededWindowError {
|
||||
model_id: model.id(),
|
||||
token_count: *tokens,
|
||||
});
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let error_message = error
|
||||
.chain()
|
||||
@@ -1153,7 +1188,7 @@ impl Thread {
|
||||
thread.auto_capture_telemetry(cx);
|
||||
|
||||
if let Ok(initial_usage) = initial_token_usage {
|
||||
let usage = thread.cumulative_token_usage.clone() - initial_usage;
|
||||
let usage = thread.cumulative_token_usage - initial_usage;
|
||||
|
||||
telemetry::event!(
|
||||
"Assistant Thread Completion",
|
||||
@@ -1324,7 +1359,7 @@ impl Thread {
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for tool_use in pending_tool_uses.iter() {
|
||||
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
|
||||
if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
|
||||
if tool.needs_confirmation(&tool_use.input, cx)
|
||||
&& !AssistantSettings::get_global(cx).always_allow_tool_actions
|
||||
{
|
||||
@@ -1376,7 +1411,7 @@ impl Thread {
|
||||
) -> Task<()> {
|
||||
let tool_name: Arc<str> = tool.name().into();
|
||||
|
||||
let tool_result = if self.tools.is_disabled(&tool.source(), &tool_name) {
|
||||
let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
|
||||
ToolResult {
|
||||
output: Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))),
|
||||
card: None,
|
||||
@@ -1500,6 +1535,7 @@ impl Thread {
|
||||
|
||||
let enabled_tool_names: Vec<String> = self
|
||||
.tools()
|
||||
.read(cx)
|
||||
.enabled_tools(cx)
|
||||
.iter()
|
||||
.map(|tool| tool.name().to_string())
|
||||
@@ -1797,10 +1833,6 @@ impl Thread {
|
||||
&self.project
|
||||
}
|
||||
|
||||
pub fn cumulative_token_usage(&self) -> TokenUsage {
|
||||
self.cumulative_token_usage.clone()
|
||||
}
|
||||
|
||||
pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
|
||||
if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
|
||||
return;
|
||||
@@ -1845,6 +1877,10 @@ impl Thread {
|
||||
.detach();
|
||||
}
|
||||
|
||||
pub fn cumulative_token_usage(&self) -> TokenUsage {
|
||||
self.cumulative_token_usage
|
||||
}
|
||||
|
||||
pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let Some(model) = model_registry.default_model() else {
|
||||
@@ -1853,6 +1889,16 @@ impl Thread {
|
||||
|
||||
let max = model.model.max_token_count();
|
||||
|
||||
if let Some(exceeded_error) = &self.exceeded_window_error {
|
||||
if model.model.id() == exceeded_error.model_id {
|
||||
return TotalTokenUsage {
|
||||
total: exceeded_error.token_count,
|
||||
max,
|
||||
ratio: TokenUsageRatio::Exceeded,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
|
||||
.unwrap_or("0.8".to_string())
|
||||
@@ -2310,7 +2356,7 @@ fn main() {{
|
||||
.update(|_, cx| {
|
||||
ThreadStore::load(
|
||||
project.clone(),
|
||||
Arc::default(),
|
||||
cx.new(|_| ToolWorkingSet::default()),
|
||||
Arc::new(PromptBuilder::new(None).unwrap()),
|
||||
cx,
|
||||
)
|
||||
|
||||
@@ -4,11 +4,14 @@ use assistant_context_editor::SavedContextMetadata;
|
||||
use editor::{Editor, EditorEvent};
|
||||
use fuzzy::{StringMatch, StringMatchCandidate};
|
||||
use gpui::{
|
||||
App, Entity, FocusHandle, Focusable, ScrollStrategy, Task, UniformListScrollHandle, WeakEntity,
|
||||
Window, uniform_list,
|
||||
App, Entity, FocusHandle, Focusable, ScrollStrategy, Stateful, Task, UniformListScrollHandle,
|
||||
WeakEntity, Window, uniform_list,
|
||||
};
|
||||
use time::{OffsetDateTime, UtcOffset};
|
||||
use ui::{HighlightedLabel, IconButtonShape, ListItem, ListItemSpacing, Tooltip, prelude::*};
|
||||
use ui::{
|
||||
HighlightedLabel, IconButtonShape, ListItem, ListItemSpacing, Scrollbar, ScrollbarState,
|
||||
Tooltip, prelude::*,
|
||||
};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::history_store::{HistoryEntry, HistoryStore};
|
||||
@@ -26,6 +29,8 @@ pub struct ThreadHistory {
|
||||
matches: Vec<StringMatch>,
|
||||
_subscriptions: Vec<gpui::Subscription>,
|
||||
_search_task: Option<Task<()>>,
|
||||
scrollbar_visibility: bool,
|
||||
scrollbar_state: ScrollbarState,
|
||||
}
|
||||
|
||||
impl ThreadHistory {
|
||||
@@ -58,10 +63,13 @@ impl ThreadHistory {
|
||||
this.update_all_entries(cx);
|
||||
});
|
||||
|
||||
let scroll_handle = UniformListScrollHandle::default();
|
||||
let scrollbar_state = ScrollbarState::new(scroll_handle.clone());
|
||||
|
||||
Self {
|
||||
assistant_panel,
|
||||
history_store,
|
||||
scroll_handle: UniformListScrollHandle::default(),
|
||||
scroll_handle,
|
||||
selected_index: 0,
|
||||
search_query: SharedString::new_static(""),
|
||||
all_entries: entries,
|
||||
@@ -69,6 +77,8 @@ impl ThreadHistory {
|
||||
search_editor,
|
||||
_subscriptions: vec![search_editor_subscription, history_store_subscription],
|
||||
_search_task: None,
|
||||
scrollbar_visibility: true,
|
||||
scrollbar_state,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -220,6 +230,43 @@ impl ThreadHistory {
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn render_scrollbar(&self, cx: &mut Context<Self>) -> Option<Stateful<Div>> {
|
||||
if !(self.scrollbar_visibility || self.scrollbar_state.is_dragging()) {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(
|
||||
div()
|
||||
.occlude()
|
||||
.id("thread-history-scroll")
|
||||
.h_full()
|
||||
.bg(cx.theme().colors().panel_background.opacity(0.8))
|
||||
.border_l_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.absolute()
|
||||
.right_1()
|
||||
.top_0()
|
||||
.bottom_0()
|
||||
.w_4()
|
||||
.pl_1()
|
||||
.cursor_default()
|
||||
.on_mouse_move(cx.listener(|_, _, _window, cx| {
|
||||
cx.notify();
|
||||
cx.stop_propagation()
|
||||
}))
|
||||
.on_hover(|_, _window, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_any_mouse_down(|_, _window, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_scroll_wheel(cx.listener(|_, _, _window, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
.children(Scrollbar::vertical(self.scrollbar_state.clone())),
|
||||
)
|
||||
}
|
||||
|
||||
fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
|
||||
if let Some(entry) = self.get_match(self.selected_index) {
|
||||
let task_result = match entry {
|
||||
@@ -305,7 +352,11 @@ impl Render for ThreadHistory {
|
||||
)
|
||||
})
|
||||
.child({
|
||||
let view = v_flex().overflow_hidden().flex_grow();
|
||||
let view = v_flex()
|
||||
.id("list-container")
|
||||
.relative()
|
||||
.overflow_hidden()
|
||||
.flex_grow();
|
||||
|
||||
if self.all_entries.is_empty() {
|
||||
view.justify_center()
|
||||
@@ -322,59 +373,70 @@ impl Render for ThreadHistory {
|
||||
),
|
||||
)
|
||||
} else {
|
||||
view.p_1().child(
|
||||
uniform_list(
|
||||
cx.entity().clone(),
|
||||
"thread-history",
|
||||
self.matched_count(),
|
||||
move |history, range, _window, _cx| {
|
||||
let range_start = range.start;
|
||||
let assistant_panel = history.assistant_panel.clone();
|
||||
view.pr_5()
|
||||
.child(
|
||||
uniform_list(
|
||||
cx.entity().clone(),
|
||||
"thread-history",
|
||||
self.matched_count(),
|
||||
move |history, range, _window, _cx| {
|
||||
let range_start = range.start;
|
||||
let assistant_panel = history.assistant_panel.clone();
|
||||
|
||||
let render_item = |index: usize,
|
||||
entry: &HistoryEntry,
|
||||
highlight_positions: Vec<usize>|
|
||||
-> Div {
|
||||
h_flex().w_full().pb_1().child(match entry {
|
||||
HistoryEntry::Thread(thread) => PastThread::new(
|
||||
thread.clone(),
|
||||
assistant_panel.clone(),
|
||||
selected_index == index + range_start,
|
||||
highlight_positions,
|
||||
)
|
||||
.into_any_element(),
|
||||
HistoryEntry::Context(context) => PastContext::new(
|
||||
context.clone(),
|
||||
assistant_panel.clone(),
|
||||
selected_index == index + range_start,
|
||||
highlight_positions,
|
||||
)
|
||||
.into_any_element(),
|
||||
})
|
||||
};
|
||||
|
||||
if history.has_search_query() {
|
||||
history.matches[range]
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(index, m)| {
|
||||
history.all_entries.get(m.candidate_id).map(|entry| {
|
||||
render_item(index, entry, m.positions.clone())
|
||||
})
|
||||
let render_item = |index: usize,
|
||||
entry: &HistoryEntry,
|
||||
highlight_positions: Vec<usize>|
|
||||
-> Div {
|
||||
h_flex().w_full().pb_1().child(match entry {
|
||||
HistoryEntry::Thread(thread) => PastThread::new(
|
||||
thread.clone(),
|
||||
assistant_panel.clone(),
|
||||
selected_index == index + range_start,
|
||||
highlight_positions,
|
||||
)
|
||||
.into_any_element(),
|
||||
HistoryEntry::Context(context) => PastContext::new(
|
||||
context.clone(),
|
||||
assistant_panel.clone(),
|
||||
selected_index == index + range_start,
|
||||
highlight_positions,
|
||||
)
|
||||
.into_any_element(),
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
history.all_entries[range]
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(index, entry)| render_item(index, entry, vec![]))
|
||||
.collect()
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
if history.has_search_query() {
|
||||
history.matches[range]
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(index, m)| {
|
||||
history.all_entries.get(m.candidate_id).map(
|
||||
|entry| {
|
||||
render_item(
|
||||
index,
|
||||
entry,
|
||||
m.positions.clone(),
|
||||
)
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
history.all_entries[range]
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(index, entry)| render_item(index, entry, vec![]))
|
||||
.collect()
|
||||
}
|
||||
},
|
||||
)
|
||||
.p_1()
|
||||
.track_scroll(self.scroll_handle.clone())
|
||||
.flex_grow(),
|
||||
)
|
||||
.track_scroll(self.scroll_handle.clone())
|
||||
.flex_grow(),
|
||||
)
|
||||
.when_some(self.render_scrollbar(cx), |div, scrollbar| {
|
||||
div.child(scrollbar)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -440,6 +502,7 @@ impl RenderOnce for PastThread {
|
||||
IconButton::new("delete", IconName::TrashAlt)
|
||||
.shape(IconButtonShape::Square)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Muted)
|
||||
.tooltip(move |window, cx| {
|
||||
Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx)
|
||||
})
|
||||
@@ -531,6 +594,7 @@ impl RenderOnce for PastContext {
|
||||
IconButton::new("delete", IconName::TrashAlt)
|
||||
.shape(IconButtonShape::Square)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Muted)
|
||||
.tooltip(move |window, cx| {
|
||||
Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx)
|
||||
})
|
||||
|
||||
@@ -27,7 +27,9 @@ use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::thread::{DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadId};
|
||||
use crate::thread::{
|
||||
DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
|
||||
};
|
||||
|
||||
const RULES_FILE_NAMES: [&'static str; 6] = [
|
||||
".rules",
|
||||
@@ -54,7 +56,7 @@ impl SharedProjectContext {
|
||||
|
||||
pub struct ThreadStore {
|
||||
project: Entity<Project>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
context_server_manager: Entity<ContextServerManager>,
|
||||
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
|
||||
@@ -72,7 +74,7 @@ impl EventEmitter<RulesLoadingError> for ThreadStore {}
|
||||
impl ThreadStore {
|
||||
pub fn load(
|
||||
project: Entity<Project>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
cx: &mut App,
|
||||
) -> Task<Entity<Self>> {
|
||||
@@ -86,7 +88,7 @@ impl ThreadStore {
|
||||
|
||||
fn new(
|
||||
project: Entity<Project>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
@@ -246,7 +248,7 @@ impl ThreadStore {
|
||||
self.context_server_manager.clone()
|
||||
}
|
||||
|
||||
pub fn tools(&self) -> Arc<ToolWorkingSet> {
|
||||
pub fn tools(&self) -> Entity<ToolWorkingSet> {
|
||||
self.tools.clone()
|
||||
}
|
||||
|
||||
@@ -353,52 +355,60 @@ impl ThreadStore {
|
||||
})
|
||||
}
|
||||
|
||||
fn load_default_profile(&self, cx: &Context<Self>) {
|
||||
fn load_default_profile(&self, cx: &mut Context<Self>) {
|
||||
let assistant_settings = AssistantSettings::get_global(cx);
|
||||
|
||||
self.load_profile_by_id(&assistant_settings.default_profile, cx);
|
||||
self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
|
||||
}
|
||||
|
||||
pub fn load_profile_by_id(&self, profile_id: &AgentProfileId, cx: &Context<Self>) {
|
||||
pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
|
||||
let assistant_settings = AssistantSettings::get_global(cx);
|
||||
|
||||
if let Some(profile) = assistant_settings.profiles.get(profile_id) {
|
||||
self.load_profile(profile, cx);
|
||||
if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
|
||||
self.load_profile(profile.clone(), cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_profile(&self, profile: &AgentProfile, cx: &Context<Self>) {
|
||||
self.tools.disable_all_tools();
|
||||
self.tools.enable(
|
||||
ToolSource::Native,
|
||||
&profile
|
||||
.tools
|
||||
.iter()
|
||||
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
|
||||
self.tools.update(cx, |tools, cx| {
|
||||
tools.disable_all_tools(cx);
|
||||
tools.enable(
|
||||
ToolSource::Native,
|
||||
&profile
|
||||
.tools
|
||||
.iter()
|
||||
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
|
||||
.collect::<Vec<_>>(),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
if profile.enable_all_context_servers {
|
||||
for context_server in self.context_server_manager.read(cx).all_servers() {
|
||||
self.tools.enable_source(
|
||||
ToolSource::ContextServer {
|
||||
id: context_server.id().into(),
|
||||
},
|
||||
cx,
|
||||
);
|
||||
self.tools.update(cx, |tools, cx| {
|
||||
tools.enable_source(
|
||||
ToolSource::ContextServer {
|
||||
id: context_server.id().into(),
|
||||
},
|
||||
cx,
|
||||
);
|
||||
});
|
||||
}
|
||||
} else {
|
||||
for (context_server_id, preset) in &profile.context_servers {
|
||||
self.tools.enable(
|
||||
ToolSource::ContextServer {
|
||||
id: context_server_id.clone().into(),
|
||||
},
|
||||
&preset
|
||||
.tools
|
||||
.iter()
|
||||
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
self.tools.update(cx, |tools, cx| {
|
||||
tools.enable(
|
||||
ToolSource::ContextServer {
|
||||
id: context_server_id.clone().into(),
|
||||
},
|
||||
&preset
|
||||
.tools
|
||||
.iter()
|
||||
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
|
||||
.collect::<Vec<_>>(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -432,29 +442,36 @@ impl ThreadStore {
|
||||
|
||||
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
|
||||
if let Some(tools) = protocol.list_tools().await.log_err() {
|
||||
let tool_ids = tools
|
||||
.tools
|
||||
.into_iter()
|
||||
.map(|tool| {
|
||||
log::info!(
|
||||
"registering context server tool: {:?}",
|
||||
tool.name
|
||||
);
|
||||
tool_working_set.insert(Arc::new(
|
||||
ContextServerTool::new(
|
||||
context_server_manager.clone(),
|
||||
server.id(),
|
||||
tool,
|
||||
),
|
||||
))
|
||||
let tool_ids = tool_working_set
|
||||
.update(cx, |tool_working_set, _| {
|
||||
tools
|
||||
.tools
|
||||
.into_iter()
|
||||
.map(|tool| {
|
||||
log::info!(
|
||||
"registering context server tool: {:?}",
|
||||
tool.name
|
||||
);
|
||||
tool_working_set.insert(Arc::new(
|
||||
ContextServerTool::new(
|
||||
context_server_manager.clone(),
|
||||
server.id(),
|
||||
tool,
|
||||
),
|
||||
))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
.log_err();
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.context_server_tool_ids.insert(server_id, tool_ids);
|
||||
this.load_default_profile(cx);
|
||||
})
|
||||
.log_err();
|
||||
if let Some(tool_ids) = tool_ids {
|
||||
this.update(cx, |this, cx| {
|
||||
this.context_server_tool_ids
|
||||
.insert(server_id, tool_ids);
|
||||
this.load_default_profile(cx);
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -464,7 +481,9 @@ impl ThreadStore {
|
||||
}
|
||||
context_server::manager::Event::ServerStopped { server_id } => {
|
||||
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
|
||||
tool_working_set.remove(&tool_ids);
|
||||
tool_working_set.update(cx, |tool_working_set, _| {
|
||||
tool_working_set.remove(&tool_ids);
|
||||
});
|
||||
self.load_default_profile(cx);
|
||||
}
|
||||
}
|
||||
@@ -491,6 +510,8 @@ pub struct SerializedThread {
|
||||
pub cumulative_token_usage: TokenUsage,
|
||||
#[serde(default)]
|
||||
pub detailed_summary_state: DetailedSummaryState,
|
||||
#[serde(default)]
|
||||
pub exceeded_window_error: Option<ExceededWindowError>,
|
||||
}
|
||||
|
||||
impl SerializedThread {
|
||||
@@ -577,6 +598,7 @@ impl LegacySerializedThread {
|
||||
initial_project_snapshot: self.initial_project_snapshot,
|
||||
cumulative_token_usage: TokenUsage::default(),
|
||||
detailed_summary_state: DetailedSummaryState::default(),
|
||||
exceeded_window_error: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
|
||||
use collections::HashMap;
|
||||
use futures::FutureExt as _;
|
||||
use futures::future::Shared;
|
||||
use gpui::{App, SharedString, Task};
|
||||
use gpui::{App, Entity, SharedString, Task};
|
||||
use language_model::{
|
||||
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
|
||||
@@ -30,7 +30,7 @@ pub struct ToolUse {
|
||||
pub const USING_TOOL_MARKER: &str = "<using_tool>";
|
||||
|
||||
pub struct ToolUseState {
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
|
||||
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
|
||||
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
|
||||
@@ -39,7 +39,7 @@ pub struct ToolUseState {
|
||||
}
|
||||
|
||||
impl ToolUseState {
|
||||
pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
|
||||
pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
|
||||
Self {
|
||||
tools,
|
||||
tool_uses_by_assistant_message: HashMap::default(),
|
||||
@@ -54,7 +54,7 @@ impl ToolUseState {
|
||||
///
|
||||
/// Accepts a function to filter the tools that should be used to populate the state.
|
||||
pub fn from_serialized_messages(
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
messages: &[SerializedMessage],
|
||||
mut filter_by_tool_name: impl FnMut(&str) -> bool,
|
||||
) -> Self {
|
||||
@@ -180,12 +180,12 @@ impl ToolUseState {
|
||||
}
|
||||
})();
|
||||
|
||||
let (icon, needs_confirmation) = if let Some(tool) = self.tools.tool(&tool_use.name, cx)
|
||||
{
|
||||
(tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
|
||||
} else {
|
||||
(IconName::Cog, false)
|
||||
};
|
||||
let (icon, needs_confirmation) =
|
||||
if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
|
||||
(tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
|
||||
} else {
|
||||
(IconName::Cog, false)
|
||||
};
|
||||
|
||||
tool_uses.push(ToolUse {
|
||||
id: tool_use.id.clone(),
|
||||
@@ -207,7 +207,7 @@ impl ToolUseState {
|
||||
input: &serde_json::Value,
|
||||
cx: &App,
|
||||
) -> SharedString {
|
||||
if let Some(tool) = self.tools.tool(tool_name, cx) {
|
||||
if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
|
||||
tool.ui_text(input).into()
|
||||
} else {
|
||||
format!("Unknown tool {tool_name:?}").into()
|
||||
|
||||
@@ -25,5 +25,4 @@ serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
strum.workspace = true
|
||||
thiserror.workspace = true
|
||||
util.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
@@ -10,7 +10,6 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use strum::{EnumIter, EnumString};
|
||||
use thiserror::Error;
|
||||
use util::ResultExt as _;
|
||||
|
||||
pub use supported_countries::*;
|
||||
|
||||
@@ -363,11 +362,25 @@ pub struct RateLimitInfo {
|
||||
|
||||
impl RateLimitInfo {
|
||||
fn from_headers(headers: &HeaderMap<HeaderValue>) -> Self {
|
||||
// Check if any rate limit headers exist
|
||||
let has_rate_limit_headers = headers
|
||||
.keys()
|
||||
.any(|k| k.as_str().starts_with("anthropic-ratelimit-"));
|
||||
|
||||
if !has_rate_limit_headers {
|
||||
return Self {
|
||||
requests: None,
|
||||
tokens: None,
|
||||
input_tokens: None,
|
||||
output_tokens: None,
|
||||
};
|
||||
}
|
||||
|
||||
Self {
|
||||
requests: RateLimit::from_headers("requests", headers).log_err(),
|
||||
tokens: RateLimit::from_headers("tokens", headers).log_err(),
|
||||
input_tokens: RateLimit::from_headers("input-tokens", headers).log_err(),
|
||||
output_tokens: RateLimit::from_headers("output-tokens", headers).log_err(),
|
||||
requests: RateLimit::from_headers("requests", headers).ok(),
|
||||
tokens: RateLimit::from_headers("tokens", headers).ok(),
|
||||
input_tokens: RateLimit::from_headers("input-tokens", headers).ok(),
|
||||
output_tokens: RateLimit::from_headers("output-tokens", headers).ok(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -724,4 +737,54 @@ impl ApiError {
|
||||
pub fn is_rate_limit_error(&self) -> bool {
|
||||
matches!(self.error_type.as_str(), "rate_limit_error")
|
||||
}
|
||||
|
||||
pub fn match_window_exceeded(&self) -> Option<usize> {
|
||||
let Some(ApiErrorCode::InvalidRequestError) = self.code() else {
|
||||
return None;
|
||||
};
|
||||
|
||||
parse_prompt_too_long(&self.message)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse_prompt_too_long(message: &str) -> Option<usize> {
|
||||
message
|
||||
.strip_prefix("prompt is too long: ")?
|
||||
.split_once(" tokens")?
|
||||
.0
|
||||
.parse::<usize>()
|
||||
.ok()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_match_window_exceeded() {
|
||||
let error = ApiError {
|
||||
error_type: "invalid_request_error".to_string(),
|
||||
message: "prompt is too long: 220000 tokens > 200000".to_string(),
|
||||
};
|
||||
assert_eq!(error.match_window_exceeded(), Some(220_000));
|
||||
|
||||
let error = ApiError {
|
||||
error_type: "invalid_request_error".to_string(),
|
||||
message: "prompt is too long: 1234953 tokens".to_string(),
|
||||
};
|
||||
assert_eq!(error.match_window_exceeded(), Some(1234953));
|
||||
|
||||
let error = ApiError {
|
||||
error_type: "invalid_request_error".to_string(),
|
||||
message: "not a prompt length error".to_string(),
|
||||
};
|
||||
assert_eq!(error.match_window_exceeded(), None);
|
||||
|
||||
let error = ApiError {
|
||||
error_type: "rate_limit_error".to_string(),
|
||||
message: "prompt is too long: 12345 tokens".to_string(),
|
||||
};
|
||||
assert_eq!(error.match_window_exceeded(), None);
|
||||
|
||||
let error = ApiError {
|
||||
error_type: "invalid_request_error".to_string(),
|
||||
message: "prompt is too long: invalid tokens".to_string(),
|
||||
};
|
||||
assert_eq!(error.match_window_exceeded(), None);
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ use collections::BTreeMap;
|
||||
use futures::{StreamExt, channel::mpsc};
|
||||
use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task, WeakEntity};
|
||||
use language::{Anchor, Buffer, BufferEvent, DiskState, Point};
|
||||
use project::{Project, ProjectItem};
|
||||
use project::{Project, ProjectItem, lsp_store::OpenLspBufferHandle};
|
||||
use std::{cmp, ops::Range, sync::Arc};
|
||||
use text::{Edit, Patch, Rope};
|
||||
use util::RangeExt;
|
||||
@@ -49,6 +49,10 @@ impl ActionLog {
|
||||
.tracked_buffers
|
||||
.entry(buffer.clone())
|
||||
.or_insert_with(|| {
|
||||
let open_lsp_handle = self.project.update(cx, |project, cx| {
|
||||
project.register_buffer_with_language_servers(&buffer, cx)
|
||||
});
|
||||
|
||||
let text_snapshot = buffer.read(cx).text_snapshot();
|
||||
let diff = cx.new(|cx| BufferDiff::new(&text_snapshot, cx));
|
||||
let (diff_update_tx, diff_update_rx) = mpsc::unbounded();
|
||||
@@ -76,6 +80,7 @@ impl ActionLog {
|
||||
version: buffer.read(cx).version(),
|
||||
diff,
|
||||
diff_update: diff_update_tx,
|
||||
_open_lsp_handle: open_lsp_handle,
|
||||
_maintain_diff: cx.spawn({
|
||||
let buffer = buffer.clone();
|
||||
async move |this, cx| {
|
||||
@@ -615,6 +620,7 @@ struct TrackedBuffer {
|
||||
diff: Entity<BufferDiff>,
|
||||
snapshot: text::BufferSnapshot,
|
||||
diff_update: mpsc::UnboundedSender<(ChangeAuthor, text::BufferSnapshot)>,
|
||||
_open_lsp_handle: OpenLspBufferHandle,
|
||||
_maintain_diff: Task<()>,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
mod action_log;
|
||||
mod tool_registry;
|
||||
mod tool_schema;
|
||||
mod tool_working_set;
|
||||
|
||||
use std::fmt;
|
||||
@@ -20,6 +21,7 @@ use project::Project;
|
||||
|
||||
pub use crate::action_log::*;
|
||||
pub use crate::tool_registry::*;
|
||||
pub use crate::tool_schema::*;
|
||||
pub use crate::tool_working_set::*;
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
@@ -139,8 +141,8 @@ pub trait Tool: 'static + Send + Sync {
|
||||
fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool;
|
||||
|
||||
/// Returns the JSON schema that describes the tool's input.
|
||||
fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> serde_json::Value {
|
||||
serde_json::Value::Object(serde_json::Map::default())
|
||||
fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
Ok(serde_json::Value::Object(serde_json::Map::default()))
|
||||
}
|
||||
|
||||
/// Returns markdown to be displayed in the UI for this tool.
|
||||
|
||||
236
crates/assistant_tool/src/tool_schema.rs
Normal file
236
crates/assistant_tool/src/tool_schema.rs
Normal file
@@ -0,0 +1,236 @@
|
||||
use anyhow::Result;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::LanguageModelToolSchemaFormat;
|
||||
|
||||
/// Tries to adapt a JSON schema representation to be compatible with the specified format.
|
||||
///
|
||||
/// If the json cannot be made compatible with the specified format, an error is returned.
|
||||
pub fn adapt_schema_to_format(
|
||||
json: &mut Value,
|
||||
format: LanguageModelToolSchemaFormat,
|
||||
) -> Result<()> {
|
||||
match format {
|
||||
LanguageModelToolSchemaFormat::JsonSchema => Ok(()),
|
||||
LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
|
||||
}
|
||||
}
|
||||
|
||||
/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
|
||||
fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
|
||||
if let Value::Object(obj) = json {
|
||||
const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
|
||||
|
||||
for key in UNSUPPORTED_KEYS {
|
||||
if obj.contains_key(key) {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Schema cannot be made compatible because it contains \"{}\" ",
|
||||
key
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
const KEYS_TO_REMOVE: [&str; 2] = ["format", "$schema"];
|
||||
for key in KEYS_TO_REMOVE {
|
||||
obj.remove(key);
|
||||
}
|
||||
|
||||
if let Some(default) = obj.get("default") {
|
||||
let is_null = default.is_null();
|
||||
// Default is not supported, so we need to remove it
|
||||
obj.remove("default");
|
||||
if is_null {
|
||||
obj.insert("nullable".to_string(), Value::Bool(true));
|
||||
}
|
||||
}
|
||||
|
||||
// If a type is not specified for an input parameter, add a default type
|
||||
if obj.contains_key("description")
|
||||
&& !obj.contains_key("type")
|
||||
&& !(obj.contains_key("anyOf")
|
||||
|| obj.contains_key("oneOf")
|
||||
|| obj.contains_key("allOf"))
|
||||
{
|
||||
obj.insert("type".to_string(), Value::String("string".to_string()));
|
||||
}
|
||||
|
||||
// Handle oneOf -> anyOf conversion
|
||||
if let Some(subschemas) = obj.get_mut("oneOf") {
|
||||
if subschemas.is_array() {
|
||||
let subschemas_clone = subschemas.clone();
|
||||
obj.remove("oneOf");
|
||||
obj.insert("anyOf".to_string(), subschemas_clone);
|
||||
}
|
||||
}
|
||||
|
||||
// Recursively process all nested objects and arrays
|
||||
for (_, value) in obj.iter_mut() {
|
||||
if let Value::Object(_) | Value::Array(_) = value {
|
||||
adapt_to_json_schema_subset(value)?;
|
||||
}
|
||||
}
|
||||
} else if let Value::Array(arr) = json {
|
||||
for item in arr.iter_mut() {
|
||||
adapt_to_json_schema_subset(item)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_transform_default_null_to_nullable() {
|
||||
let mut json = json!({
|
||||
"description": "A test field",
|
||||
"type": "string",
|
||||
"default": null
|
||||
});
|
||||
|
||||
adapt_to_json_schema_subset(&mut json).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
json,
|
||||
json!({
|
||||
"description": "A test field",
|
||||
"type": "string",
|
||||
"nullable": true
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transform_adds_type_when_missing() {
|
||||
let mut json = json!({
|
||||
"description": "A test field without type"
|
||||
});
|
||||
|
||||
adapt_to_json_schema_subset(&mut json).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
json,
|
||||
json!({
|
||||
"description": "A test field without type",
|
||||
"type": "string"
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transform_removes_format() {
|
||||
let mut json = json!({
|
||||
"description": "A test field",
|
||||
"type": "integer",
|
||||
"format": "uint32"
|
||||
});
|
||||
|
||||
adapt_to_json_schema_subset(&mut json).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
json,
|
||||
json!({
|
||||
"description": "A test field",
|
||||
"type": "integer"
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transform_one_of_to_any_of() {
|
||||
let mut json = json!({
|
||||
"description": "A test field",
|
||||
"oneOf": [
|
||||
{ "type": "string" },
|
||||
{ "type": "integer" }
|
||||
]
|
||||
});
|
||||
|
||||
adapt_to_json_schema_subset(&mut json).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
json,
|
||||
json!({
|
||||
"description": "A test field",
|
||||
"anyOf": [
|
||||
{ "type": "string" },
|
||||
{ "type": "integer" }
|
||||
]
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transform_nested_objects() {
|
||||
let mut json = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"nested": {
|
||||
"oneOf": [
|
||||
{ "type": "string" },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"format": "email"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
adapt_to_json_schema_subset(&mut json).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
json,
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"nested": {
|
||||
"anyOf": [
|
||||
{ "type": "string" },
|
||||
{ "type": "null" }
|
||||
]
|
||||
}
|
||||
}
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transform_fails_if_unsupported_keys_exist() {
|
||||
let mut json = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"$ref": "#/definitions/User",
|
||||
}
|
||||
});
|
||||
|
||||
assert!(adapt_to_json_schema_subset(&mut json).is_err());
|
||||
|
||||
let mut json = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"if": "...",
|
||||
}
|
||||
});
|
||||
|
||||
assert!(adapt_to_json_schema_subset(&mut json).is_err());
|
||||
|
||||
let mut json = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"then": "...",
|
||||
}
|
||||
});
|
||||
|
||||
assert!(adapt_to_json_schema_subset(&mut json).is_err());
|
||||
|
||||
let mut json = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"else": "...",
|
||||
}
|
||||
});
|
||||
|
||||
assert!(adapt_to_json_schema_subset(&mut json).is_err());
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use collections::{HashMap, HashSet, IndexMap};
|
||||
use gpui::App;
|
||||
use parking_lot::Mutex;
|
||||
use gpui::{App, Context, EventEmitter};
|
||||
|
||||
use crate::{Tool, ToolRegistry, ToolSource};
|
||||
|
||||
@@ -12,11 +11,6 @@ pub struct ToolId(usize);
|
||||
/// A working set of tools for use in one instance of the Assistant Panel.
|
||||
#[derive(Default)]
|
||||
pub struct ToolWorkingSet {
|
||||
state: Mutex<WorkingSetState>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct WorkingSetState {
|
||||
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
|
||||
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
|
||||
enabled_sources: HashSet<ToolSource>,
|
||||
@@ -24,99 +18,27 @@ struct WorkingSetState {
|
||||
next_tool_id: ToolId,
|
||||
}
|
||||
|
||||
pub enum ToolWorkingSetEvent {
|
||||
EnabledToolsChanged,
|
||||
}
|
||||
|
||||
impl EventEmitter<ToolWorkingSetEvent> for ToolWorkingSet {}
|
||||
|
||||
impl ToolWorkingSet {
|
||||
pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
|
||||
self.state
|
||||
.lock()
|
||||
.context_server_tools_by_name
|
||||
self.context_server_tools_by_name
|
||||
.get(name)
|
||||
.cloned()
|
||||
.or_else(|| ToolRegistry::global(cx).tool(name))
|
||||
}
|
||||
|
||||
pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
|
||||
self.state.lock().tools(cx)
|
||||
}
|
||||
|
||||
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
|
||||
self.state.lock().tools_by_source(cx)
|
||||
}
|
||||
|
||||
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
|
||||
self.state.lock().enabled_tools(cx)
|
||||
}
|
||||
|
||||
pub fn disable_all_tools(&self) {
|
||||
let mut state = self.state.lock();
|
||||
state.disable_all_tools();
|
||||
}
|
||||
|
||||
pub fn enable_source(&self, source: ToolSource, cx: &App) {
|
||||
let mut state = self.state.lock();
|
||||
state.enable_source(source, cx);
|
||||
}
|
||||
|
||||
pub fn disable_source(&self, source: &ToolSource) {
|
||||
let mut state = self.state.lock();
|
||||
state.disable_source(source);
|
||||
}
|
||||
|
||||
pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
|
||||
let mut state = self.state.lock();
|
||||
let tool_id = state.next_tool_id;
|
||||
state.next_tool_id.0 += 1;
|
||||
state
|
||||
.context_server_tools_by_id
|
||||
.insert(tool_id, tool.clone());
|
||||
state.tools_changed();
|
||||
tool_id
|
||||
}
|
||||
|
||||
pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
|
||||
self.state.lock().is_enabled(source, name)
|
||||
}
|
||||
|
||||
pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
|
||||
self.state.lock().is_disabled(source, name)
|
||||
}
|
||||
|
||||
pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
|
||||
let mut state = self.state.lock();
|
||||
state.enable(source, tools_to_enable);
|
||||
}
|
||||
|
||||
pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
|
||||
let mut state = self.state.lock();
|
||||
state.disable(source, tools_to_disable);
|
||||
}
|
||||
|
||||
pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
|
||||
let mut state = self.state.lock();
|
||||
state
|
||||
.context_server_tools_by_id
|
||||
.retain(|id, _| !tool_ids_to_remove.contains(id));
|
||||
state.tools_changed();
|
||||
}
|
||||
}
|
||||
|
||||
impl WorkingSetState {
|
||||
fn tools_changed(&mut self) {
|
||||
self.context_server_tools_by_name.clear();
|
||||
self.context_server_tools_by_name.extend(
|
||||
self.context_server_tools_by_id
|
||||
.values()
|
||||
.map(|tool| (tool.name(), tool.clone())),
|
||||
);
|
||||
}
|
||||
|
||||
fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
|
||||
let mut tools = ToolRegistry::global(cx).tools();
|
||||
tools.extend(self.context_server_tools_by_id.values().cloned());
|
||||
|
||||
tools
|
||||
}
|
||||
|
||||
fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
|
||||
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
|
||||
let mut tools_by_source = IndexMap::default();
|
||||
|
||||
for tool in self.tools(cx) {
|
||||
@@ -135,7 +57,7 @@ impl WorkingSetState {
|
||||
tools_by_source
|
||||
}
|
||||
|
||||
fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
|
||||
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
|
||||
let all_tools = self.tools(cx);
|
||||
|
||||
all_tools
|
||||
@@ -144,31 +66,12 @@ impl WorkingSetState {
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
|
||||
self.enabled_tools_by_source
|
||||
.get(source)
|
||||
.map_or(false, |enabled_tools| enabled_tools.contains(name))
|
||||
pub fn disable_all_tools(&mut self, cx: &mut Context<Self>) {
|
||||
self.enabled_tools_by_source.clear();
|
||||
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
|
||||
}
|
||||
|
||||
fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
|
||||
!self.is_enabled(source, name)
|
||||
}
|
||||
|
||||
fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
|
||||
self.enabled_tools_by_source
|
||||
.entry(source)
|
||||
.or_default()
|
||||
.extend(tools_to_enable.into_iter().cloned());
|
||||
}
|
||||
|
||||
fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
|
||||
self.enabled_tools_by_source
|
||||
.entry(source)
|
||||
.or_default()
|
||||
.retain(|name| !tools_to_disable.contains(name));
|
||||
}
|
||||
|
||||
fn enable_source(&mut self, source: ToolSource, cx: &App) {
|
||||
pub fn enable_source(&mut self, source: ToolSource, cx: &mut Context<Self>) {
|
||||
self.enabled_sources.insert(source.clone());
|
||||
|
||||
let tools_by_source = self.tools_by_source(cx);
|
||||
@@ -181,14 +84,72 @@ impl WorkingSetState {
|
||||
.collect::<HashSet<_>>(),
|
||||
);
|
||||
}
|
||||
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
|
||||
}
|
||||
|
||||
fn disable_source(&mut self, source: &ToolSource) {
|
||||
pub fn disable_source(&mut self, source: &ToolSource, cx: &mut Context<Self>) {
|
||||
self.enabled_sources.remove(source);
|
||||
self.enabled_tools_by_source.remove(source);
|
||||
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
|
||||
}
|
||||
|
||||
fn disable_all_tools(&mut self) {
|
||||
self.enabled_tools_by_source.clear();
|
||||
pub fn insert(&mut self, tool: Arc<dyn Tool>) -> ToolId {
|
||||
let tool_id = self.next_tool_id;
|
||||
self.next_tool_id.0 += 1;
|
||||
self.context_server_tools_by_id
|
||||
.insert(tool_id, tool.clone());
|
||||
self.tools_changed();
|
||||
tool_id
|
||||
}
|
||||
|
||||
pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
|
||||
self.enabled_tools_by_source
|
||||
.get(source)
|
||||
.map_or(false, |enabled_tools| enabled_tools.contains(name))
|
||||
}
|
||||
|
||||
pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
|
||||
!self.is_enabled(source, name)
|
||||
}
|
||||
|
||||
pub fn enable(
|
||||
&mut self,
|
||||
source: ToolSource,
|
||||
tools_to_enable: &[Arc<str>],
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.enabled_tools_by_source
|
||||
.entry(source)
|
||||
.or_default()
|
||||
.extend(tools_to_enable.into_iter().cloned());
|
||||
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
|
||||
}
|
||||
|
||||
pub fn disable(
|
||||
&mut self,
|
||||
source: ToolSource,
|
||||
tools_to_disable: &[Arc<str>],
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.enabled_tools_by_source
|
||||
.entry(source)
|
||||
.or_default()
|
||||
.retain(|name| !tools_to_disable.contains(name));
|
||||
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, tool_ids_to_remove: &[ToolId]) {
|
||||
self.context_server_tools_by_id
|
||||
.retain(|id, _| !tool_ids_to_remove.contains(id));
|
||||
self.tools_changed();
|
||||
}
|
||||
|
||||
fn tools_changed(&mut self) {
|
||||
self.context_server_tools_by_name.clear();
|
||||
self.context_server_tools_by_name.extend(
|
||||
self.context_server_tools_by_id
|
||||
.values()
|
||||
.map(|tool| (tool.name(), tool.clone())),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
mod batch_tool;
|
||||
mod code_action_tool;
|
||||
mod code_symbols_tool;
|
||||
mod contents_tool;
|
||||
mod copy_path_tool;
|
||||
mod create_directory_tool;
|
||||
mod create_file_tool;
|
||||
@@ -35,6 +36,7 @@ use web_search_tool::WebSearchTool;
|
||||
use crate::batch_tool::BatchTool;
|
||||
use crate::code_action_tool::CodeActionTool;
|
||||
use crate::code_symbols_tool::CodeSymbolsTool;
|
||||
use crate::contents_tool::ContentsTool;
|
||||
use crate::create_directory_tool::CreateDirectoryTool;
|
||||
use crate::create_file_tool::CreateFileTool;
|
||||
use crate::delete_path_tool::DeletePathTool;
|
||||
@@ -59,6 +61,7 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
registry.register_tool(BatchTool);
|
||||
registry.register_tool(CodeActionTool);
|
||||
registry.register_tool(CodeSymbolsTool);
|
||||
registry.register_tool(ContentsTool);
|
||||
registry.register_tool(CopyPathTool);
|
||||
registry.register_tool(CreateDirectoryTool);
|
||||
registry.register_tool(CreateFileTool);
|
||||
@@ -79,3 +82,42 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
registry.register_tool(ThinkingTool);
|
||||
registry.register_tool(WebSearchTool);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use http_client::FakeHttpClient;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_builtin_tool_schema_compatibility(cx: &mut App) {
|
||||
crate::init(
|
||||
Arc::new(http_client::HttpClientWithUrl::new(
|
||||
FakeHttpClient::with_200_response(),
|
||||
"https://zed.dev",
|
||||
None,
|
||||
)),
|
||||
cx,
|
||||
);
|
||||
|
||||
for tool in ToolRegistry::global(cx).tools() {
|
||||
let actual_schema = tool
|
||||
.input_schema(language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset)
|
||||
.unwrap();
|
||||
let mut expected_schema = actual_schema.clone();
|
||||
assistant_tool::adapt_schema_to_format(
|
||||
&mut expected_schema,
|
||||
language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let error_message = format!(
|
||||
"Tool schema for `{}` is not compatible with `language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset` (Gemini Models).\n\
|
||||
Are you using `schema::json_schema_for<T>(format)` to generate the schema?",
|
||||
tool.name(),
|
||||
);
|
||||
|
||||
assert_eq!(actual_schema, expected_schema, "{}", error_message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,7 +172,7 @@ impl Tool for BatchTool {
|
||||
IconName::Cog
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<BatchToolInput>(format)
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ use anyhow::{Context as _, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, Entity, Task};
|
||||
use language::{self, Anchor, Buffer, ToPointUtf16};
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::{self, LspAction, Project};
|
||||
use regex::Regex;
|
||||
use schemars::JsonSchema;
|
||||
@@ -10,6 +10,8 @@ use serde::{Deserialize, Serialize};
|
||||
use std::{ops::Range, sync::Arc};
|
||||
use ui::IconName;
|
||||
|
||||
use crate::schema::json_schema_for;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct CodeActionToolInput {
|
||||
/// The relative path to the file containing the text range.
|
||||
@@ -95,12 +97,8 @@ impl Tool for CodeActionTool {
|
||||
IconName::Wand
|
||||
}
|
||||
|
||||
fn input_schema(
|
||||
&self,
|
||||
_format: language_model::LanguageModelToolSchemaFormat,
|
||||
) -> serde_json::Value {
|
||||
let schema = schemars::schema_for!(CodeActionToolInput);
|
||||
serde_json::to_value(&schema).unwrap()
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<CodeActionToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
|
||||
@@ -91,7 +91,7 @@ impl Tool for CodeSymbolsTool {
|
||||
IconName::Code
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<CodeSymbolsInput>(format)
|
||||
}
|
||||
|
||||
|
||||
239
crates/assistant_tools/src/contents_tool.rs
Normal file
239
crates/assistant_tools/src/contents_tool.rs
Normal file
@@ -0,0 +1,239 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool};
|
||||
use gpui::{App, Entity, Task};
|
||||
use itertools::Itertools;
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{fmt::Write, path::Path};
|
||||
use ui::IconName;
|
||||
use util::markdown::MarkdownString;
|
||||
|
||||
/// If the model requests to read a file whose size exceeds this, then
|
||||
/// the tool will return the file's symbol outline instead of its contents,
|
||||
/// and suggest trying again using line ranges from the outline.
|
||||
const MAX_FILE_SIZE_TO_READ: usize = 16384;
|
||||
|
||||
/// If the model requests to list the entries in a directory with more
|
||||
/// entries than this, then the tool will return a subset of the entries
|
||||
/// and suggest trying again.
|
||||
const MAX_DIR_ENTRIES: usize = 1024;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct ContentsToolInput {
|
||||
/// The relative path of the file or directory to access.
|
||||
///
|
||||
/// This path should never be absolute, and the first component
|
||||
/// of the path should always be a root directory in a project.
|
||||
///
|
||||
/// <example>
|
||||
/// If the project has the following root directories:
|
||||
///
|
||||
/// - directory1
|
||||
/// - directory2
|
||||
///
|
||||
/// If you want to access `file.txt` in `directory1`, you should use the path `directory1/file.txt`.
|
||||
/// If you want to list contents in the directory `directory2/subfolder`, you should use the path `directory2/subfolder`.
|
||||
/// </example>
|
||||
pub path: String,
|
||||
|
||||
/// Optional position (1-based index) to start reading on, if you want to read a subset of the contents.
|
||||
/// When reading a file, this refers to a line number in the file (e.g. 1 is the first line).
|
||||
/// When reading a directory, this refers to the number of the directory entry (e.g. 1 is the first entry).
|
||||
///
|
||||
/// Defaults to 1.
|
||||
pub start: Option<u32>,
|
||||
|
||||
/// Optional position (1-based index) to end reading on, if you want to read a subset of the contents.
|
||||
/// When reading a file, this refers to a line number in the file (e.g. 1 is the first line).
|
||||
/// When reading a directory, this refers to the number of the directory entry (e.g. 1 is the first entry).
|
||||
///
|
||||
/// Defaults to reading until the end of the file or directory.
|
||||
pub end: Option<u32>,
|
||||
}
|
||||
|
||||
pub struct ContentsTool;
|
||||
|
||||
impl Tool for ContentsTool {
|
||||
fn name(&self) -> String {
|
||||
"contents".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
include_str!("./contents_tool/description.md").into()
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
IconName::FileSearch
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<ContentsToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<ContentsToolInput>(input.clone()) {
|
||||
Ok(input) => {
|
||||
let path = MarkdownString::inline_code(&input.path);
|
||||
|
||||
match (input.start, input.end) {
|
||||
(Some(start), None) => format!("Read {path} (from line {start})"),
|
||||
(Some(start), Some(end)) => {
|
||||
format!("Read {path} (lines {start}-{end})")
|
||||
}
|
||||
_ => format!("Read {path}"),
|
||||
}
|
||||
}
|
||||
Err(_) => "Read file or directory".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
let input = match serde_json::from_value::<ContentsToolInput>(input) {
|
||||
Ok(input) => input,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))),
|
||||
};
|
||||
|
||||
// Sometimes models will return these even though we tell it to give a path and not a glob.
|
||||
// When this happens, just list the root worktree directories.
|
||||
if matches!(input.path.as_str(), "." | "" | "./" | "*") {
|
||||
let output = project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.filter_map(|worktree| {
|
||||
worktree.read(cx).root_entry().and_then(|entry| {
|
||||
if entry.is_dir() {
|
||||
entry.path.to_str()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
return Task::ready(Ok(output));
|
||||
}
|
||||
|
||||
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
|
||||
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path)));
|
||||
};
|
||||
|
||||
let Some(worktree) = project
|
||||
.read(cx)
|
||||
.worktree_for_id(project_path.worktree_id, cx)
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("Worktree not found")));
|
||||
};
|
||||
let worktree = worktree.read(cx);
|
||||
|
||||
let Some(entry) = worktree.entry_for_path(&project_path.path) else {
|
||||
return Task::ready(Err(anyhow!("Path not found: {}", input.path)));
|
||||
};
|
||||
|
||||
// If it's a directory, list its contents
|
||||
if entry.is_dir() {
|
||||
let mut output = String::new();
|
||||
let start_index = input
|
||||
.start
|
||||
.map(|line| (line as usize).saturating_sub(1))
|
||||
.unwrap_or(0);
|
||||
let end_index = input
|
||||
.end
|
||||
.map(|line| (line as usize).saturating_sub(1))
|
||||
.unwrap_or(MAX_DIR_ENTRIES);
|
||||
let mut skipped = 0;
|
||||
|
||||
for (index, entry) in worktree.child_entries(&project_path.path).enumerate() {
|
||||
if index >= start_index && index <= end_index {
|
||||
writeln!(
|
||||
output,
|
||||
"{}",
|
||||
Path::new(worktree.root_name()).join(&entry.path).display(),
|
||||
)
|
||||
.unwrap();
|
||||
} else {
|
||||
skipped += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if output.is_empty() {
|
||||
output.push_str(&input.path);
|
||||
output.push_str(" is empty.");
|
||||
}
|
||||
|
||||
if skipped > 0 {
|
||||
write!(
|
||||
output,
|
||||
"\n\nNote: Skipped {skipped} entries. Adjust start and end to see other entries.",
|
||||
).ok();
|
||||
}
|
||||
|
||||
Task::ready(Ok(output))
|
||||
} else {
|
||||
// It's a file, so read its contents
|
||||
let file_path = input.path.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let buffer = cx
|
||||
.update(|cx| {
|
||||
project.update(cx, |project, cx| project.open_buffer(project_path, cx))
|
||||
})?
|
||||
.await?;
|
||||
|
||||
if input.start.is_some() || input.end.is_some() {
|
||||
let result = buffer.read_with(cx, |buffer, _cx| {
|
||||
let text = buffer.text();
|
||||
let start = input.start.unwrap_or(1);
|
||||
let lines = text.split('\n').skip(start as usize - 1);
|
||||
if let Some(end) = input.end {
|
||||
let count = end.saturating_sub(start).max(1); // Ensure at least 1 line
|
||||
Itertools::intersperse(lines.take(count as usize), "\n").collect()
|
||||
} else {
|
||||
Itertools::intersperse(lines, "\n").collect()
|
||||
}
|
||||
})?;
|
||||
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.buffer_read(buffer, cx);
|
||||
})?;
|
||||
|
||||
Ok(result)
|
||||
} else {
|
||||
// No line ranges specified, so check file size to see if it's too big.
|
||||
let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
|
||||
|
||||
if file_size <= MAX_FILE_SIZE_TO_READ {
|
||||
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
|
||||
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.buffer_read(buffer, cx);
|
||||
})?;
|
||||
|
||||
Ok(result)
|
||||
} else {
|
||||
// File is too big, so return its outline and a suggestion to
|
||||
// read again with a line number range specified.
|
||||
let outline = file_outline(project, file_path, action_log, None, 0, cx).await?;
|
||||
|
||||
Ok(format!("This file was too big to read all at once. Here is an outline of its symbols:\n\n{outline}\n\nUsing the line numbers in this outline, you can call this tool again while specifying the start and end fields to see the implementations of symbols in the outline."))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
9
crates/assistant_tools/src/contents_tool/description.md
Normal file
9
crates/assistant_tools/src/contents_tool/description.md
Normal file
@@ -0,0 +1,9 @@
|
||||
Reads the contents of a path on the filesystem.
|
||||
|
||||
If the path is a directory, this lists all files and directories within that path.
|
||||
If the path is a file, this returns the file's contents.
|
||||
|
||||
When reading a file, if the file is too big and no line range is specified, an outline of the file's code symbols is listed instead, which can be used to request specific line ranges in a subsequent call.
|
||||
|
||||
Similarly, if a directory has too many entries to show at once, a subset of entries will be shown,
|
||||
and subsequent requests can use starting and ending line numbers to get other subsets.
|
||||
@@ -55,7 +55,7 @@ impl Tool for CopyPathTool {
|
||||
IconName::Clipboard
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<CopyPathToolInput>(format)
|
||||
}
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ impl Tool for CreateDirectoryTool {
|
||||
IconName::Folder
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<CreateDirectoryToolInput>(format)
|
||||
}
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ impl Tool for CreateFileTool {
|
||||
IconName::FileCreate
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<CreateFileToolInput>(format)
|
||||
}
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ impl Tool for DeletePathTool {
|
||||
IconName::FileDelete
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<DeletePathToolInput>(format)
|
||||
}
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ impl Tool for DiagnosticsTool {
|
||||
IconName::XCircle
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<DiagnosticsToolInput>(format)
|
||||
}
|
||||
|
||||
|
||||
@@ -128,7 +128,7 @@ impl Tool for FetchTool {
|
||||
IconName::Globe
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<FetchToolInput>(format)
|
||||
}
|
||||
|
||||
|
||||
@@ -151,7 +151,7 @@ impl Tool for FindReplaceFileTool {
|
||||
IconName::Pencil
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<FindReplaceFileToolInput>(format)
|
||||
}
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ impl Tool for ListDirectoryTool {
|
||||
IconName::Folder
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<ListDirectoryToolInput>(format)
|
||||
}
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ impl Tool for MovePathTool {
|
||||
IconName::ArrowRightLeft
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<MovePathToolInput>(format)
|
||||
}
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ impl Tool for NowTool {
|
||||
IconName::Info
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<NowToolInput>(format)
|
||||
}
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ impl Tool for OpenTool {
|
||||
IconName::ArrowUpRight
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<OpenToolInput>(format)
|
||||
}
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ impl Tool for PathSearchTool {
|
||||
IconName::SearchCode
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<PathSearchToolInput>(format)
|
||||
}
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ impl Tool for ReadFileTool {
|
||||
IconName::FileSearch
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<ReadFileToolInput>(format)
|
||||
}
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ impl Tool for RegexSearchTool {
|
||||
IconName::Regex
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<RegexSearchToolInput>(format)
|
||||
}
|
||||
|
||||
|
||||
@@ -2,13 +2,15 @@ use anyhow::{Context as _, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, Entity, Task};
|
||||
use language::{self, Buffer, ToPointUtf16};
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use ui::IconName;
|
||||
|
||||
use crate::schema::json_schema_for;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct RenameToolInput {
|
||||
/// The relative path to the file containing the symbol to rename.
|
||||
@@ -66,12 +68,8 @@ impl Tool for RenameTool {
|
||||
IconName::Pencil
|
||||
}
|
||||
|
||||
fn input_schema(
|
||||
&self,
|
||||
_format: language_model::LanguageModelToolSchemaFormat,
|
||||
) -> serde_json::Value {
|
||||
let schema = schemars::schema_for!(RenameToolInput);
|
||||
serde_json::to_value(&schema).unwrap()
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<RenameToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
|
||||
@@ -5,23 +5,20 @@ use schemars::{
|
||||
schema::{RootSchema, Schema, SchemaObject},
|
||||
};
|
||||
|
||||
pub fn json_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> serde_json::Value {
|
||||
pub fn json_schema_for<T: JsonSchema>(
|
||||
format: LanguageModelToolSchemaFormat,
|
||||
) -> Result<serde_json::Value> {
|
||||
let schema = root_schema_for::<T>(format);
|
||||
schema_to_json(&schema, format).expect("Failed to convert tool calling schema to JSON")
|
||||
schema_to_json(&schema, format)
|
||||
}
|
||||
|
||||
pub fn schema_to_json(
|
||||
fn schema_to_json(
|
||||
schema: &RootSchema,
|
||||
format: LanguageModelToolSchemaFormat,
|
||||
) -> Result<serde_json::Value> {
|
||||
let mut value = serde_json::to_value(schema)?;
|
||||
match format {
|
||||
LanguageModelToolSchemaFormat::JsonSchema => Ok(value),
|
||||
LanguageModelToolSchemaFormat::JsonSchemaSubset => {
|
||||
transform_fields_to_json_schema_subset(&mut value);
|
||||
Ok(value)
|
||||
}
|
||||
}
|
||||
assistant_tool::adapt_schema_to_format(&mut value, format)?;
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> RootSchema {
|
||||
@@ -79,42 +76,3 @@ impl schemars::visit::Visitor for TransformToJsonSchemaSubsetVisitor {
|
||||
schemars::visit::visit_schema_object(self, schema)
|
||||
}
|
||||
}
|
||||
|
||||
fn transform_fields_to_json_schema_subset(json: &mut serde_json::Value) {
|
||||
if let serde_json::Value::Object(obj) = json {
|
||||
if let Some(default) = obj.get("default") {
|
||||
let is_null = default.is_null();
|
||||
//Default is not supported, so we need to remove it.
|
||||
obj.remove("default");
|
||||
if is_null {
|
||||
obj.insert("nullable".to_string(), serde_json::Value::Bool(true));
|
||||
}
|
||||
}
|
||||
|
||||
// If a type is not specified for an input parameter we need to add it.
|
||||
if obj.contains_key("description")
|
||||
&& !obj.contains_key("type")
|
||||
&& !(obj.contains_key("anyOf")
|
||||
|| obj.contains_key("oneOf")
|
||||
|| obj.contains_key("allOf"))
|
||||
{
|
||||
obj.insert(
|
||||
"type".to_string(),
|
||||
serde_json::Value::String("string".to_string()),
|
||||
);
|
||||
}
|
||||
|
||||
//Format field is only partially supported (e.g. not uint compatibility)
|
||||
obj.remove("format");
|
||||
|
||||
for (_, value) in obj.iter_mut() {
|
||||
if let serde_json::Value::Object(_) | serde_json::Value::Array(_) = value {
|
||||
transform_fields_to_json_schema_subset(value);
|
||||
}
|
||||
}
|
||||
} else if let serde_json::Value::Array(arr) = json {
|
||||
for item in arr.iter_mut() {
|
||||
transform_fields_to_json_schema_subset(item);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ impl Tool for SymbolInfoTool {
|
||||
IconName::Code
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<SymbolInfoToolInput>(format)
|
||||
}
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ impl Tool for TerminalTool {
|
||||
IconName::Terminal
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<TerminalToolInput>(format)
|
||||
}
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ impl Tool for ThinkingTool {
|
||||
IconName::LightBulb
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<ThinkingToolInput>(format)
|
||||
}
|
||||
|
||||
|
||||
@@ -27,6 +27,8 @@ serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
tempfile.workspace = true
|
||||
which.workspace = true
|
||||
workspace.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[target.'cfg(not(target_os = "windows"))'.dependencies]
|
||||
which.workspace = true
|
||||
|
||||
@@ -23,7 +23,6 @@ use std::{
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use which::which;
|
||||
use workspace::Workspace;
|
||||
|
||||
const SHOULD_SHOW_UPDATE_NOTIFICATION_KEY: &str = "auto-updater-should-show-updated-notification";
|
||||
@@ -63,7 +62,7 @@ pub struct AutoUpdater {
|
||||
pending_poll: Option<Task<Option<()>>>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct JsonRelease {
|
||||
pub version: String,
|
||||
pub url: String,
|
||||
@@ -237,6 +236,46 @@ pub fn view_release_notes(_: &ViewReleaseNotes, cx: &mut App) -> Option<()> {
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
struct InstallerDir(tempfile::TempDir);
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
impl InstallerDir {
|
||||
async fn new() -> Result<Self> {
|
||||
Ok(Self(
|
||||
tempfile::Builder::new()
|
||||
.prefix("zed-auto-update")
|
||||
.tempdir()?,
|
||||
))
|
||||
}
|
||||
|
||||
fn path(&self) -> &Path {
|
||||
self.0.path()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
struct InstallerDir(PathBuf);
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
impl InstallerDir {
|
||||
async fn new() -> Result<Self> {
|
||||
let installer_dir = std::env::current_exe()?
|
||||
.parent()
|
||||
.context("No parent dir for Zed.exe")?
|
||||
.join("updates");
|
||||
if smol::fs::metadata(&installer_dir).await.is_ok() {
|
||||
smol::fs::remove_dir_all(&installer_dir).await?;
|
||||
}
|
||||
smol::fs::create_dir(&installer_dir).await?;
|
||||
Ok(Self(installer_dir))
|
||||
}
|
||||
|
||||
fn path(&self) -> &Path {
|
||||
self.0.as_path()
|
||||
}
|
||||
}
|
||||
|
||||
impl AutoUpdater {
|
||||
pub fn get(cx: &mut App) -> Option<Entity<Self>> {
|
||||
cx.default_global::<GlobalAutoUpdate>().0.clone()
|
||||
@@ -469,22 +508,21 @@ impl AutoUpdater {
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
let temp_dir = tempfile::Builder::new()
|
||||
.prefix("zed-auto-update")
|
||||
.tempdir()?;
|
||||
|
||||
let installer_dir = InstallerDir::new().await?;
|
||||
let filename = match OS {
|
||||
"macos" => Ok("Zed.dmg"),
|
||||
"linux" => Ok("zed.tar.gz"),
|
||||
"windows" => Ok("ZedUpdateInstaller.exe"),
|
||||
_ => Err(anyhow!("not supported: {:?}", OS)),
|
||||
}?;
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
anyhow::ensure!(
|
||||
which("rsync").is_ok(),
|
||||
which::which("rsync").is_ok(),
|
||||
"Aborting. Could not find rsync which is required for auto-updates."
|
||||
);
|
||||
|
||||
let downloaded_asset = temp_dir.path().join(filename);
|
||||
let downloaded_asset = installer_dir.path().join(filename);
|
||||
download_release(&downloaded_asset, release, client, &cx).await?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
@@ -493,8 +531,9 @@ impl AutoUpdater {
|
||||
})?;
|
||||
|
||||
let binary_path = match OS {
|
||||
"macos" => install_release_macos(&temp_dir, downloaded_asset, &cx).await,
|
||||
"linux" => install_release_linux(&temp_dir, downloaded_asset, &cx).await,
|
||||
"macos" => install_release_macos(&installer_dir, downloaded_asset, &cx).await,
|
||||
"linux" => install_release_linux(&installer_dir, downloaded_asset, &cx).await,
|
||||
"windows" => install_release_windows(downloaded_asset).await,
|
||||
_ => Err(anyhow!("not supported: {:?}", OS)),
|
||||
}?;
|
||||
|
||||
@@ -629,7 +668,7 @@ async fn download_release(
|
||||
}
|
||||
|
||||
async fn install_release_linux(
|
||||
temp_dir: &tempfile::TempDir,
|
||||
temp_dir: &InstallerDir,
|
||||
downloaded_tar_gz: PathBuf,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<PathBuf> {
|
||||
@@ -696,7 +735,7 @@ async fn install_release_linux(
|
||||
}
|
||||
|
||||
async fn install_release_macos(
|
||||
temp_dir: &tempfile::TempDir,
|
||||
temp_dir: &InstallerDir,
|
||||
downloaded_dmg: PathBuf,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<PathBuf> {
|
||||
@@ -743,3 +782,41 @@ async fn install_release_macos(
|
||||
|
||||
Ok(running_app_path)
|
||||
}
|
||||
|
||||
async fn install_release_windows(downloaded_installer: PathBuf) -> Result<PathBuf> {
|
||||
let output = Command::new(downloaded_installer)
|
||||
.arg("/verysilent")
|
||||
.arg("/update=true")
|
||||
.arg("!desktopicon")
|
||||
.arg("!quicklaunchicon")
|
||||
.output()
|
||||
.await?;
|
||||
anyhow::ensure!(
|
||||
output.status.success(),
|
||||
"failed to start installer: {:?}",
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
Ok(std::env::current_exe()?)
|
||||
}
|
||||
|
||||
pub fn check_pending_installation() -> bool {
|
||||
let Some(installer_path) = std::env::current_exe()
|
||||
.ok()
|
||||
.and_then(|p| p.parent().map(|p| p.join("updates")))
|
||||
else {
|
||||
return false;
|
||||
};
|
||||
|
||||
// The installer will create a flag file after it finishes updating
|
||||
let flag_file = installer_path.join("versions.txt");
|
||||
if flag_file.exists() {
|
||||
if let Some(helper) = installer_path
|
||||
.parent()
|
||||
.map(|p| p.join("tools\\auto_update_helper.exe"))
|
||||
{
|
||||
let _ = std::process::Command::new(helper).spawn();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
29
crates/auto_update_helper/Cargo.toml
Normal file
29
crates/auto_update_helper/Cargo.toml
Normal file
@@ -0,0 +1,29 @@
|
||||
[package]
|
||||
name = "auto_update_helper"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "auto_update_helper"
|
||||
path = "src/auto_update_helper.rs"
|
||||
doctest = false
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
log.workspace = true
|
||||
simplelog.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[target.'cfg(target_os = "windows")'.dependencies]
|
||||
windows.workspace = true
|
||||
|
||||
[target.'cfg(target_os = "windows")'.build-dependencies]
|
||||
winresource = "0.1"
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
targets = ["x86_64-pc-windows-msvc"]
|
||||
1
crates/auto_update_helper/LICENSE-GPL
Symbolic link
1
crates/auto_update_helper/LICENSE-GPL
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-GPL
|
||||
BIN
crates/auto_update_helper/app-icon.ico
Normal file
BIN
crates/auto_update_helper/app-icon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 577 KiB |
15
crates/auto_update_helper/build.rs
Normal file
15
crates/auto_update_helper/build.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
fn main() {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
println!("cargo:rerun-if-changed=manifest.xml");
|
||||
|
||||
let mut res = winresource::WindowsResource::new();
|
||||
res.set_manifest_file("manifest.xml");
|
||||
res.set_icon("app-icon.ico");
|
||||
|
||||
if let Err(e) = res.compile() {
|
||||
eprintln!("{}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
16
crates/auto_update_helper/manifest.xml
Normal file
16
crates/auto_update_helper/manifest.xml
Normal file
@@ -0,0 +1,16 @@
|
||||
<assembly xmlns="urn:schemas-microsoft-com:asm.v1" manifestVersion="1.0" xmlns:asmv3="urn:schemas-microsoft-com:asm.v3">
|
||||
<asmv3:application>
|
||||
<asmv3:windowsSettings>
|
||||
<dpiAware xmlns="http://schemas.microsoft.com/SMI/2005/WindowsSettings">true</dpiAware>
|
||||
<dpiAwareness xmlns="http://schemas.microsoft.com/SMI/2016/WindowsSettings">PerMonitorV2</dpiAwareness>
|
||||
</asmv3:windowsSettings>
|
||||
</asmv3:application>
|
||||
<dependency>
|
||||
<dependentAssembly>
|
||||
<assemblyIdentity type='win32'
|
||||
name='Microsoft.Windows.Common-Controls'
|
||||
version='6.0.0.0' processorArchitecture='*'
|
||||
publicKeyToken='6595b64144ccf1df' />
|
||||
</dependentAssembly>
|
||||
</dependency>
|
||||
</assembly>
|
||||
94
crates/auto_update_helper/src/auto_update_helper.rs
Normal file
94
crates/auto_update_helper/src/auto_update_helper.rs
Normal file
@@ -0,0 +1,94 @@
|
||||
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
mod dialog;
|
||||
#[cfg(target_os = "windows")]
|
||||
mod updater;
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn main() {
|
||||
if let Err(e) = windows_impl::run() {
|
||||
log::error!("Error: Zed update failed, {:?}", e);
|
||||
windows_impl::show_error(format!("Error: {:?}", e));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
fn main() {}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
mod windows_impl {
|
||||
use std::path::Path;
|
||||
|
||||
use super::dialog::create_dialog_window;
|
||||
use super::updater::perform_update;
|
||||
use anyhow::{Context, Result};
|
||||
use windows::{
|
||||
Win32::{
|
||||
Foundation::{HWND, LPARAM, WPARAM},
|
||||
UI::WindowsAndMessaging::{
|
||||
DispatchMessageW, GetMessageW, MB_ICONERROR, MB_SYSTEMMODAL, MSG, MessageBoxW,
|
||||
PostMessageW, WM_USER,
|
||||
},
|
||||
},
|
||||
core::HSTRING,
|
||||
};
|
||||
|
||||
pub(crate) const WM_JOB_UPDATED: u32 = WM_USER + 1;
|
||||
pub(crate) const WM_TERMINATE: u32 = WM_USER + 2;
|
||||
|
||||
pub(crate) fn run() -> Result<()> {
|
||||
let helper_dir = std::env::current_exe()?
|
||||
.parent()
|
||||
.context("No parent directory")?
|
||||
.to_path_buf();
|
||||
init_log(&helper_dir)?;
|
||||
let app_dir = helper_dir
|
||||
.parent()
|
||||
.context("No parent directory")?
|
||||
.to_path_buf();
|
||||
|
||||
log::info!("======= Starting Zed update =======");
|
||||
let (tx, rx) = std::sync::mpsc::channel();
|
||||
let hwnd = create_dialog_window(rx)?.0 as isize;
|
||||
std::thread::spawn(move || {
|
||||
let result = perform_update(app_dir.as_path(), Some(hwnd));
|
||||
tx.send(result).ok();
|
||||
unsafe { PostMessageW(Some(HWND(hwnd as _)), WM_TERMINATE, WPARAM(0), LPARAM(0)) }.ok();
|
||||
});
|
||||
unsafe {
|
||||
let mut message = MSG::default();
|
||||
while GetMessageW(&mut message, None, 0, 0).as_bool() {
|
||||
DispatchMessageW(&message);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn init_log(helper_dir: &Path) -> Result<()> {
|
||||
simplelog::WriteLogger::init(
|
||||
simplelog::LevelFilter::Info,
|
||||
simplelog::Config::default(),
|
||||
std::fs::File::options()
|
||||
.append(true)
|
||||
.create(true)
|
||||
.open(helper_dir.join("auto_update_helper.log"))?,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn show_error(mut content: String) {
|
||||
if content.len() > 600 {
|
||||
content.truncate(600);
|
||||
content.push_str("...\n");
|
||||
}
|
||||
let _ = unsafe {
|
||||
MessageBoxW(
|
||||
None,
|
||||
&HSTRING::from(content),
|
||||
windows::core::w!("Error: Zed update failed."),
|
||||
MB_ICONERROR | MB_SYSTEMMODAL,
|
||||
)
|
||||
};
|
||||
}
|
||||
}
|
||||
236
crates/auto_update_helper/src/dialog.rs
Normal file
236
crates/auto_update_helper/src/dialog.rs
Normal file
@@ -0,0 +1,236 @@
|
||||
use std::{cell::RefCell, sync::mpsc::Receiver};
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use windows::{
|
||||
Win32::{
|
||||
Foundation::{HWND, LPARAM, LRESULT, RECT, WPARAM},
|
||||
Graphics::Gdi::{
|
||||
BeginPaint, CLEARTYPE_QUALITY, CLIP_DEFAULT_PRECIS, CreateFontW, DEFAULT_CHARSET,
|
||||
DeleteObject, EndPaint, FW_NORMAL, LOGFONTW, OUT_TT_ONLY_PRECIS, PAINTSTRUCT,
|
||||
ReleaseDC, SelectObject, TextOutW,
|
||||
},
|
||||
System::LibraryLoader::GetModuleHandleW,
|
||||
UI::{
|
||||
Controls::{PBM_SETRANGE, PBM_SETSTEP, PBM_STEPIT, PROGRESS_CLASS},
|
||||
WindowsAndMessaging::{
|
||||
CREATESTRUCTW, CS_HREDRAW, CS_VREDRAW, CreateWindowExW, DefWindowProcW,
|
||||
GWLP_USERDATA, GetDesktopWindow, GetWindowLongPtrW, GetWindowRect, HICON,
|
||||
IMAGE_ICON, LR_DEFAULTSIZE, LR_SHARED, LoadImageW, PostQuitMessage, RegisterClassW,
|
||||
SPI_GETICONTITLELOGFONT, SYSTEM_PARAMETERS_INFO_UPDATE_FLAGS, SendMessageW,
|
||||
SetWindowLongPtrW, SystemParametersInfoW, WINDOW_EX_STYLE, WM_CLOSE, WM_CREATE,
|
||||
WM_DESTROY, WM_NCCREATE, WM_PAINT, WNDCLASSW, WS_CAPTION, WS_CHILD, WS_EX_TOPMOST,
|
||||
WS_POPUP, WS_VISIBLE,
|
||||
},
|
||||
},
|
||||
},
|
||||
core::HSTRING,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
updater::JOBS,
|
||||
windows_impl::{WM_JOB_UPDATED, WM_TERMINATE, show_error},
|
||||
};
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug)]
|
||||
struct DialogInfo {
|
||||
rx: Receiver<Result<()>>,
|
||||
progress_bar: isize,
|
||||
}
|
||||
|
||||
pub(crate) fn create_dialog_window(receiver: Receiver<Result<()>>) -> Result<HWND> {
|
||||
unsafe {
|
||||
let class_name = windows::core::w!("Zed-Auto-Updater-Dialog-Class");
|
||||
let module = GetModuleHandleW(None).context("unable to get module handle")?;
|
||||
let handle = LoadImageW(
|
||||
Some(module.into()),
|
||||
windows::core::PCWSTR(1 as _),
|
||||
IMAGE_ICON,
|
||||
0,
|
||||
0,
|
||||
LR_DEFAULTSIZE | LR_SHARED,
|
||||
)
|
||||
.context("unable to load icon file")?;
|
||||
let wc = WNDCLASSW {
|
||||
lpfnWndProc: Some(wnd_proc),
|
||||
lpszClassName: class_name,
|
||||
style: CS_HREDRAW | CS_VREDRAW,
|
||||
hIcon: HICON(handle.0),
|
||||
..Default::default()
|
||||
};
|
||||
RegisterClassW(&wc);
|
||||
let mut rect = RECT::default();
|
||||
GetWindowRect(GetDesktopWindow(), &mut rect)
|
||||
.context("unable to get desktop window rect")?;
|
||||
let width = 400;
|
||||
let height = 150;
|
||||
let info = Box::new(RefCell::new(DialogInfo {
|
||||
rx: receiver,
|
||||
progress_bar: 0,
|
||||
}));
|
||||
|
||||
let hwnd = CreateWindowExW(
|
||||
WS_EX_TOPMOST,
|
||||
class_name,
|
||||
windows::core::w!("Zed Editor"),
|
||||
WS_VISIBLE | WS_POPUP | WS_CAPTION,
|
||||
rect.right / 2 - width / 2,
|
||||
rect.bottom / 2 - height / 2,
|
||||
width,
|
||||
height,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
Some(Box::into_raw(info) as _),
|
||||
)
|
||||
.context("unable to create dialog window")?;
|
||||
Ok(hwnd)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! return_if_failed {
|
||||
($e:expr) => {
|
||||
match $e {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
return LRESULT(e.code().0 as _);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! make_lparam {
|
||||
($l:expr, $h:expr) => {
|
||||
LPARAM(($l as u32 | ($h as u32) << 16) as isize)
|
||||
};
|
||||
}
|
||||
|
||||
unsafe extern "system" fn wnd_proc(
|
||||
hwnd: HWND,
|
||||
msg: u32,
|
||||
wparam: WPARAM,
|
||||
lparam: LPARAM,
|
||||
) -> LRESULT {
|
||||
match msg {
|
||||
WM_NCCREATE => unsafe {
|
||||
let create_struct = lparam.0 as *const CREATESTRUCTW;
|
||||
let info = (*create_struct).lpCreateParams as *mut RefCell<DialogInfo>;
|
||||
let info = Box::from_raw(info);
|
||||
SetWindowLongPtrW(hwnd, GWLP_USERDATA, Box::into_raw(info) as _);
|
||||
DefWindowProcW(hwnd, msg, wparam, lparam)
|
||||
},
|
||||
WM_CREATE => unsafe {
|
||||
// Create progress bar
|
||||
let mut rect = RECT::default();
|
||||
return_if_failed!(GetWindowRect(hwnd, &mut rect));
|
||||
let progress_bar = return_if_failed!(CreateWindowExW(
|
||||
WINDOW_EX_STYLE(0),
|
||||
PROGRESS_CLASS,
|
||||
None,
|
||||
WS_CHILD | WS_VISIBLE,
|
||||
20,
|
||||
50,
|
||||
340,
|
||||
35,
|
||||
Some(hwnd),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
));
|
||||
SendMessageW(
|
||||
progress_bar,
|
||||
PBM_SETRANGE,
|
||||
None,
|
||||
Some(make_lparam!(0, JOBS.len() * 10)),
|
||||
);
|
||||
SendMessageW(progress_bar, PBM_SETSTEP, Some(WPARAM(10)), None);
|
||||
with_dialog_data(hwnd, |data| {
|
||||
data.borrow_mut().progress_bar = progress_bar.0 as isize
|
||||
});
|
||||
LRESULT(0)
|
||||
},
|
||||
WM_PAINT => unsafe {
|
||||
let mut ps = PAINTSTRUCT::default();
|
||||
let hdc = BeginPaint(hwnd, &mut ps);
|
||||
|
||||
let font_name = get_system_ui_font_name();
|
||||
let font = CreateFontW(
|
||||
24,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
FW_NORMAL.0 as _,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
DEFAULT_CHARSET,
|
||||
OUT_TT_ONLY_PRECIS,
|
||||
CLIP_DEFAULT_PRECIS,
|
||||
CLEARTYPE_QUALITY,
|
||||
0,
|
||||
&HSTRING::from(font_name),
|
||||
);
|
||||
let temp = SelectObject(hdc, font.into());
|
||||
let string = HSTRING::from("Zed Editor is updating...");
|
||||
return_if_failed!(TextOutW(hdc, 20, 15, &string).ok());
|
||||
return_if_failed!(DeleteObject(temp).ok());
|
||||
|
||||
return_if_failed!(EndPaint(hwnd, &ps).ok());
|
||||
ReleaseDC(Some(hwnd), hdc);
|
||||
|
||||
LRESULT(0)
|
||||
},
|
||||
WM_JOB_UPDATED => with_dialog_data(hwnd, |data| {
|
||||
let progress_bar = data.borrow().progress_bar;
|
||||
unsafe { SendMessageW(HWND(progress_bar as _), PBM_STEPIT, None, None) }
|
||||
}),
|
||||
WM_TERMINATE => {
|
||||
with_dialog_data(hwnd, |data| {
|
||||
if let Ok(result) = data.borrow_mut().rx.recv() {
|
||||
if let Err(e) = result {
|
||||
log::error!("Failed to update Zed: {:?}", e);
|
||||
show_error(format!("Error: {:?}", e));
|
||||
}
|
||||
}
|
||||
});
|
||||
unsafe { PostQuitMessage(0) };
|
||||
LRESULT(0)
|
||||
}
|
||||
WM_CLOSE => LRESULT(0), // Prevent user occasionally closing the window
|
||||
WM_DESTROY => {
|
||||
unsafe { PostQuitMessage(0) };
|
||||
LRESULT(0)
|
||||
}
|
||||
_ => unsafe { DefWindowProcW(hwnd, msg, wparam, lparam) },
|
||||
}
|
||||
}
|
||||
|
||||
fn with_dialog_data<F, T>(hwnd: HWND, f: F) -> T
|
||||
where
|
||||
F: FnOnce(&RefCell<DialogInfo>) -> T,
|
||||
{
|
||||
let raw = unsafe { GetWindowLongPtrW(hwnd, GWLP_USERDATA) as *mut RefCell<DialogInfo> };
|
||||
let data = unsafe { Box::from_raw(raw) };
|
||||
let result = f(data.as_ref());
|
||||
unsafe { SetWindowLongPtrW(hwnd, GWLP_USERDATA, Box::into_raw(data) as _) };
|
||||
result
|
||||
}
|
||||
|
||||
fn get_system_ui_font_name() -> String {
|
||||
unsafe {
|
||||
let mut info: LOGFONTW = std::mem::zeroed();
|
||||
if SystemParametersInfoW(
|
||||
SPI_GETICONTITLELOGFONT,
|
||||
std::mem::size_of::<LOGFONTW>() as u32,
|
||||
Some(&mut info as *mut _ as _),
|
||||
SYSTEM_PARAMETERS_INFO_UPDATE_FLAGS(0),
|
||||
)
|
||||
.is_ok()
|
||||
{
|
||||
let font_name = String::from_utf16_lossy(&info.lfFaceName);
|
||||
font_name.trim_matches(char::from(0)).to_owned()
|
||||
} else {
|
||||
"MS Shell Dlg".to_owned()
|
||||
}
|
||||
}
|
||||
}
|
||||
171
crates/auto_update_helper/src/updater.rs
Normal file
171
crates/auto_update_helper/src/updater.rs
Normal file
@@ -0,0 +1,171 @@
|
||||
use std::{
|
||||
os::windows::process::CommandExt,
|
||||
path::Path,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use windows::Win32::{
|
||||
Foundation::{HWND, LPARAM, WPARAM},
|
||||
System::Threading::CREATE_NEW_PROCESS_GROUP,
|
||||
UI::WindowsAndMessaging::PostMessageW,
|
||||
};
|
||||
|
||||
use crate::windows_impl::WM_JOB_UPDATED;
|
||||
|
||||
type Job = fn(&Path) -> Result<()>;
|
||||
|
||||
#[cfg(not(test))]
|
||||
pub(crate) const JOBS: [Job; 6] = [
|
||||
// Delete old files
|
||||
|app_dir| {
|
||||
let zed_executable = app_dir.join("Zed.exe");
|
||||
log::info!("Removing old file: {}", zed_executable.display());
|
||||
std::fs::remove_file(&zed_executable).context(format!(
|
||||
"Failed to remove old file {}",
|
||||
zed_executable.display()
|
||||
))
|
||||
},
|
||||
|app_dir| {
|
||||
let zed_cli = app_dir.join("bin\\zed.exe");
|
||||
log::info!("Removing old file: {}", zed_cli.display());
|
||||
std::fs::remove_file(&zed_cli)
|
||||
.context(format!("Failed to remove old file {}", zed_cli.display()))
|
||||
},
|
||||
// Copy new files
|
||||
|app_dir| {
|
||||
let zed_executable_source = app_dir.join("install\\Zed.exe");
|
||||
let zed_executable_dest = app_dir.join("Zed.exe");
|
||||
log::info!(
|
||||
"Copying new file {} to {}",
|
||||
zed_executable_source.display(),
|
||||
zed_executable_dest.display()
|
||||
);
|
||||
std::fs::copy(&zed_executable_source, &zed_executable_dest)
|
||||
.map(|_| ())
|
||||
.context(format!(
|
||||
"Failed to copy new file {} to {}",
|
||||
zed_executable_source.display(),
|
||||
zed_executable_dest.display()
|
||||
))
|
||||
},
|
||||
|app_dir| {
|
||||
let zed_cli_source = app_dir.join("install\\bin\\zed.exe");
|
||||
let zed_cli_dest = app_dir.join("bin\\zed.exe");
|
||||
log::info!(
|
||||
"Copying new file {} to {}",
|
||||
zed_cli_source.display(),
|
||||
zed_cli_dest.display()
|
||||
);
|
||||
std::fs::copy(&zed_cli_source, &zed_cli_dest)
|
||||
.map(|_| ())
|
||||
.context(format!(
|
||||
"Failed to copy new file {} to {}",
|
||||
zed_cli_source.display(),
|
||||
zed_cli_dest.display()
|
||||
))
|
||||
},
|
||||
// Clean up installer folder and updates folder
|
||||
|app_dir| {
|
||||
let updates_folder = app_dir.join("updates");
|
||||
log::info!("Cleaning up: {}", updates_folder.display());
|
||||
std::fs::remove_dir_all(&updates_folder).context(format!(
|
||||
"Failed to remove updates folder {}",
|
||||
updates_folder.display()
|
||||
))
|
||||
},
|
||||
|app_dir| {
|
||||
let installer_folder = app_dir.join("install");
|
||||
log::info!("Cleaning up: {}", installer_folder.display());
|
||||
std::fs::remove_dir_all(&installer_folder).context(format!(
|
||||
"Failed to remove installer folder {}",
|
||||
installer_folder.display()
|
||||
))
|
||||
},
|
||||
];
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) const JOBS: [Job; 2] = [
|
||||
|_| {
|
||||
std::thread::sleep(Duration::from_millis(1000));
|
||||
if let Ok(config) = std::env::var("ZED_AUTO_UPDATE") {
|
||||
match config.as_str() {
|
||||
"err" => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
"Simulated error",
|
||||
))
|
||||
.context("Anyhow!"),
|
||||
_ => panic!("Unknown ZED_AUTO_UPDATE value: {}", config),
|
||||
}
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
},
|
||||
|_| {
|
||||
std::thread::sleep(Duration::from_millis(1000));
|
||||
if let Ok(config) = std::env::var("ZED_AUTO_UPDATE") {
|
||||
match config.as_str() {
|
||||
"err" => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
"Simulated error",
|
||||
))
|
||||
.context("Anyhow!"),
|
||||
_ => panic!("Unknown ZED_AUTO_UPDATE value: {}", config),
|
||||
}
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
},
|
||||
];
|
||||
|
||||
pub(crate) fn perform_update(app_dir: &Path, hwnd: Option<isize>) -> Result<()> {
|
||||
let hwnd = hwnd.map(|ptr| HWND(ptr as _));
|
||||
|
||||
for job in JOBS.iter() {
|
||||
let start = Instant::now();
|
||||
loop {
|
||||
if start.elapsed().as_secs() > 2 {
|
||||
return Err(anyhow::anyhow!("Timed out"));
|
||||
}
|
||||
match (*job)(app_dir) {
|
||||
Ok(_) => {
|
||||
unsafe { PostMessageW(hwnd, WM_JOB_UPDATED, WPARAM(0), LPARAM(0))? };
|
||||
break;
|
||||
}
|
||||
Err(err) => {
|
||||
// Check if it's a "not found" error
|
||||
let io_err = err.downcast_ref::<std::io::Error>().unwrap();
|
||||
if io_err.kind() == std::io::ErrorKind::NotFound {
|
||||
log::warn!("File or folder not found.");
|
||||
unsafe { PostMessageW(hwnd, WM_JOB_UPDATED, WPARAM(0), LPARAM(0))? };
|
||||
break;
|
||||
}
|
||||
|
||||
log::error!("Operation failed: {}", err);
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = std::process::Command::new(app_dir.join("Zed.exe"))
|
||||
.creation_flags(CREATE_NEW_PROCESS_GROUP.0)
|
||||
.spawn();
|
||||
log::info!("Update completed successfully");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::perform_update;
|
||||
|
||||
#[test]
|
||||
fn test_perform_update() {
|
||||
let app_dir = std::path::Path::new("C:/");
|
||||
assert!(perform_update(app_dir, None).is_ok());
|
||||
|
||||
// Simulate a timeout
|
||||
unsafe { std::env::set_var("ZED_AUTO_UPDATE", "err") };
|
||||
let ret = perform_update(app_dir, None);
|
||||
assert!(ret.is_err_and(|e| e.to_string().as_str() == "Timed out"));
|
||||
}
|
||||
}
|
||||
@@ -7,8 +7,6 @@ fn main() {
|
||||
|
||||
if cfg!(target_os = "macos") {
|
||||
println!("cargo:rustc-env=MACOSX_DEPLOYMENT_TARGET=10.15.7");
|
||||
// Weakly link ScreenCaptureKit to ensure can be used on macOS 10.15+.
|
||||
println!("cargo:rustc-link-arg=-Wl,-weak_framework,ScreenCaptureKit");
|
||||
}
|
||||
|
||||
// Populate git sha environment variable if git is available
|
||||
|
||||
@@ -18,7 +18,6 @@ sqlite = ["sea-orm/sqlx-sqlite", "sqlx/sqlite"]
|
||||
test-support = ["sqlite"]
|
||||
|
||||
[dependencies]
|
||||
anthropic.workspace = true
|
||||
anyhow.workspace = true
|
||||
async-stripe.workspace = true
|
||||
async-tungstenite.workspace = true
|
||||
|
||||
@@ -253,7 +253,6 @@ impl Config {
|
||||
pub enum ServiceMode {
|
||||
Api,
|
||||
Collab,
|
||||
Llm,
|
||||
All,
|
||||
}
|
||||
|
||||
@@ -265,10 +264,6 @@ impl ServiceMode {
|
||||
pub fn is_api(&self) -> bool {
|
||||
matches!(self, Self::Api | Self::All)
|
||||
}
|
||||
|
||||
pub fn is_llm(&self) -> bool {
|
||||
matches!(self, Self::Llm | Self::All)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AppState {
|
||||
|
||||
@@ -1,448 +1,10 @@
|
||||
mod authorization;
|
||||
pub mod db;
|
||||
mod token;
|
||||
|
||||
use crate::api::CloudflareIpCountryHeader;
|
||||
use crate::api::events::SnowflakeRow;
|
||||
use crate::build_kinesis_client;
|
||||
use crate::rpc::MIN_ACCOUNT_AGE_FOR_LLM_USE;
|
||||
use crate::{Cents, Config, Error, Result, db::UserId, executor::Executor};
|
||||
use anyhow::{Context as _, anyhow};
|
||||
use authorization::authorize_access_to_language_model;
|
||||
use axum::routing::get;
|
||||
use axum::{
|
||||
Extension, Json, Router, TypedHeader,
|
||||
body::Body,
|
||||
http::{self, HeaderName, HeaderValue, Request, StatusCode},
|
||||
middleware::{self, Next},
|
||||
response::{IntoResponse, Response},
|
||||
routing::post,
|
||||
};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use collections::HashMap;
|
||||
use db::TokenUsage;
|
||||
use db::{ActiveUserCount, LlmDatabase, usage_measure::UsageMeasure};
|
||||
use futures::{Stream, StreamExt as _};
|
||||
use reqwest_client::ReqwestClient;
|
||||
use rpc::{
|
||||
EXPIRED_LLM_TOKEN_HEADER_NAME, LanguageModelProvider, PerformCompletionParams, proto::Plan,
|
||||
};
|
||||
use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
|
||||
use serde_json::json;
|
||||
use std::{
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use strum::IntoEnumIterator;
|
||||
use tokio::sync::RwLock;
|
||||
use util::ResultExt;
|
||||
use crate::Cents;
|
||||
|
||||
pub use token::*;
|
||||
|
||||
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
|
||||
|
||||
pub struct LlmState {
|
||||
pub config: Config,
|
||||
pub executor: Executor,
|
||||
pub db: Arc<LlmDatabase>,
|
||||
pub http_client: ReqwestClient,
|
||||
pub kinesis_client: Option<aws_sdk_kinesis::Client>,
|
||||
active_user_count_by_model:
|
||||
RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
|
||||
}
|
||||
|
||||
impl LlmState {
|
||||
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
|
||||
let database_url = config
|
||||
.llm_database_url
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
|
||||
let max_connections = config
|
||||
.llm_database_max_connections
|
||||
.ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
|
||||
|
||||
let mut db_options = db::ConnectOptions::new(database_url);
|
||||
db_options.max_connections(max_connections);
|
||||
let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
|
||||
db.initialize().await?;
|
||||
|
||||
let db = Arc::new(db);
|
||||
|
||||
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
|
||||
let http_client =
|
||||
ReqwestClient::user_agent(&user_agent).context("failed to construct http client")?;
|
||||
|
||||
let this = Self {
|
||||
executor,
|
||||
db,
|
||||
http_client,
|
||||
kinesis_client: if config.kinesis_access_key.is_some() {
|
||||
build_kinesis_client(&config).await.log_err()
|
||||
} else {
|
||||
None
|
||||
},
|
||||
active_user_count_by_model: RwLock::new(HashMap::default()),
|
||||
config,
|
||||
};
|
||||
|
||||
Ok(Arc::new(this))
|
||||
}
|
||||
|
||||
pub async fn get_active_user_count(
|
||||
&self,
|
||||
provider: LanguageModelProvider,
|
||||
model: &str,
|
||||
) -> Result<ActiveUserCount> {
|
||||
let now = Utc::now();
|
||||
|
||||
{
|
||||
let active_user_count_by_model = self.active_user_count_by_model.read().await;
|
||||
if let Some((last_updated, count)) =
|
||||
active_user_count_by_model.get(&(provider, model.to_string()))
|
||||
{
|
||||
if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
|
||||
return Ok(*count);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut cache = self.active_user_count_by_model.write().await;
|
||||
let new_count = self.db.get_active_user_count(provider, model, now).await?;
|
||||
cache.insert((provider, model.to_string()), (now, new_count));
|
||||
Ok(new_count)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn routes() -> Router<(), Body> {
|
||||
Router::new()
|
||||
.route("/models", get(list_models))
|
||||
.route("/completion", post(perform_completion))
|
||||
.layer(middleware::from_fn(validate_api_token))
|
||||
}
|
||||
|
||||
async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
|
||||
let token = req
|
||||
.headers()
|
||||
.get(http::header::AUTHORIZATION)
|
||||
.and_then(|header| header.to_str().ok())
|
||||
.ok_or_else(|| {
|
||||
Error::http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"missing authorization header".to_string(),
|
||||
)
|
||||
})?
|
||||
.strip_prefix("Bearer ")
|
||||
.ok_or_else(|| {
|
||||
Error::http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"invalid authorization header".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let state = req.extensions().get::<Arc<LlmState>>().unwrap();
|
||||
match LlmTokenClaims::validate(token, &state.config) {
|
||||
Ok(claims) => {
|
||||
if state.db.is_access_token_revoked(&claims.jti).await? {
|
||||
return Err(Error::http(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"unauthorized".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
tracing::Span::current()
|
||||
.record("user_id", claims.user_id)
|
||||
.record("login", claims.github_user_login.clone())
|
||||
.record("authn.jti", &claims.jti)
|
||||
.record("is_staff", claims.is_staff);
|
||||
|
||||
req.extensions_mut().insert(claims);
|
||||
Ok::<_, Error>(next.run(req).await.into_response())
|
||||
}
|
||||
Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"unauthorized".to_string(),
|
||||
[(
|
||||
HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
|
||||
HeaderValue::from_static("true"),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
)),
|
||||
Err(_err) => Err(Error::http(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"unauthorized".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_models(
|
||||
Extension(state): Extension<Arc<LlmState>>,
|
||||
Extension(claims): Extension<LlmTokenClaims>,
|
||||
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
|
||||
) -> Result<Json<ListModelsResponse>> {
|
||||
let country_code = country_code_header.map(|header| header.to_string());
|
||||
|
||||
let mut accessible_models = Vec::new();
|
||||
|
||||
for (provider, model) in state.db.all_models() {
|
||||
let authorize_result = authorize_access_to_language_model(
|
||||
&state.config,
|
||||
&claims,
|
||||
country_code.as_deref(),
|
||||
provider,
|
||||
&model.name,
|
||||
);
|
||||
|
||||
if authorize_result.is_ok() {
|
||||
accessible_models.push(rpc::LanguageModel {
|
||||
provider,
|
||||
name: model.name,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Json(ListModelsResponse {
|
||||
models: accessible_models,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn perform_completion(
|
||||
Extension(state): Extension<Arc<LlmState>>,
|
||||
Extension(claims): Extension<LlmTokenClaims>,
|
||||
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
|
||||
Json(params): Json<PerformCompletionParams>,
|
||||
) -> Result<impl IntoResponse> {
|
||||
let model = normalize_model_name(
|
||||
state.db.model_names_for_provider(params.provider),
|
||||
params.model,
|
||||
);
|
||||
|
||||
let bypass_account_age_check = claims.has_llm_subscription || claims.bypass_account_age_check;
|
||||
if !bypass_account_age_check {
|
||||
if Utc::now().naive_utc() - claims.account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
|
||||
Err(anyhow!("account too young"))?
|
||||
}
|
||||
}
|
||||
|
||||
authorize_access_to_language_model(
|
||||
&state.config,
|
||||
&claims,
|
||||
country_code_header
|
||||
.map(|header| header.to_string())
|
||||
.as_deref(),
|
||||
params.provider,
|
||||
&model,
|
||||
)?;
|
||||
|
||||
check_usage_limit(&state, params.provider, &model, &claims).await?;
|
||||
|
||||
let stream = match params.provider {
|
||||
LanguageModelProvider::Anthropic => {
|
||||
let api_key = if claims.is_staff {
|
||||
state
|
||||
.config
|
||||
.anthropic_staff_api_key
|
||||
.as_ref()
|
||||
.context("no Anthropic AI staff API key configured on the server")?
|
||||
} else {
|
||||
state
|
||||
.config
|
||||
.anthropic_api_key
|
||||
.as_ref()
|
||||
.context("no Anthropic AI API key configured on the server")?
|
||||
};
|
||||
|
||||
let mut request: anthropic::Request =
|
||||
serde_json::from_str(params.provider_request.get())?;
|
||||
|
||||
// Override the model on the request with the latest version of the model that is
|
||||
// known to the server.
|
||||
//
|
||||
// Right now, we use the version that's defined in `model.id()`, but we will likely
|
||||
// want to change this code once a new version of an Anthropic model is released,
|
||||
// so that users can use the new version, without having to update Zed.
|
||||
request.model = match model.as_str() {
|
||||
"claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(),
|
||||
"claude-3-7-sonnet" => anthropic::Model::Claude3_7Sonnet.id().to_string(),
|
||||
"claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(),
|
||||
"claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(),
|
||||
"claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(),
|
||||
_ => request.model,
|
||||
};
|
||||
|
||||
let (chunks, rate_limit_info) = anthropic::stream_completion_with_rate_limit_info(
|
||||
&state.http_client,
|
||||
anthropic::ANTHROPIC_API_URL,
|
||||
api_key,
|
||||
request,
|
||||
)
|
||||
.await
|
||||
.map_err(|err| match err {
|
||||
anthropic::AnthropicError::ApiError(ref api_error) => match api_error.code() {
|
||||
Some(anthropic::ApiErrorCode::RateLimitError) => {
|
||||
tracing::info!(
|
||||
target: "upstream rate limit exceeded",
|
||||
user_id = claims.user_id,
|
||||
login = claims.github_user_login,
|
||||
authn.jti = claims.jti,
|
||||
is_staff = claims.is_staff,
|
||||
provider = params.provider.to_string(),
|
||||
model = model
|
||||
);
|
||||
|
||||
Error::http(
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
"Upstream Anthropic rate limit exceeded.".to_string(),
|
||||
)
|
||||
}
|
||||
Some(anthropic::ApiErrorCode::InvalidRequestError) => {
|
||||
Error::http(StatusCode::BAD_REQUEST, api_error.message.clone())
|
||||
}
|
||||
Some(anthropic::ApiErrorCode::OverloadedError) => {
|
||||
Error::http(StatusCode::SERVICE_UNAVAILABLE, api_error.message.clone())
|
||||
}
|
||||
Some(_) => {
|
||||
Error::http(StatusCode::INTERNAL_SERVER_ERROR, api_error.message.clone())
|
||||
}
|
||||
None => Error::Internal(anyhow!(err)),
|
||||
},
|
||||
anthropic::AnthropicError::Other(err) => Error::Internal(err),
|
||||
})?;
|
||||
|
||||
if let Some(rate_limit_info) = rate_limit_info {
|
||||
tracing::info!(
|
||||
target: "upstream rate limit",
|
||||
is_staff = claims.is_staff,
|
||||
provider = params.provider.to_string(),
|
||||
model = model,
|
||||
tokens_remaining = rate_limit_info.tokens.as_ref().map(|limits| limits.remaining),
|
||||
input_tokens_remaining = rate_limit_info.input_tokens.as_ref().map(|limits| limits.remaining),
|
||||
output_tokens_remaining = rate_limit_info.output_tokens.as_ref().map(|limits| limits.remaining),
|
||||
requests_remaining = rate_limit_info.requests.as_ref().map(|limits| limits.remaining),
|
||||
requests_reset = ?rate_limit_info.requests.as_ref().map(|limits| limits.reset),
|
||||
tokens_reset = ?rate_limit_info.tokens.as_ref().map(|limits| limits.reset),
|
||||
input_tokens_reset = ?rate_limit_info.input_tokens.as_ref().map(|limits| limits.reset),
|
||||
output_tokens_reset = ?rate_limit_info.output_tokens.as_ref().map(|limits| limits.reset),
|
||||
);
|
||||
}
|
||||
|
||||
chunks
|
||||
.map(move |event| {
|
||||
let chunk = event?;
|
||||
let (
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_input_tokens,
|
||||
cache_read_input_tokens,
|
||||
) = match &chunk {
|
||||
anthropic::Event::MessageStart {
|
||||
message: anthropic::Response { usage, .. },
|
||||
}
|
||||
| anthropic::Event::MessageDelta { usage, .. } => (
|
||||
usage.input_tokens.unwrap_or(0) as usize,
|
||||
usage.output_tokens.unwrap_or(0) as usize,
|
||||
usage.cache_creation_input_tokens.unwrap_or(0) as usize,
|
||||
usage.cache_read_input_tokens.unwrap_or(0) as usize,
|
||||
),
|
||||
_ => (0, 0, 0, 0),
|
||||
};
|
||||
|
||||
anyhow::Ok(CompletionChunk {
|
||||
bytes: serde_json::to_vec(&chunk).unwrap(),
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_input_tokens,
|
||||
cache_read_input_tokens,
|
||||
})
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
LanguageModelProvider::OpenAi => {
|
||||
let api_key = state
|
||||
.config
|
||||
.openai_api_key
|
||||
.as_ref()
|
||||
.context("no OpenAI API key configured on the server")?;
|
||||
let chunks = open_ai::stream_completion(
|
||||
&state.http_client,
|
||||
open_ai::OPEN_AI_API_URL,
|
||||
api_key,
|
||||
serde_json::from_str(params.provider_request.get())?,
|
||||
)
|
||||
.await?;
|
||||
|
||||
chunks
|
||||
.map(|event| {
|
||||
event.map(|chunk| {
|
||||
let input_tokens =
|
||||
chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
|
||||
let output_tokens =
|
||||
chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
|
||||
CompletionChunk {
|
||||
bytes: serde_json::to_vec(&chunk).unwrap(),
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 0,
|
||||
}
|
||||
})
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
LanguageModelProvider::Google => {
|
||||
let api_key = state
|
||||
.config
|
||||
.google_ai_api_key
|
||||
.as_ref()
|
||||
.context("no Google AI API key configured on the server")?;
|
||||
let chunks = google_ai::stream_generate_content(
|
||||
&state.http_client,
|
||||
google_ai::API_URL,
|
||||
api_key,
|
||||
serde_json::from_str(params.provider_request.get())?,
|
||||
)
|
||||
.await?;
|
||||
|
||||
chunks
|
||||
.map(|event| {
|
||||
event.map(|chunk| {
|
||||
// TODO - implement token counting for Google AI
|
||||
CompletionChunk {
|
||||
bytes: serde_json::to_vec(&chunk).unwrap(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 0,
|
||||
}
|
||||
})
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Response::new(Body::wrap_stream(TokenCountingStream {
|
||||
state,
|
||||
claims,
|
||||
provider: params.provider,
|
||||
model,
|
||||
tokens: TokenUsage::default(),
|
||||
inner_stream: stream,
|
||||
})))
|
||||
}
|
||||
|
||||
fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
|
||||
if let Some(known_model_name) = known_models
|
||||
.iter()
|
||||
.filter(|known_model_name| name.starts_with(known_model_name.as_str()))
|
||||
.max_by_key(|known_model_name| known_model_name.len())
|
||||
{
|
||||
known_model_name.to_string()
|
||||
} else {
|
||||
name
|
||||
}
|
||||
}
|
||||
|
||||
/// The maximum monthly spending an individual user can reach on the free tier
|
||||
/// before they have to pay.
|
||||
pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);
|
||||
@@ -452,330 +14,3 @@ pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);
|
||||
///
|
||||
/// Used to prevent surprise bills.
|
||||
pub const DEFAULT_MAX_MONTHLY_SPEND: Cents = Cents::from_dollars(10);
|
||||
|
||||
async fn check_usage_limit(
|
||||
state: &Arc<LlmState>,
|
||||
provider: LanguageModelProvider,
|
||||
model_name: &str,
|
||||
claims: &LlmTokenClaims,
|
||||
) -> Result<()> {
|
||||
if claims.is_staff {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let user_id = UserId::from_proto(claims.user_id);
|
||||
let model = state.db.model(provider, model_name)?;
|
||||
let free_tier = claims.free_tier_monthly_spending_limit();
|
||||
|
||||
let spending_this_month = state
|
||||
.db
|
||||
.get_user_spending_for_month(user_id, Utc::now())
|
||||
.await?;
|
||||
if spending_this_month >= free_tier {
|
||||
if !claims.has_llm_subscription {
|
||||
return Err(Error::http(
|
||||
StatusCode::PAYMENT_REQUIRED,
|
||||
"Maximum spending limit reached for this month.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let monthly_spend = spending_this_month.saturating_sub(free_tier);
|
||||
if monthly_spend >= Cents(claims.max_monthly_spend_in_cents) {
|
||||
return Err(Error::Http(
|
||||
StatusCode::FORBIDDEN,
|
||||
"Maximum spending limit reached for this month.".to_string(),
|
||||
[(
|
||||
HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME),
|
||||
HeaderValue::from_static("true"),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let active_users = state.get_active_user_count(provider, model_name).await?;
|
||||
|
||||
let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
|
||||
let users_in_recent_days = active_users.users_in_recent_days.max(1);
|
||||
|
||||
let per_user_max_requests_per_minute =
|
||||
model.max_requests_per_minute as usize / users_in_recent_minutes;
|
||||
let per_user_max_tokens_per_minute =
|
||||
model.max_tokens_per_minute as usize / users_in_recent_minutes;
|
||||
let per_user_max_input_tokens_per_minute =
|
||||
model.max_input_tokens_per_minute as usize / users_in_recent_minutes;
|
||||
let per_user_max_output_tokens_per_minute =
|
||||
model.max_output_tokens_per_minute as usize / users_in_recent_minutes;
|
||||
let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days;
|
||||
|
||||
let usage = state
|
||||
.db
|
||||
.get_usage(user_id, provider, model_name, Utc::now())
|
||||
.await?;
|
||||
|
||||
let checks = match (provider, model_name) {
|
||||
(LanguageModelProvider::Anthropic, "claude-3-7-sonnet") => vec![
|
||||
(
|
||||
usage.requests_this_minute,
|
||||
per_user_max_requests_per_minute,
|
||||
UsageMeasure::RequestsPerMinute,
|
||||
),
|
||||
(
|
||||
usage.input_tokens_this_minute,
|
||||
per_user_max_tokens_per_minute,
|
||||
UsageMeasure::InputTokensPerMinute,
|
||||
),
|
||||
(
|
||||
usage.output_tokens_this_minute,
|
||||
per_user_max_tokens_per_minute,
|
||||
UsageMeasure::OutputTokensPerMinute,
|
||||
),
|
||||
(
|
||||
usage.tokens_this_day,
|
||||
per_user_max_tokens_per_day,
|
||||
UsageMeasure::TokensPerDay,
|
||||
),
|
||||
],
|
||||
_ => vec![
|
||||
(
|
||||
usage.requests_this_minute,
|
||||
per_user_max_requests_per_minute,
|
||||
UsageMeasure::RequestsPerMinute,
|
||||
),
|
||||
(
|
||||
usage.tokens_this_minute,
|
||||
per_user_max_tokens_per_minute,
|
||||
UsageMeasure::TokensPerMinute,
|
||||
),
|
||||
(
|
||||
usage.tokens_this_day,
|
||||
per_user_max_tokens_per_day,
|
||||
UsageMeasure::TokensPerDay,
|
||||
),
|
||||
],
|
||||
};
|
||||
|
||||
for (used, limit, usage_measure) in checks {
|
||||
if used > limit {
|
||||
let resource = match usage_measure {
|
||||
UsageMeasure::RequestsPerMinute => "requests_per_minute",
|
||||
UsageMeasure::TokensPerMinute => "tokens_per_minute",
|
||||
UsageMeasure::InputTokensPerMinute => "input_tokens_per_minute",
|
||||
UsageMeasure::OutputTokensPerMinute => "output_tokens_per_minute",
|
||||
UsageMeasure::TokensPerDay => "tokens_per_day",
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
target: "user rate limit",
|
||||
user_id = claims.user_id,
|
||||
login = claims.github_user_login,
|
||||
authn.jti = claims.jti,
|
||||
is_staff = claims.is_staff,
|
||||
provider = provider.to_string(),
|
||||
model = model.name,
|
||||
usage_measure = resource,
|
||||
requests_this_minute = usage.requests_this_minute,
|
||||
tokens_this_minute = usage.tokens_this_minute,
|
||||
input_tokens_this_minute = usage.input_tokens_this_minute,
|
||||
output_tokens_this_minute = usage.output_tokens_this_minute,
|
||||
tokens_this_day = usage.tokens_this_day,
|
||||
users_in_recent_minutes = users_in_recent_minutes,
|
||||
users_in_recent_days = users_in_recent_days,
|
||||
max_requests_per_minute = per_user_max_requests_per_minute,
|
||||
max_tokens_per_minute = per_user_max_tokens_per_minute,
|
||||
max_input_tokens_per_minute = per_user_max_input_tokens_per_minute,
|
||||
max_output_tokens_per_minute = per_user_max_output_tokens_per_minute,
|
||||
max_tokens_per_day = per_user_max_tokens_per_day,
|
||||
);
|
||||
|
||||
SnowflakeRow::new(
|
||||
"Language Model Rate Limited",
|
||||
Some(claims.metrics_id),
|
||||
claims.is_staff,
|
||||
claims.system_id.clone(),
|
||||
json!({
|
||||
"usage": usage,
|
||||
"users_in_recent_minutes": users_in_recent_minutes,
|
||||
"users_in_recent_days": users_in_recent_days,
|
||||
"max_requests_per_minute": per_user_max_requests_per_minute,
|
||||
"max_tokens_per_minute": per_user_max_tokens_per_minute,
|
||||
"max_input_tokens_per_minute": per_user_max_input_tokens_per_minute,
|
||||
"max_output_tokens_per_minute": per_user_max_output_tokens_per_minute,
|
||||
"max_tokens_per_day": per_user_max_tokens_per_day,
|
||||
"plan": match claims.plan {
|
||||
Plan::Free => "free".to_string(),
|
||||
Plan::ZedPro => "zed_pro".to_string(),
|
||||
},
|
||||
"model": model.name.clone(),
|
||||
"provider": provider.to_string(),
|
||||
"usage_measure": resource.to_string(),
|
||||
}),
|
||||
)
|
||||
.write(&state.kinesis_client, &state.config.kinesis_stream)
|
||||
.await
|
||||
.log_err();
|
||||
|
||||
return Err(Error::http(
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
format!("Rate limit exceeded. Maximum {} reached.", resource),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct CompletionChunk {
|
||||
bytes: Vec<u8>,
|
||||
input_tokens: usize,
|
||||
output_tokens: usize,
|
||||
cache_creation_input_tokens: usize,
|
||||
cache_read_input_tokens: usize,
|
||||
}
|
||||
|
||||
struct TokenCountingStream<S> {
|
||||
state: Arc<LlmState>,
|
||||
claims: LlmTokenClaims,
|
||||
provider: LanguageModelProvider,
|
||||
model: String,
|
||||
tokens: TokenUsage,
|
||||
inner_stream: S,
|
||||
}
|
||||
|
||||
impl<S> Stream for TokenCountingStream<S>
|
||||
where
|
||||
S: Stream<Item = Result<CompletionChunk, anyhow::Error>> + Unpin,
|
||||
{
|
||||
type Item = Result<Vec<u8>, anyhow::Error>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
match Pin::new(&mut self.inner_stream).poll_next(cx) {
|
||||
Poll::Ready(Some(Ok(mut chunk))) => {
|
||||
chunk.bytes.push(b'\n');
|
||||
self.tokens.input += chunk.input_tokens;
|
||||
self.tokens.output += chunk.output_tokens;
|
||||
self.tokens.input_cache_creation += chunk.cache_creation_input_tokens;
|
||||
self.tokens.input_cache_read += chunk.cache_read_input_tokens;
|
||||
Poll::Ready(Some(Ok(chunk.bytes)))
|
||||
}
|
||||
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
|
||||
Poll::Ready(None) => Poll::Ready(None),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Drop for TokenCountingStream<S> {
|
||||
fn drop(&mut self) {
|
||||
let state = self.state.clone();
|
||||
let claims = self.claims.clone();
|
||||
let provider = self.provider;
|
||||
let model = std::mem::take(&mut self.model);
|
||||
let tokens = self.tokens;
|
||||
self.state.executor.spawn_detached(async move {
|
||||
let usage = state
|
||||
.db
|
||||
.record_usage(
|
||||
UserId::from_proto(claims.user_id),
|
||||
claims.is_staff,
|
||||
provider,
|
||||
&model,
|
||||
tokens,
|
||||
claims.has_llm_subscription,
|
||||
Cents(claims.max_monthly_spend_in_cents),
|
||||
claims.free_tier_monthly_spending_limit(),
|
||||
Utc::now(),
|
||||
)
|
||||
.await
|
||||
.log_err();
|
||||
|
||||
if let Some(usage) = usage {
|
||||
tracing::info!(
|
||||
target: "user usage",
|
||||
user_id = claims.user_id,
|
||||
login = claims.github_user_login,
|
||||
authn.jti = claims.jti,
|
||||
is_staff = claims.is_staff,
|
||||
provider = provider.to_string(),
|
||||
model = model,
|
||||
requests_this_minute = usage.requests_this_minute,
|
||||
tokens_this_minute = usage.tokens_this_minute,
|
||||
input_tokens_this_minute = usage.input_tokens_this_minute,
|
||||
output_tokens_this_minute = usage.output_tokens_this_minute,
|
||||
);
|
||||
|
||||
let properties = json!({
|
||||
"has_llm_subscription": claims.has_llm_subscription,
|
||||
"max_monthly_spend_in_cents": claims.max_monthly_spend_in_cents,
|
||||
"plan": match claims.plan {
|
||||
Plan::Free => "free".to_string(),
|
||||
Plan::ZedPro => "zed_pro".to_string(),
|
||||
},
|
||||
"model": model,
|
||||
"provider": provider,
|
||||
"usage": usage,
|
||||
"tokens": tokens
|
||||
});
|
||||
SnowflakeRow::new(
|
||||
"Language Model Used",
|
||||
Some(claims.metrics_id),
|
||||
claims.is_staff,
|
||||
claims.system_id.clone(),
|
||||
properties,
|
||||
)
|
||||
.write(&state.kinesis_client, &state.config.kinesis_stream)
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn log_usage_periodically(state: Arc<LlmState>) {
|
||||
state.executor.clone().spawn_detached(async move {
|
||||
loop {
|
||||
state
|
||||
.executor
|
||||
.sleep(std::time::Duration::from_secs(30))
|
||||
.await;
|
||||
|
||||
for provider in LanguageModelProvider::iter() {
|
||||
for model in state.db.model_names_for_provider(provider) {
|
||||
if let Some(active_user_count) = state
|
||||
.get_active_user_count(provider, &model)
|
||||
.await
|
||||
.log_err()
|
||||
{
|
||||
tracing::info!(
|
||||
target: "active user counts",
|
||||
provider = provider.to_string(),
|
||||
model = model,
|
||||
users_in_recent_minutes = active_user_count.users_in_recent_minutes,
|
||||
users_in_recent_days = active_user_count.users_in_recent_days,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(usages) = state
|
||||
.db
|
||||
.get_application_wide_usages_by_model(Utc::now())
|
||||
.await
|
||||
.log_err()
|
||||
{
|
||||
for usage in usages {
|
||||
tracing::info!(
|
||||
target: "computed usage",
|
||||
provider = usage.provider.to_string(),
|
||||
model = usage.model,
|
||||
requests_this_minute = usage.requests_this_minute,
|
||||
tokens_this_minute = usage.tokens_this_minute,
|
||||
input_tokens_this_minute = usage.input_tokens_this_minute,
|
||||
output_tokens_this_minute = usage.output_tokens_this_minute,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,330 +0,0 @@
|
||||
use reqwest::StatusCode;
|
||||
use rpc::LanguageModelProvider;
|
||||
|
||||
use crate::llm::LlmTokenClaims;
|
||||
use crate::{Config, Error, Result};
|
||||
|
||||
pub fn authorize_access_to_language_model(
|
||||
config: &Config,
|
||||
claims: &LlmTokenClaims,
|
||||
country_code: Option<&str>,
|
||||
provider: LanguageModelProvider,
|
||||
model: &str,
|
||||
) -> Result<()> {
|
||||
authorize_access_for_country(config, country_code, provider)?;
|
||||
authorize_access_to_model(config, claims, provider, model)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn authorize_access_to_model(
|
||||
config: &Config,
|
||||
claims: &LlmTokenClaims,
|
||||
provider: LanguageModelProvider,
|
||||
model: &str,
|
||||
) -> Result<()> {
|
||||
if claims.is_staff {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if provider == LanguageModelProvider::Anthropic {
|
||||
if model == "claude-3-5-sonnet" || model == "claude-3-7-sonnet" {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if claims.has_llm_closed_beta_feature_flag
|
||||
&& Some(model) == config.llm_closed_beta_model_name.as_deref()
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
Err(Error::http(
|
||||
StatusCode::FORBIDDEN,
|
||||
format!("access to model {model:?} is not included in your plan"),
|
||||
))
|
||||
}
|
||||
|
||||
fn authorize_access_for_country(
|
||||
config: &Config,
|
||||
country_code: Option<&str>,
|
||||
provider: LanguageModelProvider,
|
||||
) -> Result<()> {
|
||||
// In development we won't have the `CF-IPCountry` header, so we can't check
|
||||
// the country code.
|
||||
//
|
||||
// This shouldn't be necessary, as anyone running in development will need to provide
|
||||
// their own API credentials in order to use an LLM provider.
|
||||
if config.is_development() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
|
||||
let country_code = match country_code {
|
||||
// `XX` - Used for clients without country code data.
|
||||
None | Some("XX") => Err(Error::http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"no country code".to_string(),
|
||||
))?,
|
||||
// `T1` - Used for clients using the Tor network.
|
||||
Some("T1") => Err(Error::http(
|
||||
StatusCode::FORBIDDEN,
|
||||
format!("access to {provider:?} models is not available over Tor"),
|
||||
))?,
|
||||
Some(country_code) => country_code,
|
||||
};
|
||||
|
||||
let is_country_supported_by_provider = match provider {
|
||||
LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
|
||||
LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
|
||||
LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
|
||||
};
|
||||
if !is_country_supported_by_provider {
|
||||
Err(Error::http(
|
||||
StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
|
||||
format!(
|
||||
"access to {provider:?} models is not available in your region ({country_code})"
|
||||
),
|
||||
))?
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use axum::response::IntoResponse;
|
||||
use pretty_assertions::assert_eq;
|
||||
use rpc::proto::Plan;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_authorize_access_to_language_model_with_supported_country(
|
||||
_cx: &mut gpui::TestAppContext,
|
||||
) {
|
||||
let config = Config::test();
|
||||
|
||||
let claims = LlmTokenClaims {
|
||||
user_id: 99,
|
||||
plan: Plan::ZedPro,
|
||||
is_staff: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let cases = vec![
|
||||
(LanguageModelProvider::Anthropic, "US"), // United States
|
||||
(LanguageModelProvider::Anthropic, "GB"), // United Kingdom
|
||||
(LanguageModelProvider::OpenAi, "US"), // United States
|
||||
(LanguageModelProvider::OpenAi, "GB"), // United Kingdom
|
||||
(LanguageModelProvider::Google, "US"), // United States
|
||||
(LanguageModelProvider::Google, "GB"), // United Kingdom
|
||||
];
|
||||
|
||||
for (provider, country_code) in cases {
|
||||
authorize_access_to_language_model(
|
||||
&config,
|
||||
&claims,
|
||||
Some(country_code),
|
||||
provider,
|
||||
"the-model",
|
||||
)
|
||||
.unwrap_or_else(|_| {
|
||||
panic!("expected authorization to return Ok for {provider:?}: {country_code}")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_authorize_access_to_language_model_with_unsupported_country(
|
||||
_cx: &mut gpui::TestAppContext,
|
||||
) {
|
||||
let config = Config::test();
|
||||
|
||||
let claims = LlmTokenClaims {
|
||||
user_id: 99,
|
||||
plan: Plan::ZedPro,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let cases = vec![
|
||||
(LanguageModelProvider::Anthropic, "AF"), // Afghanistan
|
||||
(LanguageModelProvider::Anthropic, "BY"), // Belarus
|
||||
(LanguageModelProvider::Anthropic, "CF"), // Central African Republic
|
||||
(LanguageModelProvider::Anthropic, "CN"), // China
|
||||
(LanguageModelProvider::Anthropic, "CU"), // Cuba
|
||||
(LanguageModelProvider::Anthropic, "ER"), // Eritrea
|
||||
(LanguageModelProvider::Anthropic, "ET"), // Ethiopia
|
||||
(LanguageModelProvider::Anthropic, "IR"), // Iran
|
||||
(LanguageModelProvider::Anthropic, "KP"), // North Korea
|
||||
(LanguageModelProvider::Anthropic, "XK"), // Kosovo
|
||||
(LanguageModelProvider::Anthropic, "LY"), // Libya
|
||||
(LanguageModelProvider::Anthropic, "MM"), // Myanmar
|
||||
(LanguageModelProvider::Anthropic, "RU"), // Russia
|
||||
(LanguageModelProvider::Anthropic, "SO"), // Somalia
|
||||
(LanguageModelProvider::Anthropic, "SS"), // South Sudan
|
||||
(LanguageModelProvider::Anthropic, "SD"), // Sudan
|
||||
(LanguageModelProvider::Anthropic, "SY"), // Syria
|
||||
(LanguageModelProvider::Anthropic, "VE"), // Venezuela
|
||||
(LanguageModelProvider::Anthropic, "YE"), // Yemen
|
||||
(LanguageModelProvider::OpenAi, "KP"), // North Korea
|
||||
(LanguageModelProvider::Google, "KP"), // North Korea
|
||||
];
|
||||
|
||||
for (provider, country_code) in cases {
|
||||
let error_response = authorize_access_to_language_model(
|
||||
&config,
|
||||
&claims,
|
||||
Some(country_code),
|
||||
provider,
|
||||
"the-model",
|
||||
)
|
||||
.expect_err(&format!(
|
||||
"expected authorization to return an error for {provider:?}: {country_code}"
|
||||
))
|
||||
.into_response();
|
||||
|
||||
assert_eq!(
|
||||
error_response.status(),
|
||||
StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
|
||||
);
|
||||
let response_body = hyper::body::to_bytes(error_response.into_body())
|
||||
.await
|
||||
.unwrap()
|
||||
.to_vec();
|
||||
assert_eq!(
|
||||
String::from_utf8(response_body).unwrap(),
|
||||
format!(
|
||||
"access to {provider:?} models is not available in your region ({country_code})"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
|
||||
let config = Config::test();
|
||||
|
||||
let claims = LlmTokenClaims {
|
||||
user_id: 99,
|
||||
plan: Plan::ZedPro,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let cases = vec![
|
||||
(LanguageModelProvider::Anthropic, "T1"), // Tor
|
||||
(LanguageModelProvider::OpenAi, "T1"), // Tor
|
||||
(LanguageModelProvider::Google, "T1"), // Tor
|
||||
];
|
||||
|
||||
for (provider, country_code) in cases {
|
||||
let error_response = authorize_access_to_language_model(
|
||||
&config,
|
||||
&claims,
|
||||
Some(country_code),
|
||||
provider,
|
||||
"the-model",
|
||||
)
|
||||
.expect_err(&format!(
|
||||
"expected authorization to return an error for {provider:?}: {country_code}"
|
||||
))
|
||||
.into_response();
|
||||
|
||||
assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
|
||||
let response_body = hyper::body::to_bytes(error_response.into_body())
|
||||
.await
|
||||
.unwrap()
|
||||
.to_vec();
|
||||
assert_eq!(
|
||||
String::from_utf8(response_body).unwrap(),
|
||||
format!("access to {provider:?} models is not available over Tor")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_authorize_access_to_language_model_based_on_plan() {
|
||||
let config = Config::test();
|
||||
|
||||
let test_cases = vec![
|
||||
// Pro plan should have access to claude-3.5-sonnet
|
||||
(
|
||||
Plan::ZedPro,
|
||||
LanguageModelProvider::Anthropic,
|
||||
"claude-3-5-sonnet",
|
||||
true,
|
||||
),
|
||||
// Free plan should have access to claude-3.5-sonnet
|
||||
(
|
||||
Plan::Free,
|
||||
LanguageModelProvider::Anthropic,
|
||||
"claude-3-5-sonnet",
|
||||
true,
|
||||
),
|
||||
// Pro plan should NOT have access to other Anthropic models
|
||||
(
|
||||
Plan::ZedPro,
|
||||
LanguageModelProvider::Anthropic,
|
||||
"claude-3-opus",
|
||||
false,
|
||||
),
|
||||
];
|
||||
|
||||
for (plan, provider, model, expected_access) in test_cases {
|
||||
let claims = LlmTokenClaims {
|
||||
plan,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result =
|
||||
authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
|
||||
|
||||
if expected_access {
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"Expected access to be granted for plan {:?}, provider {:?}, model {}",
|
||||
plan,
|
||||
provider,
|
||||
model
|
||||
);
|
||||
} else {
|
||||
let error = result.expect_err(&format!(
|
||||
"Expected access to be denied for plan {:?}, provider {:?}, model {}",
|
||||
plan, provider, model
|
||||
));
|
||||
let response = error.into_response();
|
||||
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_authorize_access_to_language_model_for_staff() {
|
||||
let config = Config::test();
|
||||
|
||||
let claims = LlmTokenClaims {
|
||||
is_staff: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Staff should have access to all models
|
||||
let test_cases = vec![
|
||||
(LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
|
||||
(LanguageModelProvider::Anthropic, "claude-2"),
|
||||
(LanguageModelProvider::Anthropic, "claude-123-agi"),
|
||||
(LanguageModelProvider::OpenAi, "gpt-4"),
|
||||
(LanguageModelProvider::Google, "gemini-pro"),
|
||||
];
|
||||
|
||||
for (provider, model) in test_cases {
|
||||
let result =
|
||||
authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
|
||||
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"Expected staff to have access to provider {:?}, model {}",
|
||||
provider,
|
||||
model
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -20,7 +20,6 @@ use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::anyhow;
|
||||
pub use queries::usages::{ActiveUserCount, TokenUsage};
|
||||
pub use sea_orm::ConnectOptions;
|
||||
use sea_orm::prelude::*;
|
||||
use sea_orm::{
|
||||
|
||||
@@ -2,5 +2,4 @@ use super::*;
|
||||
|
||||
pub mod billing_events;
|
||||
pub mod providers;
|
||||
pub mod revoked_access_tokens;
|
||||
pub mod usages;
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
use super::*;
|
||||
|
||||
impl LlmDatabase {
|
||||
/// Returns whether the access token with the given `jti` has been revoked.
|
||||
pub async fn is_access_token_revoked(&self, jti: &str) -> Result<bool> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(revoked_access_token::Entity::find()
|
||||
.filter(revoked_access_token::Column::Jti.eq(jti))
|
||||
.one(&*tx)
|
||||
.await?
|
||||
.is_some())
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -1,56 +1,12 @@
|
||||
use crate::db::UserId;
|
||||
use crate::llm::Cents;
|
||||
use chrono::{Datelike, Duration};
|
||||
use chrono::Datelike;
|
||||
use futures::StreamExt as _;
|
||||
use rpc::LanguageModelProvider;
|
||||
use sea_orm::QuerySelect;
|
||||
use std::{iter, str::FromStr};
|
||||
use std::str::FromStr;
|
||||
use strum::IntoEnumIterator as _;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy, Default, serde::Serialize)]
|
||||
pub struct TokenUsage {
|
||||
pub input: usize,
|
||||
pub input_cache_creation: usize,
|
||||
pub input_cache_read: usize,
|
||||
pub output: usize,
|
||||
}
|
||||
|
||||
impl TokenUsage {
|
||||
pub fn total(&self) -> usize {
|
||||
self.input + self.input_cache_creation + self.input_cache_read + self.output
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy, serde::Serialize)]
|
||||
pub struct Usage {
|
||||
pub requests_this_minute: usize,
|
||||
pub tokens_this_minute: usize,
|
||||
pub input_tokens_this_minute: usize,
|
||||
pub output_tokens_this_minute: usize,
|
||||
pub tokens_this_day: usize,
|
||||
pub tokens_this_month: TokenUsage,
|
||||
pub spending_this_month: Cents,
|
||||
pub lifetime_spending: Cents,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct ApplicationWideUsage {
|
||||
pub provider: LanguageModelProvider,
|
||||
pub model: String,
|
||||
pub requests_this_minute: usize,
|
||||
pub tokens_this_minute: usize,
|
||||
pub input_tokens_this_minute: usize,
|
||||
pub output_tokens_this_minute: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
pub struct ActiveUserCount {
|
||||
pub users_in_recent_minutes: usize,
|
||||
pub users_in_recent_days: usize,
|
||||
}
|
||||
|
||||
impl LlmDatabase {
|
||||
pub async fn initialize_usage_measures(&mut self) -> Result<()> {
|
||||
let all_measures = self
|
||||
@@ -90,100 +46,6 @@ impl LlmDatabase {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_application_wide_usages_by_model(
|
||||
&self,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<Vec<ApplicationWideUsage>> {
|
||||
self.transaction(|tx| async move {
|
||||
let past_minute = now - Duration::minutes(1);
|
||||
let requests_per_minute = self.usage_measure_ids[&UsageMeasure::RequestsPerMinute];
|
||||
let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
|
||||
let input_tokens_per_minute =
|
||||
self.usage_measure_ids[&UsageMeasure::InputTokensPerMinute];
|
||||
let output_tokens_per_minute =
|
||||
self.usage_measure_ids[&UsageMeasure::OutputTokensPerMinute];
|
||||
|
||||
let mut results = Vec::new();
|
||||
for ((provider, model_name), model) in self.models.iter() {
|
||||
let mut usages = usage::Entity::find()
|
||||
.filter(
|
||||
usage::Column::Timestamp
|
||||
.gte(past_minute.naive_utc())
|
||||
.and(usage::Column::IsStaff.eq(false))
|
||||
.and(usage::Column::ModelId.eq(model.id))
|
||||
.and(
|
||||
usage::Column::MeasureId
|
||||
.eq(requests_per_minute)
|
||||
.or(usage::Column::MeasureId.eq(tokens_per_minute)),
|
||||
),
|
||||
)
|
||||
.stream(&*tx)
|
||||
.await?;
|
||||
|
||||
let mut requests_this_minute = 0;
|
||||
let mut tokens_this_minute = 0;
|
||||
let mut input_tokens_this_minute = 0;
|
||||
let mut output_tokens_this_minute = 0;
|
||||
while let Some(usage) = usages.next().await {
|
||||
let usage = usage?;
|
||||
if usage.measure_id == requests_per_minute {
|
||||
requests_this_minute += Self::get_live_buckets(
|
||||
&usage,
|
||||
now.naive_utc(),
|
||||
UsageMeasure::RequestsPerMinute,
|
||||
)
|
||||
.0
|
||||
.iter()
|
||||
.copied()
|
||||
.sum::<i64>() as usize;
|
||||
} else if usage.measure_id == tokens_per_minute {
|
||||
tokens_this_minute += Self::get_live_buckets(
|
||||
&usage,
|
||||
now.naive_utc(),
|
||||
UsageMeasure::TokensPerMinute,
|
||||
)
|
||||
.0
|
||||
.iter()
|
||||
.copied()
|
||||
.sum::<i64>() as usize;
|
||||
} else if usage.measure_id == input_tokens_per_minute {
|
||||
input_tokens_this_minute += Self::get_live_buckets(
|
||||
&usage,
|
||||
now.naive_utc(),
|
||||
UsageMeasure::InputTokensPerMinute,
|
||||
)
|
||||
.0
|
||||
.iter()
|
||||
.copied()
|
||||
.sum::<i64>() as usize;
|
||||
} else if usage.measure_id == output_tokens_per_minute {
|
||||
output_tokens_this_minute += Self::get_live_buckets(
|
||||
&usage,
|
||||
now.naive_utc(),
|
||||
UsageMeasure::OutputTokensPerMinute,
|
||||
)
|
||||
.0
|
||||
.iter()
|
||||
.copied()
|
||||
.sum::<i64>() as usize;
|
||||
}
|
||||
}
|
||||
|
||||
results.push(ApplicationWideUsage {
|
||||
provider: *provider,
|
||||
model: model_name.clone(),
|
||||
requests_this_minute,
|
||||
tokens_this_minute,
|
||||
input_tokens_this_minute,
|
||||
output_tokens_this_minute,
|
||||
})
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_user_spending_for_month(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
@@ -223,499 +85,6 @@ impl LlmDatabase {
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_usage(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
provider: LanguageModelProvider,
|
||||
model_name: &str,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<Usage> {
|
||||
self.transaction(|tx| async move {
|
||||
let model = self
|
||||
.models
|
||||
.get(&(provider, model_name.to_string()))
|
||||
.ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?;
|
||||
|
||||
let usages = usage::Entity::find()
|
||||
.filter(
|
||||
usage::Column::UserId
|
||||
.eq(user_id)
|
||||
.and(usage::Column::ModelId.eq(model.id)),
|
||||
)
|
||||
.all(&*tx)
|
||||
.await?;
|
||||
|
||||
let month = now.date_naive().month() as i32;
|
||||
let year = now.date_naive().year();
|
||||
let monthly_usage = monthly_usage::Entity::find()
|
||||
.filter(
|
||||
monthly_usage::Column::UserId
|
||||
.eq(user_id)
|
||||
.and(monthly_usage::Column::ModelId.eq(model.id))
|
||||
.and(monthly_usage::Column::Month.eq(month))
|
||||
.and(monthly_usage::Column::Year.eq(year)),
|
||||
)
|
||||
.one(&*tx)
|
||||
.await?;
|
||||
let lifetime_usage = lifetime_usage::Entity::find()
|
||||
.filter(
|
||||
lifetime_usage::Column::UserId
|
||||
.eq(user_id)
|
||||
.and(lifetime_usage::Column::ModelId.eq(model.id)),
|
||||
)
|
||||
.one(&*tx)
|
||||
.await?;
|
||||
|
||||
let requests_this_minute =
|
||||
self.get_usage_for_measure(&usages, now, UsageMeasure::RequestsPerMinute)?;
|
||||
let tokens_this_minute =
|
||||
self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMinute)?;
|
||||
let input_tokens_this_minute =
|
||||
self.get_usage_for_measure(&usages, now, UsageMeasure::InputTokensPerMinute)?;
|
||||
let output_tokens_this_minute =
|
||||
self.get_usage_for_measure(&usages, now, UsageMeasure::OutputTokensPerMinute)?;
|
||||
let tokens_this_day =
|
||||
self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerDay)?;
|
||||
let spending_this_month = if let Some(monthly_usage) = &monthly_usage {
|
||||
calculate_spending(
|
||||
model,
|
||||
monthly_usage.input_tokens as usize,
|
||||
monthly_usage.cache_creation_input_tokens as usize,
|
||||
monthly_usage.cache_read_input_tokens as usize,
|
||||
monthly_usage.output_tokens as usize,
|
||||
)
|
||||
} else {
|
||||
Cents::ZERO
|
||||
};
|
||||
let lifetime_spending = if let Some(lifetime_usage) = &lifetime_usage {
|
||||
calculate_spending(
|
||||
model,
|
||||
lifetime_usage.input_tokens as usize,
|
||||
lifetime_usage.cache_creation_input_tokens as usize,
|
||||
lifetime_usage.cache_read_input_tokens as usize,
|
||||
lifetime_usage.output_tokens as usize,
|
||||
)
|
||||
} else {
|
||||
Cents::ZERO
|
||||
};
|
||||
|
||||
Ok(Usage {
|
||||
requests_this_minute,
|
||||
tokens_this_minute,
|
||||
input_tokens_this_minute,
|
||||
output_tokens_this_minute,
|
||||
tokens_this_day,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.input_tokens as usize),
|
||||
input_cache_creation: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.cache_creation_input_tokens as usize),
|
||||
input_cache_read: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.cache_read_input_tokens as usize),
|
||||
output: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.output_tokens as usize),
|
||||
},
|
||||
spending_this_month,
|
||||
lifetime_spending,
|
||||
})
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn record_usage(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
is_staff: bool,
|
||||
provider: LanguageModelProvider,
|
||||
model_name: &str,
|
||||
tokens: TokenUsage,
|
||||
has_llm_subscription: bool,
|
||||
max_monthly_spend: Cents,
|
||||
free_tier_monthly_spending_limit: Cents,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<Usage> {
|
||||
self.transaction(|tx| async move {
|
||||
let model = self.model(provider, model_name)?;
|
||||
|
||||
let usages = usage::Entity::find()
|
||||
.filter(
|
||||
usage::Column::UserId
|
||||
.eq(user_id)
|
||||
.and(usage::Column::ModelId.eq(model.id)),
|
||||
)
|
||||
.all(&*tx)
|
||||
.await?;
|
||||
|
||||
let requests_this_minute = self
|
||||
.update_usage_for_measure(
|
||||
user_id,
|
||||
is_staff,
|
||||
model.id,
|
||||
&usages,
|
||||
UsageMeasure::RequestsPerMinute,
|
||||
now,
|
||||
1,
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
let tokens_this_minute = self
|
||||
.update_usage_for_measure(
|
||||
user_id,
|
||||
is_staff,
|
||||
model.id,
|
||||
&usages,
|
||||
UsageMeasure::TokensPerMinute,
|
||||
now,
|
||||
tokens.total(),
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
let input_tokens_this_minute = self
|
||||
.update_usage_for_measure(
|
||||
user_id,
|
||||
is_staff,
|
||||
model.id,
|
||||
&usages,
|
||||
UsageMeasure::InputTokensPerMinute,
|
||||
now,
|
||||
// Cache read input tokens are not counted for the purposes of rate limits (but they are still billed).
|
||||
tokens.input + tokens.input_cache_creation,
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
let output_tokens_this_minute = self
|
||||
.update_usage_for_measure(
|
||||
user_id,
|
||||
is_staff,
|
||||
model.id,
|
||||
&usages,
|
||||
UsageMeasure::OutputTokensPerMinute,
|
||||
now,
|
||||
tokens.output,
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
let tokens_this_day = self
|
||||
.update_usage_for_measure(
|
||||
user_id,
|
||||
is_staff,
|
||||
model.id,
|
||||
&usages,
|
||||
UsageMeasure::TokensPerDay,
|
||||
now,
|
||||
tokens.total(),
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let month = now.date_naive().month() as i32;
|
||||
let year = now.date_naive().year();
|
||||
|
||||
// Update monthly usage
|
||||
let monthly_usage = monthly_usage::Entity::find()
|
||||
.filter(
|
||||
monthly_usage::Column::UserId
|
||||
.eq(user_id)
|
||||
.and(monthly_usage::Column::ModelId.eq(model.id))
|
||||
.and(monthly_usage::Column::Month.eq(month))
|
||||
.and(monthly_usage::Column::Year.eq(year)),
|
||||
)
|
||||
.one(&*tx)
|
||||
.await?;
|
||||
|
||||
let monthly_usage = match monthly_usage {
|
||||
Some(usage) => {
|
||||
monthly_usage::Entity::update(monthly_usage::ActiveModel {
|
||||
id: ActiveValue::unchanged(usage.id),
|
||||
input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
|
||||
cache_creation_input_tokens: ActiveValue::set(
|
||||
usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
|
||||
),
|
||||
cache_read_input_tokens: ActiveValue::set(
|
||||
usage.cache_read_input_tokens + tokens.input_cache_read as i64,
|
||||
),
|
||||
output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
|
||||
..Default::default()
|
||||
})
|
||||
.exec(&*tx)
|
||||
.await?
|
||||
}
|
||||
None => {
|
||||
monthly_usage::ActiveModel {
|
||||
user_id: ActiveValue::set(user_id),
|
||||
model_id: ActiveValue::set(model.id),
|
||||
month: ActiveValue::set(month),
|
||||
year: ActiveValue::set(year),
|
||||
input_tokens: ActiveValue::set(tokens.input as i64),
|
||||
cache_creation_input_tokens: ActiveValue::set(
|
||||
tokens.input_cache_creation as i64,
|
||||
),
|
||||
cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
|
||||
output_tokens: ActiveValue::set(tokens.output as i64),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&*tx)
|
||||
.await?
|
||||
}
|
||||
};
|
||||
|
||||
let spending_this_month = calculate_spending(
|
||||
model,
|
||||
monthly_usage.input_tokens as usize,
|
||||
monthly_usage.cache_creation_input_tokens as usize,
|
||||
monthly_usage.cache_read_input_tokens as usize,
|
||||
monthly_usage.output_tokens as usize,
|
||||
);
|
||||
|
||||
if !is_staff
|
||||
&& spending_this_month > free_tier_monthly_spending_limit
|
||||
&& has_llm_subscription
|
||||
&& (spending_this_month - free_tier_monthly_spending_limit) <= max_monthly_spend
|
||||
{
|
||||
billing_event::ActiveModel {
|
||||
id: ActiveValue::not_set(),
|
||||
idempotency_key: ActiveValue::not_set(),
|
||||
user_id: ActiveValue::set(user_id),
|
||||
model_id: ActiveValue::set(model.id),
|
||||
input_tokens: ActiveValue::set(tokens.input as i64),
|
||||
input_cache_creation_tokens: ActiveValue::set(
|
||||
tokens.input_cache_creation as i64,
|
||||
),
|
||||
input_cache_read_tokens: ActiveValue::set(tokens.input_cache_read as i64),
|
||||
output_tokens: ActiveValue::set(tokens.output as i64),
|
||||
}
|
||||
.insert(&*tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Update lifetime usage
|
||||
let lifetime_usage = lifetime_usage::Entity::find()
|
||||
.filter(
|
||||
lifetime_usage::Column::UserId
|
||||
.eq(user_id)
|
||||
.and(lifetime_usage::Column::ModelId.eq(model.id)),
|
||||
)
|
||||
.one(&*tx)
|
||||
.await?;
|
||||
|
||||
let lifetime_usage = match lifetime_usage {
|
||||
Some(usage) => {
|
||||
lifetime_usage::Entity::update(lifetime_usage::ActiveModel {
|
||||
id: ActiveValue::unchanged(usage.id),
|
||||
input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
|
||||
cache_creation_input_tokens: ActiveValue::set(
|
||||
usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
|
||||
),
|
||||
cache_read_input_tokens: ActiveValue::set(
|
||||
usage.cache_read_input_tokens + tokens.input_cache_read as i64,
|
||||
),
|
||||
output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
|
||||
..Default::default()
|
||||
})
|
||||
.exec(&*tx)
|
||||
.await?
|
||||
}
|
||||
None => {
|
||||
lifetime_usage::ActiveModel {
|
||||
user_id: ActiveValue::set(user_id),
|
||||
model_id: ActiveValue::set(model.id),
|
||||
input_tokens: ActiveValue::set(tokens.input as i64),
|
||||
cache_creation_input_tokens: ActiveValue::set(
|
||||
tokens.input_cache_creation as i64,
|
||||
),
|
||||
cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
|
||||
output_tokens: ActiveValue::set(tokens.output as i64),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&*tx)
|
||||
.await?
|
||||
}
|
||||
};
|
||||
|
||||
let lifetime_spending = calculate_spending(
|
||||
model,
|
||||
lifetime_usage.input_tokens as usize,
|
||||
lifetime_usage.cache_creation_input_tokens as usize,
|
||||
lifetime_usage.cache_read_input_tokens as usize,
|
||||
lifetime_usage.output_tokens as usize,
|
||||
);
|
||||
|
||||
Ok(Usage {
|
||||
requests_this_minute,
|
||||
tokens_this_minute,
|
||||
input_tokens_this_minute,
|
||||
output_tokens_this_minute,
|
||||
tokens_this_day,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: monthly_usage.input_tokens as usize,
|
||||
input_cache_creation: monthly_usage.cache_creation_input_tokens as usize,
|
||||
input_cache_read: monthly_usage.cache_read_input_tokens as usize,
|
||||
output: monthly_usage.output_tokens as usize,
|
||||
},
|
||||
spending_this_month,
|
||||
lifetime_spending,
|
||||
})
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns the active user count for the specified model.
|
||||
pub async fn get_active_user_count(
|
||||
&self,
|
||||
provider: LanguageModelProvider,
|
||||
model_name: &str,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<ActiveUserCount> {
|
||||
self.transaction(|tx| async move {
|
||||
let minute_since = now - Duration::minutes(5);
|
||||
let day_since = now - Duration::days(5);
|
||||
|
||||
let model = self
|
||||
.models
|
||||
.get(&(provider, model_name.to_string()))
|
||||
.ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?;
|
||||
|
||||
let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
|
||||
|
||||
let users_in_recent_minutes = usage::Entity::find()
|
||||
.filter(
|
||||
usage::Column::ModelId
|
||||
.eq(model.id)
|
||||
.and(usage::Column::MeasureId.eq(tokens_per_minute))
|
||||
.and(usage::Column::Timestamp.gte(minute_since.naive_utc()))
|
||||
.and(usage::Column::IsStaff.eq(false)),
|
||||
)
|
||||
.select_only()
|
||||
.column(usage::Column::UserId)
|
||||
.group_by(usage::Column::UserId)
|
||||
.count(&*tx)
|
||||
.await? as usize;
|
||||
|
||||
let users_in_recent_days = usage::Entity::find()
|
||||
.filter(
|
||||
usage::Column::ModelId
|
||||
.eq(model.id)
|
||||
.and(usage::Column::MeasureId.eq(tokens_per_minute))
|
||||
.and(usage::Column::Timestamp.gte(day_since.naive_utc()))
|
||||
.and(usage::Column::IsStaff.eq(false)),
|
||||
)
|
||||
.select_only()
|
||||
.column(usage::Column::UserId)
|
||||
.group_by(usage::Column::UserId)
|
||||
.count(&*tx)
|
||||
.await? as usize;
|
||||
|
||||
Ok(ActiveUserCount {
|
||||
users_in_recent_minutes,
|
||||
users_in_recent_days,
|
||||
})
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn update_usage_for_measure(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
is_staff: bool,
|
||||
model_id: ModelId,
|
||||
usages: &[usage::Model],
|
||||
usage_measure: UsageMeasure,
|
||||
now: DateTimeUtc,
|
||||
usage_to_add: usize,
|
||||
tx: &DatabaseTransaction,
|
||||
) -> Result<usize> {
|
||||
let now = now.naive_utc();
|
||||
let measure_id = *self
|
||||
.usage_measure_ids
|
||||
.get(&usage_measure)
|
||||
.ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
|
||||
|
||||
let mut id = None;
|
||||
let mut timestamp = now;
|
||||
let mut buckets = vec![0_i64];
|
||||
|
||||
if let Some(old_usage) = usages.iter().find(|usage| usage.measure_id == measure_id) {
|
||||
id = Some(old_usage.id);
|
||||
let (live_buckets, buckets_since) =
|
||||
Self::get_live_buckets(old_usage, now, usage_measure);
|
||||
if !live_buckets.is_empty() {
|
||||
buckets.clear();
|
||||
buckets.extend_from_slice(live_buckets);
|
||||
buckets.extend(iter::repeat(0).take(buckets_since));
|
||||
timestamp =
|
||||
old_usage.timestamp + (usage_measure.bucket_duration() * buckets_since as i32);
|
||||
}
|
||||
}
|
||||
|
||||
*buckets.last_mut().unwrap() += usage_to_add as i64;
|
||||
let total_usage = buckets.iter().sum::<i64>() as usize;
|
||||
|
||||
let mut model = usage::ActiveModel {
|
||||
user_id: ActiveValue::set(user_id),
|
||||
is_staff: ActiveValue::set(is_staff),
|
||||
model_id: ActiveValue::set(model_id),
|
||||
measure_id: ActiveValue::set(measure_id),
|
||||
timestamp: ActiveValue::set(timestamp),
|
||||
buckets: ActiveValue::set(buckets),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
if let Some(id) = id {
|
||||
model.id = ActiveValue::unchanged(id);
|
||||
model.update(tx).await?;
|
||||
} else {
|
||||
usage::Entity::insert(model)
|
||||
.exec_without_returning(tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(total_usage)
|
||||
}
|
||||
|
||||
fn get_usage_for_measure(
|
||||
&self,
|
||||
usages: &[usage::Model],
|
||||
now: DateTimeUtc,
|
||||
usage_measure: UsageMeasure,
|
||||
) -> Result<usize> {
|
||||
let now = now.naive_utc();
|
||||
let measure_id = *self
|
||||
.usage_measure_ids
|
||||
.get(&usage_measure)
|
||||
.ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
|
||||
let Some(usage) = usages.iter().find(|usage| usage.measure_id == measure_id) else {
|
||||
return Ok(0);
|
||||
};
|
||||
|
||||
let (live_buckets, _) = Self::get_live_buckets(usage, now, usage_measure);
|
||||
Ok(live_buckets.iter().sum::<i64>() as _)
|
||||
}
|
||||
|
||||
fn get_live_buckets(
|
||||
usage: &usage::Model,
|
||||
now: chrono::NaiveDateTime,
|
||||
measure: UsageMeasure,
|
||||
) -> (&[i64], usize) {
|
||||
let seconds_since_usage = (now - usage.timestamp).num_seconds().max(0);
|
||||
let buckets_since_usage =
|
||||
seconds_since_usage as f32 / measure.bucket_duration().num_seconds() as f32;
|
||||
let buckets_since_usage = buckets_since_usage.ceil() as usize;
|
||||
let mut live_buckets = &[] as &[i64];
|
||||
if buckets_since_usage < measure.bucket_count() {
|
||||
let expired_bucket_count =
|
||||
(usage.buckets.len() + buckets_since_usage).saturating_sub(measure.bucket_count());
|
||||
live_buckets = &usage.buckets[expired_bucket_count..];
|
||||
while live_buckets.first() == Some(&0) {
|
||||
live_buckets = &live_buckets[1..];
|
||||
}
|
||||
}
|
||||
(live_buckets, buckets_since_usage)
|
||||
}
|
||||
}
|
||||
|
||||
fn calculate_spending(
|
||||
@@ -741,32 +110,3 @@ fn calculate_spending(
|
||||
+ output_token_cost;
|
||||
Cents::new(spending as u32)
|
||||
}
|
||||
|
||||
const MINUTE_BUCKET_COUNT: usize = 12;
|
||||
const DAY_BUCKET_COUNT: usize = 48;
|
||||
|
||||
impl UsageMeasure {
|
||||
fn bucket_count(&self) -> usize {
|
||||
match self {
|
||||
UsageMeasure::RequestsPerMinute => MINUTE_BUCKET_COUNT,
|
||||
UsageMeasure::TokensPerMinute
|
||||
| UsageMeasure::InputTokensPerMinute
|
||||
| UsageMeasure::OutputTokensPerMinute => MINUTE_BUCKET_COUNT,
|
||||
UsageMeasure::TokensPerDay => DAY_BUCKET_COUNT,
|
||||
}
|
||||
}
|
||||
|
||||
fn total_duration(&self) -> Duration {
|
||||
match self {
|
||||
UsageMeasure::RequestsPerMinute => Duration::minutes(1),
|
||||
UsageMeasure::TokensPerMinute
|
||||
| UsageMeasure::InputTokensPerMinute
|
||||
| UsageMeasure::OutputTokensPerMinute => Duration::minutes(1),
|
||||
UsageMeasure::TokensPerDay => Duration::hours(24),
|
||||
}
|
||||
}
|
||||
|
||||
fn bucket_duration(&self) -> Duration {
|
||||
self.total_duration() / self.bucket_count() as i32
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
pub mod billing_event;
|
||||
pub mod lifetime_usage;
|
||||
pub mod model;
|
||||
pub mod monthly_usage;
|
||||
pub mod provider;
|
||||
pub mod revoked_access_token;
|
||||
pub mod usage;
|
||||
pub mod usage_measure;
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
use crate::{db::UserId, llm::db::ModelId};
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "lifetime_usages")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: i32,
|
||||
pub user_id: UserId,
|
||||
pub model_id: ModelId,
|
||||
pub input_tokens: i64,
|
||||
pub cache_creation_input_tokens: i64,
|
||||
pub cache_read_input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
@@ -1,19 +0,0 @@
|
||||
use chrono::NaiveDateTime;
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
use crate::llm::db::RevokedAccessTokenId;
|
||||
|
||||
/// A revoked access token.
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "revoked_access_tokens")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: RevokedAccessTokenId,
|
||||
pub jti: String,
|
||||
pub revoked_at: NaiveDateTime,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
@@ -1,6 +1,4 @@
|
||||
mod billing_tests;
|
||||
mod provider_tests;
|
||||
mod usage_tests;
|
||||
|
||||
use gpui::BackgroundExecutor;
|
||||
use parking_lot::Mutex;
|
||||
|
||||
@@ -1,152 +0,0 @@
|
||||
use crate::{
|
||||
Cents,
|
||||
db::UserId,
|
||||
llm::{
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
db::{LlmDatabase, TokenUsage, queries::providers::ModelParams},
|
||||
},
|
||||
test_llm_db,
|
||||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use pretty_assertions::assert_eq;
|
||||
use rpc::LanguageModelProvider;
|
||||
|
||||
test_llm_db!(
|
||||
test_billing_limit_exceeded,
|
||||
test_billing_limit_exceeded_postgres
|
||||
);
|
||||
|
||||
async fn test_billing_limit_exceeded(db: &mut LlmDatabase) {
|
||||
let provider = LanguageModelProvider::Anthropic;
|
||||
let model = "fake-claude-limerick";
|
||||
const PRICE_PER_MILLION_INPUT_TOKENS: i32 = 5;
|
||||
const PRICE_PER_MILLION_OUTPUT_TOKENS: i32 = 5;
|
||||
|
||||
// Initialize the database and insert the model
|
||||
db.initialize().await.unwrap();
|
||||
db.insert_models(&[ModelParams {
|
||||
provider,
|
||||
name: model.to_string(),
|
||||
max_requests_per_minute: 5,
|
||||
max_tokens_per_minute: 10_000,
|
||||
max_tokens_per_day: 50_000,
|
||||
price_per_million_input_tokens: PRICE_PER_MILLION_INPUT_TOKENS,
|
||||
price_per_million_output_tokens: PRICE_PER_MILLION_OUTPUT_TOKENS,
|
||||
}])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Set a fixed datetime for consistent testing
|
||||
let now = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z")
|
||||
.unwrap()
|
||||
.with_timezone(&Utc);
|
||||
|
||||
let user_id = UserId::from_proto(123);
|
||||
|
||||
let max_monthly_spend = Cents::from_dollars(11);
|
||||
|
||||
// Record usage that brings us close to the limit but doesn't exceed it
|
||||
// Let's say we use $10.50 worth of tokens
|
||||
let tokens_to_use = 210_000_000; // This will cost $10.50 at $0.05 per 1 million tokens
|
||||
let usage = TokenUsage {
|
||||
input: tokens_to_use,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
};
|
||||
|
||||
// Verify that before we record any usage, there are 0 billing events
|
||||
let billing_events = db.get_billing_events().await.unwrap();
|
||||
assert_eq!(billing_events.len(), 0);
|
||||
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
usage,
|
||||
true,
|
||||
max_monthly_spend,
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify the recorded usage and spending
|
||||
let recorded_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
// Verify that we exceeded the free tier usage
|
||||
assert_eq!(recorded_usage.spending_this_month, Cents::new(1050));
|
||||
assert!(recorded_usage.spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT);
|
||||
|
||||
// Verify that there is one `billing_event` record
|
||||
let billing_events = db.get_billing_events().await.unwrap();
|
||||
assert_eq!(billing_events.len(), 1);
|
||||
|
||||
let (billing_event, _model) = &billing_events[0];
|
||||
assert_eq!(billing_event.user_id, user_id);
|
||||
assert_eq!(billing_event.input_tokens, tokens_to_use as i64);
|
||||
assert_eq!(billing_event.input_cache_creation_tokens, 0);
|
||||
assert_eq!(billing_event.input_cache_read_tokens, 0);
|
||||
assert_eq!(billing_event.output_tokens, 0);
|
||||
|
||||
// Record usage that puts us at $20.50
|
||||
let usage_2 = TokenUsage {
|
||||
input: 200_000_000, // This will cost $10 more, pushing us from $10.50 to $20.50,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
};
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
usage_2,
|
||||
true,
|
||||
max_monthly_spend,
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify the updated usage and spending
|
||||
let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(updated_usage.spending_this_month, Cents::new(2050));
|
||||
|
||||
// Verify that there are now two billing events
|
||||
let billing_events = db.get_billing_events().await.unwrap();
|
||||
assert_eq!(billing_events.len(), 2);
|
||||
|
||||
let tokens_to_exceed = 20_000_000; // This will cost $1.00 more, pushing us from $20.50 to $21.50, which is over the $11 monthly maximum limit
|
||||
let usage_exceeding = TokenUsage {
|
||||
input: tokens_to_exceed,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
};
|
||||
|
||||
// This should still create a billing event as it's the first request that exceeds the limit
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
usage_exceeding,
|
||||
true,
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
max_monthly_spend,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
// Verify the updated usage and spending
|
||||
let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(updated_usage.spending_this_month, Cents::new(2150));
|
||||
|
||||
// Verify that we never exceed the user max spending for the user
|
||||
// and avoid charging them.
|
||||
let billing_events = db.get_billing_events().await.unwrap();
|
||||
assert_eq!(billing_events.len(), 2);
|
||||
}
|
||||
@@ -1,306 +0,0 @@
|
||||
use crate::llm::FREE_TIER_MONTHLY_SPENDING_LIMIT;
|
||||
use crate::{
|
||||
Cents,
|
||||
db::UserId,
|
||||
llm::db::{
|
||||
LlmDatabase, TokenUsage,
|
||||
queries::{providers::ModelParams, usages::Usage},
|
||||
},
|
||||
test_llm_db,
|
||||
};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use pretty_assertions::assert_eq;
|
||||
use rpc::LanguageModelProvider;
|
||||
|
||||
test_llm_db!(test_tracking_usage, test_tracking_usage_postgres);
|
||||
|
||||
async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
let provider = LanguageModelProvider::Anthropic;
|
||||
let model = "claude-3-5-sonnet";
|
||||
|
||||
db.initialize().await.unwrap();
|
||||
db.insert_models(&[ModelParams {
|
||||
provider,
|
||||
name: model.to_string(),
|
||||
max_requests_per_minute: 5,
|
||||
max_tokens_per_minute: 10_000,
|
||||
max_tokens_per_day: 50_000,
|
||||
price_per_million_input_tokens: 50,
|
||||
price_per_million_output_tokens: 50,
|
||||
}])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// We're using a fixed datetime to prevent flakiness based on the clock.
|
||||
let t0 = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z")
|
||||
.unwrap()
|
||||
.with_timezone(&Utc);
|
||||
let user_id = UserId::from_proto(123);
|
||||
|
||||
let now = t0;
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 1000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let now = t0 + Duration::seconds(10);
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 2000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
usage,
|
||||
Usage {
|
||||
requests_this_minute: 2,
|
||||
tokens_this_minute: 3000,
|
||||
input_tokens_this_minute: 3000,
|
||||
output_tokens_this_minute: 0,
|
||||
tokens_this_day: 3000,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 3000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
}
|
||||
);
|
||||
|
||||
let now = t0 + Duration::seconds(60);
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
usage,
|
||||
Usage {
|
||||
requests_this_minute: 1,
|
||||
tokens_this_minute: 2000,
|
||||
input_tokens_this_minute: 2000,
|
||||
output_tokens_this_minute: 0,
|
||||
tokens_this_day: 3000,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 3000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
}
|
||||
);
|
||||
|
||||
let now = t0 + Duration::seconds(60);
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 3000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
usage,
|
||||
Usage {
|
||||
requests_this_minute: 2,
|
||||
tokens_this_minute: 5000,
|
||||
input_tokens_this_minute: 5000,
|
||||
output_tokens_this_minute: 0,
|
||||
tokens_this_day: 6000,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 6000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
}
|
||||
);
|
||||
|
||||
let t1 = t0 + Duration::hours(24);
|
||||
let now = t1;
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
usage,
|
||||
Usage {
|
||||
requests_this_minute: 0,
|
||||
tokens_this_minute: 0,
|
||||
input_tokens_this_minute: 0,
|
||||
output_tokens_this_minute: 0,
|
||||
tokens_this_day: 5000,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 6000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
}
|
||||
);
|
||||
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 4000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
usage,
|
||||
Usage {
|
||||
requests_this_minute: 1,
|
||||
tokens_this_minute: 4000,
|
||||
input_tokens_this_minute: 4000,
|
||||
output_tokens_this_minute: 0,
|
||||
tokens_this_day: 9000,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 10000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
}
|
||||
);
|
||||
|
||||
// We're using a fixed datetime to prevent flakiness based on the clock.
|
||||
let now = DateTime::parse_from_rfc3339("2024-10-08T22:15:58Z")
|
||||
.unwrap()
|
||||
.with_timezone(&Utc);
|
||||
|
||||
// Test cache creation input tokens
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 1000,
|
||||
input_cache_creation: 500,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
usage,
|
||||
Usage {
|
||||
requests_this_minute: 1,
|
||||
tokens_this_minute: 1500,
|
||||
input_tokens_this_minute: 1500,
|
||||
output_tokens_this_minute: 0,
|
||||
tokens_this_day: 1500,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 1000,
|
||||
input_cache_creation: 500,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
}
|
||||
);
|
||||
|
||||
// Test cache read input tokens
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 1000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 300,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
usage,
|
||||
Usage {
|
||||
requests_this_minute: 2,
|
||||
tokens_this_minute: 2800,
|
||||
input_tokens_this_minute: 2500,
|
||||
output_tokens_this_minute: 0,
|
||||
tokens_this_day: 2800,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 2000,
|
||||
input_cache_creation: 500,
|
||||
input_cache_read: 300,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
}
|
||||
);
|
||||
}
|
||||
@@ -9,14 +9,14 @@ use axum::{
|
||||
|
||||
use collab::api::CloudflareIpCountryHeader;
|
||||
use collab::api::billing::sync_llm_usage_with_stripe_periodically;
|
||||
use collab::llm::{db::LlmDatabase, log_usage_periodically};
|
||||
use collab::llm::db::LlmDatabase;
|
||||
use collab::migrations::run_database_migrations;
|
||||
use collab::user_backfiller::spawn_user_backfiller;
|
||||
use collab::{
|
||||
AppState, Config, RateLimiter, Result, api::fetch_extensions_from_blob_store_periodically, db,
|
||||
env, executor::Executor, rpc::ResultExt,
|
||||
};
|
||||
use collab::{ServiceMode, api::billing::poll_stripe_events_periodically, llm::LlmState};
|
||||
use collab::{ServiceMode, api::billing::poll_stripe_events_periodically};
|
||||
use db::Database;
|
||||
use std::{
|
||||
env::args,
|
||||
@@ -74,11 +74,10 @@ async fn main() -> Result<()> {
|
||||
let mode = match args.next().as_deref() {
|
||||
Some("collab") => ServiceMode::Collab,
|
||||
Some("api") => ServiceMode::Api,
|
||||
Some("llm") => ServiceMode::Llm,
|
||||
Some("all") => ServiceMode::All,
|
||||
_ => {
|
||||
return Err(anyhow!(
|
||||
"usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
|
||||
"usage: collab <version | migrate | seed | serve <api|collab|all>>"
|
||||
))?;
|
||||
}
|
||||
};
|
||||
@@ -97,20 +96,9 @@ async fn main() -> Result<()> {
|
||||
|
||||
let mut on_shutdown = None;
|
||||
|
||||
if mode.is_llm() {
|
||||
setup_llm_database(&config).await?;
|
||||
|
||||
let state = LlmState::new(config.clone(), Executor::Production).await?;
|
||||
|
||||
log_usage_periodically(state.clone());
|
||||
|
||||
app = app
|
||||
.merge(collab::llm::routes())
|
||||
.layer(Extension(state.clone()));
|
||||
}
|
||||
|
||||
if mode.is_collab() || mode.is_api() {
|
||||
setup_app_database(&config).await?;
|
||||
setup_llm_database(&config).await?;
|
||||
|
||||
let state = AppState::new(config, Executor::Production).await?;
|
||||
|
||||
@@ -336,18 +324,11 @@ async fn handle_root(Extension(mode): Extension<ServiceMode>) -> String {
|
||||
format!("zed:{mode} v{VERSION} ({})", REVISION.unwrap_or("unknown"))
|
||||
}
|
||||
|
||||
async fn handle_liveness_probe(
|
||||
app_state: Option<Extension<Arc<AppState>>>,
|
||||
llm_state: Option<Extension<Arc<LlmState>>>,
|
||||
) -> Result<String> {
|
||||
async fn handle_liveness_probe(app_state: Option<Extension<Arc<AppState>>>) -> Result<String> {
|
||||
if let Some(state) = app_state {
|
||||
state.db.get_all_users(0, 1).await?;
|
||||
}
|
||||
|
||||
if let Some(llm_state) = llm_state {
|
||||
llm_state.db.list_providers().await?;
|
||||
}
|
||||
|
||||
Ok("ok".to_string())
|
||||
}
|
||||
|
||||
|
||||
@@ -694,7 +694,15 @@ async fn test_collaborating_with_code_actions(
|
||||
// Confirming the code action will trigger a resolve request.
|
||||
let confirm_action = editor_b
|
||||
.update_in(cx_b, |editor, window, cx| {
|
||||
Editor::confirm_code_action(editor, &ConfirmCodeAction { item_ix: Some(0) }, window, cx)
|
||||
Editor::confirm_code_action(
|
||||
editor,
|
||||
&ConfirmCodeAction {
|
||||
item_ix: Some(0),
|
||||
from_mouse_context_menu: false,
|
||||
},
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
fake_language_server.set_request_handler::<lsp::request::CodeActionResolveRequest, _, _>(
|
||||
|
||||
@@ -191,6 +191,14 @@ pub fn components() -> AllComponents {
|
||||
all_components
|
||||
}
|
||||
|
||||
// #[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
// pub enum ComponentStatus {
|
||||
// WorkInProgress,
|
||||
// EngineeringReady,
|
||||
// Live,
|
||||
// Deprecated,
|
||||
// }
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum ComponentScope {
|
||||
Collaboration,
|
||||
@@ -241,24 +249,30 @@ pub struct ComponentExample {
|
||||
impl RenderOnce for ComponentExample {
|
||||
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
|
||||
div()
|
||||
.pt_2()
|
||||
.w_full()
|
||||
.flex()
|
||||
.flex_col()
|
||||
.gap_3()
|
||||
.child(
|
||||
div()
|
||||
.child(self.variant_name.clone())
|
||||
.text_size(rems(1.25))
|
||||
.text_color(cx.theme().colors().text),
|
||||
.flex()
|
||||
.flex_col()
|
||||
.child(
|
||||
div()
|
||||
.child(self.variant_name.clone())
|
||||
.text_size(rems(1.0))
|
||||
.text_color(cx.theme().colors().text),
|
||||
)
|
||||
.when_some(self.description, |this, description| {
|
||||
this.child(
|
||||
div()
|
||||
.text_size(rems(0.875))
|
||||
.text_color(cx.theme().colors().text_muted)
|
||||
.child(description.clone()),
|
||||
)
|
||||
}),
|
||||
)
|
||||
.when_some(self.description, |this, description| {
|
||||
this.child(
|
||||
div()
|
||||
.text_size(rems(0.9375))
|
||||
.text_color(cx.theme().colors().text_muted)
|
||||
.child(description.clone()),
|
||||
)
|
||||
})
|
||||
.child(
|
||||
div()
|
||||
.flex()
|
||||
@@ -268,11 +282,11 @@ impl RenderOnce for ComponentExample {
|
||||
.justify_center()
|
||||
.p_8()
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.border_color(cx.theme().colors().border.opacity(0.5))
|
||||
.bg(pattern_slash(
|
||||
cx.theme().colors().surface_background.opacity(0.5),
|
||||
24.0,
|
||||
24.0,
|
||||
12.0,
|
||||
12.0,
|
||||
))
|
||||
.shadow_sm()
|
||||
.child(self.element),
|
||||
|
||||
@@ -16,12 +16,16 @@ default = []
|
||||
|
||||
[dependencies]
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
component.workspace = true
|
||||
gpui.workspace = true
|
||||
languages.workspace = true
|
||||
notifications.workspace = true
|
||||
project.workspace = true
|
||||
ui.workspace = true
|
||||
workspace.workspace = true
|
||||
notifications.workspace = true
|
||||
collections.workspace = true
|
||||
ui_input.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
workspace.workspace = true
|
||||
db.workspace = true
|
||||
anyhow.workspace = true
|
||||
serde.workspace = true
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
//!
|
||||
//! A view for exploring Zed components.
|
||||
|
||||
mod persistence;
|
||||
|
||||
use std::iter::Iterator;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -9,24 +11,27 @@ use client::UserStore;
|
||||
use component::{ComponentId, ComponentMetadata, components};
|
||||
use gpui::{
|
||||
App, Entity, EventEmitter, FocusHandle, Focusable, Task, WeakEntity, Window, list, prelude::*,
|
||||
uniform_list,
|
||||
};
|
||||
|
||||
use collections::HashMap;
|
||||
|
||||
use gpui::{ListState, ScrollHandle, UniformListScrollHandle};
|
||||
use gpui::{ListState, ScrollHandle, ScrollStrategy, UniformListScrollHandle};
|
||||
use languages::LanguageRegistry;
|
||||
use notifications::status_toast::{StatusToast, ToastIcon};
|
||||
use persistence::COMPONENT_PREVIEW_DB;
|
||||
use project::Project;
|
||||
use ui::{Divider, ListItem, ListSubHeader, prelude::*};
|
||||
use ui::{Divider, HighlightedLabel, ListItem, ListSubHeader, prelude::*};
|
||||
|
||||
use ui_input::SingleLineInput;
|
||||
use workspace::{AppState, ItemId, SerializableItem};
|
||||
use workspace::{Item, Workspace, WorkspaceId, item::ItemEvent};
|
||||
|
||||
pub fn init(app_state: Arc<AppState>, cx: &mut App) {
|
||||
workspace::register_serializable_item::<ComponentPreview>(cx);
|
||||
|
||||
let app_state = app_state.clone();
|
||||
|
||||
cx.observe_new(move |workspace: &mut Workspace, _, cx| {
|
||||
cx.observe_new(move |workspace: &mut Workspace, _window, cx| {
|
||||
let app_state = app_state.clone();
|
||||
let weak_workspace = cx.entity().downgrade();
|
||||
|
||||
@@ -44,6 +49,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
|
||||
user_store,
|
||||
None,
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
@@ -64,13 +70,13 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
|
||||
enum PreviewEntry {
|
||||
AllComponents,
|
||||
Separator,
|
||||
Component(ComponentMetadata),
|
||||
Component(ComponentMetadata, Option<Vec<usize>>),
|
||||
SectionHeader(SharedString),
|
||||
}
|
||||
|
||||
impl From<ComponentMetadata> for PreviewEntry {
|
||||
fn from(component: ComponentMetadata) -> Self {
|
||||
PreviewEntry::Component(component)
|
||||
PreviewEntry::Component(component, None)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,6 +94,7 @@ enum PreviewPage {
|
||||
}
|
||||
|
||||
struct ComponentPreview {
|
||||
workspace_id: Option<WorkspaceId>,
|
||||
focus_handle: FocusHandle,
|
||||
_view_scroll_handle: ScrollHandle,
|
||||
nav_scroll_handle: UniformListScrollHandle,
|
||||
@@ -99,6 +106,8 @@ struct ComponentPreview {
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
user_store: Entity<UserStore>,
|
||||
filter_editor: Entity<SingleLineInput>,
|
||||
filter_text: String,
|
||||
}
|
||||
|
||||
impl ComponentPreview {
|
||||
@@ -108,11 +117,14 @@ impl ComponentPreview {
|
||||
user_store: Entity<UserStore>,
|
||||
selected_index: impl Into<Option<usize>>,
|
||||
active_page: Option<PreviewPage>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let sorted_components = components().all_sorted();
|
||||
let selected_index = selected_index.into().unwrap_or(0);
|
||||
let active_page = active_page.unwrap_or(PreviewPage::AllComponents);
|
||||
let filter_editor =
|
||||
cx.new(|cx| SingleLineInput::new(window, cx, "Find components or usages…"));
|
||||
|
||||
let component_list = ListState::new(
|
||||
sorted_components.len(),
|
||||
@@ -132,6 +144,7 @@ impl ComponentPreview {
|
||||
);
|
||||
|
||||
let mut component_preview = Self {
|
||||
workspace_id: None,
|
||||
focus_handle: cx.focus_handle(),
|
||||
_view_scroll_handle: ScrollHandle::new(),
|
||||
nav_scroll_handle: UniformListScrollHandle::new(),
|
||||
@@ -143,6 +156,8 @@ impl ComponentPreview {
|
||||
components: sorted_components,
|
||||
component_list,
|
||||
cursor_index: selected_index,
|
||||
filter_editor,
|
||||
filter_text: String::new(),
|
||||
};
|
||||
|
||||
if component_preview.cursor_index > 0 {
|
||||
@@ -154,6 +169,13 @@ impl ComponentPreview {
|
||||
component_preview
|
||||
}
|
||||
|
||||
pub fn active_page_id(&self, _cx: &App) -> ActivePageId {
|
||||
match &self.active_page {
|
||||
PreviewPage::AllComponents => ActivePageId::default(),
|
||||
PreviewPage::Component(component_id) => ActivePageId(component_id.0.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn scroll_to_preview(&mut self, ix: usize, cx: &mut Context<Self>) {
|
||||
self.component_list.scroll_to_reveal_item(ix);
|
||||
self.cursor_index = ix;
|
||||
@@ -162,6 +184,7 @@ impl ComponentPreview {
|
||||
|
||||
fn set_active_page(&mut self, page: PreviewPage, cx: &mut Context<Self>) {
|
||||
self.active_page = page;
|
||||
cx.emit(ItemEvent::UpdateTab);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
@@ -169,20 +192,94 @@ impl ComponentPreview {
|
||||
self.components[ix].clone()
|
||||
}
|
||||
|
||||
fn filtered_components(&self) -> Vec<ComponentMetadata> {
|
||||
if self.filter_text.is_empty() {
|
||||
return self.components.clone();
|
||||
}
|
||||
|
||||
let filter = self.filter_text.to_lowercase();
|
||||
self.components
|
||||
.iter()
|
||||
.filter(|component| {
|
||||
let component_name = component.name().to_lowercase();
|
||||
let scope_name = component.scope().to_string().to_lowercase();
|
||||
let description = component
|
||||
.description()
|
||||
.map(|d| d.to_lowercase())
|
||||
.unwrap_or_default();
|
||||
|
||||
component_name.contains(&filter)
|
||||
|| scope_name.contains(&filter)
|
||||
|| description.contains(&filter)
|
||||
})
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn scope_ordered_entries(&self) -> Vec<PreviewEntry> {
|
||||
use std::collections::HashMap;
|
||||
|
||||
let mut scope_groups: HashMap<ComponentScope, Vec<ComponentMetadata>> = HashMap::default();
|
||||
let mut scope_groups: HashMap<
|
||||
ComponentScope,
|
||||
Vec<(ComponentMetadata, Option<Vec<usize>>)>,
|
||||
> = HashMap::default();
|
||||
let lowercase_filter = self.filter_text.to_lowercase();
|
||||
|
||||
for component in &self.components {
|
||||
scope_groups
|
||||
.entry(component.scope())
|
||||
.or_insert_with(Vec::new)
|
||||
.push(component.clone());
|
||||
if self.filter_text.is_empty() {
|
||||
scope_groups
|
||||
.entry(component.scope())
|
||||
.or_insert_with(Vec::new)
|
||||
.push((component.clone(), None));
|
||||
continue;
|
||||
}
|
||||
|
||||
// let full_component_name = component.name();
|
||||
let scopeless_name = component.scopeless_name();
|
||||
let scope_name = component.scope().to_string();
|
||||
let description = component.description().unwrap_or_default();
|
||||
|
||||
let lowercase_scopeless = scopeless_name.to_lowercase();
|
||||
let lowercase_scope = scope_name.to_lowercase();
|
||||
let lowercase_desc = description.to_lowercase();
|
||||
|
||||
if lowercase_scopeless.contains(&lowercase_filter) {
|
||||
if let Some(index) = lowercase_scopeless.find(&lowercase_filter) {
|
||||
let end = index + lowercase_filter.len();
|
||||
|
||||
if end <= scopeless_name.len() {
|
||||
let mut positions = Vec::new();
|
||||
for i in index..end {
|
||||
if scopeless_name.is_char_boundary(i) {
|
||||
positions.push(i);
|
||||
}
|
||||
}
|
||||
|
||||
if !positions.is_empty() {
|
||||
scope_groups
|
||||
.entry(component.scope())
|
||||
.or_insert_with(Vec::new)
|
||||
.push((component.clone(), Some(positions)));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if lowercase_scopeless.contains(&lowercase_filter)
|
||||
|| lowercase_scope.contains(&lowercase_filter)
|
||||
|| lowercase_desc.contains(&lowercase_filter)
|
||||
{
|
||||
scope_groups
|
||||
.entry(component.scope())
|
||||
.or_insert_with(Vec::new)
|
||||
.push((component.clone(), None));
|
||||
}
|
||||
}
|
||||
|
||||
// Sort the components in each group
|
||||
for components in scope_groups.values_mut() {
|
||||
components.sort_by_key(|c| c.name().to_lowercase());
|
||||
components.sort_by_key(|(c, _)| c.sort_name());
|
||||
}
|
||||
|
||||
let mut entries = Vec::new();
|
||||
@@ -204,10 +301,10 @@ impl ComponentPreview {
|
||||
if !components.is_empty() {
|
||||
entries.push(PreviewEntry::SectionHeader(scope.to_string().into()));
|
||||
let mut sorted_components = components;
|
||||
sorted_components.sort_by_key(|component| component.sort_name());
|
||||
sorted_components.sort_by_key(|(component, _)| component.sort_name());
|
||||
|
||||
for component in sorted_components {
|
||||
entries.push(PreviewEntry::Component(component));
|
||||
for (component, positions) in sorted_components {
|
||||
entries.push(PreviewEntry::Component(component, positions));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -219,10 +316,10 @@ impl ComponentPreview {
|
||||
entries.push(PreviewEntry::Separator);
|
||||
entries.push(PreviewEntry::SectionHeader("Uncategorized".into()));
|
||||
let mut sorted_components = components.clone();
|
||||
sorted_components.sort_by_key(|c| c.sort_name());
|
||||
sorted_components.sort_by_key(|(c, _)| c.sort_name());
|
||||
|
||||
for component in sorted_components {
|
||||
entries.push(PreviewEntry::Component(component.clone()));
|
||||
for (component, positions) in sorted_components {
|
||||
entries.push(PreviewEntry::Component(component, positions));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -237,14 +334,33 @@ impl ComponentPreview {
|
||||
cx: &Context<Self>,
|
||||
) -> impl IntoElement + use<> {
|
||||
match entry {
|
||||
PreviewEntry::Component(component_metadata) => {
|
||||
PreviewEntry::Component(component_metadata, highlight_positions) => {
|
||||
let id = component_metadata.id();
|
||||
let selected = self.active_page == PreviewPage::Component(id.clone());
|
||||
let name = component_metadata.scopeless_name();
|
||||
|
||||
ListItem::new(ix)
|
||||
.child(
|
||||
Label::new(component_metadata.scopeless_name().clone())
|
||||
.color(Color::Default),
|
||||
)
|
||||
.child(if let Some(_positions) = highlight_positions {
|
||||
let name_lower = name.to_lowercase();
|
||||
let filter_lower = self.filter_text.to_lowercase();
|
||||
let valid_positions = if let Some(start) = name_lower.find(&filter_lower) {
|
||||
let end = start + filter_lower.len();
|
||||
(start..end).collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
if valid_positions.is_empty() {
|
||||
Label::new(name.clone())
|
||||
.color(Color::Default)
|
||||
.into_any_element()
|
||||
} else {
|
||||
HighlightedLabel::new(name.clone(), valid_positions).into_any_element()
|
||||
}
|
||||
} else {
|
||||
Label::new(name.clone())
|
||||
.color(Color::Default)
|
||||
.into_any_element()
|
||||
})
|
||||
.selectable(true)
|
||||
.toggle_state(selected)
|
||||
.inset(true)
|
||||
@@ -282,20 +398,70 @@ impl ComponentPreview {
|
||||
}
|
||||
|
||||
fn update_component_list(&mut self, cx: &mut Context<Self>) {
|
||||
let new_len = self.scope_ordered_entries().len();
|
||||
let entries = self.scope_ordered_entries();
|
||||
let new_len = entries.len();
|
||||
let weak_entity = cx.entity().downgrade();
|
||||
|
||||
if new_len > 0 {
|
||||
self.nav_scroll_handle
|
||||
.scroll_to_item(0, ScrollStrategy::Top);
|
||||
}
|
||||
|
||||
let filtered_components = self.filtered_components();
|
||||
|
||||
if !self.filter_text.is_empty() && !matches!(self.active_page, PreviewPage::AllComponents) {
|
||||
if let PreviewPage::Component(ref component_id) = self.active_page {
|
||||
let component_still_visible = filtered_components
|
||||
.iter()
|
||||
.any(|component| component.id() == *component_id);
|
||||
|
||||
if !component_still_visible {
|
||||
if !filtered_components.is_empty() {
|
||||
let first_component = &filtered_components[0];
|
||||
self.set_active_page(PreviewPage::Component(first_component.id()), cx);
|
||||
} else {
|
||||
self.set_active_page(PreviewPage::AllComponents, cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.component_list = ListState::new(
|
||||
filtered_components.len(),
|
||||
gpui::ListAlignment::Top,
|
||||
px(1500.0),
|
||||
{
|
||||
let components = filtered_components.clone();
|
||||
let this = cx.entity().downgrade();
|
||||
move |ix, window: &mut Window, cx: &mut App| {
|
||||
if ix >= components.len() {
|
||||
return div().w_full().h_0().into_any_element();
|
||||
}
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
let component = &components[ix];
|
||||
this.render_preview(component, window, cx)
|
||||
.into_any_element()
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
let new_list = ListState::new(
|
||||
new_len,
|
||||
gpui::ListAlignment::Top,
|
||||
px(1500.0),
|
||||
move |ix, window, cx| {
|
||||
if ix >= entries.len() {
|
||||
return div().w_full().h_0().into_any_element();
|
||||
}
|
||||
|
||||
let entry = &entries[ix];
|
||||
|
||||
weak_entity
|
||||
.update(cx, |this, cx| match entry {
|
||||
PreviewEntry::Component(component) => this
|
||||
PreviewEntry::Component(component, _) => this
|
||||
.render_preview(component, window, cx)
|
||||
.into_any_element(),
|
||||
PreviewEntry::SectionHeader(shared_string) => this
|
||||
@@ -309,6 +475,7 @@ impl ComponentPreview {
|
||||
);
|
||||
|
||||
self.component_list = new_list;
|
||||
cx.emit(ItemEvent::UpdateTab);
|
||||
}
|
||||
|
||||
fn render_scope_header(
|
||||
@@ -377,16 +544,27 @@ impl ComponentPreview {
|
||||
.into_any_element()
|
||||
}
|
||||
|
||||
fn render_all_components(&self) -> impl IntoElement {
|
||||
fn render_all_components(&self, cx: &Context<Self>) -> impl IntoElement {
|
||||
v_flex()
|
||||
.id("component-list")
|
||||
.px_8()
|
||||
.pt_4()
|
||||
.size_full()
|
||||
.child(
|
||||
list(self.component_list.clone())
|
||||
.flex_grow()
|
||||
.with_sizing_behavior(gpui::ListSizingBehavior::Auto),
|
||||
if self.filtered_components().is_empty() && !self.filter_text.is_empty() {
|
||||
div()
|
||||
.size_full()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
.text_color(cx.theme().colors().text_muted)
|
||||
.child(format!("No components matching '{}'.", self.filter_text))
|
||||
.into_any_element()
|
||||
} else {
|
||||
list(self.component_list.clone())
|
||||
.flex_grow()
|
||||
.with_sizing_behavior(gpui::ListSizingBehavior::Auto)
|
||||
.into_any_element()
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -432,6 +610,19 @@ impl ComponentPreview {
|
||||
|
||||
impl Render for ComponentPreview {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
// TODO: move this into the struct
|
||||
let current_filter = self.filter_editor.update(cx, |input, cx| {
|
||||
if input.is_empty(cx) {
|
||||
String::new()
|
||||
} else {
|
||||
input.editor().read(cx).text(cx).to_string()
|
||||
}
|
||||
});
|
||||
|
||||
if current_filter != self.filter_text {
|
||||
self.filter_text = current_filter;
|
||||
self.update_component_list(cx);
|
||||
}
|
||||
let sidebar_entries = self.scope_ordered_entries();
|
||||
let active_page = self.active_page.clone();
|
||||
|
||||
@@ -449,14 +640,22 @@ impl Render for ComponentPreview {
|
||||
.border_color(cx.theme().colors().border)
|
||||
.h_full()
|
||||
.child(
|
||||
uniform_list(
|
||||
gpui::uniform_list(
|
||||
cx.entity().clone(),
|
||||
"component-nav",
|
||||
sidebar_entries.len(),
|
||||
move |this, range, _window, cx| {
|
||||
range
|
||||
.map(|ix| {
|
||||
this.render_sidebar_entry(ix, &sidebar_entries[ix], cx)
|
||||
.filter_map(|ix| {
|
||||
if ix < sidebar_entries.len() {
|
||||
Some(this.render_sidebar_entry(
|
||||
ix,
|
||||
&sidebar_entries[ix],
|
||||
cx,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
},
|
||||
@@ -481,12 +680,29 @@ impl Render for ComponentPreview {
|
||||
),
|
||||
),
|
||||
)
|
||||
.child(match active_page {
|
||||
PreviewPage::AllComponents => self.render_all_components().into_any_element(),
|
||||
PreviewPage::Component(id) => self
|
||||
.render_component_page(&id, window, cx)
|
||||
.into_any_element(),
|
||||
})
|
||||
.child(
|
||||
v_flex()
|
||||
.id("content-area")
|
||||
.flex_1()
|
||||
.size_full()
|
||||
.overflow_hidden()
|
||||
.child(
|
||||
div()
|
||||
.p_2()
|
||||
.w_full()
|
||||
.border_b_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.child(self.filter_editor.clone()),
|
||||
)
|
||||
.child(match active_page {
|
||||
PreviewPage::AllComponents => {
|
||||
self.render_all_components(cx).into_any_element()
|
||||
}
|
||||
PreviewPage::Component(id) => self
|
||||
.render_component_page(&id, window, cx)
|
||||
.into_any_element(),
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -498,6 +714,21 @@ impl Focusable for ComponentPreview {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct ActivePageId(pub String);
|
||||
|
||||
impl Default for ActivePageId {
|
||||
fn default() -> Self {
|
||||
ActivePageId("AllComponents".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ComponentId> for ActivePageId {
|
||||
fn from(id: ComponentId) -> Self {
|
||||
ActivePageId(id.0.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl Item for ComponentPreview {
|
||||
type Event = ItemEvent;
|
||||
|
||||
@@ -516,7 +747,7 @@ impl Item for ComponentPreview {
|
||||
fn clone_on_split(
|
||||
&self,
|
||||
_workspace_id: Option<WorkspaceId>,
|
||||
_window: &mut Window,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<gpui::Entity<Self>>
|
||||
where
|
||||
@@ -535,6 +766,7 @@ impl Item for ComponentPreview {
|
||||
user_store,
|
||||
selected_index,
|
||||
Some(active_page),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}))
|
||||
@@ -543,6 +775,15 @@ impl Item for ComponentPreview {
|
||||
fn to_item_events(event: &Self::Event, mut f: impl FnMut(workspace::item::ItemEvent)) {
|
||||
f(*event)
|
||||
}
|
||||
|
||||
fn added_to_workspace(
|
||||
&mut self,
|
||||
workspace: &mut Workspace,
|
||||
_window: &mut Window,
|
||||
_cx: &mut Context<Self>,
|
||||
) {
|
||||
self.workspace_id = workspace.database_id();
|
||||
}
|
||||
}
|
||||
|
||||
impl SerializableItem for ComponentPreview {
|
||||
@@ -553,26 +794,53 @@ impl SerializableItem for ComponentPreview {
|
||||
fn deserialize(
|
||||
project: Entity<Project>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
_workspace_id: WorkspaceId,
|
||||
_item_id: ItemId,
|
||||
workspace_id: WorkspaceId,
|
||||
item_id: ItemId,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Task<gpui::Result<Entity<Self>>> {
|
||||
let deserialized_active_page =
|
||||
match COMPONENT_PREVIEW_DB.get_active_page(item_id, workspace_id) {
|
||||
Ok(page) => {
|
||||
if let Some(page) = page {
|
||||
ActivePageId(page)
|
||||
} else {
|
||||
ActivePageId::default()
|
||||
}
|
||||
}
|
||||
Err(_) => ActivePageId::default(),
|
||||
};
|
||||
|
||||
let user_store = project.read(cx).user_store().clone();
|
||||
let language_registry = project.read(cx).languages().clone();
|
||||
let preview_page = if deserialized_active_page.0 == ActivePageId::default().0 {
|
||||
Some(PreviewPage::default())
|
||||
} else {
|
||||
let component_str = deserialized_active_page.0;
|
||||
let component_registry = components();
|
||||
let all_components = component_registry.all();
|
||||
let found_component = all_components.iter().find(|c| c.id().0 == component_str);
|
||||
|
||||
if let Some(component) = found_component {
|
||||
Some(PreviewPage::Component(component.id().clone()))
|
||||
} else {
|
||||
Some(PreviewPage::default())
|
||||
}
|
||||
};
|
||||
|
||||
window.spawn(cx, async move |cx| {
|
||||
let user_store = user_store.clone();
|
||||
let language_registry = language_registry.clone();
|
||||
let weak_workspace = workspace.clone();
|
||||
cx.update(|_, cx| {
|
||||
cx.update(move |window, cx| {
|
||||
Ok(cx.new(|cx| {
|
||||
ComponentPreview::new(
|
||||
weak_workspace,
|
||||
language_registry,
|
||||
user_store,
|
||||
None,
|
||||
None,
|
||||
preview_page,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}))
|
||||
@@ -581,34 +849,41 @@ impl SerializableItem for ComponentPreview {
|
||||
}
|
||||
|
||||
fn cleanup(
|
||||
_workspace_id: WorkspaceId,
|
||||
_alive_items: Vec<ItemId>,
|
||||
workspace_id: WorkspaceId,
|
||||
alive_items: Vec<ItemId>,
|
||||
_window: &mut Window,
|
||||
_cx: &mut App,
|
||||
cx: &mut App,
|
||||
) -> Task<gpui::Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
// window.spawn(cx, |_| {
|
||||
// ...
|
||||
// })
|
||||
cx.background_spawn(async move {
|
||||
COMPONENT_PREVIEW_DB
|
||||
.delete_unloaded_items(workspace_id, alive_items)
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
fn serialize(
|
||||
&mut self,
|
||||
_workspace: &mut Workspace,
|
||||
_item_id: ItemId,
|
||||
item_id: ItemId,
|
||||
_closing: bool,
|
||||
_window: &mut Window,
|
||||
_cx: &mut Context<Self>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<Task<gpui::Result<()>>> {
|
||||
// TODO: Serialize the active index so we can re-open to the same place
|
||||
None
|
||||
let active_page = self.active_page_id(cx);
|
||||
let workspace_id = self.workspace_id?;
|
||||
Some(cx.background_spawn(async move {
|
||||
COMPONENT_PREVIEW_DB
|
||||
.save_active_page(item_id, workspace_id, active_page.0)
|
||||
.await
|
||||
}))
|
||||
}
|
||||
|
||||
fn should_serialize(&self, _event: &Self::Event) -> bool {
|
||||
false
|
||||
fn should_serialize(&self, event: &Self::Event) -> bool {
|
||||
matches!(event, ItemEvent::UpdateTab)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: use language registry to allow rendering markdown
|
||||
#[derive(IntoElement)]
|
||||
pub struct ComponentPreviewPage {
|
||||
// languages: Arc<LanguageRegistry>,
|
||||
|
||||
73
crates/component_preview/src/persistence.rs
Normal file
73
crates/component_preview/src/persistence.rs
Normal file
@@ -0,0 +1,73 @@
|
||||
use anyhow::Result;
|
||||
use db::{define_connection, query, sqlez::statement::Statement, sqlez_macros::sql};
|
||||
use workspace::{ItemId, WorkspaceDb, WorkspaceId};
|
||||
|
||||
define_connection! {
|
||||
pub static ref COMPONENT_PREVIEW_DB: ComponentPreviewDb<WorkspaceDb> =
|
||||
&[sql!(
|
||||
CREATE TABLE component_previews (
|
||||
workspace_id INTEGER,
|
||||
item_id INTEGER UNIQUE,
|
||||
active_page_id TEXT,
|
||||
PRIMARY KEY(workspace_id, item_id),
|
||||
FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
|
||||
ON DELETE CASCADE
|
||||
) STRICT;
|
||||
)];
|
||||
}
|
||||
|
||||
impl ComponentPreviewDb {
|
||||
pub async fn save_active_page(
|
||||
&self,
|
||||
item_id: ItemId,
|
||||
workspace_id: WorkspaceId,
|
||||
active_page_id: String,
|
||||
) -> Result<()> {
|
||||
let query = "INSERT INTO component_previews(item_id, workspace_id, active_page_id)
|
||||
VALUES (?1, ?2, ?3)
|
||||
ON CONFLICT DO UPDATE SET
|
||||
active_page_id = ?3";
|
||||
self.write(move |conn| {
|
||||
let mut statement = Statement::prepare(conn, query)?;
|
||||
let mut next_index = statement.bind(&item_id, 1)?;
|
||||
next_index = statement.bind(&workspace_id, next_index)?;
|
||||
statement.bind(&active_page_id, next_index)?;
|
||||
statement.exec()
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
query! {
|
||||
pub fn get_active_page(item_id: ItemId, workspace_id: WorkspaceId) -> Result<Option<String>> {
|
||||
SELECT active_page_id
|
||||
FROM component_previews
|
||||
WHERE item_id = ? AND workspace_id = ?
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_unloaded_items(
|
||||
&self,
|
||||
workspace: WorkspaceId,
|
||||
alive_items: Vec<ItemId>,
|
||||
) -> Result<()> {
|
||||
let placeholders = alive_items
|
||||
.iter()
|
||||
.map(|_| "?")
|
||||
.collect::<Vec<&str>>()
|
||||
.join(", ");
|
||||
|
||||
let query = format!(
|
||||
"DELETE FROM component_previews WHERE workspace_id = ? AND item_id NOT IN ({placeholders})"
|
||||
);
|
||||
|
||||
self.write(move |conn| {
|
||||
let mut statement = Statement::prepare(conn, query)?;
|
||||
let mut next_index = statement.bind(&workspace, 1)?;
|
||||
for id in alive_items {
|
||||
next_index = statement.bind(&id, next_index)?;
|
||||
}
|
||||
statement.exec()
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -53,16 +53,18 @@ impl Tool for ContextServerTool {
|
||||
true
|
||||
}
|
||||
|
||||
fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> serde_json::Value {
|
||||
match &self.tool.input_schema {
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
let mut schema = self.tool.input_schema.clone();
|
||||
assistant_tool::adapt_schema_to_format(&mut schema, format)?;
|
||||
Ok(match schema {
|
||||
serde_json::Value::Null => {
|
||||
serde_json::json!({ "type": "object", "properties": [] })
|
||||
}
|
||||
serde_json::Value::Object(map) if map.is_empty() => {
|
||||
serde_json::json!({ "type": "object", "properties": [] })
|
||||
}
|
||||
_ => self.tool.input_schema.clone(),
|
||||
}
|
||||
_ => schema,
|
||||
})
|
||||
}
|
||||
|
||||
fn ui_text(&self, _input: &serde_json::Value) -> String {
|
||||
|
||||
@@ -33,6 +33,8 @@ pub enum Model {
|
||||
Gpt4o,
|
||||
#[serde(alias = "gpt-4", rename = "gpt-4")]
|
||||
Gpt4,
|
||||
#[serde(alias = "gpt-4.1", rename = "gpt-4.1")]
|
||||
Gpt4_1,
|
||||
#[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")]
|
||||
Gpt3_5Turbo,
|
||||
#[serde(alias = "o1", rename = "o1")]
|
||||
@@ -50,6 +52,8 @@ pub enum Model {
|
||||
Claude3_7SonnetThinking,
|
||||
#[serde(alias = "gemini-2.0-flash", rename = "gemini-2.0-flash-001")]
|
||||
Gemini20Flash,
|
||||
#[serde(alias = "gemini-2.5-pro", rename = "gemini-2.5-pro")]
|
||||
Gemini25Pro,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
@@ -57,11 +61,12 @@ impl Model {
|
||||
match self {
|
||||
Self::Gpt4o
|
||||
| Self::Gpt4
|
||||
| Self::Gpt4_1
|
||||
| Self::Gpt3_5Turbo
|
||||
| Self::Claude3_5Sonnet
|
||||
| Self::Claude3_7Sonnet
|
||||
| Self::Claude3_7SonnetThinking => true,
|
||||
Self::O3Mini | Self::O1 | Self::Gemini20Flash => false,
|
||||
Self::O3Mini | Self::O1 | Self::Gemini20Flash | Self::Gemini25Pro => false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,6 +74,7 @@ impl Model {
|
||||
match id {
|
||||
"gpt-4o" => Ok(Self::Gpt4o),
|
||||
"gpt-4" => Ok(Self::Gpt4),
|
||||
"gpt-4.1" => Ok(Self::Gpt4_1),
|
||||
"gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo),
|
||||
"o1" => Ok(Self::O1),
|
||||
"o3-mini" => Ok(Self::O3Mini),
|
||||
@@ -76,6 +82,7 @@ impl Model {
|
||||
"claude-3-7-sonnet" => Ok(Self::Claude3_7Sonnet),
|
||||
"claude-3.7-sonnet-thought" => Ok(Self::Claude3_7SonnetThinking),
|
||||
"gemini-2.0-flash-001" => Ok(Self::Gemini20Flash),
|
||||
"gemini-2.5-pro" => Ok(Self::Gemini25Pro),
|
||||
_ => Err(anyhow!("Invalid model id: {}", id)),
|
||||
}
|
||||
}
|
||||
@@ -84,6 +91,7 @@ impl Model {
|
||||
match self {
|
||||
Self::Gpt3_5Turbo => "gpt-3.5-turbo",
|
||||
Self::Gpt4 => "gpt-4",
|
||||
Self::Gpt4_1 => "gpt-4.1",
|
||||
Self::Gpt4o => "gpt-4o",
|
||||
Self::O3Mini => "o3-mini",
|
||||
Self::O1 => "o1",
|
||||
@@ -91,6 +99,7 @@ impl Model {
|
||||
Self::Claude3_7Sonnet => "claude-3-7-sonnet",
|
||||
Self::Claude3_7SonnetThinking => "claude-3.7-sonnet-thought",
|
||||
Self::Gemini20Flash => "gemini-2.0-flash-001",
|
||||
Self::Gemini25Pro => "gemini-2.5-pro",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,6 +107,7 @@ impl Model {
|
||||
match self {
|
||||
Self::Gpt3_5Turbo => "GPT-3.5",
|
||||
Self::Gpt4 => "GPT-4",
|
||||
Self::Gpt4_1 => "GPT-4.1",
|
||||
Self::Gpt4o => "GPT-4o",
|
||||
Self::O3Mini => "o3-mini",
|
||||
Self::O1 => "o1",
|
||||
@@ -105,6 +115,7 @@ impl Model {
|
||||
Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
|
||||
Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
|
||||
Self::Gemini20Flash => "Gemini 2.0 Flash",
|
||||
Self::Gemini25Pro => "Gemini 2.5 Pro",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,13 +123,15 @@ impl Model {
|
||||
match self {
|
||||
Self::Gpt4o => 64_000,
|
||||
Self::Gpt4 => 32_768,
|
||||
Self::Gpt4_1 => 1_047_576,
|
||||
Self::Gpt3_5Turbo => 12_288,
|
||||
Self::O3Mini => 64_000,
|
||||
Self::O1 => 20_000,
|
||||
Self::Claude3_5Sonnet => 200_000,
|
||||
Self::Claude3_7Sonnet => 90_000,
|
||||
Self::Claude3_7SonnetThinking => 90_000,
|
||||
Model::Gemini20Flash => 128_000,
|
||||
Self::Gemini20Flash => 128_000,
|
||||
Self::Gemini25Pro => 128_000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,7 +39,6 @@ log.workspace = true
|
||||
node_runtime.workspace = true
|
||||
parking_lot.workspace = true
|
||||
paths.workspace = true
|
||||
regex.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
||||
@@ -20,7 +20,7 @@ use std::{
|
||||
net::Ipv4Addr,
|
||||
ops::Deref,
|
||||
path::PathBuf,
|
||||
sync::{Arc, LazyLock},
|
||||
sync::Arc,
|
||||
};
|
||||
use task::{DebugAdapterConfig, DebugTaskDefinition};
|
||||
use util::ResultExt;
|
||||
@@ -291,14 +291,7 @@ pub trait DebugAdapter: 'static + Send + Sync {
|
||||
|
||||
/// Should return base configuration to make the debug adapter work
|
||||
fn request_args(&self, config: &DebugTaskDefinition) -> Value;
|
||||
|
||||
fn attach_processes_filter(&self) -> regex::Regex {
|
||||
EMPTY_REGEX.clone()
|
||||
}
|
||||
}
|
||||
|
||||
static EMPTY_REGEX: LazyLock<regex::Regex> =
|
||||
LazyLock::new(|| regex::Regex::new("").expect("Regex compilation to succeed"));
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub struct FakeAdapter {}
|
||||
|
||||
@@ -375,10 +368,4 @@ impl DebugAdapter for FakeAdapter {
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn attach_processes_filter(&self) -> regex::Regex {
|
||||
static REGEX: LazyLock<regex::Regex> =
|
||||
LazyLock::new(|| regex::Regex::new("^fake-binary").unwrap());
|
||||
REGEX.clone()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ struct DapRegistryState {
|
||||
adapters: BTreeMap<DebugAdapterName, Arc<dyn DebugAdapter>>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
#[derive(Clone, Default)]
|
||||
/// Stores available debug adapters.
|
||||
pub struct DapRegistry(Arc<RwLock<DapRegistryState>>);
|
||||
|
||||
|
||||
@@ -27,7 +27,6 @@ dap.workspace = true
|
||||
gpui.workspace = true
|
||||
language.workspace = true
|
||||
paths.workspace = true
|
||||
regex.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
task.workspace = true
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user