Compare commits
169 Commits
tune-for-g
...
vim-syntax
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0e6f44d87c | ||
|
|
707a4c7f20 | ||
|
|
854076f96d | ||
|
|
cf931247d0 | ||
|
|
b74477d12e | ||
|
|
3077abf9cf | ||
|
|
07dab4e94a | ||
|
|
59686f1f44 | ||
|
|
a60bea8a3d | ||
|
|
b820aa1fcd | ||
|
|
55d91bce53 | ||
|
|
b798392050 | ||
|
|
657c8b1084 | ||
|
|
2bb8aa2f73 | ||
|
|
beeb42da29 | ||
|
|
6d66ff1d95 | ||
|
|
e0b818af62 | ||
|
|
58a400b1ee | ||
|
|
8ab7d44d51 | ||
|
|
56d4c0af9f | ||
|
|
feeda7fa37 | ||
|
|
4a5c55a8f2 | ||
|
|
7c1ae9bcc3 | ||
|
|
6f97da3435 | ||
|
|
63c1033448 | ||
|
|
b16911e756 | ||
|
|
b14401f817 | ||
|
|
17cf865d1e | ||
|
|
b7ec437b13 | ||
|
|
f1aab1120d | ||
|
|
3f90bc81bd | ||
|
|
9d5fb3c3f3 | ||
|
|
864767ad35 | ||
|
|
ec69b68e72 | ||
|
|
9dd18e5ee1 | ||
|
|
2ebe16a52f | ||
|
|
1ed4647203 | ||
|
|
ebed567adb | ||
|
|
a6544c70c5 | ||
|
|
b363e1a482 | ||
|
|
65e3e84cbc | ||
|
|
1e1d4430c2 | ||
|
|
c874f1fa9d | ||
|
|
9a9e96ed5a | ||
|
|
8c46e290df | ||
|
|
aacbb9c2f4 | ||
|
|
f90333f92e | ||
|
|
b24f614ca3 | ||
|
|
cefa0cbed8 | ||
|
|
3fb1023667 | ||
|
|
9c715b470e | ||
|
|
ae219e9e99 | ||
|
|
6d99c12796 | ||
|
|
8fb7fa941a | ||
|
|
22d75b798e | ||
|
|
06a199da4d | ||
|
|
ab6125ddde | ||
|
|
d3bc561f26 | ||
|
|
f13f2dfb70 | ||
|
|
24e4446cd3 | ||
|
|
cc536655a1 | ||
|
|
2a9e73c65d | ||
|
|
4f1728e5ee | ||
|
|
40c91d5df0 | ||
|
|
fe1b36671d | ||
|
|
bb9e2b0403 | ||
|
|
4f8d7f0a6b | ||
|
|
caf3d30bf6 | ||
|
|
df0cf22347 | ||
|
|
a305eda8d1 | ||
|
|
ba7b1db054 | ||
|
|
019c8ded77 | ||
|
|
1704dbea7e | ||
|
|
eefa6c4882 | ||
|
|
1f17df7fb0 | ||
|
|
6d687a2c2c | ||
|
|
32214abb64 | ||
|
|
a78563b80b | ||
|
|
f881cacd8a | ||
|
|
a539a38f13 | ||
|
|
ca6fd101c1 | ||
|
|
f8097c7c98 | ||
|
|
c1427ea802 | ||
|
|
1e83022f03 | ||
|
|
0ee900e8fb | ||
|
|
f9f4be1fc4 | ||
|
|
a00b07371a | ||
|
|
f725b5e248 | ||
|
|
07436b4284 | ||
|
|
8bec4cbecb | ||
|
|
047e7eacec | ||
|
|
1d5d3de85c | ||
|
|
c4dbaa91f0 | ||
|
|
97c01c6720 | ||
|
|
310ea43048 | ||
|
|
6bb4b5fa64 | ||
|
|
e0fa3032ec | ||
|
|
9cf6be2057 | ||
|
|
5462e199fb | ||
|
|
3a60420b41 | ||
|
|
89c184a26f | ||
|
|
d7f0241d7b | ||
|
|
1445af559b | ||
|
|
804de3316e | ||
|
|
a387bf5f54 | ||
|
|
c7047d5f0a | ||
|
|
406d975f39 | ||
|
|
cbed580db0 | ||
|
|
8aef64bbfa | ||
|
|
9086784038 | ||
|
|
2abc5893c1 | ||
|
|
a23ee61a4b | ||
|
|
38e45e828b | ||
|
|
181bf78b7d | ||
|
|
c42d060509 | ||
|
|
6ea9abdc1b | ||
|
|
070eac28e3 | ||
|
|
05692e298a | ||
|
|
ccb049bd97 | ||
|
|
fe57eedb44 | ||
|
|
c57e6bc784 | ||
|
|
83135e98e6 | ||
|
|
703ee29658 | ||
|
|
f792827a01 | ||
|
|
45f9edcbb9 | ||
|
|
e3354543c0 | ||
|
|
cb187b0b4d | ||
|
|
d989b2260b | ||
|
|
ae076fa415 | ||
|
|
b4af61edfe | ||
|
|
ea8a3be91b | ||
|
|
5173a1a968 | ||
|
|
87f097a0ab | ||
|
|
f9407db7d6 | ||
|
|
384b11392a | ||
|
|
f20596c33b | ||
|
|
eb863f8fd6 | ||
|
|
97579662e6 | ||
|
|
53849cf983 | ||
|
|
1e25249055 | ||
|
|
469824c350 | ||
|
|
a1c645e57e | ||
|
|
0791596cda | ||
|
|
9cc1851be7 | ||
|
|
50bd8770bd | ||
|
|
00bdebc89d | ||
|
|
d5134062ac | ||
|
|
0e9f6986cf | ||
|
|
1035c6aab5 | ||
|
|
75e69a5ae9 | ||
|
|
05afe95539 | ||
|
|
a5a116439e | ||
|
|
361ceee72b | ||
|
|
68724ea99e | ||
|
|
e12106e025 | ||
|
|
77aa667bf3 | ||
|
|
8b47b40dc0 | ||
|
|
01990c8375 | ||
|
|
4e7dc37f01 | ||
|
|
00fd045844 | ||
|
|
7443fde4e9 | ||
|
|
d5ab42aeb8 | ||
|
|
07403f0b08 | ||
|
|
00bc154c46 | ||
|
|
f627ac92ee | ||
|
|
218e8d09c5 | ||
|
|
2c4b75ab30 | ||
|
|
aab76208b5 | ||
|
|
f3f0766242 |
@@ -1,8 +1,8 @@
|
|||||||
name: Bug Report (Agent Panel)
|
name: Bug Report (AI Related)
|
||||||
description: Zed Agent Panel Bugs
|
description: Zed Agent Panel Bugs
|
||||||
type: "Bug"
|
type: "Bug"
|
||||||
labels: ["agent", "ai"]
|
labels: ["ai"]
|
||||||
title: "Agent Panel: <a short description of the Agent Panel bug>"
|
title: "AI: <a short description of the AI Related bug>"
|
||||||
body:
|
body:
|
||||||
- type: textarea
|
- type: textarea
|
||||||
attributes:
|
attributes:
|
||||||
@@ -14,7 +14,6 @@ body:
|
|||||||
|
|
||||||
### Description
|
### Description
|
||||||
<!-- Describe with sufficient detail to reproduce from a clean Zed install. -->
|
<!-- Describe with sufficient detail to reproduce from a clean Zed install. -->
|
||||||
<!-- Please include the LLM provider and model name you are using -->
|
|
||||||
Steps to trigger the problem:
|
Steps to trigger the problem:
|
||||||
1.
|
1.
|
||||||
2.
|
2.
|
||||||
@@ -22,6 +21,13 @@ body:
|
|||||||
|
|
||||||
Actual Behavior:
|
Actual Behavior:
|
||||||
Expected Behavior:
|
Expected Behavior:
|
||||||
|
|
||||||
|
### Model Provider Details
|
||||||
|
- Provider: (Anthropic via ZedPro, Anthropic via API key, Copilot Chat, Mistral, OpenAI, etc)
|
||||||
|
- Model Name:
|
||||||
|
- Mode: (Agent Panel, Inline Assistant, Terminal Assistant or Text Threads)
|
||||||
|
- MCP Servers in-use:
|
||||||
|
- Other Details:
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
|
|
||||||
5
.github/workflows/ci.yml
vendored
@@ -482,7 +482,9 @@ jobs:
|
|||||||
- macos_tests
|
- macos_tests
|
||||||
- windows_clippy
|
- windows_clippy
|
||||||
- windows_tests
|
- windows_tests
|
||||||
if: always()
|
if: |
|
||||||
|
github.repository_owner == 'zed-industries' &&
|
||||||
|
always()
|
||||||
steps:
|
steps:
|
||||||
- name: Check all tests passed
|
- name: Check all tests passed
|
||||||
run: |
|
run: |
|
||||||
@@ -714,6 +716,7 @@ jobs:
|
|||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
nix-build:
|
nix-build:
|
||||||
|
name: Build with Nix
|
||||||
uses: ./.github/workflows/nix.yml
|
uses: ./.github/workflows/nix.yml
|
||||||
if: github.repository_owner == 'zed-industries' && contains(github.event.pull_request.labels.*.name, 'run-nix')
|
if: github.repository_owner == 'zed-industries' && contains(github.event.pull_request.labels.*.name, 'run-nix')
|
||||||
with:
|
with:
|
||||||
|
|||||||
1
.github/workflows/nix.yml
vendored
@@ -56,6 +56,7 @@ jobs:
|
|||||||
name: zed
|
name: zed
|
||||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||||
pushFilter: "${{ inputs.cachix-filter }}"
|
pushFilter: "${{ inputs.cachix-filter }}"
|
||||||
|
cachixArgs: '-v'
|
||||||
|
|
||||||
- run: nix build .#${{ inputs.flake-output }} -L --accept-flake-config
|
- run: nix build .#${{ inputs.flake-output }} -L --accept-flake-config
|
||||||
|
|
||||||
|
|||||||
1
.github/workflows/release_nightly.yml
vendored
@@ -168,6 +168,7 @@ jobs:
|
|||||||
run: script/upload-nightly linux-targz
|
run: script/upload-nightly linux-targz
|
||||||
|
|
||||||
bundle-nix:
|
bundle-nix:
|
||||||
|
name: Build and cache Nix package
|
||||||
needs: tests
|
needs: tests
|
||||||
uses: ./.github/workflows/nix.yml
|
uses: ./.github/workflows/nix.yml
|
||||||
|
|
||||||
|
|||||||
@@ -2,16 +2,11 @@
|
|||||||
{
|
{
|
||||||
"label": "Debug Zed (CodeLLDB)",
|
"label": "Debug Zed (CodeLLDB)",
|
||||||
"adapter": "CodeLLDB",
|
"adapter": "CodeLLDB",
|
||||||
"program": "$ZED_WORKTREE_ROOT/target/debug/zed",
|
"build": { "label": "Build Zed", "command": "cargo", "args": ["build"] }
|
||||||
"request": "launch"
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"label": "Debug Zed (GDB)",
|
"label": "Debug Zed (GDB)",
|
||||||
"adapter": "GDB",
|
"adapter": "GDB",
|
||||||
"program": "$ZED_WORKTREE_ROOT/target/debug/zed",
|
"build": { "label": "Build Zed", "command": "cargo", "args": ["build"] }
|
||||||
"request": "launch",
|
|
||||||
"initialize_args": {
|
|
||||||
"stopAtBeginningOfMainSubprogram": true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
26
Cargo.lock
generated
@@ -114,6 +114,7 @@ dependencies = [
|
|||||||
"serde_json_lenient",
|
"serde_json_lenient",
|
||||||
"settings",
|
"settings",
|
||||||
"smol",
|
"smol",
|
||||||
|
"sqlez",
|
||||||
"streaming_diff",
|
"streaming_diff",
|
||||||
"telemetry",
|
"telemetry",
|
||||||
"telemetry_events",
|
"telemetry_events",
|
||||||
@@ -133,6 +134,7 @@ dependencies = [
|
|||||||
"workspace-hack",
|
"workspace-hack",
|
||||||
"zed_actions",
|
"zed_actions",
|
||||||
"zed_llm_client",
|
"zed_llm_client",
|
||||||
|
"zstd",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -525,6 +527,7 @@ dependencies = [
|
|||||||
"fuzzy",
|
"fuzzy",
|
||||||
"gpui",
|
"gpui",
|
||||||
"indexed_docs",
|
"indexed_docs",
|
||||||
|
"indoc",
|
||||||
"language",
|
"language",
|
||||||
"language_model",
|
"language_model",
|
||||||
"languages",
|
"languages",
|
||||||
@@ -559,6 +562,7 @@ dependencies = [
|
|||||||
"workspace",
|
"workspace",
|
||||||
"workspace-hack",
|
"workspace-hack",
|
||||||
"zed_actions",
|
"zed_actions",
|
||||||
|
"zed_llm_client",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -683,6 +687,7 @@ dependencies = [
|
|||||||
"language_model",
|
"language_model",
|
||||||
"language_models",
|
"language_models",
|
||||||
"log",
|
"log",
|
||||||
|
"lsp",
|
||||||
"markdown",
|
"markdown",
|
||||||
"open",
|
"open",
|
||||||
"paths",
|
"paths",
|
||||||
@@ -2198,6 +2203,7 @@ dependencies = [
|
|||||||
"editor",
|
"editor",
|
||||||
"gpui",
|
"gpui",
|
||||||
"itertools 0.14.0",
|
"itertools 0.14.0",
|
||||||
|
"settings",
|
||||||
"theme",
|
"theme",
|
||||||
"ui",
|
"ui",
|
||||||
"workspace",
|
"workspace",
|
||||||
@@ -4732,6 +4738,7 @@ dependencies = [
|
|||||||
"tree-sitter-rust",
|
"tree-sitter-rust",
|
||||||
"tree-sitter-typescript",
|
"tree-sitter-typescript",
|
||||||
"ui",
|
"ui",
|
||||||
|
"unicode-script",
|
||||||
"unicode-segmentation",
|
"unicode-segmentation",
|
||||||
"unindent",
|
"unindent",
|
||||||
"url",
|
"url",
|
||||||
@@ -5045,6 +5052,7 @@ dependencies = [
|
|||||||
"util",
|
"util",
|
||||||
"uuid",
|
"uuid",
|
||||||
"workspace-hack",
|
"workspace-hack",
|
||||||
|
"zed_llm_client",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -6147,6 +6155,7 @@ dependencies = [
|
|||||||
"workspace",
|
"workspace",
|
||||||
"workspace-hack",
|
"workspace-hack",
|
||||||
"zed_actions",
|
"zed_actions",
|
||||||
|
"zed_llm_client",
|
||||||
"zlog",
|
"zlog",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -7064,6 +7073,7 @@ dependencies = [
|
|||||||
"image",
|
"image",
|
||||||
"inventory",
|
"inventory",
|
||||||
"itertools 0.14.0",
|
"itertools 0.14.0",
|
||||||
|
"libc",
|
||||||
"log",
|
"log",
|
||||||
"lyon",
|
"lyon",
|
||||||
"media",
|
"media",
|
||||||
@@ -8752,6 +8762,7 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"settings",
|
"settings",
|
||||||
|
"shellexpand 2.1.2",
|
||||||
"smallvec",
|
"smallvec",
|
||||||
"smol",
|
"smol",
|
||||||
"streaming-iterator",
|
"streaming-iterator",
|
||||||
@@ -8929,6 +8940,7 @@ dependencies = [
|
|||||||
"async-compression",
|
"async-compression",
|
||||||
"async-tar",
|
"async-tar",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
|
"chrono",
|
||||||
"collections",
|
"collections",
|
||||||
"dap",
|
"dap",
|
||||||
"futures 0.3.31",
|
"futures 0.3.31",
|
||||||
@@ -8982,6 +8994,7 @@ dependencies = [
|
|||||||
"tree-sitter-yaml",
|
"tree-sitter-yaml",
|
||||||
"unindent",
|
"unindent",
|
||||||
"util",
|
"util",
|
||||||
|
"which 6.0.3",
|
||||||
"workspace",
|
"workspace",
|
||||||
"workspace-hack",
|
"workspace-hack",
|
||||||
]
|
]
|
||||||
@@ -15581,6 +15594,7 @@ dependencies = [
|
|||||||
"futures 0.3.31",
|
"futures 0.3.31",
|
||||||
"gpui",
|
"gpui",
|
||||||
"hex",
|
"hex",
|
||||||
|
"log",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
"pretty_assertions",
|
"pretty_assertions",
|
||||||
"proto",
|
"proto",
|
||||||
@@ -16480,9 +16494,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tree-sitter"
|
name = "tree-sitter"
|
||||||
version = "0.25.3"
|
version = "0.25.5"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b9ac5ea5e7f2f1700842ec071401010b9c59bf735295f6e9fa079c3dc035b167"
|
checksum = "ac5fff5c47490dfdf473b5228039bfacad9d765d9b6939d26bf7cc064c1c7822"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cc",
|
"cc",
|
||||||
"regex",
|
"regex",
|
||||||
@@ -17115,8 +17129,6 @@ dependencies = [
|
|||||||
"tempfile",
|
"tempfile",
|
||||||
"tendril",
|
"tendril",
|
||||||
"unicase",
|
"unicase",
|
||||||
"unicode-script",
|
|
||||||
"unicode-segmentation",
|
|
||||||
"util_macros",
|
"util_macros",
|
||||||
"walkdir",
|
"walkdir",
|
||||||
"workspace-hack",
|
"workspace-hack",
|
||||||
@@ -19680,7 +19692,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "zed"
|
name = "zed"
|
||||||
version = "0.189.0"
|
version = "0.190.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"activity_indicator",
|
"activity_indicator",
|
||||||
"agent",
|
"agent",
|
||||||
@@ -19876,9 +19888,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "zed_llm_client"
|
name = "zed_llm_client"
|
||||||
version = "0.8.3"
|
version = "0.8.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "22a8b9575b215536ed8ad254ba07171e4e13bd029eda3b54cca4b184d2768050"
|
checksum = "de7d9523255f4e00ee3d0918e5407bd252d798a4a8e71f6d37f23317a1588203"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"serde",
|
"serde",
|
||||||
|
|||||||
@@ -572,7 +572,7 @@ tokio = { version = "1" }
|
|||||||
tokio-tungstenite = { version = "0.26", features = ["__rustls-tls"] }
|
tokio-tungstenite = { version = "0.26", features = ["__rustls-tls"] }
|
||||||
toml = "0.8"
|
toml = "0.8"
|
||||||
tower-http = "0.4.4"
|
tower-http = "0.4.4"
|
||||||
tree-sitter = { version = "0.25.3", features = ["wasm"] }
|
tree-sitter = { version = "0.25.5", features = ["wasm"] }
|
||||||
tree-sitter-bash = "0.23"
|
tree-sitter-bash = "0.23"
|
||||||
tree-sitter-c = "0.23"
|
tree-sitter-c = "0.23"
|
||||||
tree-sitter-cpp = "0.23"
|
tree-sitter-cpp = "0.23"
|
||||||
@@ -617,7 +617,7 @@ wasmtime = { version = "29", default-features = false, features = [
|
|||||||
wasmtime-wasi = "29"
|
wasmtime-wasi = "29"
|
||||||
which = "6.0.0"
|
which = "6.0.0"
|
||||||
workspace-hack = "0.1.0"
|
workspace-hack = "0.1.0"
|
||||||
zed_llm_client = "0.8.3"
|
zed_llm_client = "0.8.4"
|
||||||
zstd = "0.11"
|
zstd = "0.11"
|
||||||
|
|
||||||
[workspace.dependencies.async-stripe]
|
[workspace.dependencies.async-stripe]
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||||
<path d="M17 20H16C14.9391 20 13.9217 19.6629 13.1716 19.0627C12.4214 18.4626 12 17.6487 12 16.8V7.2C12 6.35131 12.4214 5.53737 13.1716 4.93726C13.9217 4.33714 14.9391 4 16 4H17" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
<path d="M11 13H10.4C9.76346 13 9.15302 12.7893 8.70296 12.4142C8.25284 12.0391 8 11.5304 8 11V5C8 4.46957 8.25284 3.96086 8.70296 3.58579C9.15302 3.21071 9.76346 3 10.4 3H11" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||||
<path d="M7 20H8C9.06087 20 10.0783 19.5786 10.8284 18.8284C11.5786 18.0783 12 17.0609 12 16V15" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
<path d="M5 13H5.6C6.23654 13 6.84698 12.7893 7.29704 12.4142C7.74716 12.0391 8 11.5304 8 11V5C8 4.46957 7.74716 3.96086 7.29704 3.58579C6.84698 3.21071 6.23654 3 5.6 3H5" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||||
<path d="M7 4H8C9.06087 4 10.0783 4.42143 10.8284 5.17157C11.5786 5.92172 12 6.93913 12 8V9" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
|
||||||
</svg>
|
</svg>
|
||||||
|
|||||||
|
Before Width: | Height: | Size: 715 B After Width: | Height: | Size: 617 B |
3
assets/icons/play_alt.svg
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||||
|
<path d="M4 3L13 8L4 13V3Z" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 214 B |
8
assets/icons/play_bug.svg
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||||
|
<path d="M4 12C2.35977 11.85 1 10.575 1 9" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||||
|
<path d="M1.00875 15.2C1.00875 13.625 0.683456 12.275 4.00001 12.2" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||||
|
<path d="M7 9C7 10.575 5.62857 11.85 4 12" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||||
|
<path d="M4 12.2C6.98117 12.2 7 13.625 7 15.2" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||||
|
<rect x="2.5" y="9" width="3" height="6" rx="1.5" fill="black"/>
|
||||||
|
<path d="M9 10L13 8L4 3V7.5" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 813 B |
@@ -1,3 +1,8 @@
|
|||||||
<svg width="17" height="17" viewBox="0 0 17 17" fill="none" xmlns="http://www.w3.org/2000/svg">
|
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M6.36667 3.79167C5.53364 3.79167 4.85833 4.46697 4.85833 5.3C4.85833 6.13303 5.53364 6.80833 6.36667 6.80833C7.1997 6.80833 7.875 6.13303 7.875 5.3C7.875 4.46697 7.1997 3.79167 6.36667 3.79167ZM2.1 5.925H3.67944C3.9626 7.14732 5.05824 8.05833 6.36667 8.05833C7.67509 8.05833 8.77073 7.14732 9.05389 5.925H14.9C15.2452 5.925 15.525 5.64518 15.525 5.3C15.525 4.95482 15.2452 4.675 14.9 4.675H9.05389C8.77073 3.45268 7.67509 2.54167 6.36667 2.54167C5.05824 2.54167 3.9626 3.45268 3.67944 4.675H2.1C1.75482 4.675 1.475 4.95482 1.475 5.3C1.475 5.64518 1.75482 5.925 2.1 5.925ZM13.3206 12.325C13.0374 13.5473 11.9418 14.4583 10.6333 14.4583C9.32491 14.4583 8.22927 13.5473 7.94611 12.325H2.1C1.75482 12.325 1.475 12.0452 1.475 11.7C1.475 11.3548 1.75482 11.075 2.1 11.075H7.94611C8.22927 9.85268 9.32491 8.94167 10.6333 8.94167C11.9418 8.94167 13.0374 9.85268 13.3206 11.075H14.9C15.2452 11.075 15.525 11.3548 15.525 11.7C15.525 12.0452 15.2452 12.325 14.9 12.325H13.3206ZM9.125 11.7C9.125 10.867 9.8003 10.1917 10.6333 10.1917C11.4664 10.1917 12.1417 10.867 12.1417 11.7C12.1417 12.533 11.4664 13.2083 10.6333 13.2083C9.8003 13.2083 9.125 12.533 9.125 11.7Z" fill="black"/>
|
<path d="M2 5H4" stroke="black" stroke-width="1.5" stroke-linecap="round"/>
|
||||||
|
<path d="M8 5L14 5" stroke="black" stroke-width="1.5" stroke-linecap="round"/>
|
||||||
|
<path d="M12 11L14 11" stroke="black" stroke-width="1.5" stroke-linecap="round"/>
|
||||||
|
<path d="M2 11H8" stroke="black" stroke-width="1.5" stroke-linecap="round"/>
|
||||||
|
<circle cx="6" cy="5" r="2" fill="black" fill-opacity="0.1" stroke="black" stroke-width="1.5" stroke-linecap="round"/>
|
||||||
|
<circle cx="10" cy="11" r="2" fill="black" fill-opacity="0.1" stroke="black" stroke-width="1.5" stroke-linecap="round"/>
|
||||||
</svg>
|
</svg>
|
||||||
|
|||||||
|
Before Width: | Height: | Size: 1.3 KiB After Width: | Height: | Size: 657 B |
@@ -1,5 +1,5 @@
|
|||||||
<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg">
|
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||||
<path d="M7 1.75L5.88467 5.14092C5.82759 5.31446 5.73055 5.47218 5.60136 5.60136C5.47218 5.73055 5.31446 5.82759 5.14092 5.88467L1.75 7L5.14092 8.11533C5.31446 8.17241 5.47218 8.26945 5.60136 8.39864C5.73055 8.52782 5.82759 8.68554 5.88467 8.85908L7 12.25L8.11533 8.85908C8.17241 8.68554 8.26945 8.52782 8.39864 8.39864C8.52782 8.26945 8.68554 8.17241 8.85908 8.11533L12.25 7L8.85908 5.88467C8.68554 5.82759 8.52782 5.73055 8.39864 5.60136C8.26945 5.47218 8.17241 5.31446 8.11533 5.14092L7 1.75Z" fill="black" fill-opacity="0.15" stroke="black" stroke-width="1.25" stroke-linecap="round" stroke-linejoin="round"/>
|
<path d="M8 2L6.72534 5.87534C6.6601 6.07367 6.5492 6.25392 6.40155 6.40155C6.25392 6.5492 6.07367 6.6601 5.87534 6.72534L2 8L5.87534 9.27466C6.07367 9.3399 6.25392 9.4508 6.40155 9.59845C6.5492 9.74608 6.6601 9.92633 6.72534 10.1247L8 14L9.27466 10.1247C9.3399 9.92633 9.4508 9.74608 9.59845 9.59845C9.74608 9.4508 9.92633 9.3399 10.1247 9.27466L14 8L10.1247 6.72534C9.92633 6.6601 9.74608 6.5492 9.59845 6.40155C9.4508 6.25392 9.3399 6.07367 9.27466 5.87534L8 2Z" fill="black" fill-opacity="0.15" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||||
<path d="M2.91667 1.75V4.08333M1.75 2.91667H4.08333" stroke="black" stroke-opacity="0.75" stroke-width="1.25" stroke-linecap="round" stroke-linejoin="round"/>
|
<path d="M3.33334 2V4.66666M2 3.33334H4.66666" stroke="black" stroke-opacity="0.75" stroke-width="1.25" stroke-linecap="round" stroke-linejoin="round"/>
|
||||||
<path d="M11.0833 9.91667V12.25M9.91667 11.0833H12.25" stroke="black" stroke-opacity="0.75" stroke-width="1.25" stroke-linecap="round" stroke-linejoin="round"/>
|
<path d="M12.6665 11.3333V14M11.3333 12.6666H13.9999" stroke="black" stroke-opacity="0.75" stroke-width="1.25" stroke-linecap="round" stroke-linejoin="round"/>
|
||||||
</svg>
|
</svg>
|
||||||
|
|||||||
|
Before Width: | Height: | Size: 1.0 KiB After Width: | Height: | Size: 998 B |
@@ -31,8 +31,6 @@
|
|||||||
"ctrl-,": "zed::OpenSettings",
|
"ctrl-,": "zed::OpenSettings",
|
||||||
"ctrl-q": "zed::Quit",
|
"ctrl-q": "zed::Quit",
|
||||||
"f4": "debugger::Start",
|
"f4": "debugger::Start",
|
||||||
"alt-f4": "debugger::RerunLastSession",
|
|
||||||
"f5": "debugger::Continue",
|
|
||||||
"shift-f5": "debugger::Stop",
|
"shift-f5": "debugger::Stop",
|
||||||
"ctrl-shift-f5": "debugger::Restart",
|
"ctrl-shift-f5": "debugger::Restart",
|
||||||
"f6": "debugger::Pause",
|
"f6": "debugger::Pause",
|
||||||
@@ -127,9 +125,7 @@
|
|||||||
"shift-f10": "editor::OpenContextMenu",
|
"shift-f10": "editor::OpenContextMenu",
|
||||||
"ctrl-shift-e": "editor::ToggleEditPrediction",
|
"ctrl-shift-e": "editor::ToggleEditPrediction",
|
||||||
"f9": "editor::ToggleBreakpoint",
|
"f9": "editor::ToggleBreakpoint",
|
||||||
"shift-f9": "editor::EditLogBreakpoint",
|
"shift-f9": "editor::EditLogBreakpoint"
|
||||||
"ctrl-shift-backspace": "editor::GoToPreviousChange",
|
|
||||||
"ctrl-shift-alt-backspace": "editor::GoToNextChange"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -148,6 +144,8 @@
|
|||||||
"ctrl->": "assistant::QuoteSelection",
|
"ctrl->": "assistant::QuoteSelection",
|
||||||
"ctrl-<": "assistant::InsertIntoEditor",
|
"ctrl-<": "assistant::InsertIntoEditor",
|
||||||
"ctrl-alt-e": "editor::SelectEnclosingSymbol",
|
"ctrl-alt-e": "editor::SelectEnclosingSymbol",
|
||||||
|
"ctrl-shift-backspace": "editor::GoToPreviousChange",
|
||||||
|
"ctrl-shift-alt-backspace": "editor::GoToNextChange",
|
||||||
"alt-enter": "editor::OpenSelectionsInMultibuffer"
|
"alt-enter": "editor::OpenSelectionsInMultibuffer"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -244,13 +242,14 @@
|
|||||||
"ctrl-i": "agent::ToggleProfileSelector",
|
"ctrl-i": "agent::ToggleProfileSelector",
|
||||||
"ctrl-alt-/": "agent::ToggleModelSelector",
|
"ctrl-alt-/": "agent::ToggleModelSelector",
|
||||||
"ctrl-shift-a": "agent::ToggleContextPicker",
|
"ctrl-shift-a": "agent::ToggleContextPicker",
|
||||||
"ctrl-shift-o": "agent::ToggleNavigationMenu",
|
"ctrl-shift-j": "agent::ToggleNavigationMenu",
|
||||||
"ctrl-shift-i": "agent::ToggleOptionsMenu",
|
"ctrl-shift-i": "agent::ToggleOptionsMenu",
|
||||||
"shift-alt-escape": "agent::ExpandMessageEditor",
|
"shift-alt-escape": "agent::ExpandMessageEditor",
|
||||||
"ctrl-alt-e": "agent::RemoveAllContext",
|
"ctrl-alt-e": "agent::RemoveAllContext",
|
||||||
"ctrl-shift-e": "project_panel::ToggleFocus",
|
"ctrl-shift-e": "project_panel::ToggleFocus",
|
||||||
"ctrl-shift-enter": "agent::ContinueThread",
|
"ctrl-shift-enter": "agent::ContinueThread",
|
||||||
"alt-enter": "agent::ContinueWithBurnMode"
|
"alt-enter": "agent::ContinueWithBurnMode",
|
||||||
|
"ctrl-alt-b": "agent::ToggleBurnMode"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -582,11 +581,24 @@
|
|||||||
"ctrl-alt-r": "task::Rerun",
|
"ctrl-alt-r": "task::Rerun",
|
||||||
"alt-t": "task::Rerun",
|
"alt-t": "task::Rerun",
|
||||||
"alt-shift-t": "task::Spawn",
|
"alt-shift-t": "task::Spawn",
|
||||||
"alt-shift-r": ["task::Spawn", { "reveal_target": "center" }]
|
"alt-shift-r": ["task::Spawn", { "reveal_target": "center" }],
|
||||||
// also possible to spawn tasks by name:
|
// also possible to spawn tasks by name:
|
||||||
// "foo-bar": ["task::Spawn", { "task_name": "MyTask", "reveal_target": "dock" }]
|
// "foo-bar": ["task::Spawn", { "task_name": "MyTask", "reveal_target": "dock" }]
|
||||||
// or by tag:
|
// or by tag:
|
||||||
// "foo-bar": ["task::Spawn", { "task_tag": "MyTag" }],
|
// "foo-bar": ["task::Spawn", { "task_tag": "MyTag" }],
|
||||||
|
"f5": "debugger::RerunLastSession"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"context": "Workspace && debugger_running",
|
||||||
|
"bindings": {
|
||||||
|
"f5": "zed::NoAction"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"context": "Workspace && debugger_stopped",
|
||||||
|
"bindings": {
|
||||||
|
"f5": "debugger::Continue"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -872,7 +884,8 @@
|
|||||||
"context": "DebugPanel",
|
"context": "DebugPanel",
|
||||||
"bindings": {
|
"bindings": {
|
||||||
"ctrl-t": "debugger::ToggleThreadPicker",
|
"ctrl-t": "debugger::ToggleThreadPicker",
|
||||||
"ctrl-i": "debugger::ToggleSessionPicker"
|
"ctrl-i": "debugger::ToggleSessionPicker",
|
||||||
|
"shift-alt-escape": "debugger::ToggleExpandItem"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -927,6 +940,13 @@
|
|||||||
"tab": "channel_modal::ToggleMode"
|
"tab": "channel_modal::ToggleMode"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"context": "FileFinder",
|
||||||
|
"bindings": {
|
||||||
|
"ctrl-shift-a": "file_finder::ToggleSplitMenu",
|
||||||
|
"ctrl-shift-i": "file_finder::ToggleFilterMenu"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"context": "FileFinder || (FileFinder > Picker > Editor) || (FileFinder > Picker > menu)",
|
"context": "FileFinder || (FileFinder > Picker > Editor) || (FileFinder > Picker > menu)",
|
||||||
"bindings": {
|
"bindings": {
|
||||||
@@ -1018,5 +1038,12 @@
|
|||||||
"bindings": {
|
"bindings": {
|
||||||
"enter": "menu::Confirm"
|
"enter": "menu::Confirm"
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"context": "RunModal",
|
||||||
|
"bindings": {
|
||||||
|
"ctrl-tab": "pane::ActivateNextItem",
|
||||||
|
"ctrl-shift-tab": "pane::ActivatePreviousItem"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -4,8 +4,6 @@
|
|||||||
"use_key_equivalents": true,
|
"use_key_equivalents": true,
|
||||||
"bindings": {
|
"bindings": {
|
||||||
"f4": "debugger::Start",
|
"f4": "debugger::Start",
|
||||||
"alt-f4": "debugger::RerunLastSession",
|
|
||||||
"f5": "debugger::Continue",
|
|
||||||
"shift-f5": "debugger::Stop",
|
"shift-f5": "debugger::Stop",
|
||||||
"shift-cmd-f5": "debugger::Restart",
|
"shift-cmd-f5": "debugger::Restart",
|
||||||
"f6": "debugger::Pause",
|
"f6": "debugger::Pause",
|
||||||
@@ -279,13 +277,14 @@
|
|||||||
"cmd-i": "agent::ToggleProfileSelector",
|
"cmd-i": "agent::ToggleProfileSelector",
|
||||||
"cmd-alt-/": "agent::ToggleModelSelector",
|
"cmd-alt-/": "agent::ToggleModelSelector",
|
||||||
"cmd-shift-a": "agent::ToggleContextPicker",
|
"cmd-shift-a": "agent::ToggleContextPicker",
|
||||||
"cmd-shift-o": "agent::ToggleNavigationMenu",
|
"cmd-shift-j": "agent::ToggleNavigationMenu",
|
||||||
"cmd-shift-i": "agent::ToggleOptionsMenu",
|
"cmd-shift-i": "agent::ToggleOptionsMenu",
|
||||||
"shift-alt-escape": "agent::ExpandMessageEditor",
|
"shift-alt-escape": "agent::ExpandMessageEditor",
|
||||||
"cmd-alt-e": "agent::RemoveAllContext",
|
"cmd-alt-e": "agent::RemoveAllContext",
|
||||||
"cmd-shift-e": "project_panel::ToggleFocus",
|
"cmd-shift-e": "project_panel::ToggleFocus",
|
||||||
"cmd-shift-enter": "agent::ContinueThread",
|
"cmd-shift-enter": "agent::ContinueThread",
|
||||||
"alt-enter": "agent::ContinueWithBurnMode"
|
"alt-enter": "agent::ContinueWithBurnMode",
|
||||||
|
"cmd-alt-b": "agent::ToggleBurnMode"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -545,9 +544,7 @@
|
|||||||
"cmd-\\": "pane::SplitRight",
|
"cmd-\\": "pane::SplitRight",
|
||||||
"cmd-k v": "markdown::OpenPreviewToTheSide",
|
"cmd-k v": "markdown::OpenPreviewToTheSide",
|
||||||
"cmd-shift-v": "markdown::OpenPreview",
|
"cmd-shift-v": "markdown::OpenPreview",
|
||||||
"ctrl-cmd-c": "editor::DisplayCursorNames",
|
"ctrl-cmd-c": "editor::DisplayCursorNames"
|
||||||
"cmd-shift-backspace": "editor::GoToPreviousChange",
|
|
||||||
"cmd-shift-alt-backspace": "editor::GoToNextChange"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -555,7 +552,9 @@
|
|||||||
"use_key_equivalents": true,
|
"use_key_equivalents": true,
|
||||||
"bindings": {
|
"bindings": {
|
||||||
"cmd-shift-o": "outline::Toggle",
|
"cmd-shift-o": "outline::Toggle",
|
||||||
"ctrl-g": "go_to_line::Toggle"
|
"ctrl-g": "go_to_line::Toggle",
|
||||||
|
"cmd-shift-backspace": "editor::GoToPreviousChange",
|
||||||
|
"cmd-shift-alt-backspace": "editor::GoToNextChange"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -634,7 +633,8 @@
|
|||||||
"cmd-k shift-right": "workspace::SwapPaneRight",
|
"cmd-k shift-right": "workspace::SwapPaneRight",
|
||||||
"cmd-k shift-up": "workspace::SwapPaneUp",
|
"cmd-k shift-up": "workspace::SwapPaneUp",
|
||||||
"cmd-k shift-down": "workspace::SwapPaneDown",
|
"cmd-k shift-down": "workspace::SwapPaneDown",
|
||||||
"cmd-shift-x": "zed::Extensions"
|
"cmd-shift-x": "zed::Extensions",
|
||||||
|
"f5": "debugger::RerunLastSession"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -651,6 +651,20 @@
|
|||||||
// "foo-bar": ["task::Spawn", { "task_tag": "MyTag" }],
|
// "foo-bar": ["task::Spawn", { "task_tag": "MyTag" }],
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"context": "Workspace && debugger_running",
|
||||||
|
"use_key_equivalents": true,
|
||||||
|
"bindings": {
|
||||||
|
"f5": "zed::NoAction"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"context": "Workspace && debugger_stopped",
|
||||||
|
"use_key_equivalents": true,
|
||||||
|
"bindings": {
|
||||||
|
"f5": "debugger::Continue"
|
||||||
|
}
|
||||||
|
},
|
||||||
// Bindings from Sublime Text
|
// Bindings from Sublime Text
|
||||||
{
|
{
|
||||||
"context": "Editor",
|
"context": "Editor",
|
||||||
@@ -935,7 +949,8 @@
|
|||||||
"context": "DebugPanel",
|
"context": "DebugPanel",
|
||||||
"bindings": {
|
"bindings": {
|
||||||
"cmd-t": "debugger::ToggleThreadPicker",
|
"cmd-t": "debugger::ToggleThreadPicker",
|
||||||
"cmd-i": "debugger::ToggleSessionPicker"
|
"cmd-i": "debugger::ToggleSessionPicker",
|
||||||
|
"shift-alt-escape": "debugger::ToggleExpandItem"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -986,6 +1001,14 @@
|
|||||||
"tab": "channel_modal::ToggleMode"
|
"tab": "channel_modal::ToggleMode"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"context": "FileFinder",
|
||||||
|
"use_key_equivalents": true,
|
||||||
|
"bindings": {
|
||||||
|
"cmd-shift-a": "file_finder::ToggleSplitMenu",
|
||||||
|
"cmd-shift-i": "file_finder::ToggleFilterMenu"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"context": "FileFinder || (FileFinder > Picker > Editor) || (FileFinder > Picker > menu)",
|
"context": "FileFinder || (FileFinder > Picker > Editor) || (FileFinder > Picker > menu)",
|
||||||
"use_key_equivalents": true,
|
"use_key_equivalents": true,
|
||||||
@@ -1108,5 +1131,13 @@
|
|||||||
"bindings": {
|
"bindings": {
|
||||||
"enter": "menu::Confirm"
|
"enter": "menu::Confirm"
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"context": "RunModal",
|
||||||
|
"use_key_equivalents": true,
|
||||||
|
"bindings": {
|
||||||
|
"ctrl-tab": "pane::ActivateNextItem",
|
||||||
|
"ctrl-shift-tab": "pane::ActivatePreviousItem"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
85
assets/keymaps/linux/cursor.json
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
[
|
||||||
|
// Cursor for MacOS. See: https://docs.cursor.com/kbd
|
||||||
|
{
|
||||||
|
"context": "Workspace",
|
||||||
|
"use_key_equivalents": true,
|
||||||
|
"bindings": {
|
||||||
|
"ctrl-i": "agent::ToggleFocus",
|
||||||
|
"ctrl-shift-i": "agent::ToggleFocus",
|
||||||
|
"ctrl-l": "agent::ToggleFocus",
|
||||||
|
"ctrl-shift-l": "agent::ToggleFocus",
|
||||||
|
"ctrl-alt-b": "agent::ToggleFocus",
|
||||||
|
"ctrl-shift-j": "agent::OpenConfiguration"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"context": "Editor && mode == full",
|
||||||
|
"use_key_equivalents": true,
|
||||||
|
"bindings": {
|
||||||
|
"ctrl-i": "agent::ToggleFocus",
|
||||||
|
"ctrl-shift-i": "agent::ToggleFocus",
|
||||||
|
"ctrl-shift-l": "assistant::QuoteSelection", // In cursor uses "Ask" mode
|
||||||
|
"ctrl-l": "assistant::QuoteSelection", // In cursor uses "Agent" mode
|
||||||
|
"ctrl-k": "assistant::InlineAssist",
|
||||||
|
"ctrl-shift-k": "assistant::InsertIntoEditor"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"context": "InlineAssistEditor",
|
||||||
|
"use_key_equivalents": true,
|
||||||
|
"bindings": {
|
||||||
|
"ctrl-shift-backspace": "editor::Cancel"
|
||||||
|
// "alt-enter": // Quick Question
|
||||||
|
// "ctrl-shift-enter": // Full File Context
|
||||||
|
// "ctrl-shift-k": // Toggle input focus (editor <> inline assist)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"context": "AgentPanel || ContextEditor || (MessageEditor > Editor)",
|
||||||
|
"use_key_equivalents": true,
|
||||||
|
"bindings": {
|
||||||
|
"ctrl-i": "workspace::ToggleRightDock",
|
||||||
|
"ctrl-shift-i": "workspace::ToggleRightDock",
|
||||||
|
"ctrl-l": "workspace::ToggleRightDock",
|
||||||
|
"ctrl-shift-l": "workspace::ToggleRightDock",
|
||||||
|
"ctrl-alt-b": "workspace::ToggleRightDock",
|
||||||
|
"ctrl-w": "workspace::ToggleRightDock", // technically should close chat
|
||||||
|
"ctrl-.": "agent::ToggleProfileSelector",
|
||||||
|
"ctrl-/": "agent::ToggleModelSelector",
|
||||||
|
"ctrl-shift-backspace": "editor::Cancel",
|
||||||
|
"ctrl-r": "agent::NewThread",
|
||||||
|
"ctrl-shift-v": "editor::Paste",
|
||||||
|
"ctrl-shift-k": "assistant::InsertIntoEditor"
|
||||||
|
// "escape": "agent::ToggleFocus"
|
||||||
|
///// Enable when Zed supports multiple thread tabs
|
||||||
|
// "ctrl-t": // new thread tab
|
||||||
|
// "ctrl-[": // next thread tab
|
||||||
|
// "ctrl-]": // next thread tab
|
||||||
|
///// Enable if Zed adds support for keyboard navigation of thread elements
|
||||||
|
// "tab": // cycle to next message
|
||||||
|
// "shift-tab": // cycle to previous message
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"context": "Editor && editor_agent_diff",
|
||||||
|
"use_key_equivalents": true,
|
||||||
|
"bindings": {
|
||||||
|
"ctrl-enter": "agent::KeepAll",
|
||||||
|
"ctrl-backspace": "agent::RejectAll"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"context": "Editor && mode == full && edit_prediction",
|
||||||
|
"use_key_equivalents": true,
|
||||||
|
"bindings": {
|
||||||
|
"ctrl-right": "editor::AcceptPartialEditPrediction"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"context": "Terminal",
|
||||||
|
"use_key_equivalents": true,
|
||||||
|
"bindings": {
|
||||||
|
"ctrl-k": "assistant::InlineAssist"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
@@ -51,7 +51,11 @@
|
|||||||
"ctrl-k ctrl-l": "editor::ConvertToLowerCase",
|
"ctrl-k ctrl-l": "editor::ConvertToLowerCase",
|
||||||
"shift-alt-m": "markdown::OpenPreviewToTheSide",
|
"shift-alt-m": "markdown::OpenPreviewToTheSide",
|
||||||
"ctrl-backspace": "editor::DeleteToPreviousWordStart",
|
"ctrl-backspace": "editor::DeleteToPreviousWordStart",
|
||||||
"ctrl-delete": "editor::DeleteToNextWordEnd"
|
"ctrl-delete": "editor::DeleteToNextWordEnd",
|
||||||
|
"alt-right": "editor::MoveToNextSubwordEnd",
|
||||||
|
"alt-left": "editor::MoveToPreviousSubwordStart",
|
||||||
|
"alt-shift-right": "editor::SelectToNextSubwordEnd",
|
||||||
|
"alt-shift-left": "editor::SelectToPreviousSubwordStart"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
85
assets/keymaps/macos/cursor.json
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
[
|
||||||
|
// Cursor for MacOS. See: https://docs.cursor.com/kbd
|
||||||
|
{
|
||||||
|
"context": "Workspace",
|
||||||
|
"use_key_equivalents": true,
|
||||||
|
"bindings": {
|
||||||
|
"cmd-i": "agent::ToggleFocus",
|
||||||
|
"cmd-shift-i": "agent::ToggleFocus",
|
||||||
|
"cmd-l": "agent::ToggleFocus",
|
||||||
|
"cmd-shift-l": "agent::ToggleFocus",
|
||||||
|
"cmd-alt-b": "agent::ToggleFocus",
|
||||||
|
"cmd-shift-j": "agent::OpenConfiguration"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"context": "Editor && mode == full",
|
||||||
|
"use_key_equivalents": true,
|
||||||
|
"bindings": {
|
||||||
|
"cmd-i": "agent::ToggleFocus",
|
||||||
|
"cmd-shift-i": "agent::ToggleFocus",
|
||||||
|
"cmd-shift-l": "assistant::QuoteSelection", // In cursor uses "Ask" mode
|
||||||
|
"cmd-l": "assistant::QuoteSelection", // In cursor uses "Agent" mode
|
||||||
|
"cmd-k": "assistant::InlineAssist",
|
||||||
|
"cmd-shift-k": "assistant::InsertIntoEditor"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"context": "InlineAssistEditor",
|
||||||
|
"use_key_equivalents": true,
|
||||||
|
"bindings": {
|
||||||
|
"cmd-shift-backspace": "editor::Cancel"
|
||||||
|
// "alt-enter": // Quick Question
|
||||||
|
// "cmd-shift-enter": // Full File Context
|
||||||
|
// "cmd-shift-k": // Toggle input focus (editor <> inline assist)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"context": "AgentPanel || ContextEditor || (MessageEditor > Editor)",
|
||||||
|
"use_key_equivalents": true,
|
||||||
|
"bindings": {
|
||||||
|
"cmd-i": "workspace::ToggleRightDock",
|
||||||
|
"cmd-shift-i": "workspace::ToggleRightDock",
|
||||||
|
"cmd-l": "workspace::ToggleRightDock",
|
||||||
|
"cmd-shift-l": "workspace::ToggleRightDock",
|
||||||
|
"cmd-alt-b": "workspace::ToggleRightDock",
|
||||||
|
"cmd-w": "workspace::ToggleRightDock", // technically should close chat
|
||||||
|
"cmd-.": "agent::ToggleProfileSelector",
|
||||||
|
"cmd-/": "agent::ToggleModelSelector",
|
||||||
|
"cmd-shift-backspace": "editor::Cancel",
|
||||||
|
"cmd-r": "agent::NewThread",
|
||||||
|
"cmd-shift-v": "editor::Paste",
|
||||||
|
"cmd-shift-k": "assistant::InsertIntoEditor"
|
||||||
|
// "escape": "agent::ToggleFocus"
|
||||||
|
///// Enable when Zed supports multiple thread tabs
|
||||||
|
// "cmd-t": // new thread tab
|
||||||
|
// "cmd-[": // next thread tab
|
||||||
|
// "cmd-]": // next thread tab
|
||||||
|
///// Enable if Zed adds support for keyboard navigation of thread elements
|
||||||
|
// "tab": // cycle to next message
|
||||||
|
// "shift-tab": // cycle to previous message
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"context": "Editor && editor_agent_diff",
|
||||||
|
"use_key_equivalents": true,
|
||||||
|
"bindings": {
|
||||||
|
"cmd-enter": "agent::KeepAll",
|
||||||
|
"cmd-backspace": "agent::RejectAll"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"context": "Editor && mode == full && edit_prediction",
|
||||||
|
"use_key_equivalents": true,
|
||||||
|
"bindings": {
|
||||||
|
"cmd-right": "editor::AcceptPartialEditPrediction"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"context": "Terminal",
|
||||||
|
"use_key_equivalents": true,
|
||||||
|
"bindings": {
|
||||||
|
"cmd-k": "assistant::InlineAssist"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
@@ -53,7 +53,11 @@
|
|||||||
"cmd-shift-j": "editor::JoinLines",
|
"cmd-shift-j": "editor::JoinLines",
|
||||||
"shift-alt-m": "markdown::OpenPreviewToTheSide",
|
"shift-alt-m": "markdown::OpenPreviewToTheSide",
|
||||||
"ctrl-backspace": "editor::DeleteToPreviousWordStart",
|
"ctrl-backspace": "editor::DeleteToPreviousWordStart",
|
||||||
"ctrl-delete": "editor::DeleteToNextWordEnd"
|
"ctrl-delete": "editor::DeleteToNextWordEnd",
|
||||||
|
"ctrl-right": "editor::MoveToNextSubwordEnd",
|
||||||
|
"ctrl-left": "editor::MoveToPreviousSubwordStart",
|
||||||
|
"ctrl-shift-right": "editor::SelectToNextSubwordEnd",
|
||||||
|
"ctrl-shift-left": "editor::SelectToPreviousSubwordStart"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -838,6 +838,19 @@
|
|||||||
"tab": "editor::AcceptEditPrediction"
|
"tab": "editor::AcceptEditPrediction"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"context": "MessageEditor > Editor && VimControl",
|
||||||
|
"bindings": {
|
||||||
|
"enter": "agent::Chat",
|
||||||
|
// TODO: Implement search
|
||||||
|
"/": null,
|
||||||
|
"?": null,
|
||||||
|
"#": null,
|
||||||
|
"*": null,
|
||||||
|
"n": null,
|
||||||
|
"shift-n": null
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"context": "os != macos && Editor && edit_prediction_conflict",
|
"context": "os != macos && Editor && edit_prediction_conflict",
|
||||||
"bindings": {
|
"bindings": {
|
||||||
|
|||||||
@@ -128,6 +128,8 @@
|
|||||||
//
|
//
|
||||||
// Default: true
|
// Default: true
|
||||||
"restore_on_file_reopen": true,
|
"restore_on_file_reopen": true,
|
||||||
|
// Whether to automatically close files that have been deleted on disk.
|
||||||
|
"close_on_file_delete": false,
|
||||||
// Size of the drop target in the editor.
|
// Size of the drop target in the editor.
|
||||||
"drop_target_size": 0.2,
|
"drop_target_size": 0.2,
|
||||||
// Whether the window should be closed when using 'close active item' on a window with no tabs.
|
// Whether the window should be closed when using 'close active item' on a window with no tabs.
|
||||||
@@ -714,7 +716,7 @@
|
|||||||
"version": "2",
|
"version": "2",
|
||||||
// Whether the agent is enabled.
|
// Whether the agent is enabled.
|
||||||
"enabled": true,
|
"enabled": true,
|
||||||
/// What completion mode to start new threads in, if available. Can be 'normal' or 'max'.
|
/// What completion mode to start new threads in, if available. Can be 'normal' or 'burn'.
|
||||||
"preferred_completion_mode": "normal",
|
"preferred_completion_mode": "normal",
|
||||||
// Whether to show the agent panel button in the status bar.
|
// Whether to show the agent panel button in the status bar.
|
||||||
"button": true,
|
"button": true,
|
||||||
@@ -731,13 +733,6 @@
|
|||||||
// The model to use.
|
// The model to use.
|
||||||
"model": "claude-sonnet-4"
|
"model": "claude-sonnet-4"
|
||||||
},
|
},
|
||||||
// The model to use when applying edits from the agent.
|
|
||||||
"editor_model": {
|
|
||||||
// The provider to use.
|
|
||||||
"provider": "zed.dev",
|
|
||||||
// The model to use.
|
|
||||||
"model": "claude-sonnet-4"
|
|
||||||
},
|
|
||||||
// Additional parameters for language model requests. When making a request to a model, parameters will be taken
|
// Additional parameters for language model requests. When making a request to a model, parameters will be taken
|
||||||
// from the last entry in this list that matches the model's provider and name. In each entry, both provider
|
// from the last entry in this list that matches the model's provider and name. In each entry, both provider
|
||||||
// and model are optional, so that you can specify parameters for either one.
|
// and model are optional, so that you can specify parameters for either one.
|
||||||
@@ -1314,7 +1309,17 @@
|
|||||||
// Settings related to running tasks.
|
// Settings related to running tasks.
|
||||||
"tasks": {
|
"tasks": {
|
||||||
"variables": {},
|
"variables": {},
|
||||||
"enabled": true
|
"enabled": true,
|
||||||
|
// Use LSP tasks over Zed language extension ones.
|
||||||
|
// If no LSP tasks are returned due to error/timeout or regular execution,
|
||||||
|
// Zed language extension tasks will be used instead.
|
||||||
|
//
|
||||||
|
// Other Zed tasks will still be shown:
|
||||||
|
// * Zed task from either of the task config file
|
||||||
|
// * Zed task from history (e.g. one-off task was spawned before)
|
||||||
|
//
|
||||||
|
// Default: true
|
||||||
|
"prefer_lsp": true
|
||||||
},
|
},
|
||||||
// An object whose keys are language names, and whose values
|
// An object whose keys are language names, and whose values
|
||||||
// are arrays of filenames or extensions of files that should
|
// are arrays of filenames or extensions of files that should
|
||||||
@@ -1452,9 +1457,7 @@
|
|||||||
"language_servers": ["erlang-ls", "!elp", "..."]
|
"language_servers": ["erlang-ls", "!elp", "..."]
|
||||||
},
|
},
|
||||||
"Git Commit": {
|
"Git Commit": {
|
||||||
"allow_rewrap": "anywhere",
|
"allow_rewrap": "anywhere"
|
||||||
"preferred_line_length": 72,
|
|
||||||
"soft_wrap": "bounded"
|
|
||||||
},
|
},
|
||||||
"Go": {
|
"Go": {
|
||||||
"code_actions_on_format": {
|
"code_actions_on_format": {
|
||||||
|
|||||||
@@ -1,3 +1,7 @@
|
|||||||
|
// Some example tasks for common languages.
|
||||||
|
//
|
||||||
|
// For more documentation on how to configure debug tasks,
|
||||||
|
// see: https://zed.dev/docs/debugger
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"label": "Debug active PHP file",
|
"label": "Debug active PHP file",
|
||||||
|
|||||||
5
assets/settings/initial_local_debug_tasks.json
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
// Project-local debug tasks
|
||||||
|
//
|
||||||
|
// For more documentation on how to configure debug tasks,
|
||||||
|
// see: https://zed.dev/docs/debugger
|
||||||
|
[]
|
||||||
@@ -311,6 +311,31 @@ impl ActivityIndicator {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(session) = self
|
||||||
|
.project
|
||||||
|
.read(cx)
|
||||||
|
.dap_store()
|
||||||
|
.read(cx)
|
||||||
|
.sessions()
|
||||||
|
.find(|s| !s.read(cx).is_started())
|
||||||
|
{
|
||||||
|
return Some(Content {
|
||||||
|
icon: Some(
|
||||||
|
Icon::new(IconName::ArrowCircle)
|
||||||
|
.size(IconSize::Small)
|
||||||
|
.with_animation(
|
||||||
|
"arrow-circle",
|
||||||
|
Animation::new(Duration::from_secs(2)).repeat(),
|
||||||
|
|icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
|
||||||
|
)
|
||||||
|
.into_any_element(),
|
||||||
|
),
|
||||||
|
message: format!("Debug: {}", session.read(cx).adapter()),
|
||||||
|
tooltip_message: Some(session.read(cx).label().to_string()),
|
||||||
|
on_click: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
let current_job = self
|
let current_job = self
|
||||||
.project
|
.project
|
||||||
.read(cx)
|
.read(cx)
|
||||||
@@ -472,7 +497,7 @@ impl ActivityIndicator {
|
|||||||
})),
|
})),
|
||||||
tooltip_message: None,
|
tooltip_message: None,
|
||||||
}),
|
}),
|
||||||
AutoUpdateStatus::Downloading => Some(Content {
|
AutoUpdateStatus::Downloading { version } => Some(Content {
|
||||||
icon: Some(
|
icon: Some(
|
||||||
Icon::new(IconName::Download)
|
Icon::new(IconName::Download)
|
||||||
.size(IconSize::Small)
|
.size(IconSize::Small)
|
||||||
@@ -482,9 +507,9 @@ impl ActivityIndicator {
|
|||||||
on_click: Some(Arc::new(|this, window, cx| {
|
on_click: Some(Arc::new(|this, window, cx| {
|
||||||
this.dismiss_error_message(&DismissErrorMessage, window, cx)
|
this.dismiss_error_message(&DismissErrorMessage, window, cx)
|
||||||
})),
|
})),
|
||||||
tooltip_message: None,
|
tooltip_message: Some(Self::version_tooltip_message(&version)),
|
||||||
}),
|
}),
|
||||||
AutoUpdateStatus::Installing => Some(Content {
|
AutoUpdateStatus::Installing { version } => Some(Content {
|
||||||
icon: Some(
|
icon: Some(
|
||||||
Icon::new(IconName::Download)
|
Icon::new(IconName::Download)
|
||||||
.size(IconSize::Small)
|
.size(IconSize::Small)
|
||||||
@@ -494,7 +519,7 @@ impl ActivityIndicator {
|
|||||||
on_click: Some(Arc::new(|this, window, cx| {
|
on_click: Some(Arc::new(|this, window, cx| {
|
||||||
this.dismiss_error_message(&DismissErrorMessage, window, cx)
|
this.dismiss_error_message(&DismissErrorMessage, window, cx)
|
||||||
})),
|
})),
|
||||||
tooltip_message: None,
|
tooltip_message: Some(Self::version_tooltip_message(&version)),
|
||||||
}),
|
}),
|
||||||
AutoUpdateStatus::Updated {
|
AutoUpdateStatus::Updated {
|
||||||
binary_path,
|
binary_path,
|
||||||
@@ -508,7 +533,7 @@ impl ActivityIndicator {
|
|||||||
};
|
};
|
||||||
move |_, _, cx| workspace::reload(&reload, cx)
|
move |_, _, cx| workspace::reload(&reload, cx)
|
||||||
})),
|
})),
|
||||||
tooltip_message: Some(Self::install_version_tooltip_message(&version)),
|
tooltip_message: Some(Self::version_tooltip_message(&version)),
|
||||||
}),
|
}),
|
||||||
AutoUpdateStatus::Errored => Some(Content {
|
AutoUpdateStatus::Errored => Some(Content {
|
||||||
icon: Some(
|
icon: Some(
|
||||||
@@ -548,8 +573,8 @@ impl ActivityIndicator {
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
fn install_version_tooltip_message(version: &VersionCheckType) -> String {
|
fn version_tooltip_message(version: &VersionCheckType) -> String {
|
||||||
format!("Install version: {}", {
|
format!("Version: {}", {
|
||||||
match version {
|
match version {
|
||||||
auto_update::VersionCheckType::Sha(sha) => format!("{}…", sha.short()),
|
auto_update::VersionCheckType::Sha(sha) => format!("{}…", sha.short()),
|
||||||
auto_update::VersionCheckType::Semantic(semantic_version) => {
|
auto_update::VersionCheckType::Semantic(semantic_version) => {
|
||||||
@@ -699,17 +724,17 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_install_version_tooltip_message() {
|
fn test_version_tooltip_message() {
|
||||||
let message = ActivityIndicator::install_version_tooltip_message(
|
let message = ActivityIndicator::version_tooltip_message(&VersionCheckType::Semantic(
|
||||||
&VersionCheckType::Semantic(SemanticVersion::new(1, 0, 0)),
|
SemanticVersion::new(1, 0, 0),
|
||||||
);
|
));
|
||||||
|
|
||||||
assert_eq!(message, "Install version: 1.0.0");
|
assert_eq!(message, "Version: 1.0.0");
|
||||||
|
|
||||||
let message = ActivityIndicator::install_version_tooltip_message(&VersionCheckType::Sha(
|
let message = ActivityIndicator::version_tooltip_message(&VersionCheckType::Sha(
|
||||||
AppCommitSha::new("14d9a4189f058d8736339b06ff2340101eaea5af".to_string()),
|
AppCommitSha::new("14d9a4189f058d8736339b06ff2340101eaea5af".to_string()),
|
||||||
));
|
));
|
||||||
|
|
||||||
assert_eq!(message, "Install version: 14d9a41…");
|
assert_eq!(message, "Version: 14d9a41…");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ git.workspace = true
|
|||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
heed.workspace = true
|
heed.workspace = true
|
||||||
html_to_markdown.workspace = true
|
html_to_markdown.workspace = true
|
||||||
|
indoc.workspace = true
|
||||||
http_client.workspace = true
|
http_client.workspace = true
|
||||||
indexed_docs.workspace = true
|
indexed_docs.workspace = true
|
||||||
inventory.workspace = true
|
inventory.workspace = true
|
||||||
@@ -78,6 +79,7 @@ serde_json.workspace = true
|
|||||||
serde_json_lenient.workspace = true
|
serde_json_lenient.workspace = true
|
||||||
settings.workspace = true
|
settings.workspace = true
|
||||||
smol.workspace = true
|
smol.workspace = true
|
||||||
|
sqlez.workspace = true
|
||||||
streaming_diff.workspace = true
|
streaming_diff.workspace = true
|
||||||
telemetry.workspace = true
|
telemetry.workspace = true
|
||||||
telemetry_events.workspace = true
|
telemetry_events.workspace = true
|
||||||
@@ -97,6 +99,7 @@ workspace-hack.workspace = true
|
|||||||
workspace.workspace = true
|
workspace.workspace = true
|
||||||
zed_actions.workspace = true
|
zed_actions.workspace = true
|
||||||
zed_llm_client.workspace = true
|
zed_llm_client.workspace = true
|
||||||
|
zstd.workspace = true
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
buffer_diff = { workspace = true, features = ["test-support"] }
|
buffer_diff = { workspace = true, features = ["test-support"] }
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ use util::ResultExt as _;
|
|||||||
use util::markdown::MarkdownCodeBlock;
|
use util::markdown::MarkdownCodeBlock;
|
||||||
use workspace::{CollaboratorId, Workspace};
|
use workspace::{CollaboratorId, Workspace};
|
||||||
use zed_actions::assistant::OpenRulesLibrary;
|
use zed_actions::assistant::OpenRulesLibrary;
|
||||||
|
use zed_llm_client::CompletionIntent;
|
||||||
|
|
||||||
pub struct ActiveThread {
|
pub struct ActiveThread {
|
||||||
context_store: Entity<ContextStore>,
|
context_store: Entity<ContextStore>,
|
||||||
@@ -1016,6 +1017,15 @@ impl ActiveThread {
|
|||||||
self.play_notification_sound(cx);
|
self.play_notification_sound(cx);
|
||||||
self.show_notification("Waiting for tool confirmation", IconName::Info, window, cx);
|
self.show_notification("Waiting for tool confirmation", IconName::Info, window, cx);
|
||||||
}
|
}
|
||||||
|
ThreadEvent::ToolUseLimitReached => {
|
||||||
|
self.play_notification_sound(cx);
|
||||||
|
self.show_notification(
|
||||||
|
"Consecutive tool use limit reached.",
|
||||||
|
IconName::Warning,
|
||||||
|
window,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
}
|
||||||
ThreadEvent::StreamedAssistantText(message_id, text) => {
|
ThreadEvent::StreamedAssistantText(message_id, text) => {
|
||||||
if let Some(rendered_message) = self.rendered_messages_by_id.get_mut(&message_id) {
|
if let Some(rendered_message) = self.rendered_messages_by_id.get_mut(&message_id) {
|
||||||
rendered_message.append_text(text, cx);
|
rendered_message.append_text(text, cx);
|
||||||
@@ -1436,6 +1446,7 @@ impl ActiveThread {
|
|||||||
let request = language_model::LanguageModelRequest {
|
let request = language_model::LanguageModelRequest {
|
||||||
thread_id: None,
|
thread_id: None,
|
||||||
prompt_id: None,
|
prompt_id: None,
|
||||||
|
intent: None,
|
||||||
mode: None,
|
mode: None,
|
||||||
messages: vec![request_message],
|
messages: vec![request_message],
|
||||||
tools: vec![],
|
tools: vec![],
|
||||||
@@ -1533,9 +1544,22 @@ impl ActiveThread {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
|
fn cancel_editing_message(
|
||||||
|
&mut self,
|
||||||
|
_: &menu::Cancel,
|
||||||
|
window: &mut Window,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
self.editing_message.take();
|
self.editing_message.take();
|
||||||
cx.notify();
|
cx.notify();
|
||||||
|
|
||||||
|
if let Some(workspace) = self.workspace.upgrade() {
|
||||||
|
workspace.update(cx, |workspace, cx| {
|
||||||
|
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||||
|
panel.focus_handle(cx).focus(window);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn confirm_editing_message(
|
fn confirm_editing_message(
|
||||||
@@ -1597,7 +1621,12 @@ impl ActiveThread {
|
|||||||
|
|
||||||
this.thread.update(cx, |thread, cx| {
|
this.thread.update(cx, |thread, cx| {
|
||||||
thread.advance_prompt_id();
|
thread.advance_prompt_id();
|
||||||
thread.send_to_model(model.model, Some(window.window_handle()), cx);
|
thread.send_to_model(
|
||||||
|
model.model,
|
||||||
|
CompletionIntent::UserPrompt,
|
||||||
|
Some(window.window_handle()),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
});
|
});
|
||||||
this._load_edited_message_context_task = None;
|
this._load_edited_message_context_task = None;
|
||||||
cx.notify();
|
cx.notify();
|
||||||
@@ -1818,6 +1847,7 @@ impl ActiveThread {
|
|||||||
|
|
||||||
let colors = cx.theme().colors();
|
let colors = cx.theme().colors();
|
||||||
let editor_bg_color = colors.editor_background;
|
let editor_bg_color = colors.editor_background;
|
||||||
|
let panel_bg = colors.panel_background;
|
||||||
|
|
||||||
let open_as_markdown = IconButton::new(("open-as-markdown", ix), IconName::DocumentText)
|
let open_as_markdown = IconButton::new(("open-as-markdown", ix), IconName::DocumentText)
|
||||||
.icon_size(IconSize::XSmall)
|
.icon_size(IconSize::XSmall)
|
||||||
@@ -1838,7 +1868,6 @@ impl ActiveThread {
|
|||||||
const RESPONSE_PADDING_X: Pixels = px(19.);
|
const RESPONSE_PADDING_X: Pixels = px(19.);
|
||||||
|
|
||||||
let show_feedback = thread.is_turn_end(ix);
|
let show_feedback = thread.is_turn_end(ix);
|
||||||
|
|
||||||
let feedback_container = h_flex()
|
let feedback_container = h_flex()
|
||||||
.group("feedback_container")
|
.group("feedback_container")
|
||||||
.mt_1()
|
.mt_1()
|
||||||
@@ -2135,16 +2164,14 @@ impl ActiveThread {
|
|||||||
message_id > *editing_message_id
|
message_id > *editing_message_id
|
||||||
});
|
});
|
||||||
|
|
||||||
let panel_background = cx.theme().colors().panel_background;
|
|
||||||
|
|
||||||
let backdrop = div()
|
let backdrop = div()
|
||||||
.id("backdrop")
|
.id(("backdrop", ix))
|
||||||
.stop_mouse_events_except_scroll()
|
.size_full()
|
||||||
.absolute()
|
.absolute()
|
||||||
.inset_0()
|
.inset_0()
|
||||||
.size_full()
|
.bg(panel_bg)
|
||||||
.bg(panel_background)
|
|
||||||
.opacity(0.8)
|
.opacity(0.8)
|
||||||
|
.block_mouse_except_scroll()
|
||||||
.on_click(cx.listener(Self::handle_cancel_click));
|
.on_click(cx.listener(Self::handle_cancel_click));
|
||||||
|
|
||||||
v_flex()
|
v_flex()
|
||||||
@@ -3691,7 +3718,8 @@ mod tests {
|
|||||||
|
|
||||||
// Stream response to user message
|
// Stream response to user message
|
||||||
thread.update(cx, |thread, cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
let request = thread.to_completion_request(model.clone(), cx);
|
let request =
|
||||||
|
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx);
|
||||||
thread.stream_completion(request, model, cx.active_window(), cx)
|
thread.stream_completion(request, model, cx.active_window(), cx)
|
||||||
});
|
});
|
||||||
// Follow the agent
|
// Follow the agent
|
||||||
|
|||||||
@@ -89,6 +89,7 @@ actions!(
|
|||||||
ResetTrialEndUpsell,
|
ResetTrialEndUpsell,
|
||||||
ContinueThread,
|
ContinueThread,
|
||||||
ContinueWithBurnMode,
|
ContinueWithBurnMode,
|
||||||
|
ToggleBurnMode,
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -699,7 +699,7 @@ fn render_diff_hunk_controls(
|
|||||||
.rounded_b_md()
|
.rounded_b_md()
|
||||||
.bg(cx.theme().colors().editor_background)
|
.bg(cx.theme().colors().editor_background)
|
||||||
.gap_1()
|
.gap_1()
|
||||||
.stop_mouse_events_except_scroll()
|
.block_mouse_except_scroll()
|
||||||
.shadow_md()
|
.shadow_md()
|
||||||
.children(vec![
|
.children(vec![
|
||||||
Button::new(("reject", row as u64), "Reject")
|
Button::new(("reject", row as u64), "Reject")
|
||||||
@@ -1372,6 +1372,7 @@ impl AgentDiff {
|
|||||||
| ThreadEvent::ToolFinished { .. }
|
| ThreadEvent::ToolFinished { .. }
|
||||||
| ThreadEvent::CheckpointChanged
|
| ThreadEvent::CheckpointChanged
|
||||||
| ThreadEvent::ToolConfirmationNeeded
|
| ThreadEvent::ToolConfirmationNeeded
|
||||||
|
| ThreadEvent::ToolUseLimitReached
|
||||||
| ThreadEvent::CancelEditing => {}
|
| ThreadEvent::CancelEditing => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1464,7 +1465,10 @@ impl AgentDiff {
|
|||||||
if !AgentSettings::get_global(cx).single_file_review {
|
if !AgentSettings::get_global(cx).single_file_review {
|
||||||
for (editor, _) in self.reviewing_editors.drain() {
|
for (editor, _) in self.reviewing_editors.drain() {
|
||||||
editor
|
editor
|
||||||
.update(cx, |editor, cx| editor.end_temporary_diff_override(cx))
|
.update(cx, |editor, cx| {
|
||||||
|
editor.end_temporary_diff_override(cx);
|
||||||
|
editor.unregister_addon::<EditorAgentDiffAddon>();
|
||||||
|
})
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
@@ -1560,7 +1564,10 @@ impl AgentDiff {
|
|||||||
|
|
||||||
if in_workspace {
|
if in_workspace {
|
||||||
editor
|
editor
|
||||||
.update(cx, |editor, cx| editor.end_temporary_diff_override(cx))
|
.update(cx, |editor, cx| {
|
||||||
|
editor.end_temporary_diff_override(cx);
|
||||||
|
editor.unregister_addon::<EditorAgentDiffAddon>();
|
||||||
|
})
|
||||||
.ok();
|
.ok();
|
||||||
self.reviewing_editors.remove(&editor);
|
self.reviewing_editors.remove(&editor);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
use agent_settings::AgentSettings;
|
use agent_settings::AgentSettings;
|
||||||
use fs::Fs;
|
use fs::Fs;
|
||||||
use gpui::{Entity, FocusHandle, SharedString};
|
use gpui::{Entity, FocusHandle, SharedString};
|
||||||
|
use picker::popover_menu::PickerPopoverMenu;
|
||||||
|
|
||||||
use crate::Thread;
|
use crate::Thread;
|
||||||
use assistant_context_editor::language_model_selector::{
|
use assistant_context_editor::language_model_selector::{
|
||||||
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
|
LanguageModelSelector, ToggleModelSelector, language_model_selector,
|
||||||
};
|
};
|
||||||
use language_model::{ConfiguredModel, LanguageModelRegistry};
|
use language_model::{ConfiguredModel, LanguageModelRegistry};
|
||||||
use settings::update_settings_file;
|
use settings::update_settings_file;
|
||||||
@@ -35,7 +36,7 @@ impl AgentModelSelector {
|
|||||||
Self {
|
Self {
|
||||||
selector: cx.new(move |cx| {
|
selector: cx.new(move |cx| {
|
||||||
let fs = fs.clone();
|
let fs = fs.clone();
|
||||||
LanguageModelSelector::new(
|
language_model_selector(
|
||||||
{
|
{
|
||||||
let model_type = model_type.clone();
|
let model_type = model_type.clone();
|
||||||
move |cx| match &model_type {
|
move |cx| match &model_type {
|
||||||
@@ -100,15 +101,14 @@ impl AgentModelSelector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Render for AgentModelSelector {
|
impl Render for AgentModelSelector {
|
||||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||||
let focus_handle = self.focus_handle.clone();
|
let focus_handle = self.focus_handle.clone();
|
||||||
|
|
||||||
let model = self.selector.read(cx).active_model(cx);
|
let model = self.selector.read(cx).delegate.active_model(cx);
|
||||||
let model_name = model
|
let model_name = model
|
||||||
.map(|model| model.model.name().0)
|
.map(|model| model.model.name().0)
|
||||||
.unwrap_or_else(|| SharedString::from("No model selected"));
|
.unwrap_or_else(|| SharedString::from("No model selected"));
|
||||||
|
PickerPopoverMenu::new(
|
||||||
LanguageModelSelectorPopoverMenu::new(
|
|
||||||
self.selector.clone(),
|
self.selector.clone(),
|
||||||
Button::new("active-model", model_name)
|
Button::new("active-model", model_name)
|
||||||
.label_size(LabelSize::Small)
|
.label_size(LabelSize::Small)
|
||||||
@@ -127,7 +127,9 @@ impl Render for AgentModelSelector {
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
gpui::Corner::BottomRight,
|
gpui::Corner::BottomRight,
|
||||||
|
cx,
|
||||||
)
|
)
|
||||||
.with_handle(self.menu_handle.clone())
|
.with_handle(self.menu_handle.clone())
|
||||||
|
.render(window, cx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ use workspace::{
|
|||||||
use zed_actions::agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding};
|
use zed_actions::agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding};
|
||||||
use zed_actions::assistant::{OpenRulesLibrary, ToggleFocus};
|
use zed_actions::assistant::{OpenRulesLibrary, ToggleFocus};
|
||||||
use zed_actions::{DecreaseBufferFontSize, IncreaseBufferFontSize, ResetBufferFontSize};
|
use zed_actions::{DecreaseBufferFontSize, IncreaseBufferFontSize, ResetBufferFontSize};
|
||||||
use zed_llm_client::UsageLimit;
|
use zed_llm_client::{CompletionIntent, UsageLimit};
|
||||||
|
|
||||||
use crate::active_thread::{self, ActiveThread, ActiveThreadEvent};
|
use crate::active_thread::{self, ActiveThread, ActiveThreadEvent};
|
||||||
use crate::agent_configuration::{AgentConfiguration, AssistantConfigurationEvent};
|
use crate::agent_configuration::{AgentConfiguration, AssistantConfigurationEvent};
|
||||||
@@ -67,8 +67,8 @@ use crate::{
|
|||||||
AddContextServer, AgentDiffPane, ContextStore, ContinueThread, ContinueWithBurnMode,
|
AddContextServer, AgentDiffPane, ContextStore, ContinueThread, ContinueWithBurnMode,
|
||||||
DeleteRecentlyOpenThread, ExpandMessageEditor, Follow, InlineAssistant, NewTextThread,
|
DeleteRecentlyOpenThread, ExpandMessageEditor, Follow, InlineAssistant, NewTextThread,
|
||||||
NewThread, OpenActiveThreadAsMarkdown, OpenAgentDiff, OpenHistory, ResetTrialEndUpsell,
|
NewThread, OpenActiveThreadAsMarkdown, OpenAgentDiff, OpenHistory, ResetTrialEndUpsell,
|
||||||
ResetTrialUpsell, TextThreadStore, ThreadEvent, ToggleContextPicker, ToggleNavigationMenu,
|
ResetTrialUpsell, TextThreadStore, ThreadEvent, ToggleBurnMode, ToggleContextPicker,
|
||||||
ToggleOptionsMenu,
|
ToggleNavigationMenu, ToggleOptionsMenu,
|
||||||
};
|
};
|
||||||
|
|
||||||
const AGENT_PANEL_KEY: &str = "agent_panel";
|
const AGENT_PANEL_KEY: &str = "agent_panel";
|
||||||
@@ -174,7 +174,7 @@ enum ActiveView {
|
|||||||
thread: WeakEntity<Thread>,
|
thread: WeakEntity<Thread>,
|
||||||
_subscriptions: Vec<gpui::Subscription>,
|
_subscriptions: Vec<gpui::Subscription>,
|
||||||
},
|
},
|
||||||
PromptEditor {
|
TextThread {
|
||||||
context_editor: Entity<ContextEditor>,
|
context_editor: Entity<ContextEditor>,
|
||||||
title_editor: Entity<Editor>,
|
title_editor: Entity<Editor>,
|
||||||
buffer_search_bar: Entity<BufferSearchBar>,
|
buffer_search_bar: Entity<BufferSearchBar>,
|
||||||
@@ -194,7 +194,7 @@ impl ActiveView {
|
|||||||
pub fn which_font_size_used(&self) -> WhichFontSize {
|
pub fn which_font_size_used(&self) -> WhichFontSize {
|
||||||
match self {
|
match self {
|
||||||
ActiveView::Thread { .. } | ActiveView::History => WhichFontSize::AgentFont,
|
ActiveView::Thread { .. } | ActiveView::History => WhichFontSize::AgentFont,
|
||||||
ActiveView::PromptEditor { .. } => WhichFontSize::BufferFont,
|
ActiveView::TextThread { .. } => WhichFontSize::BufferFont,
|
||||||
ActiveView::Configuration => WhichFontSize::None,
|
ActiveView::Configuration => WhichFontSize::None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -333,7 +333,7 @@ impl ActiveView {
|
|||||||
buffer_search_bar.set_active_pane_item(Some(&context_editor), window, cx)
|
buffer_search_bar.set_active_pane_item(Some(&context_editor), window, cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
Self::PromptEditor {
|
Self::TextThread {
|
||||||
context_editor,
|
context_editor,
|
||||||
title_editor: editor,
|
title_editor: editor,
|
||||||
buffer_search_bar,
|
buffer_search_bar,
|
||||||
@@ -1084,9 +1084,23 @@ impl AgentPanel {
|
|||||||
pub fn go_back(&mut self, _: &workspace::GoBack, window: &mut Window, cx: &mut Context<Self>) {
|
pub fn go_back(&mut self, _: &workspace::GoBack, window: &mut Window, cx: &mut Context<Self>) {
|
||||||
match self.active_view {
|
match self.active_view {
|
||||||
ActiveView::Configuration | ActiveView::History => {
|
ActiveView::Configuration | ActiveView::History => {
|
||||||
self.active_view =
|
if let Some(previous_view) = self.previous_view.take() {
|
||||||
ActiveView::thread(self.thread.read(cx).thread().clone(), window, cx);
|
self.active_view = previous_view;
|
||||||
self.message_editor.focus_handle(cx).focus(window);
|
|
||||||
|
match &self.active_view {
|
||||||
|
ActiveView::Thread { .. } => {
|
||||||
|
self.message_editor.focus_handle(cx).focus(window);
|
||||||
|
}
|
||||||
|
ActiveView::TextThread { context_editor, .. } => {
|
||||||
|
context_editor.focus_handle(cx).focus(window);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
self.active_view =
|
||||||
|
ActiveView::thread(self.thread.read(cx).thread().clone(), window, cx);
|
||||||
|
self.message_editor.focus_handle(cx).focus(window);
|
||||||
|
}
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
@@ -1296,7 +1310,12 @@ impl AgentPanel {
|
|||||||
active_thread.thread().update(cx, |thread, cx| {
|
active_thread.thread().update(cx, |thread, cx| {
|
||||||
thread.insert_invisible_continue_message(cx);
|
thread.insert_invisible_continue_message(cx);
|
||||||
thread.advance_prompt_id();
|
thread.advance_prompt_id();
|
||||||
thread.send_to_model(model, Some(window.window_handle()), cx);
|
thread.send_to_model(
|
||||||
|
model,
|
||||||
|
CompletionIntent::UserPrompt,
|
||||||
|
Some(window.window_handle()),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
@@ -1304,9 +1323,27 @@ impl AgentPanel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn toggle_burn_mode(
|
||||||
|
&mut self,
|
||||||
|
_: &ToggleBurnMode,
|
||||||
|
_window: &mut Window,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
self.thread.update(cx, |active_thread, cx| {
|
||||||
|
active_thread.thread().update(cx, |thread, _cx| {
|
||||||
|
let current_mode = thread.completion_mode();
|
||||||
|
|
||||||
|
thread.set_completion_mode(match current_mode {
|
||||||
|
CompletionMode::Burn => CompletionMode::Normal,
|
||||||
|
CompletionMode::Normal => CompletionMode::Burn,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn active_context_editor(&self) -> Option<Entity<ContextEditor>> {
|
pub(crate) fn active_context_editor(&self) -> Option<Entity<ContextEditor>> {
|
||||||
match &self.active_view {
|
match &self.active_view {
|
||||||
ActiveView::PromptEditor { context_editor, .. } => Some(context_editor.clone()),
|
ActiveView::TextThread { context_editor, .. } => Some(context_editor.clone()),
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1329,6 +1366,12 @@ impl AgentPanel {
|
|||||||
let current_is_history = matches!(self.active_view, ActiveView::History);
|
let current_is_history = matches!(self.active_view, ActiveView::History);
|
||||||
let new_is_history = matches!(new_view, ActiveView::History);
|
let new_is_history = matches!(new_view, ActiveView::History);
|
||||||
|
|
||||||
|
let current_is_config = matches!(self.active_view, ActiveView::Configuration);
|
||||||
|
let new_is_config = matches!(new_view, ActiveView::Configuration);
|
||||||
|
|
||||||
|
let current_is_special = current_is_history || current_is_config;
|
||||||
|
let new_is_special = new_is_history || new_is_config;
|
||||||
|
|
||||||
match &self.active_view {
|
match &self.active_view {
|
||||||
ActiveView::Thread { thread, .. } => {
|
ActiveView::Thread { thread, .. } => {
|
||||||
if let Some(thread) = thread.upgrade() {
|
if let Some(thread) = thread.upgrade() {
|
||||||
@@ -1340,7 +1383,7 @@ impl AgentPanel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ActiveView::PromptEditor { context_editor, .. } => {
|
ActiveView::TextThread { context_editor, .. } => {
|
||||||
let context = context_editor.read(cx).context();
|
let context = context_editor.read(cx).context();
|
||||||
// When switching away from an unsaved text thread, delete its entry.
|
// When switching away from an unsaved text thread, delete its entry.
|
||||||
if context.read(cx).path().is_none() {
|
if context.read(cx).path().is_none() {
|
||||||
@@ -1360,7 +1403,7 @@ impl AgentPanel {
|
|||||||
store.push_recently_opened_entry(RecentEntry::Thread(id, thread), cx);
|
store.push_recently_opened_entry(RecentEntry::Thread(id, thread), cx);
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
ActiveView::PromptEditor { context_editor, .. } => {
|
ActiveView::TextThread { context_editor, .. } => {
|
||||||
self.history_store.update(cx, |store, cx| {
|
self.history_store.update(cx, |store, cx| {
|
||||||
let context = context_editor.read(cx).context().clone();
|
let context = context_editor.read(cx).context().clone();
|
||||||
store.push_recently_opened_entry(RecentEntry::Context(context), cx)
|
store.push_recently_opened_entry(RecentEntry::Context(context), cx)
|
||||||
@@ -1369,12 +1412,12 @@ impl AgentPanel {
|
|||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
if current_is_history && !new_is_history {
|
if current_is_special && !new_is_special {
|
||||||
self.active_view = new_view;
|
self.active_view = new_view;
|
||||||
} else if !current_is_history && new_is_history {
|
} else if !current_is_special && new_is_special {
|
||||||
self.previous_view = Some(std::mem::replace(&mut self.active_view, new_view));
|
self.previous_view = Some(std::mem::replace(&mut self.active_view, new_view));
|
||||||
} else {
|
} else {
|
||||||
if !new_is_history {
|
if !new_is_special {
|
||||||
self.previous_view = None;
|
self.previous_view = None;
|
||||||
}
|
}
|
||||||
self.active_view = new_view;
|
self.active_view = new_view;
|
||||||
@@ -1389,7 +1432,7 @@ impl Focusable for AgentPanel {
|
|||||||
match &self.active_view {
|
match &self.active_view {
|
||||||
ActiveView::Thread { .. } => self.message_editor.focus_handle(cx),
|
ActiveView::Thread { .. } => self.message_editor.focus_handle(cx),
|
||||||
ActiveView::History => self.history.focus_handle(cx),
|
ActiveView::History => self.history.focus_handle(cx),
|
||||||
ActiveView::PromptEditor { context_editor, .. } => context_editor.focus_handle(cx),
|
ActiveView::TextThread { context_editor, .. } => context_editor.focus_handle(cx),
|
||||||
ActiveView::Configuration => {
|
ActiveView::Configuration => {
|
||||||
if let Some(configuration) = self.configuration.as_ref() {
|
if let Some(configuration) = self.configuration.as_ref() {
|
||||||
configuration.focus_handle(cx)
|
configuration.focus_handle(cx)
|
||||||
@@ -1541,7 +1584,7 @@ impl AgentPanel {
|
|||||||
.into_any_element(),
|
.into_any_element(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ActiveView::PromptEditor {
|
ActiveView::TextThread {
|
||||||
title_editor,
|
title_editor,
|
||||||
context_editor,
|
context_editor,
|
||||||
..
|
..
|
||||||
@@ -1633,7 +1676,7 @@ impl AgentPanel {
|
|||||||
|
|
||||||
let show_token_count = match &self.active_view {
|
let show_token_count = match &self.active_view {
|
||||||
ActiveView::Thread { .. } => !is_empty || !editor_empty,
|
ActiveView::Thread { .. } => !is_empty || !editor_empty,
|
||||||
ActiveView::PromptEditor { .. } => true,
|
ActiveView::TextThread { .. } => true,
|
||||||
_ => false,
|
_ => false,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -1949,7 +1992,7 @@ impl AgentPanel {
|
|||||||
|
|
||||||
Some(token_count)
|
Some(token_count)
|
||||||
}
|
}
|
||||||
ActiveView::PromptEditor { context_editor, .. } => {
|
ActiveView::TextThread { context_editor, .. } => {
|
||||||
let element = render_remaining_tokens(context_editor, cx)?;
|
let element = render_remaining_tokens(context_editor, cx)?;
|
||||||
|
|
||||||
Some(element.into_any_element())
|
Some(element.into_any_element())
|
||||||
@@ -2663,7 +2706,7 @@ impl AgentPanel {
|
|||||||
.on_click(cx.listener(|this, _, window, cx| {
|
.on_click(cx.listener(|this, _, window, cx| {
|
||||||
this.thread.update(cx, |active_thread, cx| {
|
this.thread.update(cx, |active_thread, cx| {
|
||||||
active_thread.thread().update(cx, |thread, _cx| {
|
active_thread.thread().update(cx, |thread, _cx| {
|
||||||
thread.set_completion_mode(CompletionMode::Max);
|
thread.set_completion_mode(CompletionMode::Burn);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
this.continue_conversation(window, cx);
|
this.continue_conversation(window, cx);
|
||||||
@@ -2867,7 +2910,7 @@ impl AgentPanel {
|
|||||||
) -> Div {
|
) -> Div {
|
||||||
let mut registrar = buffer_search::DivRegistrar::new(
|
let mut registrar = buffer_search::DivRegistrar::new(
|
||||||
|this, _, _cx| match &this.active_view {
|
|this, _, _cx| match &this.active_view {
|
||||||
ActiveView::PromptEditor {
|
ActiveView::TextThread {
|
||||||
buffer_search_bar, ..
|
buffer_search_bar, ..
|
||||||
} => Some(buffer_search_bar.clone()),
|
} => Some(buffer_search_bar.clone()),
|
||||||
_ => None,
|
_ => None,
|
||||||
@@ -2985,7 +3028,7 @@ impl AgentPanel {
|
|||||||
.detach();
|
.detach();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
ActiveView::PromptEditor { context_editor, .. } => {
|
ActiveView::TextThread { context_editor, .. } => {
|
||||||
context_editor.update(cx, |context_editor, cx| {
|
context_editor.update(cx, |context_editor, cx| {
|
||||||
ContextEditor::insert_dragged_files(
|
ContextEditor::insert_dragged_files(
|
||||||
context_editor,
|
context_editor,
|
||||||
@@ -3012,7 +3055,7 @@ impl AgentPanel {
|
|||||||
fn key_context(&self) -> KeyContext {
|
fn key_context(&self) -> KeyContext {
|
||||||
let mut key_context = KeyContext::new_with_defaults();
|
let mut key_context = KeyContext::new_with_defaults();
|
||||||
key_context.add("AgentPanel");
|
key_context.add("AgentPanel");
|
||||||
if matches!(self.active_view, ActiveView::PromptEditor { .. }) {
|
if matches!(self.active_view, ActiveView::TextThread { .. }) {
|
||||||
key_context.add("prompt_editor");
|
key_context.add("prompt_editor");
|
||||||
}
|
}
|
||||||
key_context
|
key_context
|
||||||
@@ -3060,11 +3103,12 @@ impl Render for AgentPanel {
|
|||||||
.on_action(cx.listener(|this, _: &ContinueWithBurnMode, window, cx| {
|
.on_action(cx.listener(|this, _: &ContinueWithBurnMode, window, cx| {
|
||||||
this.thread.update(cx, |active_thread, cx| {
|
this.thread.update(cx, |active_thread, cx| {
|
||||||
active_thread.thread().update(cx, |thread, _cx| {
|
active_thread.thread().update(cx, |thread, _cx| {
|
||||||
thread.set_completion_mode(CompletionMode::Max);
|
thread.set_completion_mode(CompletionMode::Burn);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
this.continue_conversation(window, cx);
|
this.continue_conversation(window, cx);
|
||||||
}))
|
}))
|
||||||
|
.on_action(cx.listener(Self::toggle_burn_mode))
|
||||||
.child(self.render_toolbar(window, cx))
|
.child(self.render_toolbar(window, cx))
|
||||||
.children(self.render_upsell(window, cx))
|
.children(self.render_upsell(window, cx))
|
||||||
.children(self.render_trial_end_upsell(window, cx))
|
.children(self.render_trial_end_upsell(window, cx))
|
||||||
@@ -3077,7 +3121,7 @@ impl Render for AgentPanel {
|
|||||||
.children(self.render_last_error(cx))
|
.children(self.render_last_error(cx))
|
||||||
.child(self.render_drag_target(cx)),
|
.child(self.render_drag_target(cx)),
|
||||||
ActiveView::History => parent.child(self.history.clone()),
|
ActiveView::History => parent.child(self.history.clone()),
|
||||||
ActiveView::PromptEditor {
|
ActiveView::TextThread {
|
||||||
context_editor,
|
context_editor,
|
||||||
buffer_search_bar,
|
buffer_search_bar,
|
||||||
..
|
..
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ use std::{
|
|||||||
};
|
};
|
||||||
use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
|
use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
|
||||||
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
|
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
|
||||||
|
use zed_llm_client::CompletionIntent;
|
||||||
|
|
||||||
pub struct BufferCodegen {
|
pub struct BufferCodegen {
|
||||||
alternatives: Vec<Entity<CodegenAlternative>>,
|
alternatives: Vec<Entity<CodegenAlternative>>,
|
||||||
@@ -464,6 +465,7 @@ impl CodegenAlternative {
|
|||||||
LanguageModelRequest {
|
LanguageModelRequest {
|
||||||
thread_id: None,
|
thread_id: None,
|
||||||
prompt_id: None,
|
prompt_id: None,
|
||||||
|
intent: Some(CompletionIntent::InlineAssist),
|
||||||
mode: None,
|
mode: None,
|
||||||
tools: Vec::new(),
|
tools: Vec::new(),
|
||||||
tool_choice: None,
|
tool_choice: None,
|
||||||
|
|||||||
@@ -734,6 +734,7 @@ impl Display for RulesContext {
|
|||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct ImageContext {
|
pub struct ImageContext {
|
||||||
pub project_path: Option<ProjectPath>,
|
pub project_path: Option<ProjectPath>,
|
||||||
|
pub full_path: Option<Arc<Path>>,
|
||||||
pub original_image: Arc<gpui::Image>,
|
pub original_image: Arc<gpui::Image>,
|
||||||
// TODO: handle this elsewhere and remove `ignore-interior-mutability` opt-out in clippy.toml
|
// TODO: handle this elsewhere and remove `ignore-interior-mutability` opt-out in clippy.toml
|
||||||
// needed due to a false positive of `clippy::mutable_key_type`.
|
// needed due to a false positive of `clippy::mutable_key_type`.
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ use http_client::HttpClientWithUrl;
|
|||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use language::{Buffer, CodeLabel, HighlightId};
|
use language::{Buffer, CodeLabel, HighlightId};
|
||||||
use lsp::CompletionContext;
|
use lsp::CompletionContext;
|
||||||
use project::{Completion, CompletionIntent, ProjectPath, Symbol, WorktreeId};
|
use project::{Completion, CompletionIntent, CompletionResponse, ProjectPath, Symbol, WorktreeId};
|
||||||
use prompt_store::PromptStore;
|
use prompt_store::PromptStore;
|
||||||
use rope::Point;
|
use rope::Point;
|
||||||
use text::{Anchor, OffsetRangeExt, ToPoint};
|
use text::{Anchor, OffsetRangeExt, ToPoint};
|
||||||
@@ -746,7 +746,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
|||||||
_trigger: CompletionContext,
|
_trigger: CompletionContext,
|
||||||
_window: &mut Window,
|
_window: &mut Window,
|
||||||
cx: &mut Context<Editor>,
|
cx: &mut Context<Editor>,
|
||||||
) -> Task<Result<Option<Vec<Completion>>>> {
|
) -> Task<Result<Vec<CompletionResponse>>> {
|
||||||
let state = buffer.update(cx, |buffer, _cx| {
|
let state = buffer.update(cx, |buffer, _cx| {
|
||||||
let position = buffer_position.to_point(buffer);
|
let position = buffer_position.to_point(buffer);
|
||||||
let line_start = Point::new(position.row, 0);
|
let line_start = Point::new(position.row, 0);
|
||||||
@@ -756,13 +756,13 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
|||||||
MentionCompletion::try_parse(line, offset_to_line)
|
MentionCompletion::try_parse(line, offset_to_line)
|
||||||
});
|
});
|
||||||
let Some(state) = state else {
|
let Some(state) = state else {
|
||||||
return Task::ready(Ok(None));
|
return Task::ready(Ok(Vec::new()));
|
||||||
};
|
};
|
||||||
|
|
||||||
let Some((workspace, context_store)) =
|
let Some((workspace, context_store)) =
|
||||||
self.workspace.upgrade().zip(self.context_store.upgrade())
|
self.workspace.upgrade().zip(self.context_store.upgrade())
|
||||||
else {
|
else {
|
||||||
return Task::ready(Ok(None));
|
return Task::ready(Ok(Vec::new()));
|
||||||
};
|
};
|
||||||
|
|
||||||
let snapshot = buffer.read(cx).snapshot();
|
let snapshot = buffer.read(cx).snapshot();
|
||||||
@@ -815,10 +815,10 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
|||||||
cx.spawn(async move |_, cx| {
|
cx.spawn(async move |_, cx| {
|
||||||
let matches = search_task.await;
|
let matches = search_task.await;
|
||||||
let Some(editor) = editor.upgrade() else {
|
let Some(editor) = editor.upgrade() else {
|
||||||
return Ok(None);
|
return Ok(Vec::new());
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Some(cx.update(|cx| {
|
let completions = cx.update(|cx| {
|
||||||
matches
|
matches
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter_map(|mat| match mat {
|
.filter_map(|mat| match mat {
|
||||||
@@ -901,7 +901,14 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
|||||||
),
|
),
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
})?))
|
})?;
|
||||||
|
|
||||||
|
Ok(vec![CompletionResponse {
|
||||||
|
completions,
|
||||||
|
// Since this does its own filtering (see `filter_completions()` returns false),
|
||||||
|
// there is no benefit to computing whether this set of completions is incomplete.
|
||||||
|
is_incomplete: true,
|
||||||
|
}])
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ use assistant_context_editor::AssistantContext;
|
|||||||
use collections::{HashSet, IndexSet};
|
use collections::{HashSet, IndexSet};
|
||||||
use futures::{self, FutureExt};
|
use futures::{self, FutureExt};
|
||||||
use gpui::{App, Context, Entity, EventEmitter, Image, SharedString, Task, WeakEntity};
|
use gpui::{App, Context, Entity, EventEmitter, Image, SharedString, Task, WeakEntity};
|
||||||
use language::Buffer;
|
use language::{Buffer, File as _};
|
||||||
use language_model::LanguageModelImage;
|
use language_model::LanguageModelImage;
|
||||||
use project::image_store::is_image_file;
|
use project::image_store::is_image_file;
|
||||||
use project::{Project, ProjectItem, ProjectPath, Symbol};
|
use project::{Project, ProjectItem, ProjectPath, Symbol};
|
||||||
@@ -304,11 +304,13 @@ impl ContextStore {
|
|||||||
project.open_image(project_path.clone(), cx)
|
project.open_image(project_path.clone(), cx)
|
||||||
})?;
|
})?;
|
||||||
let image_item = open_image_task.await?;
|
let image_item = open_image_task.await?;
|
||||||
let image = image_item.read_with(cx, |image_item, _| image_item.image.clone())?;
|
|
||||||
this.update(cx, |this, cx| {
|
this.update(cx, |this, cx| {
|
||||||
|
let item = image_item.read(cx);
|
||||||
this.insert_image(
|
this.insert_image(
|
||||||
Some(image_item.read(cx).project_path(cx)),
|
Some(item.project_path(cx)),
|
||||||
image,
|
Some(item.file.full_path(cx).into()),
|
||||||
|
item.image.clone(),
|
||||||
remove_if_exists,
|
remove_if_exists,
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
@@ -317,12 +319,13 @@ impl ContextStore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
|
pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
|
||||||
self.insert_image(None, image, false, cx);
|
self.insert_image(None, None, image, false, cx);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn insert_image(
|
fn insert_image(
|
||||||
&mut self,
|
&mut self,
|
||||||
project_path: Option<ProjectPath>,
|
project_path: Option<ProjectPath>,
|
||||||
|
full_path: Option<Arc<Path>>,
|
||||||
image: Arc<Image>,
|
image: Arc<Image>,
|
||||||
remove_if_exists: bool,
|
remove_if_exists: bool,
|
||||||
cx: &mut Context<ContextStore>,
|
cx: &mut Context<ContextStore>,
|
||||||
@@ -330,6 +333,7 @@ impl ContextStore {
|
|||||||
let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
|
let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
|
||||||
let context = AgentContextHandle::Image(ImageContext {
|
let context = AgentContextHandle::Image(ImageContext {
|
||||||
project_path,
|
project_path,
|
||||||
|
full_path,
|
||||||
original_image: image,
|
original_image: image,
|
||||||
image_task,
|
image_task,
|
||||||
context_id: self.next_context_id.post_inc(),
|
context_id: self.next_context_id.post_inc(),
|
||||||
|
|||||||
@@ -152,7 +152,7 @@ impl HistoryStore {
|
|||||||
let entries = join_all(entries)
|
let entries = join_all(entries)
|
||||||
.await
|
.await
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter_map(|result| result.log_err())
|
.filter_map(|result| result.log_with_level(log::Level::Debug))
|
||||||
.collect::<VecDeque<_>>();
|
.collect::<VecDeque<_>>();
|
||||||
|
|
||||||
this.update(cx, |this, _| {
|
this.update(cx, |this, _| {
|
||||||
|
|||||||
@@ -1445,7 +1445,7 @@ impl InlineAssistant {
|
|||||||
style: BlockStyle::Flex,
|
style: BlockStyle::Flex,
|
||||||
render: Arc::new(move |cx| {
|
render: Arc::new(move |cx| {
|
||||||
div()
|
div()
|
||||||
.block_mouse_down()
|
.block_mouse_except_scroll()
|
||||||
.bg(cx.theme().status().deleted_background)
|
.bg(cx.theme().status().deleted_background)
|
||||||
.size_full()
|
.size_full()
|
||||||
.h(height as f32 * cx.window.line_height())
|
.h(height as f32 * cx.window.line_height())
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ impl<T: 'static> Render for PromptEditor<T> {
|
|||||||
v_flex()
|
v_flex()
|
||||||
.key_context("PromptEditor")
|
.key_context("PromptEditor")
|
||||||
.bg(cx.theme().colors().editor_background)
|
.bg(cx.theme().colors().editor_background)
|
||||||
.block_mouse_down()
|
.block_mouse_except_scroll()
|
||||||
.gap_0p5()
|
.gap_0p5()
|
||||||
.border_y_1()
|
.border_y_1()
|
||||||
.border_color(cx.theme().status().info_border)
|
.border_color(cx.theme().status().info_border)
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ use theme::ThemeSettings;
|
|||||||
use ui::{Disclosure, KeyBinding, PopoverMenuHandle, Tooltip, prelude::*};
|
use ui::{Disclosure, KeyBinding, PopoverMenuHandle, Tooltip, prelude::*};
|
||||||
use util::{ResultExt as _, maybe};
|
use util::{ResultExt as _, maybe};
|
||||||
use workspace::{CollaboratorId, Workspace};
|
use workspace::{CollaboratorId, Workspace};
|
||||||
|
use zed_llm_client::CompletionIntent;
|
||||||
|
|
||||||
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention};
|
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention};
|
||||||
use crate::context_store::ContextStore;
|
use crate::context_store::ContextStore;
|
||||||
@@ -51,7 +52,7 @@ use crate::thread::{MessageCrease, Thread, TokenUsageRatio};
|
|||||||
use crate::thread_store::{TextThreadStore, ThreadStore};
|
use crate::thread_store::{TextThreadStore, ThreadStore};
|
||||||
use crate::{
|
use crate::{
|
||||||
ActiveThread, AgentDiffPane, Chat, ChatWithFollow, ExpandMessageEditor, Follow, NewThread,
|
ActiveThread, AgentDiffPane, Chat, ChatWithFollow, ExpandMessageEditor, Follow, NewThread,
|
||||||
OpenAgentDiff, RemoveAllContext, ToggleContextPicker, ToggleProfileSelector,
|
OpenAgentDiff, RemoveAllContext, ToggleBurnMode, ToggleContextPicker, ToggleProfileSelector,
|
||||||
register_agent_preview,
|
register_agent_preview,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -111,6 +112,7 @@ pub(crate) fn create_editor(
|
|||||||
editor.set_placeholder_text("Message the agent – @ to include context", cx);
|
editor.set_placeholder_text("Message the agent – @ to include context", cx);
|
||||||
editor.set_show_indent_guides(false, cx);
|
editor.set_show_indent_guides(false, cx);
|
||||||
editor.set_soft_wrap();
|
editor.set_soft_wrap();
|
||||||
|
editor.set_use_modal_editing(true);
|
||||||
editor.set_context_menu_options(ContextMenuOptions {
|
editor.set_context_menu_options(ContextMenuOptions {
|
||||||
min_entries_visible: 12,
|
min_entries_visible: 12,
|
||||||
max_entries_visible: 12,
|
max_entries_visible: 12,
|
||||||
@@ -375,7 +377,12 @@ impl MessageEditor {
|
|||||||
thread
|
thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.advance_prompt_id();
|
thread.advance_prompt_id();
|
||||||
thread.send_to_model(model, Some(window_handle), cx);
|
thread.send_to_model(
|
||||||
|
model,
|
||||||
|
CompletionIntent::UserPrompt,
|
||||||
|
Some(window_handle),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
})
|
})
|
||||||
.log_err();
|
.log_err();
|
||||||
})
|
})
|
||||||
@@ -471,6 +478,22 @@ impl MessageEditor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn toggle_burn_mode(
|
||||||
|
&mut self,
|
||||||
|
_: &ToggleBurnMode,
|
||||||
|
_window: &mut Window,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
self.thread.update(cx, |thread, _cx| {
|
||||||
|
let active_completion_mode = thread.completion_mode();
|
||||||
|
|
||||||
|
thread.set_completion_mode(match active_completion_mode {
|
||||||
|
CompletionMode::Burn => CompletionMode::Normal,
|
||||||
|
CompletionMode::Normal => CompletionMode::Burn,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
fn render_max_mode_toggle(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
|
fn render_max_mode_toggle(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
|
||||||
let thread = self.thread.read(cx);
|
let thread = self.thread.read(cx);
|
||||||
let model = thread.configured_model();
|
let model = thread.configured_model();
|
||||||
@@ -479,8 +502,8 @@ impl MessageEditor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let active_completion_mode = thread.completion_mode();
|
let active_completion_mode = thread.completion_mode();
|
||||||
let max_mode_enabled = active_completion_mode == CompletionMode::Max;
|
let burn_mode_enabled = active_completion_mode == CompletionMode::Burn;
|
||||||
let icon = if max_mode_enabled {
|
let icon = if burn_mode_enabled {
|
||||||
IconName::ZedBurnModeOn
|
IconName::ZedBurnModeOn
|
||||||
} else {
|
} else {
|
||||||
IconName::ZedBurnMode
|
IconName::ZedBurnMode
|
||||||
@@ -490,18 +513,13 @@ impl MessageEditor {
|
|||||||
IconButton::new("burn-mode", icon)
|
IconButton::new("burn-mode", icon)
|
||||||
.icon_size(IconSize::Small)
|
.icon_size(IconSize::Small)
|
||||||
.icon_color(Color::Muted)
|
.icon_color(Color::Muted)
|
||||||
.toggle_state(max_mode_enabled)
|
.toggle_state(burn_mode_enabled)
|
||||||
.selected_icon_color(Color::Error)
|
.selected_icon_color(Color::Error)
|
||||||
.on_click(cx.listener(move |this, _event, _window, cx| {
|
.on_click(cx.listener(|this, _event, window, cx| {
|
||||||
this.thread.update(cx, |thread, _cx| {
|
this.toggle_burn_mode(&ToggleBurnMode, window, cx);
|
||||||
thread.set_completion_mode(match active_completion_mode {
|
|
||||||
CompletionMode::Max => CompletionMode::Normal,
|
|
||||||
CompletionMode::Normal => CompletionMode::Max,
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}))
|
}))
|
||||||
.tooltip(move |_window, cx| {
|
.tooltip(move |_window, cx| {
|
||||||
cx.new(|_| MaxModeTooltip::new().selected(max_mode_enabled))
|
cx.new(|_| MaxModeTooltip::new().selected(burn_mode_enabled))
|
||||||
.into()
|
.into()
|
||||||
})
|
})
|
||||||
.into_any_element(),
|
.into_any_element(),
|
||||||
@@ -596,6 +614,7 @@ impl MessageEditor {
|
|||||||
.on_action(cx.listener(Self::remove_all_context))
|
.on_action(cx.listener(Self::remove_all_context))
|
||||||
.on_action(cx.listener(Self::move_up))
|
.on_action(cx.listener(Self::move_up))
|
||||||
.on_action(cx.listener(Self::expand_message_editor))
|
.on_action(cx.listener(Self::expand_message_editor))
|
||||||
|
.on_action(cx.listener(Self::toggle_burn_mode))
|
||||||
.capture_action(cx.listener(Self::paste))
|
.capture_action(cx.listener(Self::paste))
|
||||||
.gap_2()
|
.gap_2()
|
||||||
.p_2()
|
.p_2()
|
||||||
@@ -1268,6 +1287,7 @@ impl MessageEditor {
|
|||||||
let request = language_model::LanguageModelRequest {
|
let request = language_model::LanguageModelRequest {
|
||||||
thread_id: None,
|
thread_id: None,
|
||||||
prompt_id: None,
|
prompt_id: None,
|
||||||
|
intent: None,
|
||||||
mode: None,
|
mode: None,
|
||||||
messages: vec![request_message],
|
messages: vec![request_message],
|
||||||
tools: vec![],
|
tools: vec![],
|
||||||
|
|||||||
1
crates/agent/src/prompts/stale_files_prompt_header.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
These files changed since last read:
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
Generate a detailed summary of this conversation. Include:
|
||||||
|
1. A brief overview of what was discussed
|
||||||
|
2. Key facts or information discovered
|
||||||
|
3. Outcomes or conclusions reached
|
||||||
|
4. Any action items or next steps if any
|
||||||
|
Format it in Markdown with headings and bullet points.
|
||||||
4
crates/agent/src/prompts/summarize_thread_prompt.txt
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
Generate a concise 3-7 word title for this conversation, omitting punctuation.
|
||||||
|
Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`.
|
||||||
|
If the conversation is about a specific subject, include it in the title.
|
||||||
|
Be descriptive. DO NOT speak in the first person.
|
||||||
@@ -179,18 +179,17 @@ impl TerminalTransaction {
|
|||||||
// Ensure that the assistant cannot accidentally execute commands that are streamed into the terminal
|
// Ensure that the assistant cannot accidentally execute commands that are streamed into the terminal
|
||||||
let input = Self::sanitize_input(hunk);
|
let input = Self::sanitize_input(hunk);
|
||||||
self.terminal
|
self.terminal
|
||||||
.update(cx, |terminal, _| terminal.input(input));
|
.update(cx, |terminal, _| terminal.input(input.into_bytes()));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn undo(&self, cx: &mut App) {
|
pub fn undo(&self, cx: &mut App) {
|
||||||
self.terminal
|
self.terminal
|
||||||
.update(cx, |terminal, _| terminal.input(CLEAR_INPUT.to_string()));
|
.update(cx, |terminal, _| terminal.input(CLEAR_INPUT.as_bytes()));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn complete(&self, cx: &mut App) {
|
pub fn complete(&self, cx: &mut App) {
|
||||||
self.terminal.update(cx, |terminal, _| {
|
self.terminal
|
||||||
terminal.input(CARRIAGE_RETURN.to_string())
|
.update(cx, |terminal, _| terminal.input(CARRIAGE_RETURN.as_bytes()));
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sanitize_input(mut input: String) -> String {
|
fn sanitize_input(mut input: String) -> String {
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ use terminal_view::TerminalView;
|
|||||||
use ui::prelude::*;
|
use ui::prelude::*;
|
||||||
use util::ResultExt;
|
use util::ResultExt;
|
||||||
use workspace::{Toast, Workspace, notifications::NotificationId};
|
use workspace::{Toast, Workspace, notifications::NotificationId};
|
||||||
|
use zed_llm_client::CompletionIntent;
|
||||||
|
|
||||||
pub fn init(
|
pub fn init(
|
||||||
fs: Arc<dyn Fs>,
|
fs: Arc<dyn Fs>,
|
||||||
@@ -105,7 +106,7 @@ impl TerminalInlineAssistant {
|
|||||||
});
|
});
|
||||||
let prompt_editor_render = prompt_editor.clone();
|
let prompt_editor_render = prompt_editor.clone();
|
||||||
let block = terminal_view::BlockProperties {
|
let block = terminal_view::BlockProperties {
|
||||||
height: 2,
|
height: 4,
|
||||||
render: Box::new(move |_| prompt_editor_render.clone().into_any_element()),
|
render: Box::new(move |_| prompt_editor_render.clone().into_any_element()),
|
||||||
};
|
};
|
||||||
terminal_view.update(cx, |terminal_view, cx| {
|
terminal_view.update(cx, |terminal_view, cx| {
|
||||||
@@ -201,7 +202,7 @@ impl TerminalInlineAssistant {
|
|||||||
.update(cx, |terminal, cx| {
|
.update(cx, |terminal, cx| {
|
||||||
terminal
|
terminal
|
||||||
.terminal()
|
.terminal()
|
||||||
.update(cx, |terminal, _| terminal.input(CLEAR_INPUT.to_string()));
|
.update(cx, |terminal, _| terminal.input(CLEAR_INPUT.as_bytes()));
|
||||||
})
|
})
|
||||||
.log_err();
|
.log_err();
|
||||||
|
|
||||||
@@ -291,6 +292,7 @@ impl TerminalInlineAssistant {
|
|||||||
thread_id: None,
|
thread_id: None,
|
||||||
prompt_id: None,
|
prompt_id: None,
|
||||||
mode: None,
|
mode: None,
|
||||||
|
intent: Some(CompletionIntent::TerminalInlineAssist),
|
||||||
messages: vec![request_message],
|
messages: vec![request_message],
|
||||||
tools: Vec::new(),
|
tools: Vec::new(),
|
||||||
tool_choice: None,
|
tool_choice: None,
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ use language_model::{
|
|||||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||||
LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
|
LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
|
||||||
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
|
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
|
||||||
StopReason, TokenUsage, WrappedTextContent,
|
StopReason, TokenUsage,
|
||||||
};
|
};
|
||||||
use postage::stream::Stream as _;
|
use postage::stream::Stream as _;
|
||||||
use project::Project;
|
use project::Project;
|
||||||
@@ -38,7 +38,7 @@ use thiserror::Error;
|
|||||||
use ui::Window;
|
use ui::Window;
|
||||||
use util::{ResultExt as _, post_inc};
|
use util::{ResultExt as _, post_inc};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
use zed_llm_client::CompletionRequestStatus;
|
use zed_llm_client::{CompletionIntent, CompletionRequestStatus};
|
||||||
|
|
||||||
use crate::ThreadStore;
|
use crate::ThreadStore;
|
||||||
use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
|
use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
|
||||||
@@ -891,10 +891,7 @@ impl Thread {
|
|||||||
|
|
||||||
pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
|
pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
|
||||||
match &self.tool_use.tool_result(id)?.content {
|
match &self.tool_use.tool_result(id)?.content {
|
||||||
LanguageModelToolResultContent::Text(text)
|
LanguageModelToolResultContent::Text(text) => Some(text),
|
||||||
| LanguageModelToolResultContent::WrappedText(WrappedTextContent { text, .. }) => {
|
|
||||||
Some(text)
|
|
||||||
}
|
|
||||||
LanguageModelToolResultContent::Image(_) => {
|
LanguageModelToolResultContent::Image(_) => {
|
||||||
// TODO: We should display image
|
// TODO: We should display image
|
||||||
None
|
None
|
||||||
@@ -1187,6 +1184,7 @@ impl Thread {
|
|||||||
pub fn send_to_model(
|
pub fn send_to_model(
|
||||||
&mut self,
|
&mut self,
|
||||||
model: Arc<dyn LanguageModel>,
|
model: Arc<dyn LanguageModel>,
|
||||||
|
intent: CompletionIntent,
|
||||||
window: Option<AnyWindowHandle>,
|
window: Option<AnyWindowHandle>,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) {
|
) {
|
||||||
@@ -1196,7 +1194,7 @@ impl Thread {
|
|||||||
|
|
||||||
self.remaining_turns -= 1;
|
self.remaining_turns -= 1;
|
||||||
|
|
||||||
let request = self.to_completion_request(model.clone(), cx);
|
let request = self.to_completion_request(model.clone(), intent, cx);
|
||||||
|
|
||||||
self.stream_completion(request, model, window, cx);
|
self.stream_completion(request, model, window, cx);
|
||||||
}
|
}
|
||||||
@@ -1216,11 +1214,13 @@ impl Thread {
|
|||||||
pub fn to_completion_request(
|
pub fn to_completion_request(
|
||||||
&self,
|
&self,
|
||||||
model: Arc<dyn LanguageModel>,
|
model: Arc<dyn LanguageModel>,
|
||||||
|
intent: CompletionIntent,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> LanguageModelRequest {
|
) -> LanguageModelRequest {
|
||||||
let mut request = LanguageModelRequest {
|
let mut request = LanguageModelRequest {
|
||||||
thread_id: Some(self.id.to_string()),
|
thread_id: Some(self.id.to_string()),
|
||||||
prompt_id: Some(self.last_prompt_id.to_string()),
|
prompt_id: Some(self.last_prompt_id.to_string()),
|
||||||
|
intent: Some(intent),
|
||||||
mode: None,
|
mode: None,
|
||||||
messages: vec![],
|
messages: vec![],
|
||||||
tools: Vec::new(),
|
tools: Vec::new(),
|
||||||
@@ -1374,12 +1374,14 @@ impl Thread {
|
|||||||
fn to_summarize_request(
|
fn to_summarize_request(
|
||||||
&self,
|
&self,
|
||||||
model: &Arc<dyn LanguageModel>,
|
model: &Arc<dyn LanguageModel>,
|
||||||
|
intent: CompletionIntent,
|
||||||
added_user_message: String,
|
added_user_message: String,
|
||||||
cx: &App,
|
cx: &App,
|
||||||
) -> LanguageModelRequest {
|
) -> LanguageModelRequest {
|
||||||
let mut request = LanguageModelRequest {
|
let mut request = LanguageModelRequest {
|
||||||
thread_id: None,
|
thread_id: None,
|
||||||
prompt_id: None,
|
prompt_id: None,
|
||||||
|
intent: Some(intent),
|
||||||
mode: None,
|
mode: None,
|
||||||
messages: vec![],
|
messages: vec![],
|
||||||
tools: Vec::new(),
|
tools: Vec::new(),
|
||||||
@@ -1426,7 +1428,7 @@ impl Thread {
|
|||||||
messages: &mut Vec<LanguageModelRequestMessage>,
|
messages: &mut Vec<LanguageModelRequestMessage>,
|
||||||
cx: &App,
|
cx: &App,
|
||||||
) {
|
) {
|
||||||
const STALE_FILES_HEADER: &str = "These files changed since last read:";
|
const STALE_FILES_HEADER: &str = include_str!("./prompts/stale_files_prompt_header.txt");
|
||||||
|
|
||||||
let mut stale_message = String::new();
|
let mut stale_message = String::new();
|
||||||
|
|
||||||
@@ -1438,7 +1440,7 @@ impl Thread {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if stale_message.is_empty() {
|
if stale_message.is_empty() {
|
||||||
write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
|
write!(&mut stale_message, "{}\n", STALE_FILES_HEADER.trim()).ok();
|
||||||
}
|
}
|
||||||
|
|
||||||
writeln!(&mut stale_message, "- {}", file.path().display()).ok();
|
writeln!(&mut stale_message, "- {}", file.path().display()).ok();
|
||||||
@@ -1671,6 +1673,7 @@ impl Thread {
|
|||||||
}
|
}
|
||||||
CompletionRequestStatus::ToolUseLimitReached => {
|
CompletionRequestStatus::ToolUseLimitReached => {
|
||||||
thread.tool_use_limit_reached = true;
|
thread.tool_use_limit_reached = true;
|
||||||
|
cx.emit(ThreadEvent::ToolUseLimitReached);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1852,12 +1855,14 @@ impl Thread {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
|
let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
|
||||||
Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
|
|
||||||
If the conversation is about a specific subject, include it in the title. \
|
|
||||||
Be descriptive. DO NOT speak in the first person.";
|
|
||||||
|
|
||||||
let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);
|
let request = self.to_summarize_request(
|
||||||
|
&model.model,
|
||||||
|
CompletionIntent::ThreadSummarization,
|
||||||
|
added_user_message.into(),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
|
||||||
self.summary = ThreadSummary::Generating;
|
self.summary = ThreadSummary::Generating;
|
||||||
|
|
||||||
@@ -1951,14 +1956,14 @@ impl Thread {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
|
let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");
|
||||||
1. A brief overview of what was discussed\n\
|
|
||||||
2. Key facts or information discovered\n\
|
|
||||||
3. Outcomes or conclusions reached\n\
|
|
||||||
4. Any action items or next steps if any\n\
|
|
||||||
Format it in Markdown with headings and bullet points.";
|
|
||||||
|
|
||||||
let request = self.to_summarize_request(&model, added_user_message.into(), cx);
|
let request = self.to_summarize_request(
|
||||||
|
&model,
|
||||||
|
CompletionIntent::ThreadContextSummarization,
|
||||||
|
added_user_message.into(),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
|
||||||
*self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
|
*self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
|
||||||
message_id: last_message_id,
|
message_id: last_message_id,
|
||||||
@@ -2050,7 +2055,8 @@ impl Thread {
|
|||||||
model: Arc<dyn LanguageModel>,
|
model: Arc<dyn LanguageModel>,
|
||||||
) -> Vec<PendingToolUse> {
|
) -> Vec<PendingToolUse> {
|
||||||
self.auto_capture_telemetry(cx);
|
self.auto_capture_telemetry(cx);
|
||||||
let request = Arc::new(self.to_completion_request(model.clone(), cx));
|
let request =
|
||||||
|
Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
|
||||||
let pending_tool_uses = self
|
let pending_tool_uses = self
|
||||||
.tool_use
|
.tool_use
|
||||||
.pending_tool_uses()
|
.pending_tool_uses()
|
||||||
@@ -2246,7 +2252,7 @@ impl Thread {
|
|||||||
if self.all_tools_finished() {
|
if self.all_tools_finished() {
|
||||||
if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
|
if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
|
||||||
if !canceled {
|
if !canceled {
|
||||||
self.send_to_model(model.clone(), window, cx);
|
self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
|
||||||
}
|
}
|
||||||
self.auto_capture_telemetry(cx);
|
self.auto_capture_telemetry(cx);
|
||||||
}
|
}
|
||||||
@@ -2593,11 +2599,7 @@ impl Thread {
|
|||||||
|
|
||||||
writeln!(markdown, "**\n")?;
|
writeln!(markdown, "**\n")?;
|
||||||
match &tool_result.content {
|
match &tool_result.content {
|
||||||
LanguageModelToolResultContent::Text(text)
|
LanguageModelToolResultContent::Text(text) => {
|
||||||
| LanguageModelToolResultContent::WrappedText(WrappedTextContent {
|
|
||||||
text,
|
|
||||||
..
|
|
||||||
}) => {
|
|
||||||
writeln!(markdown, "{text}")?;
|
writeln!(markdown, "{text}")?;
|
||||||
}
|
}
|
||||||
LanguageModelToolResultContent::Image(image) => {
|
LanguageModelToolResultContent::Image(image) => {
|
||||||
@@ -2842,6 +2844,7 @@ pub enum ThreadEvent {
|
|||||||
},
|
},
|
||||||
CheckpointChanged,
|
CheckpointChanged,
|
||||||
ToolConfirmationNeeded,
|
ToolConfirmationNeeded,
|
||||||
|
ToolUseLimitReached,
|
||||||
CancelEditing,
|
CancelEditing,
|
||||||
CompletionCanceled,
|
CompletionCanceled,
|
||||||
}
|
}
|
||||||
@@ -2941,7 +2944,7 @@ fn main() {{
|
|||||||
|
|
||||||
// Check message in request
|
// Check message in request
|
||||||
let request = thread.update(cx, |thread, cx| {
|
let request = thread.update(cx, |thread, cx| {
|
||||||
thread.to_completion_request(model.clone(), cx)
|
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
assert_eq!(request.messages.len(), 2);
|
assert_eq!(request.messages.len(), 2);
|
||||||
@@ -3036,7 +3039,7 @@ fn main() {{
|
|||||||
|
|
||||||
// Check entire request to make sure all contexts are properly included
|
// Check entire request to make sure all contexts are properly included
|
||||||
let request = thread.update(cx, |thread, cx| {
|
let request = thread.update(cx, |thread, cx| {
|
||||||
thread.to_completion_request(model.clone(), cx)
|
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
// The request should contain all 3 messages
|
// The request should contain all 3 messages
|
||||||
@@ -3143,7 +3146,7 @@ fn main() {{
|
|||||||
|
|
||||||
// Check message in request
|
// Check message in request
|
||||||
let request = thread.update(cx, |thread, cx| {
|
let request = thread.update(cx, |thread, cx| {
|
||||||
thread.to_completion_request(model.clone(), cx)
|
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
assert_eq!(request.messages.len(), 2);
|
assert_eq!(request.messages.len(), 2);
|
||||||
@@ -3169,7 +3172,7 @@ fn main() {{
|
|||||||
|
|
||||||
// Check that both messages appear in the request
|
// Check that both messages appear in the request
|
||||||
let request = thread.update(cx, |thread, cx| {
|
let request = thread.update(cx, |thread, cx| {
|
||||||
thread.to_completion_request(model.clone(), cx)
|
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
assert_eq!(request.messages.len(), 3);
|
assert_eq!(request.messages.len(), 3);
|
||||||
@@ -3214,7 +3217,7 @@ fn main() {{
|
|||||||
|
|
||||||
// Create a request and check that it doesn't have a stale buffer warning yet
|
// Create a request and check that it doesn't have a stale buffer warning yet
|
||||||
let initial_request = thread.update(cx, |thread, cx| {
|
let initial_request = thread.update(cx, |thread, cx| {
|
||||||
thread.to_completion_request(model.clone(), cx)
|
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
// Make sure we don't have a stale file warning yet
|
// Make sure we don't have a stale file warning yet
|
||||||
@@ -3250,7 +3253,7 @@ fn main() {{
|
|||||||
|
|
||||||
// Create a new request and check for the stale buffer warning
|
// Create a new request and check for the stale buffer warning
|
||||||
let new_request = thread.update(cx, |thread, cx| {
|
let new_request = thread.update(cx, |thread, cx| {
|
||||||
thread.to_completion_request(model.clone(), cx)
|
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
// We should have a stale file warning as the last message
|
// We should have a stale file warning as the last message
|
||||||
@@ -3300,7 +3303,7 @@ fn main() {{
|
|||||||
});
|
});
|
||||||
|
|
||||||
let request = thread.update(cx, |thread, cx| {
|
let request = thread.update(cx, |thread, cx| {
|
||||||
thread.to_completion_request(model.clone(), cx)
|
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
|
||||||
});
|
});
|
||||||
assert_eq!(request.temperature, Some(0.66));
|
assert_eq!(request.temperature, Some(0.66));
|
||||||
|
|
||||||
@@ -3320,7 +3323,7 @@ fn main() {{
|
|||||||
});
|
});
|
||||||
|
|
||||||
let request = thread.update(cx, |thread, cx| {
|
let request = thread.update(cx, |thread, cx| {
|
||||||
thread.to_completion_request(model.clone(), cx)
|
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
|
||||||
});
|
});
|
||||||
assert_eq!(request.temperature, Some(0.66));
|
assert_eq!(request.temperature, Some(0.66));
|
||||||
|
|
||||||
@@ -3340,7 +3343,7 @@ fn main() {{
|
|||||||
});
|
});
|
||||||
|
|
||||||
let request = thread.update(cx, |thread, cx| {
|
let request = thread.update(cx, |thread, cx| {
|
||||||
thread.to_completion_request(model.clone(), cx)
|
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
|
||||||
});
|
});
|
||||||
assert_eq!(request.temperature, Some(0.66));
|
assert_eq!(request.temperature, Some(0.66));
|
||||||
|
|
||||||
@@ -3360,7 +3363,7 @@ fn main() {{
|
|||||||
});
|
});
|
||||||
|
|
||||||
let request = thread.update(cx, |thread, cx| {
|
let request = thread.update(cx, |thread, cx| {
|
||||||
thread.to_completion_request(model.clone(), cx)
|
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
|
||||||
});
|
});
|
||||||
assert_eq!(request.temperature, None);
|
assert_eq!(request.temperature, None);
|
||||||
}
|
}
|
||||||
@@ -3392,7 +3395,12 @@ fn main() {{
|
|||||||
// Send a message
|
// Send a message
|
||||||
thread.update(cx, |thread, cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
|
thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
|
||||||
thread.send_to_model(model.clone(), None, cx);
|
thread.send_to_model(
|
||||||
|
model.clone(),
|
||||||
|
CompletionIntent::ThreadSummarization,
|
||||||
|
None,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
let fake_model = model.as_fake();
|
let fake_model = model.as_fake();
|
||||||
@@ -3487,7 +3495,7 @@ fn main() {{
|
|||||||
vec![],
|
vec![],
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
thread.send_to_model(model.clone(), None, cx);
|
thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
|
||||||
});
|
});
|
||||||
|
|
||||||
let fake_model = model.as_fake();
|
let fake_model = model.as_fake();
|
||||||
@@ -3525,7 +3533,12 @@ fn main() {{
|
|||||||
) {
|
) {
|
||||||
thread.update(cx, |thread, cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
|
thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
|
||||||
thread.send_to_model(model.clone(), None, cx);
|
thread.send_to_model(
|
||||||
|
model.clone(),
|
||||||
|
CompletionIntent::ThreadSummarization,
|
||||||
|
None,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
let fake_model = model.as_fake();
|
let fake_model = model.as_fake();
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
use std::borrow::Cow;
|
|
||||||
use std::cell::{Ref, RefCell};
|
use std::cell::{Ref, RefCell};
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
use std::sync::Arc;
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, CompletionMode};
|
use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, CompletionMode};
|
||||||
use anyhow::{Context as _, Result, anyhow};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
@@ -17,8 +16,7 @@ use gpui::{
|
|||||||
App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
|
App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
|
||||||
Subscription, Task, prelude::*,
|
Subscription, Task, prelude::*,
|
||||||
};
|
};
|
||||||
use heed::Database;
|
|
||||||
use heed::types::SerdeBincode;
|
|
||||||
use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
|
use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
|
||||||
use project::context_server_store::{ContextServerStatus, ContextServerStore};
|
use project::context_server_store::{ContextServerStatus, ContextServerStore};
|
||||||
use project::{Project, ProjectItem, ProjectPath, Worktree};
|
use project::{Project, ProjectItem, ProjectPath, Worktree};
|
||||||
@@ -35,6 +33,42 @@ use crate::context_server_tool::ContextServerTool;
|
|||||||
use crate::thread::{
|
use crate::thread::{
|
||||||
DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
|
DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
|
||||||
};
|
};
|
||||||
|
use indoc::indoc;
|
||||||
|
use sqlez::{
|
||||||
|
bindable::{Bind, Column},
|
||||||
|
connection::Connection,
|
||||||
|
statement::Statement,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
pub enum DataType {
|
||||||
|
#[serde(rename = "json")]
|
||||||
|
Json,
|
||||||
|
#[serde(rename = "zstd")]
|
||||||
|
Zstd,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Bind for DataType {
|
||||||
|
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||||
|
let value = match self {
|
||||||
|
DataType::Json => "json",
|
||||||
|
DataType::Zstd => "zstd",
|
||||||
|
};
|
||||||
|
value.bind(statement, start_index)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Column for DataType {
|
||||||
|
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
||||||
|
let (value, next_index) = String::column(statement, start_index)?;
|
||||||
|
let data_type = match value.as_str() {
|
||||||
|
"json" => DataType::Json,
|
||||||
|
"zstd" => DataType::Zstd,
|
||||||
|
_ => anyhow::bail!("Unknown data type: {}", value),
|
||||||
|
};
|
||||||
|
Ok((data_type, next_index))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const RULES_FILE_NAMES: [&'static str; 6] = [
|
const RULES_FILE_NAMES: [&'static str; 6] = [
|
||||||
".rules",
|
".rules",
|
||||||
@@ -866,25 +900,27 @@ impl Global for GlobalThreadsDatabase {}
|
|||||||
|
|
||||||
pub(crate) struct ThreadsDatabase {
|
pub(crate) struct ThreadsDatabase {
|
||||||
executor: BackgroundExecutor,
|
executor: BackgroundExecutor,
|
||||||
env: heed::Env,
|
connection: Arc<Mutex<Connection>>,
|
||||||
threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl heed::BytesEncode<'_> for SerializedThread {
|
impl ThreadsDatabase {
|
||||||
type EItem = SerializedThread;
|
fn connection(&self) -> Arc<Mutex<Connection>> {
|
||||||
|
self.connection.clone()
|
||||||
|
}
|
||||||
|
|
||||||
fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
|
const COMPRESSION_LEVEL: i32 = 3;
|
||||||
serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
|
}
|
||||||
|
|
||||||
|
impl Bind for ThreadId {
|
||||||
|
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||||
|
self.to_string().bind(statement, start_index)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> heed::BytesDecode<'a> for SerializedThread {
|
impl Column for ThreadId {
|
||||||
type DItem = SerializedThread;
|
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
||||||
|
let (id_str, next_index) = String::column(statement, start_index)?;
|
||||||
fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
|
Ok((ThreadId::from(id_str.as_str()), next_index))
|
||||||
// We implement this type manually because we want to call `SerializedThread::from_json`,
|
|
||||||
// instead of the Deserialize trait implementation for `SerializedThread`.
|
|
||||||
SerializedThread::from_json(bytes).map_err(Into::into)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -900,8 +936,8 @@ impl ThreadsDatabase {
|
|||||||
let database_future = executor
|
let database_future = executor
|
||||||
.spawn({
|
.spawn({
|
||||||
let executor = executor.clone();
|
let executor = executor.clone();
|
||||||
let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
|
let threads_dir = paths::data_dir().join("threads");
|
||||||
async move { ThreadsDatabase::new(database_path, executor) }
|
async move { ThreadsDatabase::new(threads_dir, executor) }
|
||||||
})
|
})
|
||||||
.then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
|
.then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
|
||||||
.boxed()
|
.boxed()
|
||||||
@@ -910,41 +946,144 @@ impl ThreadsDatabase {
|
|||||||
cx.set_global(GlobalThreadsDatabase(database_future));
|
cx.set_global(GlobalThreadsDatabase(database_future));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
|
pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
|
||||||
std::fs::create_dir_all(&path)?;
|
std::fs::create_dir_all(&threads_dir)?;
|
||||||
|
|
||||||
|
let sqlite_path = threads_dir.join("threads.db");
|
||||||
|
let mdb_path = threads_dir.join("threads-db.1.mdb");
|
||||||
|
|
||||||
|
let needs_migration_from_heed = mdb_path.exists();
|
||||||
|
|
||||||
|
let connection = Connection::open_file(&sqlite_path.to_string_lossy());
|
||||||
|
|
||||||
|
connection.exec(indoc! {"
|
||||||
|
CREATE TABLE IF NOT EXISTS threads (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
summary TEXT NOT NULL,
|
||||||
|
updated_at TEXT NOT NULL,
|
||||||
|
data_type TEXT NOT NULL,
|
||||||
|
data BLOB NOT NULL
|
||||||
|
)
|
||||||
|
"})?()
|
||||||
|
.map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
|
||||||
|
|
||||||
|
let db = Self {
|
||||||
|
executor: executor.clone(),
|
||||||
|
connection: Arc::new(Mutex::new(connection)),
|
||||||
|
};
|
||||||
|
|
||||||
|
if needs_migration_from_heed {
|
||||||
|
let db_connection = db.connection();
|
||||||
|
let executor_clone = executor.clone();
|
||||||
|
executor
|
||||||
|
.spawn(async move {
|
||||||
|
log::info!("Starting threads.db migration");
|
||||||
|
Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
|
||||||
|
std::fs::remove_dir_all(mdb_path)?;
|
||||||
|
log::info!("threads.db migrated to sqlite");
|
||||||
|
Ok::<(), anyhow::Error>(())
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove this migration after 2025-09-01
|
||||||
|
fn migrate_from_heed(
|
||||||
|
mdb_path: &Path,
|
||||||
|
connection: Arc<Mutex<Connection>>,
|
||||||
|
_executor: BackgroundExecutor,
|
||||||
|
) -> Result<()> {
|
||||||
|
use heed::types::SerdeBincode;
|
||||||
|
struct SerializedThreadHeed(SerializedThread);
|
||||||
|
|
||||||
|
impl heed::BytesEncode<'_> for SerializedThreadHeed {
|
||||||
|
type EItem = SerializedThreadHeed;
|
||||||
|
|
||||||
|
fn bytes_encode(
|
||||||
|
item: &Self::EItem,
|
||||||
|
) -> Result<std::borrow::Cow<[u8]>, heed::BoxedError> {
|
||||||
|
serde_json::to_vec(&item.0)
|
||||||
|
.map(std::borrow::Cow::Owned)
|
||||||
|
.map_err(Into::into)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
|
||||||
|
type DItem = SerializedThreadHeed;
|
||||||
|
|
||||||
|
fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
|
||||||
|
SerializedThread::from_json(bytes)
|
||||||
|
.map(SerializedThreadHeed)
|
||||||
|
.map_err(Into::into)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
|
const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
|
||||||
|
|
||||||
let env = unsafe {
|
let env = unsafe {
|
||||||
heed::EnvOpenOptions::new()
|
heed::EnvOpenOptions::new()
|
||||||
.map_size(ONE_GB_IN_BYTES)
|
.map_size(ONE_GB_IN_BYTES)
|
||||||
.max_dbs(1)
|
.max_dbs(1)
|
||||||
.open(path)?
|
.open(mdb_path)?
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut txn = env.write_txn()?;
|
let txn = env.write_txn()?;
|
||||||
let threads = env.create_database(&mut txn, Some("threads"))?;
|
let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
|
||||||
txn.commit()?;
|
.open_database(&txn, Some("threads"))?
|
||||||
|
.ok_or_else(|| anyhow!("threads database not found"))?;
|
||||||
|
|
||||||
Ok(Self {
|
for result in threads.iter(&txn)? {
|
||||||
executor,
|
let (thread_id, thread_heed) = result?;
|
||||||
env,
|
Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
|
||||||
threads,
|
}
|
||||||
})
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn save_thread_sync(
|
||||||
|
connection: &Arc<Mutex<Connection>>,
|
||||||
|
id: ThreadId,
|
||||||
|
thread: SerializedThread,
|
||||||
|
) -> Result<()> {
|
||||||
|
let json_data = serde_json::to_string(&thread)?;
|
||||||
|
let summary = thread.summary.to_string();
|
||||||
|
let updated_at = thread.updated_at.to_rfc3339();
|
||||||
|
|
||||||
|
let connection = connection.lock().unwrap();
|
||||||
|
|
||||||
|
let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
|
||||||
|
let data_type = DataType::Zstd;
|
||||||
|
let data = compressed;
|
||||||
|
|
||||||
|
let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
|
||||||
|
INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
|
||||||
|
"})?;
|
||||||
|
|
||||||
|
insert((id, summary, updated_at, data_type, data))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
|
pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
|
||||||
let env = self.env.clone();
|
let connection = self.connection.clone();
|
||||||
let threads = self.threads;
|
|
||||||
|
|
||||||
self.executor.spawn(async move {
|
self.executor.spawn(async move {
|
||||||
let txn = env.read_txn()?;
|
let connection = connection.lock().unwrap();
|
||||||
let mut iter = threads.iter(&txn)?;
|
let mut select =
|
||||||
|
connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
|
||||||
|
SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
|
||||||
|
"})?;
|
||||||
|
|
||||||
|
let rows = select(())?;
|
||||||
let mut threads = Vec::new();
|
let mut threads = Vec::new();
|
||||||
while let Some((key, value)) = iter.next().transpose()? {
|
|
||||||
|
for (id, summary, updated_at) in rows {
|
||||||
threads.push(SerializedThreadMetadata {
|
threads.push(SerializedThreadMetadata {
|
||||||
id: key,
|
id,
|
||||||
summary: value.summary,
|
summary: summary.into(),
|
||||||
updated_at: value.updated_at,
|
updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -953,36 +1092,51 @@ impl ThreadsDatabase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
|
pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
|
||||||
let env = self.env.clone();
|
let connection = self.connection.clone();
|
||||||
let threads = self.threads;
|
|
||||||
|
|
||||||
self.executor.spawn(async move {
|
self.executor.spawn(async move {
|
||||||
let txn = env.read_txn()?;
|
let connection = connection.lock().unwrap();
|
||||||
let thread = threads.get(&txn, &id)?;
|
let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
|
||||||
Ok(thread)
|
SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
|
||||||
|
"})?;
|
||||||
|
|
||||||
|
let rows = select(id)?;
|
||||||
|
if let Some((data_type, data)) = rows.into_iter().next() {
|
||||||
|
let json_data = match data_type {
|
||||||
|
DataType::Zstd => {
|
||||||
|
let decompressed = zstd::decode_all(&data[..])?;
|
||||||
|
String::from_utf8(decompressed)?
|
||||||
|
}
|
||||||
|
DataType::Json => String::from_utf8(data)?,
|
||||||
|
};
|
||||||
|
|
||||||
|
let thread = SerializedThread::from_json(json_data.as_bytes())?;
|
||||||
|
Ok(Some(thread))
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
|
pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
|
||||||
let env = self.env.clone();
|
let connection = self.connection.clone();
|
||||||
let threads = self.threads;
|
|
||||||
|
|
||||||
self.executor.spawn(async move {
|
self.executor
|
||||||
let mut txn = env.write_txn()?;
|
.spawn(async move { Self::save_thread_sync(&connection, id, thread) })
|
||||||
threads.put(&mut txn, &id, &thread)?;
|
|
||||||
txn.commit()?;
|
|
||||||
Ok(())
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
|
pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
|
||||||
let env = self.env.clone();
|
let connection = self.connection.clone();
|
||||||
let threads = self.threads;
|
|
||||||
|
|
||||||
self.executor.spawn(async move {
|
self.executor.spawn(async move {
|
||||||
let mut txn = env.write_txn()?;
|
let connection = connection.lock().unwrap();
|
||||||
threads.delete(&mut txn, &id)?;
|
|
||||||
txn.commit()?;
|
let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
|
||||||
|
DELETE FROM threads WHERE id = ?
|
||||||
|
"})?;
|
||||||
|
|
||||||
|
delete(id)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -304,7 +304,7 @@ impl AddedContext {
|
|||||||
AgentContextHandle::Thread(handle) => Some(Self::pending_thread(handle, cx)),
|
AgentContextHandle::Thread(handle) => Some(Self::pending_thread(handle, cx)),
|
||||||
AgentContextHandle::TextThread(handle) => Some(Self::pending_text_thread(handle, cx)),
|
AgentContextHandle::TextThread(handle) => Some(Self::pending_text_thread(handle, cx)),
|
||||||
AgentContextHandle::Rules(handle) => Self::pending_rules(handle, prompt_store, cx),
|
AgentContextHandle::Rules(handle) => Self::pending_rules(handle, prompt_store, cx),
|
||||||
AgentContextHandle::Image(handle) => Some(Self::image(handle)),
|
AgentContextHandle::Image(handle) => Some(Self::image(handle, cx)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -318,7 +318,7 @@ impl AddedContext {
|
|||||||
AgentContext::Thread(context) => Self::attached_thread(context),
|
AgentContext::Thread(context) => Self::attached_thread(context),
|
||||||
AgentContext::TextThread(context) => Self::attached_text_thread(context),
|
AgentContext::TextThread(context) => Self::attached_text_thread(context),
|
||||||
AgentContext::Rules(context) => Self::attached_rules(context),
|
AgentContext::Rules(context) => Self::attached_rules(context),
|
||||||
AgentContext::Image(context) => Self::image(context.clone()),
|
AgentContext::Image(context) => Self::image(context.clone(), cx),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -333,14 +333,8 @@ impl AddedContext {
|
|||||||
|
|
||||||
fn file(handle: FileContextHandle, full_path: &Path, cx: &App) -> AddedContext {
|
fn file(handle: FileContextHandle, full_path: &Path, cx: &App) -> AddedContext {
|
||||||
let full_path_string: SharedString = full_path.to_string_lossy().into_owned().into();
|
let full_path_string: SharedString = full_path.to_string_lossy().into_owned().into();
|
||||||
let name = full_path
|
let (name, parent) =
|
||||||
.file_name()
|
extract_file_name_and_directory_from_full_path(full_path, &full_path_string);
|
||||||
.map(|n| n.to_string_lossy().into_owned().into())
|
|
||||||
.unwrap_or_else(|| full_path_string.clone());
|
|
||||||
let parent = full_path
|
|
||||||
.parent()
|
|
||||||
.and_then(|p| p.file_name())
|
|
||||||
.map(|n| n.to_string_lossy().into_owned().into());
|
|
||||||
AddedContext {
|
AddedContext {
|
||||||
kind: ContextKind::File,
|
kind: ContextKind::File,
|
||||||
name,
|
name,
|
||||||
@@ -370,14 +364,8 @@ impl AddedContext {
|
|||||||
|
|
||||||
fn directory(handle: DirectoryContextHandle, full_path: &Path) -> AddedContext {
|
fn directory(handle: DirectoryContextHandle, full_path: &Path) -> AddedContext {
|
||||||
let full_path_string: SharedString = full_path.to_string_lossy().into_owned().into();
|
let full_path_string: SharedString = full_path.to_string_lossy().into_owned().into();
|
||||||
let name = full_path
|
let (name, parent) =
|
||||||
.file_name()
|
extract_file_name_and_directory_from_full_path(full_path, &full_path_string);
|
||||||
.map(|n| n.to_string_lossy().into_owned().into())
|
|
||||||
.unwrap_or_else(|| full_path_string.clone());
|
|
||||||
let parent = full_path
|
|
||||||
.parent()
|
|
||||||
.and_then(|p| p.file_name())
|
|
||||||
.map(|n| n.to_string_lossy().into_owned().into());
|
|
||||||
AddedContext {
|
AddedContext {
|
||||||
kind: ContextKind::Directory,
|
kind: ContextKind::Directory,
|
||||||
name,
|
name,
|
||||||
@@ -605,13 +593,23 @@ impl AddedContext {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn image(context: ImageContext) -> AddedContext {
|
fn image(context: ImageContext, cx: &App) -> AddedContext {
|
||||||
|
let (name, parent, icon_path) = if let Some(full_path) = context.full_path.as_ref() {
|
||||||
|
let full_path_string: SharedString = full_path.to_string_lossy().into_owned().into();
|
||||||
|
let (name, parent) =
|
||||||
|
extract_file_name_and_directory_from_full_path(full_path, &full_path_string);
|
||||||
|
let icon_path = FileIcons::get_icon(&full_path, cx);
|
||||||
|
(name, parent, icon_path)
|
||||||
|
} else {
|
||||||
|
("Image".into(), None, None)
|
||||||
|
};
|
||||||
|
|
||||||
AddedContext {
|
AddedContext {
|
||||||
kind: ContextKind::Image,
|
kind: ContextKind::Image,
|
||||||
name: "Image".into(),
|
name,
|
||||||
parent: None,
|
parent,
|
||||||
tooltip: None,
|
tooltip: None,
|
||||||
icon_path: None,
|
icon_path,
|
||||||
status: match context.status() {
|
status: match context.status() {
|
||||||
ImageStatus::Loading => ContextStatus::Loading {
|
ImageStatus::Loading => ContextStatus::Loading {
|
||||||
message: "Loading…".into(),
|
message: "Loading…".into(),
|
||||||
@@ -639,6 +637,22 @@ impl AddedContext {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn extract_file_name_and_directory_from_full_path(
|
||||||
|
path: &Path,
|
||||||
|
name_fallback: &SharedString,
|
||||||
|
) -> (SharedString, Option<SharedString>) {
|
||||||
|
let name = path
|
||||||
|
.file_name()
|
||||||
|
.map(|n| n.to_string_lossy().into_owned().into())
|
||||||
|
.unwrap_or_else(|| name_fallback.clone());
|
||||||
|
let parent = path
|
||||||
|
.parent()
|
||||||
|
.and_then(|p| p.file_name())
|
||||||
|
.map(|n| n.to_string_lossy().into_owned().into());
|
||||||
|
|
||||||
|
(name, parent)
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct ContextFileExcerpt {
|
struct ContextFileExcerpt {
|
||||||
pub file_name_and_range: SharedString,
|
pub file_name_and_range: SharedString,
|
||||||
@@ -765,37 +779,49 @@ impl Component for AddedContext {
|
|||||||
let mut next_context_id = ContextId::zero();
|
let mut next_context_id = ContextId::zero();
|
||||||
let image_ready = (
|
let image_ready = (
|
||||||
"Ready",
|
"Ready",
|
||||||
AddedContext::image(ImageContext {
|
AddedContext::image(
|
||||||
context_id: next_context_id.post_inc(),
|
ImageContext {
|
||||||
project_path: None,
|
context_id: next_context_id.post_inc(),
|
||||||
original_image: Arc::new(Image::empty()),
|
project_path: None,
|
||||||
image_task: Task::ready(Some(LanguageModelImage::empty())).shared(),
|
full_path: None,
|
||||||
}),
|
original_image: Arc::new(Image::empty()),
|
||||||
|
image_task: Task::ready(Some(LanguageModelImage::empty())).shared(),
|
||||||
|
},
|
||||||
|
cx,
|
||||||
|
),
|
||||||
);
|
);
|
||||||
|
|
||||||
let image_loading = (
|
let image_loading = (
|
||||||
"Loading",
|
"Loading",
|
||||||
AddedContext::image(ImageContext {
|
AddedContext::image(
|
||||||
context_id: next_context_id.post_inc(),
|
ImageContext {
|
||||||
project_path: None,
|
context_id: next_context_id.post_inc(),
|
||||||
original_image: Arc::new(Image::empty()),
|
project_path: None,
|
||||||
image_task: cx
|
full_path: None,
|
||||||
.background_spawn(async move {
|
original_image: Arc::new(Image::empty()),
|
||||||
smol::Timer::after(Duration::from_secs(60 * 5)).await;
|
image_task: cx
|
||||||
Some(LanguageModelImage::empty())
|
.background_spawn(async move {
|
||||||
})
|
smol::Timer::after(Duration::from_secs(60 * 5)).await;
|
||||||
.shared(),
|
Some(LanguageModelImage::empty())
|
||||||
}),
|
})
|
||||||
|
.shared(),
|
||||||
|
},
|
||||||
|
cx,
|
||||||
|
),
|
||||||
);
|
);
|
||||||
|
|
||||||
let image_error = (
|
let image_error = (
|
||||||
"Error",
|
"Error",
|
||||||
AddedContext::image(ImageContext {
|
AddedContext::image(
|
||||||
context_id: next_context_id.post_inc(),
|
ImageContext {
|
||||||
project_path: None,
|
context_id: next_context_id.post_inc(),
|
||||||
original_image: Arc::new(Image::empty()),
|
project_path: None,
|
||||||
image_task: Task::ready(None).shared(),
|
full_path: None,
|
||||||
}),
|
original_image: Arc::new(Image::empty()),
|
||||||
|
image_task: Task::ready(None).shared(),
|
||||||
|
},
|
||||||
|
cx,
|
||||||
|
),
|
||||||
);
|
);
|
||||||
|
|
||||||
Some(
|
Some(
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
use gpui::{Context, IntoElement, Render, Window};
|
use crate::ToggleBurnMode;
|
||||||
use ui::{prelude::*, tooltip_container};
|
use gpui::{Context, FontWeight, IntoElement, Render, Window};
|
||||||
|
use ui::{KeyBinding, prelude::*, tooltip_container};
|
||||||
|
|
||||||
pub struct MaxModeTooltip {
|
pub struct MaxModeTooltip {
|
||||||
selected: bool,
|
selected: bool,
|
||||||
@@ -18,39 +19,48 @@ impl MaxModeTooltip {
|
|||||||
|
|
||||||
impl Render for MaxModeTooltip {
|
impl Render for MaxModeTooltip {
|
||||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||||
let icon = if self.selected {
|
let (icon, color) = if self.selected {
|
||||||
IconName::ZedBurnModeOn
|
(IconName::ZedBurnModeOn, Color::Error)
|
||||||
} else {
|
} else {
|
||||||
IconName::ZedBurnMode
|
(IconName::ZedBurnMode, Color::Default)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let turned_on = h_flex()
|
||||||
|
.h_4()
|
||||||
|
.px_1()
|
||||||
|
.border_1()
|
||||||
|
.border_color(cx.theme().colors().border)
|
||||||
|
.bg(cx.theme().colors().text_accent.opacity(0.1))
|
||||||
|
.rounded_sm()
|
||||||
|
.child(
|
||||||
|
Label::new("ON")
|
||||||
|
.size(LabelSize::XSmall)
|
||||||
|
.weight(FontWeight::SEMIBOLD)
|
||||||
|
.color(Color::Accent),
|
||||||
|
);
|
||||||
|
|
||||||
let title = h_flex()
|
let title = h_flex()
|
||||||
.gap_1()
|
.gap_1p5()
|
||||||
.child(Icon::new(icon).size(IconSize::Small))
|
.child(Icon::new(icon).size(IconSize::Small).color(color))
|
||||||
.child(Label::new("Burn Mode"));
|
.child(Label::new("Burn Mode"))
|
||||||
|
.when(self.selected, |title| title.child(turned_on));
|
||||||
|
|
||||||
|
let keybinding = KeyBinding::for_action(&ToggleBurnMode, window, cx)
|
||||||
|
.map(|kb| kb.size(rems_from_px(12.)));
|
||||||
|
|
||||||
tooltip_container(window, cx, |this, _, _| {
|
tooltip_container(window, cx, |this, _, _| {
|
||||||
this.gap_0p5()
|
this
|
||||||
.map(|header| if self.selected {
|
.child(
|
||||||
header.child(
|
h_flex()
|
||||||
h_flex()
|
.justify_between()
|
||||||
.justify_between()
|
.child(title)
|
||||||
.child(title)
|
.children(keybinding)
|
||||||
.child(
|
)
|
||||||
h_flex()
|
|
||||||
.gap_0p5()
|
|
||||||
.child(Icon::new(IconName::Check).size(IconSize::XSmall).color(Color::Accent))
|
|
||||||
.child(Label::new("Turned On").size(LabelSize::XSmall).color(Color::Accent))
|
|
||||||
)
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
header.child(title)
|
|
||||||
})
|
|
||||||
.child(
|
.child(
|
||||||
div()
|
div()
|
||||||
.max_w_72()
|
.max_w_64()
|
||||||
.child(
|
.child(
|
||||||
Label::new("Enables models to use large context windows, unlimited tool calls, and other capabilities for expanded reasoning, offering an unfettered agentic experience.")
|
Label::new("Enables models to use large context windows, unlimited tool calls, and other capabilities for expanded reasoning.")
|
||||||
.size(LabelSize::Small)
|
.size(LabelSize::Small)
|
||||||
.color(Color::Muted)
|
.color(Color::Muted)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -372,6 +372,8 @@ impl AgentSettingsContent {
|
|||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
Some(language_model.supports_tools()),
|
Some(language_model.supports_tools()),
|
||||||
|
Some(language_model.supports_images()),
|
||||||
|
None,
|
||||||
)),
|
)),
|
||||||
api_url,
|
api_url,
|
||||||
});
|
});
|
||||||
@@ -689,14 +691,15 @@ pub struct AgentSettingsContentV2 {
|
|||||||
pub enum CompletionMode {
|
pub enum CompletionMode {
|
||||||
#[default]
|
#[default]
|
||||||
Normal,
|
Normal,
|
||||||
Max,
|
#[serde(alias = "max")]
|
||||||
|
Burn,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<CompletionMode> for zed_llm_client::CompletionMode {
|
impl From<CompletionMode> for zed_llm_client::CompletionMode {
|
||||||
fn from(value: CompletionMode) -> Self {
|
fn from(value: CompletionMode) -> Self {
|
||||||
match value {
|
match value {
|
||||||
CompletionMode::Normal => zed_llm_client::CompletionMode::Normal,
|
CompletionMode::Normal => zed_llm_client::CompletionMode::Normal,
|
||||||
CompletionMode::Max => zed_llm_client::CompletionMode::Max,
|
CompletionMode::Burn => zed_llm_client::CompletionMode::Max,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -57,8 +57,10 @@ uuid.workspace = true
|
|||||||
workspace-hack.workspace = true
|
workspace-hack.workspace = true
|
||||||
workspace.workspace = true
|
workspace.workspace = true
|
||||||
zed_actions.workspace = true
|
zed_actions.workspace = true
|
||||||
|
zed_llm_client.workspace = true
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
indoc.workspace = true
|
||||||
language_model = { workspace = true, features = ["test-support"] }
|
language_model = { workspace = true, features = ["test-support"] }
|
||||||
languages = { workspace = true, features = ["test-support"] }
|
languages = { workspace = true, features = ["test-support"] }
|
||||||
pretty_assertions.workspace = true
|
pretty_assertions.workspace = true
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ use text::{BufferSnapshot, ToPoint};
|
|||||||
use ui::IconName;
|
use ui::IconName;
|
||||||
use util::{ResultExt, TryFutureExt, post_inc};
|
use util::{ResultExt, TryFutureExt, post_inc};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
use zed_llm_client::CompletionIntent;
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Eq, PartialEq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
|
||||||
pub struct ContextId(String);
|
pub struct ContextId(String);
|
||||||
@@ -2272,6 +2273,7 @@ impl AssistantContext {
|
|||||||
let mut completion_request = LanguageModelRequest {
|
let mut completion_request = LanguageModelRequest {
|
||||||
thread_id: None,
|
thread_id: None,
|
||||||
prompt_id: None,
|
prompt_id: None,
|
||||||
|
intent: Some(CompletionIntent::UserPrompt),
|
||||||
mode: None,
|
mode: None,
|
||||||
messages: Vec::new(),
|
messages: Vec::new(),
|
||||||
tools: Vec::new(),
|
tools: Vec::new(),
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
language_model_selector::{
|
language_model_selector::{
|
||||||
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
|
LanguageModelSelector, ToggleModelSelector, language_model_selector,
|
||||||
},
|
},
|
||||||
max_mode_tooltip::MaxModeTooltip,
|
max_mode_tooltip::MaxModeTooltip,
|
||||||
};
|
};
|
||||||
@@ -43,7 +43,7 @@ use language_model::{
|
|||||||
Role,
|
Role,
|
||||||
};
|
};
|
||||||
use multi_buffer::MultiBufferRow;
|
use multi_buffer::MultiBufferRow;
|
||||||
use picker::Picker;
|
use picker::{Picker, popover_menu::PickerPopoverMenu};
|
||||||
use project::{Project, Worktree};
|
use project::{Project, Worktree};
|
||||||
use project::{ProjectPath, lsp_store::LocalLspAdapterDelegate};
|
use project::{ProjectPath, lsp_store::LocalLspAdapterDelegate};
|
||||||
use rope::Point;
|
use rope::Point;
|
||||||
@@ -283,7 +283,7 @@ impl ContextEditor {
|
|||||||
slash_menu_handle: Default::default(),
|
slash_menu_handle: Default::default(),
|
||||||
dragged_file_worktrees: Vec::new(),
|
dragged_file_worktrees: Vec::new(),
|
||||||
language_model_selector: cx.new(|cx| {
|
language_model_selector: cx.new(|cx| {
|
||||||
LanguageModelSelector::new(
|
language_model_selector(
|
||||||
|cx| LanguageModelRegistry::read_global(cx).default_model(),
|
|cx| LanguageModelRegistry::read_global(cx).default_model(),
|
||||||
move |model, cx| {
|
move |model, cx| {
|
||||||
update_settings_file::<AgentSettings>(
|
update_settings_file::<AgentSettings>(
|
||||||
@@ -1646,34 +1646,35 @@ impl ContextEditor {
|
|||||||
let context = self.context.read(cx);
|
let context = self.context.read(cx);
|
||||||
|
|
||||||
let mut text = String::new();
|
let mut text = String::new();
|
||||||
for message in context.messages(cx) {
|
|
||||||
if message.offset_range.start >= selection.range().end {
|
// If selection is empty, we want to copy the entire line
|
||||||
break;
|
if selection.range().is_empty() {
|
||||||
} else if message.offset_range.end >= selection.range().start {
|
let snapshot = context.buffer().read(cx).snapshot();
|
||||||
let range = cmp::max(message.offset_range.start, selection.range().start)
|
let point = snapshot.offset_to_point(selection.range().start);
|
||||||
..cmp::min(message.offset_range.end, selection.range().end);
|
selection.start = snapshot.point_to_offset(Point::new(point.row, 0));
|
||||||
if range.is_empty() {
|
selection.end = snapshot
|
||||||
let snapshot = context.buffer().read(cx).snapshot();
|
.point_to_offset(cmp::min(Point::new(point.row + 1, 0), snapshot.max_point()));
|
||||||
let point = snapshot.offset_to_point(range.start);
|
for chunk in context.buffer().read(cx).text_for_range(selection.range()) {
|
||||||
selection.start = snapshot.point_to_offset(Point::new(point.row, 0));
|
text.push_str(chunk);
|
||||||
selection.end = snapshot.point_to_offset(cmp::min(
|
}
|
||||||
Point::new(point.row + 1, 0),
|
} else {
|
||||||
snapshot.max_point(),
|
for message in context.messages(cx) {
|
||||||
));
|
if message.offset_range.start >= selection.range().end {
|
||||||
for chunk in context.buffer().read(cx).text_for_range(selection.range()) {
|
break;
|
||||||
text.push_str(chunk);
|
} else if message.offset_range.end >= selection.range().start {
|
||||||
}
|
let range = cmp::max(message.offset_range.start, selection.range().start)
|
||||||
} else {
|
..cmp::min(message.offset_range.end, selection.range().end);
|
||||||
for chunk in context.buffer().read(cx).text_for_range(range) {
|
if !range.is_empty() {
|
||||||
text.push_str(chunk);
|
for chunk in context.buffer().read(cx).text_for_range(range) {
|
||||||
}
|
text.push_str(chunk);
|
||||||
if message.offset_range.end < selection.range().end {
|
}
|
||||||
text.push('\n');
|
if message.offset_range.end < selection.range().end {
|
||||||
|
text.push('\n');
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
(text, CopyMetadata { creases }, vec![selection])
|
(text, CopyMetadata { creases }, vec![selection])
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2071,8 +2072,8 @@ impl ContextEditor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let active_completion_mode = context.completion_mode();
|
let active_completion_mode = context.completion_mode();
|
||||||
let max_mode_enabled = active_completion_mode == CompletionMode::Max;
|
let burn_mode_enabled = active_completion_mode == CompletionMode::Burn;
|
||||||
let icon = if max_mode_enabled {
|
let icon = if burn_mode_enabled {
|
||||||
IconName::ZedBurnModeOn
|
IconName::ZedBurnModeOn
|
||||||
} else {
|
} else {
|
||||||
IconName::ZedBurnMode
|
IconName::ZedBurnMode
|
||||||
@@ -2082,25 +2083,29 @@ impl ContextEditor {
|
|||||||
IconButton::new("burn-mode", icon)
|
IconButton::new("burn-mode", icon)
|
||||||
.icon_size(IconSize::Small)
|
.icon_size(IconSize::Small)
|
||||||
.icon_color(Color::Muted)
|
.icon_color(Color::Muted)
|
||||||
.toggle_state(max_mode_enabled)
|
.toggle_state(burn_mode_enabled)
|
||||||
.selected_icon_color(Color::Error)
|
.selected_icon_color(Color::Error)
|
||||||
.on_click(cx.listener(move |this, _event, _window, cx| {
|
.on_click(cx.listener(move |this, _event, _window, cx| {
|
||||||
this.context().update(cx, |context, _cx| {
|
this.context().update(cx, |context, _cx| {
|
||||||
context.set_completion_mode(match active_completion_mode {
|
context.set_completion_mode(match active_completion_mode {
|
||||||
CompletionMode::Max => CompletionMode::Normal,
|
CompletionMode::Burn => CompletionMode::Normal,
|
||||||
CompletionMode::Normal => CompletionMode::Max,
|
CompletionMode::Normal => CompletionMode::Burn,
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}))
|
}))
|
||||||
.tooltip(move |_window, cx| {
|
.tooltip(move |_window, cx| {
|
||||||
cx.new(|_| MaxModeTooltip::new().selected(max_mode_enabled))
|
cx.new(|_| MaxModeTooltip::new().selected(burn_mode_enabled))
|
||||||
.into()
|
.into()
|
||||||
})
|
})
|
||||||
.into_any_element(),
|
.into_any_element(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render_language_model_selector(&self, cx: &mut Context<Self>) -> impl IntoElement {
|
fn render_language_model_selector(
|
||||||
|
&self,
|
||||||
|
window: &mut Window,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> impl IntoElement {
|
||||||
let active_model = LanguageModelRegistry::read_global(cx)
|
let active_model = LanguageModelRegistry::read_global(cx)
|
||||||
.default_model()
|
.default_model()
|
||||||
.map(|default| default.model);
|
.map(|default| default.model);
|
||||||
@@ -2110,7 +2115,7 @@ impl ContextEditor {
|
|||||||
None => SharedString::from("No model selected"),
|
None => SharedString::from("No model selected"),
|
||||||
};
|
};
|
||||||
|
|
||||||
LanguageModelSelectorPopoverMenu::new(
|
PickerPopoverMenu::new(
|
||||||
self.language_model_selector.clone(),
|
self.language_model_selector.clone(),
|
||||||
ButtonLike::new("active-model")
|
ButtonLike::new("active-model")
|
||||||
.style(ButtonStyle::Subtle)
|
.style(ButtonStyle::Subtle)
|
||||||
@@ -2138,8 +2143,10 @@ impl ContextEditor {
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
gpui::Corner::BottomLeft,
|
gpui::Corner::BottomLeft,
|
||||||
|
cx,
|
||||||
)
|
)
|
||||||
.with_handle(self.language_model_selector_menu_handle.clone())
|
.with_handle(self.language_model_selector_menu_handle.clone())
|
||||||
|
.render(window, cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render_last_error(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
|
fn render_last_error(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
|
||||||
@@ -2615,7 +2622,7 @@ impl Render for ContextEditor {
|
|||||||
.child(
|
.child(
|
||||||
h_flex()
|
h_flex()
|
||||||
.gap_1()
|
.gap_1()
|
||||||
.child(self.render_language_model_selector(cx))
|
.child(self.render_language_model_selector(window, cx))
|
||||||
.child(self.render_send_button(window, cx)),
|
.child(self.render_send_button(window, cx)),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -3258,74 +3265,92 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use fs::FakeFs;
|
use fs::FakeFs;
|
||||||
use gpui::{App, TestAppContext, VisualTestContext};
|
use gpui::{App, TestAppContext, VisualTestContext};
|
||||||
|
use indoc::indoc;
|
||||||
use language::{Buffer, LanguageRegistry};
|
use language::{Buffer, LanguageRegistry};
|
||||||
|
use pretty_assertions::assert_eq;
|
||||||
use prompt_store::PromptBuilder;
|
use prompt_store::PromptBuilder;
|
||||||
|
use text::OffsetRangeExt;
|
||||||
use unindent::Unindent;
|
use unindent::Unindent;
|
||||||
use util::path;
|
use util::path;
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_copy_paste_whole_message(cx: &mut TestAppContext) {
|
||||||
|
let (context, context_editor, mut cx) = setup_context_editor_text(vec![
|
||||||
|
(Role::User, "What is the Zed editor?"),
|
||||||
|
(
|
||||||
|
Role::Assistant,
|
||||||
|
"Zed is a modern, high-performance code editor designed from the ground up for speed and collaboration.",
|
||||||
|
),
|
||||||
|
(Role::User, ""),
|
||||||
|
],cx).await;
|
||||||
|
|
||||||
|
// Select & Copy whole user message
|
||||||
|
assert_copy_paste_context_editor(
|
||||||
|
&context_editor,
|
||||||
|
message_range(&context, 0, &mut cx),
|
||||||
|
indoc! {"
|
||||||
|
What is the Zed editor?
|
||||||
|
Zed is a modern, high-performance code editor designed from the ground up for speed and collaboration.
|
||||||
|
What is the Zed editor?
|
||||||
|
"},
|
||||||
|
&mut cx,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Select & Copy whole assistant message
|
||||||
|
assert_copy_paste_context_editor(
|
||||||
|
&context_editor,
|
||||||
|
message_range(&context, 1, &mut cx),
|
||||||
|
indoc! {"
|
||||||
|
What is the Zed editor?
|
||||||
|
Zed is a modern, high-performance code editor designed from the ground up for speed and collaboration.
|
||||||
|
What is the Zed editor?
|
||||||
|
Zed is a modern, high-performance code editor designed from the ground up for speed and collaboration.
|
||||||
|
"},
|
||||||
|
&mut cx,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_copy_paste_no_selection(cx: &mut TestAppContext) {
|
async fn test_copy_paste_no_selection(cx: &mut TestAppContext) {
|
||||||
cx.update(init_test);
|
let (context, context_editor, mut cx) = setup_context_editor_text(
|
||||||
|
vec![
|
||||||
|
(Role::User, "user1"),
|
||||||
|
(Role::Assistant, "assistant1"),
|
||||||
|
(Role::Assistant, "assistant2"),
|
||||||
|
(Role::User, ""),
|
||||||
|
],
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
let fs = FakeFs::new(cx.executor());
|
// Copy and paste first assistant message
|
||||||
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
let message_2_range = message_range(&context, 1, &mut cx);
|
||||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
assert_copy_paste_context_editor(
|
||||||
let context = cx.new(|cx| {
|
&context_editor,
|
||||||
AssistantContext::local(
|
message_2_range.start..message_2_range.start,
|
||||||
registry,
|
indoc! {"
|
||||||
None,
|
user1
|
||||||
None,
|
assistant1
|
||||||
prompt_builder.clone(),
|
assistant2
|
||||||
Arc::new(SlashCommandWorkingSet::default()),
|
assistant1
|
||||||
cx,
|
"},
|
||||||
)
|
&mut cx,
|
||||||
});
|
);
|
||||||
let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
|
|
||||||
let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx));
|
|
||||||
let workspace = window.root(cx).unwrap();
|
|
||||||
let cx = &mut VisualTestContext::from_window(*window, cx);
|
|
||||||
|
|
||||||
let context_editor = window
|
// Copy and cut second assistant message
|
||||||
.update(cx, |_, window, cx| {
|
let message_3_range = message_range(&context, 2, &mut cx);
|
||||||
cx.new(|cx| {
|
assert_copy_paste_context_editor(
|
||||||
ContextEditor::for_context(
|
&context_editor,
|
||||||
context,
|
message_3_range.start..message_3_range.start,
|
||||||
fs,
|
indoc! {"
|
||||||
workspace.downgrade(),
|
user1
|
||||||
project,
|
assistant1
|
||||||
None,
|
assistant2
|
||||||
window,
|
assistant1
|
||||||
cx,
|
assistant2
|
||||||
)
|
"},
|
||||||
})
|
&mut cx,
|
||||||
})
|
);
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
context_editor.update_in(cx, |context_editor, window, cx| {
|
|
||||||
context_editor.editor.update(cx, |editor, cx| {
|
|
||||||
editor.set_text("abc\ndef\nghi", window, cx);
|
|
||||||
editor.move_to_beginning(&Default::default(), window, cx);
|
|
||||||
})
|
|
||||||
});
|
|
||||||
|
|
||||||
context_editor.update_in(cx, |context_editor, window, cx| {
|
|
||||||
context_editor.editor.update(cx, |editor, cx| {
|
|
||||||
editor.copy(&Default::default(), window, cx);
|
|
||||||
editor.paste(&Default::default(), window, cx);
|
|
||||||
|
|
||||||
assert_eq!(editor.text(cx), "abc\nabc\ndef\nghi");
|
|
||||||
})
|
|
||||||
});
|
|
||||||
|
|
||||||
context_editor.update_in(cx, |context_editor, window, cx| {
|
|
||||||
context_editor.editor.update(cx, |editor, cx| {
|
|
||||||
editor.cut(&Default::default(), window, cx);
|
|
||||||
assert_eq!(editor.text(cx), "abc\ndef\nghi");
|
|
||||||
|
|
||||||
editor.paste(&Default::default(), window, cx);
|
|
||||||
assert_eq!(editor.text(cx), "abc\nabc\ndef\nghi");
|
|
||||||
})
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
@@ -3402,6 +3427,129 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn setup_context_editor_text(
|
||||||
|
messages: Vec<(Role, &str)>,
|
||||||
|
cx: &mut TestAppContext,
|
||||||
|
) -> (
|
||||||
|
Entity<AssistantContext>,
|
||||||
|
Entity<ContextEditor>,
|
||||||
|
VisualTestContext,
|
||||||
|
) {
|
||||||
|
cx.update(init_test);
|
||||||
|
|
||||||
|
let fs = FakeFs::new(cx.executor());
|
||||||
|
let context = create_context_with_messages(messages, cx);
|
||||||
|
|
||||||
|
let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
|
||||||
|
let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx));
|
||||||
|
let workspace = window.root(cx).unwrap();
|
||||||
|
let mut cx = VisualTestContext::from_window(*window, cx);
|
||||||
|
|
||||||
|
let context_editor = window
|
||||||
|
.update(&mut cx, |_, window, cx| {
|
||||||
|
cx.new(|cx| {
|
||||||
|
let editor = ContextEditor::for_context(
|
||||||
|
context.clone(),
|
||||||
|
fs,
|
||||||
|
workspace.downgrade(),
|
||||||
|
project,
|
||||||
|
None,
|
||||||
|
window,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
editor
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
(context, context_editor, cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn message_range(
|
||||||
|
context: &Entity<AssistantContext>,
|
||||||
|
message_ix: usize,
|
||||||
|
cx: &mut TestAppContext,
|
||||||
|
) -> Range<usize> {
|
||||||
|
context.update(cx, |context, cx| {
|
||||||
|
context
|
||||||
|
.messages(cx)
|
||||||
|
.nth(message_ix)
|
||||||
|
.unwrap()
|
||||||
|
.anchor_range
|
||||||
|
.to_offset(&context.buffer().read(cx).snapshot())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn assert_copy_paste_context_editor<T: editor::ToOffset>(
|
||||||
|
context_editor: &Entity<ContextEditor>,
|
||||||
|
range: Range<T>,
|
||||||
|
expected_text: &str,
|
||||||
|
cx: &mut VisualTestContext,
|
||||||
|
) {
|
||||||
|
context_editor.update_in(cx, |context_editor, window, cx| {
|
||||||
|
context_editor.editor.update(cx, |editor, cx| {
|
||||||
|
editor.change_selections(None, window, cx, |s| s.select_ranges([range]));
|
||||||
|
});
|
||||||
|
|
||||||
|
context_editor.copy(&Default::default(), window, cx);
|
||||||
|
|
||||||
|
context_editor.editor.update(cx, |editor, cx| {
|
||||||
|
editor.move_to_end(&Default::default(), window, cx);
|
||||||
|
});
|
||||||
|
|
||||||
|
context_editor.paste(&Default::default(), window, cx);
|
||||||
|
|
||||||
|
context_editor.editor.update(cx, |editor, cx| {
|
||||||
|
assert_eq!(editor.text(cx), expected_text);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_context_with_messages(
|
||||||
|
mut messages: Vec<(Role, &str)>,
|
||||||
|
cx: &mut TestAppContext,
|
||||||
|
) -> Entity<AssistantContext> {
|
||||||
|
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
||||||
|
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||||
|
cx.new(|cx| {
|
||||||
|
let mut context = AssistantContext::local(
|
||||||
|
registry,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
prompt_builder.clone(),
|
||||||
|
Arc::new(SlashCommandWorkingSet::default()),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
let mut message_1 = context.messages(cx).next().unwrap();
|
||||||
|
let (role, text) = messages.remove(0);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
if role == message_1.role {
|
||||||
|
context.buffer().update(cx, |buffer, cx| {
|
||||||
|
buffer.edit([(message_1.offset_range, text)], None, cx);
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let mut ids = HashSet::default();
|
||||||
|
ids.insert(message_1.id);
|
||||||
|
context.cycle_message_roles(ids, cx);
|
||||||
|
message_1 = context.messages(cx).next().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut last_message_id = message_1.id;
|
||||||
|
for (role, text) in messages {
|
||||||
|
context.insert_message_after(last_message_id, role, MessageStatus::Done, cx);
|
||||||
|
let message = context.messages(cx).last().unwrap();
|
||||||
|
last_message_id = message.id;
|
||||||
|
context.buffer().update(cx, |buffer, cx| {
|
||||||
|
buffer.edit([(message.offset_range, text)], None, cx);
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
context
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn init_test(cx: &mut App) {
|
fn init_test(cx: &mut App) {
|
||||||
let settings_store = SettingsStore::test(cx);
|
let settings_store = SettingsStore::test(cx);
|
||||||
prompt_store::init(cx);
|
prompt_store::init(cx);
|
||||||
|
|||||||
@@ -4,8 +4,7 @@ use collections::{HashSet, IndexMap};
|
|||||||
use feature_flags::ZedProFeatureFlag;
|
use feature_flags::ZedProFeatureFlag;
|
||||||
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
|
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
|
||||||
use gpui::{
|
use gpui::{
|
||||||
Action, AnyElement, AnyView, App, BackgroundExecutor, Corner, DismissEvent, Entity,
|
Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task,
|
||||||
EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity,
|
|
||||||
action_with_deprecated_aliases,
|
action_with_deprecated_aliases,
|
||||||
};
|
};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
@@ -15,7 +14,7 @@ use language_model::{
|
|||||||
use ordered_float::OrderedFloat;
|
use ordered_float::OrderedFloat;
|
||||||
use picker::{Picker, PickerDelegate};
|
use picker::{Picker, PickerDelegate};
|
||||||
use proto::Plan;
|
use proto::Plan;
|
||||||
use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
|
use ui::{ListItem, ListItemSpacing, prelude::*};
|
||||||
|
|
||||||
action_with_deprecated_aliases!(
|
action_with_deprecated_aliases!(
|
||||||
agent,
|
agent,
|
||||||
@@ -31,77 +30,128 @@ const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
|
|||||||
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
|
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
|
||||||
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
|
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
|
||||||
|
|
||||||
pub struct LanguageModelSelector {
|
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
|
||||||
picker: Entity<Picker<LanguageModelPickerDelegate>>,
|
|
||||||
|
pub fn language_model_selector(
|
||||||
|
get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
|
||||||
|
on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
|
||||||
|
window: &mut Window,
|
||||||
|
cx: &mut Context<LanguageModelSelector>,
|
||||||
|
) -> LanguageModelSelector {
|
||||||
|
let delegate = LanguageModelPickerDelegate::new(get_active_model, on_model_changed, window, cx);
|
||||||
|
Picker::list(delegate, window, cx)
|
||||||
|
.show_scrollbar(true)
|
||||||
|
.width(rems(20.))
|
||||||
|
.max_height(Some(rems(20.).into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn all_models(cx: &App) -> GroupedModels {
|
||||||
|
let providers = LanguageModelRegistry::global(cx).read(cx).providers();
|
||||||
|
|
||||||
|
let recommended = providers
|
||||||
|
.iter()
|
||||||
|
.flat_map(|provider| {
|
||||||
|
provider
|
||||||
|
.recommended_models(cx)
|
||||||
|
.into_iter()
|
||||||
|
.map(|model| ModelInfo {
|
||||||
|
model,
|
||||||
|
icon: provider.icon(),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let other = providers
|
||||||
|
.iter()
|
||||||
|
.flat_map(|provider| {
|
||||||
|
provider
|
||||||
|
.provided_models(cx)
|
||||||
|
.into_iter()
|
||||||
|
.map(|model| ModelInfo {
|
||||||
|
model,
|
||||||
|
icon: provider.icon(),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
GroupedModels::new(other, recommended)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct ModelInfo {
|
||||||
|
model: Arc<dyn LanguageModel>,
|
||||||
|
icon: IconName,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct LanguageModelPickerDelegate {
|
||||||
|
on_model_changed: OnModelChanged,
|
||||||
|
get_active_model: GetActiveModel,
|
||||||
|
all_models: Arc<GroupedModels>,
|
||||||
|
filtered_entries: Vec<LanguageModelPickerEntry>,
|
||||||
|
selected_index: usize,
|
||||||
_authenticate_all_providers_task: Task<()>,
|
_authenticate_all_providers_task: Task<()>,
|
||||||
_subscriptions: Vec<Subscription>,
|
_subscriptions: Vec<Subscription>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LanguageModelSelector {
|
impl LanguageModelPickerDelegate {
|
||||||
pub fn new(
|
fn new(
|
||||||
get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
|
get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
|
||||||
on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
|
on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
|
||||||
window: &mut Window,
|
window: &mut Window,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Picker<Self>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let on_model_changed = Arc::new(on_model_changed);
|
let on_model_changed = Arc::new(on_model_changed);
|
||||||
|
let models = all_models(cx);
|
||||||
|
let entries = models.entries();
|
||||||
|
|
||||||
let all_models = Self::all_models(cx);
|
Self {
|
||||||
let entries = all_models.entries();
|
|
||||||
|
|
||||||
let delegate = LanguageModelPickerDelegate {
|
|
||||||
language_model_selector: cx.entity().downgrade(),
|
|
||||||
on_model_changed: on_model_changed.clone(),
|
on_model_changed: on_model_changed.clone(),
|
||||||
all_models: Arc::new(all_models),
|
all_models: Arc::new(models),
|
||||||
selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
|
selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
|
||||||
filtered_entries: entries,
|
filtered_entries: entries,
|
||||||
get_active_model: Arc::new(get_active_model),
|
get_active_model: Arc::new(get_active_model),
|
||||||
};
|
|
||||||
|
|
||||||
let picker = cx.new(|cx| {
|
|
||||||
Picker::list(delegate, window, cx)
|
|
||||||
.show_scrollbar(true)
|
|
||||||
.width(rems(20.))
|
|
||||||
.max_height(Some(rems(20.).into()))
|
|
||||||
});
|
|
||||||
|
|
||||||
let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
|
|
||||||
|
|
||||||
LanguageModelSelector {
|
|
||||||
picker,
|
|
||||||
_authenticate_all_providers_task: Self::authenticate_all_providers(cx),
|
_authenticate_all_providers_task: Self::authenticate_all_providers(cx),
|
||||||
_subscriptions: vec![
|
_subscriptions: vec![cx.subscribe_in(
|
||||||
cx.subscribe_in(
|
&LanguageModelRegistry::global(cx),
|
||||||
&LanguageModelRegistry::global(cx),
|
window,
|
||||||
window,
|
|picker, _, event, window, cx| {
|
||||||
Self::handle_language_model_registry_event,
|
match event {
|
||||||
),
|
language_model::Event::ProviderStateChanged
|
||||||
subscription,
|
| language_model::Event::AddedProvider(_)
|
||||||
],
|
| language_model::Event::RemovedProvider(_) => {
|
||||||
|
let query = picker.query(cx);
|
||||||
|
picker.delegate.all_models = Arc::new(all_models(cx));
|
||||||
|
// Update matches will automatically drop the previous task
|
||||||
|
// if we get a provider event again
|
||||||
|
picker.update_matches(query, window, cx)
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_language_model_registry_event(
|
fn get_active_model_index(
|
||||||
&mut self,
|
entries: &[LanguageModelPickerEntry],
|
||||||
_registry: &Entity<LanguageModelRegistry>,
|
active_model: Option<ConfiguredModel>,
|
||||||
event: &language_model::Event,
|
) -> usize {
|
||||||
window: &mut Window,
|
entries
|
||||||
cx: &mut Context<Self>,
|
.iter()
|
||||||
) {
|
.position(|entry| {
|
||||||
match event {
|
if let LanguageModelPickerEntry::Model(model) = entry {
|
||||||
language_model::Event::ProviderStateChanged
|
active_model
|
||||||
| language_model::Event::AddedProvider(_)
|
.as_ref()
|
||||||
| language_model::Event::RemovedProvider(_) => {
|
.map(|active_model| {
|
||||||
self.picker.update(cx, |this, cx| {
|
active_model.model.id() == model.model.id()
|
||||||
let query = this.query(cx);
|
&& active_model.provider.id() == model.model.provider_id()
|
||||||
this.delegate.all_models = Arc::new(Self::all_models(cx));
|
})
|
||||||
// Update matches will automatically drop the previous task
|
.unwrap_or_default()
|
||||||
// if we get a provider event again
|
} else {
|
||||||
this.update_matches(query, window, cx)
|
false
|
||||||
});
|
}
|
||||||
}
|
})
|
||||||
_ => {}
|
.unwrap_or(0)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Authenticates all providers in the [`LanguageModelRegistry`].
|
/// Authenticates all providers in the [`LanguageModelRegistry`].
|
||||||
@@ -154,169 +204,9 @@ impl LanguageModelSelector {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn all_models(cx: &App) -> GroupedModels {
|
|
||||||
let mut recommended = Vec::new();
|
|
||||||
let mut recommended_set = HashSet::default();
|
|
||||||
for provider in LanguageModelRegistry::global(cx)
|
|
||||||
.read(cx)
|
|
||||||
.providers()
|
|
||||||
.iter()
|
|
||||||
{
|
|
||||||
let models = provider.recommended_models(cx);
|
|
||||||
recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
|
|
||||||
recommended.extend(
|
|
||||||
provider
|
|
||||||
.recommended_models(cx)
|
|
||||||
.into_iter()
|
|
||||||
.map(move |model| ModelInfo {
|
|
||||||
model: model.clone(),
|
|
||||||
icon: provider.icon(),
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let other_models = LanguageModelRegistry::global(cx)
|
|
||||||
.read(cx)
|
|
||||||
.providers()
|
|
||||||
.iter()
|
|
||||||
.map(|provider| {
|
|
||||||
(
|
|
||||||
provider.id(),
|
|
||||||
provider
|
|
||||||
.provided_models(cx)
|
|
||||||
.into_iter()
|
|
||||||
.filter_map(|model| {
|
|
||||||
let not_included =
|
|
||||||
!recommended_set.contains(&(model.provider_id(), model.id()));
|
|
||||||
not_included.then(|| ModelInfo {
|
|
||||||
model: model.clone(),
|
|
||||||
icon: provider.icon(),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>(),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.collect::<IndexMap<_, _>>();
|
|
||||||
|
|
||||||
GroupedModels {
|
|
||||||
recommended,
|
|
||||||
other: other_models,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
|
pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
|
||||||
(self.picker.read(cx).delegate.get_active_model)(cx)
|
(self.get_active_model)(cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_active_model_index(
|
|
||||||
entries: &[LanguageModelPickerEntry],
|
|
||||||
active_model: Option<ConfiguredModel>,
|
|
||||||
) -> usize {
|
|
||||||
entries
|
|
||||||
.iter()
|
|
||||||
.position(|entry| {
|
|
||||||
if let LanguageModelPickerEntry::Model(model) = entry {
|
|
||||||
active_model
|
|
||||||
.as_ref()
|
|
||||||
.map(|active_model| {
|
|
||||||
active_model.model.id() == model.model.id()
|
|
||||||
&& active_model.provider.id() == model.model.provider_id()
|
|
||||||
})
|
|
||||||
.unwrap_or_default()
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.unwrap_or(0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EventEmitter<DismissEvent> for LanguageModelSelector {}
|
|
||||||
|
|
||||||
impl Focusable for LanguageModelSelector {
|
|
||||||
fn focus_handle(&self, cx: &App) -> FocusHandle {
|
|
||||||
self.picker.focus_handle(cx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Render for LanguageModelSelector {
|
|
||||||
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
|
|
||||||
self.picker.clone()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(IntoElement)]
|
|
||||||
pub struct LanguageModelSelectorPopoverMenu<T, TT>
|
|
||||||
where
|
|
||||||
T: PopoverTrigger + ButtonCommon,
|
|
||||||
TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
|
|
||||||
{
|
|
||||||
language_model_selector: Entity<LanguageModelSelector>,
|
|
||||||
trigger: T,
|
|
||||||
tooltip: TT,
|
|
||||||
handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
|
|
||||||
anchor: Corner,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
|
|
||||||
where
|
|
||||||
T: PopoverTrigger + ButtonCommon,
|
|
||||||
TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
|
|
||||||
{
|
|
||||||
pub fn new(
|
|
||||||
language_model_selector: Entity<LanguageModelSelector>,
|
|
||||||
trigger: T,
|
|
||||||
tooltip: TT,
|
|
||||||
anchor: Corner,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
language_model_selector,
|
|
||||||
trigger,
|
|
||||||
tooltip,
|
|
||||||
handle: None,
|
|
||||||
anchor,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
|
|
||||||
self.handle = Some(handle);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
|
|
||||||
where
|
|
||||||
T: PopoverTrigger + ButtonCommon,
|
|
||||||
TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
|
|
||||||
{
|
|
||||||
fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
|
|
||||||
let language_model_selector = self.language_model_selector.clone();
|
|
||||||
|
|
||||||
PopoverMenu::new("model-switcher")
|
|
||||||
.menu(move |_window, _cx| Some(language_model_selector.clone()))
|
|
||||||
.trigger_with_tooltip(self.trigger, self.tooltip)
|
|
||||||
.anchor(self.anchor)
|
|
||||||
.when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
|
|
||||||
.offset(gpui::Point {
|
|
||||||
x: px(0.0),
|
|
||||||
y: px(-2.0),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
struct ModelInfo {
|
|
||||||
model: Arc<dyn LanguageModel>,
|
|
||||||
icon: IconName,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct LanguageModelPickerDelegate {
|
|
||||||
language_model_selector: WeakEntity<LanguageModelSelector>,
|
|
||||||
on_model_changed: OnModelChanged,
|
|
||||||
get_active_model: GetActiveModel,
|
|
||||||
all_models: Arc<GroupedModels>,
|
|
||||||
filtered_entries: Vec<LanguageModelPickerEntry>,
|
|
||||||
selected_index: usize,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct GroupedModels {
|
struct GroupedModels {
|
||||||
@@ -326,11 +216,14 @@ struct GroupedModels {
|
|||||||
|
|
||||||
impl GroupedModels {
|
impl GroupedModels {
|
||||||
pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
|
pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
|
||||||
let recommended_ids: HashSet<_> = recommended.iter().map(|info| info.model.id()).collect();
|
let recommended_ids = recommended
|
||||||
|
.iter()
|
||||||
|
.map(|info| (info.model.provider_id(), info.model.id()))
|
||||||
|
.collect::<HashSet<_>>();
|
||||||
|
|
||||||
let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
|
let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
|
||||||
for model in other {
|
for model in other {
|
||||||
if recommended_ids.contains(&model.model.id()) {
|
if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -577,9 +470,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
|
fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||||
self.language_model_selector
|
cx.emit(DismissEvent);
|
||||||
.update(cx, |_this, cx| cx.emit(DismissEvent))
|
|
||||||
.ok();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render_match(
|
fn render_match(
|
||||||
@@ -917,4 +808,26 @@ mod tests {
|
|||||||
// Recommended models should not appear in "other"
|
// Recommended models should not appear in "other"
|
||||||
assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
|
assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) {
|
||||||
|
let recommended_models = create_models(vec![("zed", "claude")]);
|
||||||
|
let all_models = create_models(vec![
|
||||||
|
("zed", "claude"), // Should be filtered out from "other"
|
||||||
|
("zed", "gemini"),
|
||||||
|
("copilot", "claude"), // Should not be filtered out from "other"
|
||||||
|
]);
|
||||||
|
|
||||||
|
let grouped_models = GroupedModels::new(all_models, recommended_models);
|
||||||
|
|
||||||
|
let actual_other_models = grouped_models
|
||||||
|
.other
|
||||||
|
.values()
|
||||||
|
.flatten()
|
||||||
|
.cloned()
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
// Recommended models should not appear in "other"
|
||||||
|
assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/claude"]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use gpui::{Context, IntoElement, Render, Window};
|
use gpui::{Context, FontWeight, IntoElement, Render, Window};
|
||||||
use ui::{prelude::*, tooltip_container};
|
use ui::{prelude::*, tooltip_container};
|
||||||
|
|
||||||
pub struct MaxModeTooltip {
|
pub struct MaxModeTooltip {
|
||||||
@@ -18,39 +18,40 @@ impl MaxModeTooltip {
|
|||||||
|
|
||||||
impl Render for MaxModeTooltip {
|
impl Render for MaxModeTooltip {
|
||||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||||
let icon = if self.selected {
|
let (icon, color) = if self.selected {
|
||||||
IconName::ZedBurnModeOn
|
(IconName::ZedBurnModeOn, Color::Error)
|
||||||
} else {
|
} else {
|
||||||
IconName::ZedBurnMode
|
(IconName::ZedBurnMode, Color::Default)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let turned_on = h_flex()
|
||||||
|
.h_4()
|
||||||
|
.px_1()
|
||||||
|
.border_1()
|
||||||
|
.border_color(cx.theme().colors().border)
|
||||||
|
.bg(cx.theme().colors().text_accent.opacity(0.1))
|
||||||
|
.rounded_sm()
|
||||||
|
.child(
|
||||||
|
Label::new("ON")
|
||||||
|
.size(LabelSize::XSmall)
|
||||||
|
.weight(FontWeight::SEMIBOLD)
|
||||||
|
.color(Color::Accent),
|
||||||
|
);
|
||||||
|
|
||||||
let title = h_flex()
|
let title = h_flex()
|
||||||
.gap_1()
|
.gap_1p5()
|
||||||
.child(Icon::new(icon).size(IconSize::Small))
|
.child(Icon::new(icon).size(IconSize::Small).color(color))
|
||||||
.child(Label::new("Burn Mode"));
|
.child(Label::new("Burn Mode"))
|
||||||
|
.when(self.selected, |title| title.child(turned_on));
|
||||||
|
|
||||||
tooltip_container(window, cx, |this, _, _| {
|
tooltip_container(window, cx, |this, _, _| {
|
||||||
this.gap_0p5()
|
this
|
||||||
.map(|header| if self.selected {
|
.child(title)
|
||||||
header.child(
|
|
||||||
h_flex()
|
|
||||||
.justify_between()
|
|
||||||
.child(title)
|
|
||||||
.child(
|
|
||||||
h_flex()
|
|
||||||
.gap_0p5()
|
|
||||||
.child(Icon::new(IconName::Check).size(IconSize::XSmall).color(Color::Accent))
|
|
||||||
.child(Label::new("Turned On").size(LabelSize::XSmall).color(Color::Accent))
|
|
||||||
)
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
header.child(title)
|
|
||||||
})
|
|
||||||
.child(
|
.child(
|
||||||
div()
|
div()
|
||||||
.max_w_72()
|
.max_w_64()
|
||||||
.child(
|
.child(
|
||||||
Label::new("Enables models to use large context windows, unlimited tool calls, and other capabilities for expanded reasoning, offering an unfettered agentic experience.")
|
Label::new("Enables models to use large context windows, unlimited tool calls, and other capabilities for expanded reasoning.")
|
||||||
.size(LabelSize::Small)
|
.size(LabelSize::Small)
|
||||||
.color(Color::Muted)
|
.color(Color::Muted)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ impl SlashCommandCompletionProvider {
|
|||||||
name_range: Range<Anchor>,
|
name_range: Range<Anchor>,
|
||||||
window: &mut Window,
|
window: &mut Window,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<Option<Vec<project::Completion>>>> {
|
) -> Task<Result<Vec<project::CompletionResponse>>> {
|
||||||
let slash_commands = self.slash_commands.clone();
|
let slash_commands = self.slash_commands.clone();
|
||||||
let candidates = slash_commands
|
let candidates = slash_commands
|
||||||
.command_names(cx)
|
.command_names(cx)
|
||||||
@@ -71,28 +71,27 @@ impl SlashCommandCompletionProvider {
|
|||||||
.await;
|
.await;
|
||||||
|
|
||||||
cx.update(|_, cx| {
|
cx.update(|_, cx| {
|
||||||
Some(
|
let completions = matches
|
||||||
matches
|
.into_iter()
|
||||||
.into_iter()
|
.filter_map(|mat| {
|
||||||
.filter_map(|mat| {
|
let command = slash_commands.command(&mat.string, cx)?;
|
||||||
let command = slash_commands.command(&mat.string, cx)?;
|
let mut new_text = mat.string.clone();
|
||||||
let mut new_text = mat.string.clone();
|
let requires_argument = command.requires_argument();
|
||||||
let requires_argument = command.requires_argument();
|
let accepts_arguments = command.accepts_arguments();
|
||||||
let accepts_arguments = command.accepts_arguments();
|
if requires_argument || accepts_arguments {
|
||||||
if requires_argument || accepts_arguments {
|
new_text.push(' ');
|
||||||
new_text.push(' ');
|
}
|
||||||
}
|
|
||||||
|
|
||||||
let confirm =
|
let confirm =
|
||||||
editor
|
editor
|
||||||
.clone()
|
.clone()
|
||||||
.zip(workspace.clone())
|
.zip(workspace.clone())
|
||||||
.map(|(editor, workspace)| {
|
.map(|(editor, workspace)| {
|
||||||
let command_name = mat.string.clone();
|
let command_name = mat.string.clone();
|
||||||
let command_range = command_range.clone();
|
let command_range = command_range.clone();
|
||||||
let editor = editor.clone();
|
let editor = editor.clone();
|
||||||
let workspace = workspace.clone();
|
let workspace = workspace.clone();
|
||||||
Arc::new(
|
Arc::new(
|
||||||
move |intent: CompletionIntent,
|
move |intent: CompletionIntent,
|
||||||
window: &mut Window,
|
window: &mut Window,
|
||||||
cx: &mut App| {
|
cx: &mut App| {
|
||||||
@@ -118,22 +117,27 @@ impl SlashCommandCompletionProvider {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
) as Arc<_>
|
) as Arc<_>
|
||||||
});
|
});
|
||||||
Some(project::Completion {
|
|
||||||
replace_range: name_range.clone(),
|
Some(project::Completion {
|
||||||
documentation: Some(CompletionDocumentation::SingleLine(
|
replace_range: name_range.clone(),
|
||||||
command.description().into(),
|
documentation: Some(CompletionDocumentation::SingleLine(
|
||||||
)),
|
command.description().into(),
|
||||||
new_text,
|
)),
|
||||||
label: command.label(cx),
|
new_text,
|
||||||
icon_path: None,
|
label: command.label(cx),
|
||||||
insert_text_mode: None,
|
icon_path: None,
|
||||||
confirm,
|
insert_text_mode: None,
|
||||||
source: CompletionSource::Custom,
|
confirm,
|
||||||
})
|
source: CompletionSource::Custom,
|
||||||
})
|
})
|
||||||
.collect(),
|
})
|
||||||
)
|
.collect();
|
||||||
|
|
||||||
|
vec![project::CompletionResponse {
|
||||||
|
completions,
|
||||||
|
is_incomplete: false,
|
||||||
|
}]
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -147,7 +151,7 @@ impl SlashCommandCompletionProvider {
|
|||||||
last_argument_range: Range<Anchor>,
|
last_argument_range: Range<Anchor>,
|
||||||
window: &mut Window,
|
window: &mut Window,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<Option<Vec<project::Completion>>>> {
|
) -> Task<Result<Vec<project::CompletionResponse>>> {
|
||||||
let new_cancel_flag = Arc::new(AtomicBool::new(false));
|
let new_cancel_flag = Arc::new(AtomicBool::new(false));
|
||||||
let mut flag = self.cancel_flag.lock();
|
let mut flag = self.cancel_flag.lock();
|
||||||
flag.store(true, SeqCst);
|
flag.store(true, SeqCst);
|
||||||
@@ -165,28 +169,27 @@ impl SlashCommandCompletionProvider {
|
|||||||
let workspace = self.workspace.clone();
|
let workspace = self.workspace.clone();
|
||||||
let arguments = arguments.to_vec();
|
let arguments = arguments.to_vec();
|
||||||
cx.background_spawn(async move {
|
cx.background_spawn(async move {
|
||||||
Ok(Some(
|
let completions = completions
|
||||||
completions
|
.await?
|
||||||
.await?
|
.into_iter()
|
||||||
.into_iter()
|
.map(|new_argument| {
|
||||||
.map(|new_argument| {
|
let confirm =
|
||||||
let confirm =
|
editor
|
||||||
editor
|
.clone()
|
||||||
.clone()
|
.zip(workspace.clone())
|
||||||
.zip(workspace.clone())
|
.map(|(editor, workspace)| {
|
||||||
.map(|(editor, workspace)| {
|
Arc::new({
|
||||||
Arc::new({
|
let mut completed_arguments = arguments.clone();
|
||||||
let mut completed_arguments = arguments.clone();
|
if new_argument.replace_previous_arguments {
|
||||||
if new_argument.replace_previous_arguments {
|
completed_arguments.clear();
|
||||||
completed_arguments.clear();
|
} else {
|
||||||
} else {
|
completed_arguments.pop();
|
||||||
completed_arguments.pop();
|
}
|
||||||
}
|
completed_arguments.push(new_argument.new_text.clone());
|
||||||
completed_arguments.push(new_argument.new_text.clone());
|
|
||||||
|
|
||||||
let command_range = command_range.clone();
|
let command_range = command_range.clone();
|
||||||
let command_name = command_name.clone();
|
let command_name = command_name.clone();
|
||||||
move |intent: CompletionIntent,
|
move |intent: CompletionIntent,
|
||||||
window: &mut Window,
|
window: &mut Window,
|
||||||
cx: &mut App| {
|
cx: &mut App| {
|
||||||
if new_argument.after_completion.run()
|
if new_argument.after_completion.run()
|
||||||
@@ -210,34 +213,41 @@ impl SlashCommandCompletionProvider {
|
|||||||
!new_argument.after_completion.run()
|
!new_argument.after_completion.run()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}) as Arc<_>
|
}) as Arc<_>
|
||||||
});
|
});
|
||||||
|
|
||||||
let mut new_text = new_argument.new_text.clone();
|
let mut new_text = new_argument.new_text.clone();
|
||||||
if new_argument.after_completion == AfterCompletion::Continue {
|
if new_argument.after_completion == AfterCompletion::Continue {
|
||||||
new_text.push(' ');
|
new_text.push(' ');
|
||||||
}
|
}
|
||||||
|
|
||||||
project::Completion {
|
project::Completion {
|
||||||
replace_range: if new_argument.replace_previous_arguments {
|
replace_range: if new_argument.replace_previous_arguments {
|
||||||
argument_range.clone()
|
argument_range.clone()
|
||||||
} else {
|
} else {
|
||||||
last_argument_range.clone()
|
last_argument_range.clone()
|
||||||
},
|
},
|
||||||
label: new_argument.label,
|
label: new_argument.label,
|
||||||
icon_path: None,
|
icon_path: None,
|
||||||
new_text,
|
new_text,
|
||||||
documentation: None,
|
documentation: None,
|
||||||
confirm,
|
confirm,
|
||||||
insert_text_mode: None,
|
insert_text_mode: None,
|
||||||
source: CompletionSource::Custom,
|
source: CompletionSource::Custom,
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect();
|
||||||
))
|
|
||||||
|
Ok(vec![project::CompletionResponse {
|
||||||
|
completions,
|
||||||
|
is_incomplete: false,
|
||||||
|
}])
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
Task::ready(Ok(Some(Vec::new())))
|
Task::ready(Ok(vec![project::CompletionResponse {
|
||||||
|
completions: Vec::new(),
|
||||||
|
is_incomplete: false,
|
||||||
|
}]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -251,7 +261,7 @@ impl CompletionProvider for SlashCommandCompletionProvider {
|
|||||||
_: editor::CompletionContext,
|
_: editor::CompletionContext,
|
||||||
window: &mut Window,
|
window: &mut Window,
|
||||||
cx: &mut Context<Editor>,
|
cx: &mut Context<Editor>,
|
||||||
) -> Task<Result<Option<Vec<project::Completion>>>> {
|
) -> Task<Result<Vec<project::CompletionResponse>>> {
|
||||||
let Some((name, arguments, command_range, last_argument_range)) =
|
let Some((name, arguments, command_range, last_argument_range)) =
|
||||||
buffer.update(cx, |buffer, _cx| {
|
buffer.update(cx, |buffer, _cx| {
|
||||||
let position = buffer_position.to_point(buffer);
|
let position = buffer_position.to_point(buffer);
|
||||||
@@ -295,7 +305,10 @@ impl CompletionProvider for SlashCommandCompletionProvider {
|
|||||||
Some((name, arguments, command_range, last_argument_range))
|
Some((name, arguments, command_range, last_argument_range))
|
||||||
})
|
})
|
||||||
else {
|
else {
|
||||||
return Task::ready(Ok(Some(Vec::new())));
|
return Task::ready(Ok(vec![project::CompletionResponse {
|
||||||
|
completions: Vec::new(),
|
||||||
|
is_incomplete: false,
|
||||||
|
}]));
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some((arguments, argument_range)) = arguments {
|
if let Some((arguments, argument_range)) = arguments {
|
||||||
|
|||||||
@@ -415,14 +415,38 @@ impl ActionLog {
|
|||||||
self.project
|
self.project
|
||||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
|
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
|
||||||
} else {
|
} else {
|
||||||
buffer
|
// For a file created by AI with no pre-existing content,
|
||||||
.read(cx)
|
// only delete the file if we're certain it contains only AI content
|
||||||
.entry_id(cx)
|
// with no edits from the user.
|
||||||
.and_then(|entry_id| {
|
|
||||||
self.project
|
let initial_version = tracked_buffer.version.clone();
|
||||||
.update(cx, |project, cx| project.delete_entry(entry_id, false, cx))
|
let current_version = buffer.read(cx).version();
|
||||||
})
|
|
||||||
.unwrap_or(Task::ready(Ok(())))
|
let current_content = buffer.read(cx).text();
|
||||||
|
let tracked_content = tracked_buffer.snapshot.text();
|
||||||
|
|
||||||
|
let is_ai_only_content =
|
||||||
|
initial_version == current_version && current_content == tracked_content;
|
||||||
|
|
||||||
|
if is_ai_only_content {
|
||||||
|
buffer
|
||||||
|
.read(cx)
|
||||||
|
.entry_id(cx)
|
||||||
|
.and_then(|entry_id| {
|
||||||
|
self.project.update(cx, |project, cx| {
|
||||||
|
project.delete_entry(entry_id, false, cx)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.unwrap_or(Task::ready(Ok(())))
|
||||||
|
} else {
|
||||||
|
// Not sure how to disentangle edits made by the user
|
||||||
|
// from edits made by the AI at this point.
|
||||||
|
// For now, preserve both to avoid data loss.
|
||||||
|
//
|
||||||
|
// TODO: Better solution (disable "Reject" after user makes some
|
||||||
|
// edit or find a way to differentiate between AI and user edits)
|
||||||
|
Task::ready(Ok(()))
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
self.tracked_buffers.remove(&buffer);
|
self.tracked_buffers.remove(&buffer);
|
||||||
@@ -1576,7 +1600,6 @@ mod tests {
|
|||||||
project.find_project_path("dir/new_file", cx)
|
project.find_project_path("dir/new_file", cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let buffer = project
|
let buffer = project
|
||||||
.update(cx, |project, cx| project.open_buffer(file_path, cx))
|
.update(cx, |project, cx| project.open_buffer(file_path, cx))
|
||||||
.await
|
.await
|
||||||
@@ -1619,6 +1642,72 @@ mod tests {
|
|||||||
assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
|
assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_reject_created_file_with_user_edits(cx: &mut TestAppContext) {
|
||||||
|
init_test(cx);
|
||||||
|
|
||||||
|
let fs = FakeFs::new(cx.executor());
|
||||||
|
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
|
||||||
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
|
|
||||||
|
let file_path = project
|
||||||
|
.read_with(cx, |project, cx| {
|
||||||
|
project.find_project_path("dir/new_file", cx)
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
let buffer = project
|
||||||
|
.update(cx, |project, cx| project.open_buffer(file_path, cx))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// AI creates file with initial content
|
||||||
|
cx.update(|cx| {
|
||||||
|
action_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx));
|
||||||
|
buffer.update(cx, |buffer, cx| buffer.set_text("ai content", cx));
|
||||||
|
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
|
||||||
|
});
|
||||||
|
|
||||||
|
project
|
||||||
|
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
// User makes additional edits
|
||||||
|
cx.update(|cx| {
|
||||||
|
buffer.update(cx, |buffer, cx| {
|
||||||
|
buffer.edit([(10..10, "\nuser added this line")], None, cx);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
project
|
||||||
|
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(fs.is_file(path!("/dir/new_file").as_ref()).await);
|
||||||
|
|
||||||
|
// Reject all
|
||||||
|
action_log
|
||||||
|
.update(cx, |log, cx| {
|
||||||
|
log.reject_edits_in_ranges(
|
||||||
|
buffer.clone(),
|
||||||
|
vec![Point::new(0, 0)..Point::new(100, 0)],
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
// File should still contain all the content
|
||||||
|
assert!(fs.is_file(path!("/dir/new_file").as_ref()).await);
|
||||||
|
|
||||||
|
let content = buffer.read_with(cx, |buffer, _| buffer.text());
|
||||||
|
assert_eq!(content, "ai content\nuser added this line");
|
||||||
|
}
|
||||||
|
|
||||||
#[gpui::test(iterations = 100)]
|
#[gpui::test(iterations = 100)]
|
||||||
async fn test_random_diffs(mut rng: StdRng, cx: &mut TestAppContext) {
|
async fn test_random_diffs(mut rng: StdRng, cx: &mut TestAppContext) {
|
||||||
init_test(cx);
|
init_test(cx);
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ itertools.workspace = true
|
|||||||
language.workspace = true
|
language.workspace = true
|
||||||
language_model.workspace = true
|
language_model.workspace = true
|
||||||
log.workspace = true
|
log.workspace = true
|
||||||
|
lsp.workspace = true
|
||||||
markdown.workspace = true
|
markdown.workspace = true
|
||||||
open.workspace = true
|
open.workspace = true
|
||||||
paths.workspace = true
|
paths.workspace = true
|
||||||
@@ -64,6 +65,7 @@ workspace.workspace = true
|
|||||||
zed_llm_client.workspace = true
|
zed_llm_client.workspace = true
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
lsp = { workspace = true, features = ["test-support"] }
|
||||||
client = { workspace = true, features = ["test-support"] }
|
client = { workspace = true, features = ["test-support"] }
|
||||||
clock = { workspace = true, features = ["test-support"] }
|
clock = { workspace = true, features = ["test-support"] }
|
||||||
collections = { workspace = true, features = ["test-support"] }
|
collections = { workspace = true, features = ["test-support"] }
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
Invoke multiple other tool calls either sequentially or concurrently.
|
|
||||||
|
|
||||||
This tool is useful when you need to perform several operations at once, improving efficiency by reducing the number of back-and-forth interactions needed to complete complex tasks.
|
|
||||||
|
|
||||||
If the tool calls are set to be run sequentially, then each tool call within the batch is executed in the order provided. If it's set to run concurrently, then they may run in a different order. Regardless, all tool calls will have the same permissions and context as if they were called individually.
|
|
||||||
|
|
||||||
This tool should never be used to run a total of one tool. Instead, just run that one tool directly. You can run batches within batches if desired, which is a way you can mix concurrent and sequential tool call execution.
|
|
||||||
|
|
||||||
When it's possible to run tools in a batch, you should run as many as possible in the batch, up to a maximum of 32. For example, don't run multiple consecutive batches of 10 when you could instead run one batch of 30.
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
A tool for applying code actions to specific sections of your code. It uses language servers to provide refactoring capabilities similar to what you'd find in an IDE.
|
|
||||||
|
|
||||||
This tool can:
|
|
||||||
- List all available code actions for a selected text range
|
|
||||||
- Execute a specific code action on that range
|
|
||||||
- Rename symbols across your codebase. This tool is the preferred way to rename things, and you should always prefer to rename code symbols using this tool rather than using textual find/replace when both are available.
|
|
||||||
|
|
||||||
Use this tool when you want to:
|
|
||||||
- Discover what code actions are available for a piece of code
|
|
||||||
- Apply automatic fixes and code transformations
|
|
||||||
- Rename variables, functions, or other symbols consistently throughout your project
|
|
||||||
- Clean up imports, implement interfaces, or perform other language-specific operations
|
|
||||||
|
|
||||||
- If unsure what actions are available, call the tool without specifying an action to get a list
|
|
||||||
- For common operations, you can directly specify actions like "quickfix.all" or "source.organizeImports"
|
|
||||||
- For renaming, use the special "textDocument/rename" action and provide the new name in the arguments field
|
|
||||||
- Be specific with your text range and context to ensure the tool identifies the correct code location
|
|
||||||
|
|
||||||
The tool will automatically save any changes it makes to your files.
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
Returns either an outline of the public code symbols in the entire project (grouped by file) or else an outline of both the public and private code symbols within a particular file.
|
|
||||||
|
|
||||||
When a path is provided, this tool returns a hierarchical outline of code symbols for that specific file.
|
|
||||||
When no path is provided, it returns a list of all public code symbols in the project, organized by file.
|
|
||||||
|
|
||||||
You can also provide an optional regular expression which filters the output by only showing code symbols which match that regex.
|
|
||||||
|
|
||||||
Results are paginated with 2000 entries per page. Use the optional 'offset' parameter to request subsequent pages.
|
|
||||||
|
|
||||||
Markdown headings indicate the structure of the output; just like
|
|
||||||
with markdown headings, the more # symbols there are at the beginning of a line,
|
|
||||||
the deeper it is in the hierarchy.
|
|
||||||
|
|
||||||
Each code symbol entry ends with a line number or range, which tells you what portion of the
|
|
||||||
underlying source code file corresponds to that part of the outline. You can use
|
|
||||||
that line information with other tools, to strategically read portions of the source code.
|
|
||||||
|
|
||||||
For example, you can use this tool to find a relevant symbol in the project, then get the outline of the file which contains that symbol, then use the line number information from that file's outline to read different sections of that file, without having to read the entire file all at once (which can be slow, or use a lot of tokens).
|
|
||||||
|
|
||||||
<example>
|
|
||||||
# class Foo [L123-136]
|
|
||||||
## method do_something(arg1, arg2) [L124-126]
|
|
||||||
## method process_data(data) [L128-135]
|
|
||||||
# class Bar [L145-161]
|
|
||||||
## method initialize() [L146-149]
|
|
||||||
## method update_state(new_state) [L160]
|
|
||||||
## private method _validate_state(state) [L161-162]
|
|
||||||
</example>
|
|
||||||
|
|
||||||
This example shows how tree-sitter outlines the structure of source code:
|
|
||||||
|
|
||||||
1. `class Foo` is defined on lines 123-136
|
|
||||||
- It contains a method `do_something` spanning lines 124-126
|
|
||||||
- It also has a method `process_data` spanning lines 128-135
|
|
||||||
|
|
||||||
2. `class Bar` is defined on lines 145-161
|
|
||||||
- It has an `initialize` method spanning lines 146-149
|
|
||||||
- It has an `update_state` method on line 160
|
|
||||||
- It has a private method `_validate_state` spanning lines 161-162
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
Reads the contents of a path on the filesystem.
|
|
||||||
|
|
||||||
If the path is a directory, this lists all files and directories within that path.
|
|
||||||
If the path is a file, this returns the file's contents.
|
|
||||||
|
|
||||||
When reading a file, if the file is too big and no line range is specified, an outline of the file's code symbols is listed instead, which can be used to request specific line ranges in a subsequent call.
|
|
||||||
|
|
||||||
Similarly, if a directory has too many entries to show at once, a subset of entries will be shown,
|
|
||||||
and subsequent requests can use starting and ending line numbers to get other subsets.
|
|
||||||
@@ -28,6 +28,7 @@ use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task::
|
|||||||
use streaming_diff::{CharOperation, StreamingDiff};
|
use streaming_diff::{CharOperation, StreamingDiff};
|
||||||
use streaming_fuzzy_matcher::StreamingFuzzyMatcher;
|
use streaming_fuzzy_matcher::StreamingFuzzyMatcher;
|
||||||
use util::debug_panic;
|
use util::debug_panic;
|
||||||
|
use zed_llm_client::CompletionIntent;
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
struct CreateFilePromptTemplate {
|
struct CreateFilePromptTemplate {
|
||||||
@@ -106,7 +107,9 @@ impl EditAgent {
|
|||||||
edit_description,
|
edit_description,
|
||||||
}
|
}
|
||||||
.render(&this.templates)?;
|
.render(&this.templates)?;
|
||||||
let new_chunks = this.request(conversation, prompt, cx).await?;
|
let new_chunks = this
|
||||||
|
.request(conversation, CompletionIntent::CreateFile, prompt, cx)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let (output, mut inner_events) = this.overwrite_with_chunks(buffer, new_chunks, cx);
|
let (output, mut inner_events) = this.overwrite_with_chunks(buffer, new_chunks, cx);
|
||||||
while let Some(event) = inner_events.next().await {
|
while let Some(event) = inner_events.next().await {
|
||||||
@@ -213,7 +216,9 @@ impl EditAgent {
|
|||||||
edit_description,
|
edit_description,
|
||||||
}
|
}
|
||||||
.render(&this.templates)?;
|
.render(&this.templates)?;
|
||||||
let edit_chunks = this.request(conversation, prompt, cx).await?;
|
let edit_chunks = this
|
||||||
|
.request(conversation, CompletionIntent::EditFile, prompt, cx)
|
||||||
|
.await?;
|
||||||
this.apply_edit_chunks(buffer, edit_chunks, events_tx, cx)
|
this.apply_edit_chunks(buffer, edit_chunks, events_tx, cx)
|
||||||
.await
|
.await
|
||||||
});
|
});
|
||||||
@@ -589,6 +594,7 @@ impl EditAgent {
|
|||||||
async fn request(
|
async fn request(
|
||||||
&self,
|
&self,
|
||||||
mut conversation: LanguageModelRequest,
|
mut conversation: LanguageModelRequest,
|
||||||
|
intent: CompletionIntent,
|
||||||
prompt: String,
|
prompt: String,
|
||||||
cx: &mut AsyncApp,
|
cx: &mut AsyncApp,
|
||||||
) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
|
) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
|
||||||
@@ -646,6 +652,7 @@ impl EditAgent {
|
|||||||
let request = LanguageModelRequest {
|
let request = LanguageModelRequest {
|
||||||
thread_id: conversation.thread_id,
|
thread_id: conversation.thread_id,
|
||||||
prompt_id: conversation.prompt_id,
|
prompt_id: conversation.prompt_id,
|
||||||
|
intent: Some(intent),
|
||||||
mode: conversation.mode,
|
mode: conversation.mode,
|
||||||
messages: conversation.messages,
|
messages: conversation.messages,
|
||||||
tool_choice,
|
tool_choice,
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use std::cell::LazyCell;
|
|||||||
use util::debug_panic;
|
use util::debug_panic;
|
||||||
|
|
||||||
const START_MARKER: LazyCell<Regex> = LazyCell::new(|| Regex::new(r"\n?```\S*\n").unwrap());
|
const START_MARKER: LazyCell<Regex> = LazyCell::new(|| Regex::new(r"\n?```\S*\n").unwrap());
|
||||||
const END_MARKER: LazyCell<Regex> = LazyCell::new(|| Regex::new(r"\n```\s*$").unwrap());
|
const END_MARKER: LazyCell<Regex> = LazyCell::new(|| Regex::new(r"(^|\n)```\s*$").unwrap());
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum CreateFileParserEvent {
|
pub enum CreateFileParserEvent {
|
||||||
@@ -184,6 +184,22 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test(iterations = 10)]
|
||||||
|
fn test_empty_file(mut rng: StdRng) {
|
||||||
|
let mut parser = CreateFileParser::new();
|
||||||
|
assert_eq!(
|
||||||
|
parse_random_chunks(
|
||||||
|
indoc! {"
|
||||||
|
```
|
||||||
|
```
|
||||||
|
"},
|
||||||
|
&mut parser,
|
||||||
|
&mut rng
|
||||||
|
),
|
||||||
|
"".to_string()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
fn parse_random_chunks(input: &str, parser: &mut CreateFileParser, rng: &mut StdRng) -> String {
|
fn parse_random_chunks(input: &str, parser: &mut CreateFileParser, rng: &mut StdRng) -> String {
|
||||||
let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
|
let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
|
||||||
let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
|
let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
|
||||||
|
|||||||
@@ -18,16 +18,21 @@ use gpui::{
|
|||||||
use indoc::formatdoc;
|
use indoc::formatdoc;
|
||||||
use language::{
|
use language::{
|
||||||
Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Point, Rope,
|
Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Point, Rope,
|
||||||
TextBuffer, language_settings::SoftWrap,
|
TextBuffer,
|
||||||
|
language_settings::{self, FormatOnSave, SoftWrap},
|
||||||
};
|
};
|
||||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||||
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
|
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
|
||||||
use project::{Project, ProjectPath};
|
use project::{
|
||||||
|
Project, ProjectPath,
|
||||||
|
lsp_store::{FormatTrigger, LspFormatTarget},
|
||||||
|
};
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use settings::Settings;
|
use settings::Settings;
|
||||||
use std::{
|
use std::{
|
||||||
cmp::Reverse,
|
cmp::Reverse,
|
||||||
|
collections::HashSet,
|
||||||
ops::Range,
|
ops::Range,
|
||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
@@ -189,8 +194,10 @@ impl Tool for EditFileTool {
|
|||||||
});
|
});
|
||||||
|
|
||||||
let card_clone = card.clone();
|
let card_clone = card.clone();
|
||||||
|
let action_log_clone = action_log.clone();
|
||||||
let task = cx.spawn(async move |cx: &mut AsyncApp| {
|
let task = cx.spawn(async move |cx: &mut AsyncApp| {
|
||||||
let edit_agent = EditAgent::new(model, project.clone(), action_log, Templates::new());
|
let edit_agent =
|
||||||
|
EditAgent::new(model, project.clone(), action_log_clone, Templates::new());
|
||||||
|
|
||||||
let buffer = project
|
let buffer = project
|
||||||
.update(cx, |project, cx| {
|
.update(cx, |project, cx| {
|
||||||
@@ -244,19 +251,53 @@ impl Tool for EditFileTool {
|
|||||||
}
|
}
|
||||||
let agent_output = output.await?;
|
let agent_output = output.await?;
|
||||||
|
|
||||||
|
// If format_on_save is enabled, format the buffer
|
||||||
|
let format_on_save_enabled = buffer
|
||||||
|
.read_with(cx, |buffer, cx| {
|
||||||
|
let settings = language_settings::language_settings(
|
||||||
|
buffer.language().map(|l| l.name()),
|
||||||
|
buffer.file(),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
!matches!(settings.format_on_save, FormatOnSave::Off)
|
||||||
|
})
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
if format_on_save_enabled {
|
||||||
|
let format_task = project.update(cx, |project, cx| {
|
||||||
|
project.format(
|
||||||
|
HashSet::from_iter([buffer.clone()]),
|
||||||
|
LspFormatTarget::Buffers,
|
||||||
|
false, // Don't push to history since the tool did it.
|
||||||
|
FormatTrigger::Save,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
format_task.await.log_err();
|
||||||
|
}
|
||||||
|
|
||||||
project
|
project
|
||||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
|
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
// Notify the action log that we've edited the buffer (*after* formatting has completed).
|
||||||
|
action_log.update(cx, |log, cx| {
|
||||||
|
log.buffer_edited(buffer.clone(), cx);
|
||||||
|
})?;
|
||||||
|
|
||||||
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||||
let new_text = cx.background_spawn({
|
let (new_text, diff) = cx
|
||||||
let new_snapshot = new_snapshot.clone();
|
.background_spawn({
|
||||||
async move { new_snapshot.text() }
|
let new_snapshot = new_snapshot.clone();
|
||||||
});
|
let old_text = old_text.clone();
|
||||||
let diff = cx.background_spawn(async move {
|
async move {
|
||||||
language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
|
let new_text = new_snapshot.text();
|
||||||
});
|
let diff = language::unified_diff(&old_text, &new_text);
|
||||||
let (new_text, diff) = futures::join!(new_text, diff);
|
|
||||||
|
(new_text, diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
let output = EditFileToolOutput {
|
let output = EditFileToolOutput {
|
||||||
original_path: project_path.path.to_path_buf(),
|
original_path: project_path.path.to_path_buf(),
|
||||||
@@ -1099,8 +1140,8 @@ async fn build_buffer_diff(
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use client::TelemetrySettings;
|
use client::TelemetrySettings;
|
||||||
use fs::FakeFs;
|
use fs::{FakeFs, Fs};
|
||||||
use gpui::TestAppContext;
|
use gpui::{TestAppContext, UpdateGlobal};
|
||||||
use language_model::fake_provider::FakeLanguageModel;
|
use language_model::fake_provider::FakeLanguageModel;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
@@ -1310,4 +1351,340 @@ mod tests {
|
|||||||
Project::init_settings(cx);
|
Project::init_settings(cx);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_format_on_save(cx: &mut TestAppContext) {
|
||||||
|
init_test(cx);
|
||||||
|
|
||||||
|
let fs = FakeFs::new(cx.executor());
|
||||||
|
fs.insert_tree("/root", json!({"src": {}})).await;
|
||||||
|
|
||||||
|
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||||
|
|
||||||
|
// Set up a Rust language with LSP formatting support
|
||||||
|
let rust_language = Arc::new(language::Language::new(
|
||||||
|
language::LanguageConfig {
|
||||||
|
name: "Rust".into(),
|
||||||
|
matcher: language::LanguageMatcher {
|
||||||
|
path_suffixes: vec!["rs".to_string()],
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
None,
|
||||||
|
));
|
||||||
|
|
||||||
|
// Register the language and fake LSP
|
||||||
|
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
|
||||||
|
language_registry.add(rust_language);
|
||||||
|
|
||||||
|
let mut fake_language_servers = language_registry.register_fake_lsp(
|
||||||
|
"Rust",
|
||||||
|
language::FakeLspAdapter {
|
||||||
|
capabilities: lsp::ServerCapabilities {
|
||||||
|
document_formatting_provider: Some(lsp::OneOf::Left(true)),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
// Create the file
|
||||||
|
fs.save(
|
||||||
|
path!("/root/src/main.rs").as_ref(),
|
||||||
|
&"initial content".into(),
|
||||||
|
language::LineEnding::Unix,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Open the buffer to trigger LSP initialization
|
||||||
|
let buffer = project
|
||||||
|
.update(cx, |project, cx| {
|
||||||
|
project.open_local_buffer(path!("/root/src/main.rs"), cx)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Register the buffer with language servers
|
||||||
|
let _handle = project.update(cx, |project, cx| {
|
||||||
|
project.register_buffer_with_language_servers(&buffer, cx)
|
||||||
|
});
|
||||||
|
|
||||||
|
const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
|
||||||
|
const FORMATTED_CONTENT: &str =
|
||||||
|
"This file was formatted by the fake formatter in the test.\n";
|
||||||
|
|
||||||
|
// Get the fake language server and set up formatting handler
|
||||||
|
let fake_language_server = fake_language_servers.next().await.unwrap();
|
||||||
|
fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
|
||||||
|
|_, _| async move {
|
||||||
|
Ok(Some(vec![lsp::TextEdit {
|
||||||
|
range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
|
||||||
|
new_text: FORMATTED_CONTENT.to_string(),
|
||||||
|
}]))
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
|
|
||||||
|
// First, test with format_on_save enabled
|
||||||
|
cx.update(|cx| {
|
||||||
|
SettingsStore::update_global(cx, |store, cx| {
|
||||||
|
store.update_user_settings::<language::language_settings::AllLanguageSettings>(
|
||||||
|
cx,
|
||||||
|
|settings| {
|
||||||
|
settings.defaults.format_on_save = Some(FormatOnSave::On);
|
||||||
|
settings.defaults.formatter =
|
||||||
|
Some(language::language_settings::SelectedFormatter::Auto);
|
||||||
|
},
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// Have the model stream unformatted content
|
||||||
|
let edit_result = {
|
||||||
|
let edit_task = cx.update(|cx| {
|
||||||
|
let input = serde_json::to_value(EditFileToolInput {
|
||||||
|
display_description: "Create main function".into(),
|
||||||
|
path: "root/src/main.rs".into(),
|
||||||
|
mode: EditFileMode::Overwrite,
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
Arc::new(EditFileTool)
|
||||||
|
.run(
|
||||||
|
input,
|
||||||
|
Arc::default(),
|
||||||
|
project.clone(),
|
||||||
|
action_log.clone(),
|
||||||
|
model.clone(),
|
||||||
|
None,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
.output
|
||||||
|
});
|
||||||
|
|
||||||
|
// Stream the unformatted content
|
||||||
|
cx.executor().run_until_parked();
|
||||||
|
model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
|
||||||
|
model.end_last_completion_stream();
|
||||||
|
|
||||||
|
edit_task.await
|
||||||
|
};
|
||||||
|
assert!(edit_result.is_ok());
|
||||||
|
|
||||||
|
// Wait for any async operations (e.g. formatting) to complete
|
||||||
|
cx.executor().run_until_parked();
|
||||||
|
|
||||||
|
// Read the file to verify it was formatted automatically
|
||||||
|
let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
// Ignore carriage returns on Windows
|
||||||
|
new_content.replace("\r\n", "\n"),
|
||||||
|
FORMATTED_CONTENT,
|
||||||
|
"Code should be formatted when format_on_save is enabled"
|
||||||
|
);
|
||||||
|
|
||||||
|
let stale_buffer_count = action_log.read_with(cx, |log, cx| log.stale_buffers(cx).count());
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
stale_buffer_count, 0,
|
||||||
|
"BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
|
||||||
|
This causes the agent to think the file was modified externally when it was just formatted.",
|
||||||
|
stale_buffer_count
|
||||||
|
);
|
||||||
|
|
||||||
|
// Next, test with format_on_save disabled
|
||||||
|
cx.update(|cx| {
|
||||||
|
SettingsStore::update_global(cx, |store, cx| {
|
||||||
|
store.update_user_settings::<language::language_settings::AllLanguageSettings>(
|
||||||
|
cx,
|
||||||
|
|settings| {
|
||||||
|
settings.defaults.format_on_save = Some(FormatOnSave::Off);
|
||||||
|
},
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// Stream unformatted edits again
|
||||||
|
let edit_result = {
|
||||||
|
let edit_task = cx.update(|cx| {
|
||||||
|
let input = serde_json::to_value(EditFileToolInput {
|
||||||
|
display_description: "Update main function".into(),
|
||||||
|
path: "root/src/main.rs".into(),
|
||||||
|
mode: EditFileMode::Overwrite,
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
Arc::new(EditFileTool)
|
||||||
|
.run(
|
||||||
|
input,
|
||||||
|
Arc::default(),
|
||||||
|
project.clone(),
|
||||||
|
action_log.clone(),
|
||||||
|
model.clone(),
|
||||||
|
None,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
.output
|
||||||
|
});
|
||||||
|
|
||||||
|
// Stream the unformatted content
|
||||||
|
cx.executor().run_until_parked();
|
||||||
|
model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
|
||||||
|
model.end_last_completion_stream();
|
||||||
|
|
||||||
|
edit_task.await
|
||||||
|
};
|
||||||
|
assert!(edit_result.is_ok());
|
||||||
|
|
||||||
|
// Wait for any async operations (e.g. formatting) to complete
|
||||||
|
cx.executor().run_until_parked();
|
||||||
|
|
||||||
|
// Verify the file was not formatted
|
||||||
|
let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
// Ignore carriage returns on Windows
|
||||||
|
new_content.replace("\r\n", "\n"),
|
||||||
|
UNFORMATTED_CONTENT,
|
||||||
|
"Code should not be formatted when format_on_save is disabled"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
|
||||||
|
init_test(cx);
|
||||||
|
|
||||||
|
let fs = FakeFs::new(cx.executor());
|
||||||
|
fs.insert_tree("/root", json!({"src": {}})).await;
|
||||||
|
|
||||||
|
// Create a simple file with trailing whitespace
|
||||||
|
fs.save(
|
||||||
|
path!("/root/src/main.rs").as_ref(),
|
||||||
|
&"initial content".into(),
|
||||||
|
language::LineEnding::Unix,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||||
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
|
|
||||||
|
// First, test with remove_trailing_whitespace_on_save enabled
|
||||||
|
cx.update(|cx| {
|
||||||
|
SettingsStore::update_global(cx, |store, cx| {
|
||||||
|
store.update_user_settings::<language::language_settings::AllLanguageSettings>(
|
||||||
|
cx,
|
||||||
|
|settings| {
|
||||||
|
settings.defaults.remove_trailing_whitespace_on_save = Some(true);
|
||||||
|
},
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
const CONTENT_WITH_TRAILING_WHITESPACE: &str =
|
||||||
|
"fn main() { \n println!(\"Hello!\"); \n}\n";
|
||||||
|
|
||||||
|
// Have the model stream content that contains trailing whitespace
|
||||||
|
let edit_result = {
|
||||||
|
let edit_task = cx.update(|cx| {
|
||||||
|
let input = serde_json::to_value(EditFileToolInput {
|
||||||
|
display_description: "Create main function".into(),
|
||||||
|
path: "root/src/main.rs".into(),
|
||||||
|
mode: EditFileMode::Overwrite,
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
Arc::new(EditFileTool)
|
||||||
|
.run(
|
||||||
|
input,
|
||||||
|
Arc::default(),
|
||||||
|
project.clone(),
|
||||||
|
action_log.clone(),
|
||||||
|
model.clone(),
|
||||||
|
None,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
.output
|
||||||
|
});
|
||||||
|
|
||||||
|
// Stream the content with trailing whitespace
|
||||||
|
cx.executor().run_until_parked();
|
||||||
|
model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
|
||||||
|
model.end_last_completion_stream();
|
||||||
|
|
||||||
|
edit_task.await
|
||||||
|
};
|
||||||
|
assert!(edit_result.is_ok());
|
||||||
|
|
||||||
|
// Wait for any async operations (e.g. formatting) to complete
|
||||||
|
cx.executor().run_until_parked();
|
||||||
|
|
||||||
|
// Read the file to verify trailing whitespace was removed automatically
|
||||||
|
assert_eq!(
|
||||||
|
// Ignore carriage returns on Windows
|
||||||
|
fs.load(path!("/root/src/main.rs").as_ref())
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.replace("\r\n", "\n"),
|
||||||
|
"fn main() {\n println!(\"Hello!\");\n}\n",
|
||||||
|
"Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Next, test with remove_trailing_whitespace_on_save disabled
|
||||||
|
cx.update(|cx| {
|
||||||
|
SettingsStore::update_global(cx, |store, cx| {
|
||||||
|
store.update_user_settings::<language::language_settings::AllLanguageSettings>(
|
||||||
|
cx,
|
||||||
|
|settings| {
|
||||||
|
settings.defaults.remove_trailing_whitespace_on_save = Some(false);
|
||||||
|
},
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// Stream edits again with trailing whitespace
|
||||||
|
let edit_result = {
|
||||||
|
let edit_task = cx.update(|cx| {
|
||||||
|
let input = serde_json::to_value(EditFileToolInput {
|
||||||
|
display_description: "Update main function".into(),
|
||||||
|
path: "root/src/main.rs".into(),
|
||||||
|
mode: EditFileMode::Overwrite,
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
Arc::new(EditFileTool)
|
||||||
|
.run(
|
||||||
|
input,
|
||||||
|
Arc::default(),
|
||||||
|
project.clone(),
|
||||||
|
action_log.clone(),
|
||||||
|
model.clone(),
|
||||||
|
None,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
.output
|
||||||
|
});
|
||||||
|
|
||||||
|
// Stream the content with trailing whitespace
|
||||||
|
cx.executor().run_until_parked();
|
||||||
|
model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
|
||||||
|
model.end_last_completion_stream();
|
||||||
|
|
||||||
|
edit_task.await
|
||||||
|
};
|
||||||
|
assert!(edit_result.is_ok());
|
||||||
|
|
||||||
|
// Wait for any async operations (e.g. formatting) to complete
|
||||||
|
cx.executor().run_until_parked();
|
||||||
|
|
||||||
|
// Verify the file still has trailing whitespace
|
||||||
|
// Read the file again - it should still have trailing whitespace
|
||||||
|
let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
// Ignore carriage returns on Windows
|
||||||
|
final_content.replace("\r\n", "\n"),
|
||||||
|
CONTENT_WITH_TRAILING_WHITESPACE,
|
||||||
|
"Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use std::cell::RefCell;
|
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::{borrow::Cow, cell::RefCell};
|
||||||
|
|
||||||
use crate::schema::json_schema_for;
|
use crate::schema::json_schema_for;
|
||||||
use anyhow::{Context as _, Result, anyhow, bail};
|
use anyhow::{Context as _, Result, anyhow, bail};
|
||||||
@@ -39,10 +39,11 @@ impl FetchTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
|
async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
|
||||||
let mut url = url.to_owned();
|
let url = if !url.starts_with("https://") && !url.starts_with("http://") {
|
||||||
if !url.starts_with("https://") && !url.starts_with("http://") {
|
Cow::Owned(format!("https://{url}"))
|
||||||
url = format!("https://{url}");
|
} else {
|
||||||
}
|
Cow::Borrowed(url)
|
||||||
|
};
|
||||||
|
|
||||||
let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
|
let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
|
||||||
|
|
||||||
@@ -156,8 +157,7 @@ impl Tool for FetchTool {
|
|||||||
|
|
||||||
let text = cx.background_spawn({
|
let text = cx.background_spawn({
|
||||||
let http_client = self.http_client.clone();
|
let http_client = self.http_client.clone();
|
||||||
let url = input.url.clone();
|
async move { Self::build_message(http_client, &input.url).await }
|
||||||
async move { Self::build_message(http_client, &url).await }
|
|
||||||
});
|
});
|
||||||
|
|
||||||
cx.foreground_executor()
|
cx.foreground_executor()
|
||||||
|
|||||||
@@ -119,14 +119,16 @@ impl Tool for FindPathTool {
|
|||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
let output = FindPathToolOutput {
|
|
||||||
glob,
|
|
||||||
paths: matches.clone(),
|
|
||||||
};
|
|
||||||
|
|
||||||
for mat in matches.into_iter().skip(offset).take(RESULTS_PER_PAGE) {
|
for mat in matches.iter().skip(offset).take(RESULTS_PER_PAGE) {
|
||||||
write!(&mut message, "\n{}", mat.display()).unwrap();
|
write!(&mut message, "\n{}", mat.display()).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let output = FindPathToolOutput {
|
||||||
|
glob,
|
||||||
|
paths: matches,
|
||||||
|
};
|
||||||
|
|
||||||
Ok(ToolResultOutput {
|
Ok(ToolResultOutput {
|
||||||
content: ToolResultContent::Text(message),
|
content: ToolResultContent::Text(message),
|
||||||
output: Some(serde_json::to_value(output)?),
|
output: Some(serde_json::to_value(output)?),
|
||||||
@@ -235,8 +237,6 @@ impl ToolCard for FindPathToolCard {
|
|||||||
format!("{} matches", self.paths.len()).into()
|
format!("{} matches", self.paths.len()).into()
|
||||||
};
|
};
|
||||||
|
|
||||||
let glob_label = self.glob.to_string();
|
|
||||||
|
|
||||||
let content = if !self.paths.is_empty() && self.expanded {
|
let content = if !self.paths.is_empty() && self.expanded {
|
||||||
Some(
|
Some(
|
||||||
v_flex()
|
v_flex()
|
||||||
@@ -310,7 +310,7 @@ impl ToolCard for FindPathToolCard {
|
|||||||
.gap_1()
|
.gap_1()
|
||||||
.child(
|
.child(
|
||||||
ToolCallCardHeader::new(IconName::SearchCode, matches_label)
|
ToolCallCardHeader::new(IconName::SearchCode, matches_label)
|
||||||
.with_code_path(glob_label)
|
.with_code_path(&self.glob)
|
||||||
.disclosure_slot(
|
.disclosure_slot(
|
||||||
Disclosure::new("path-search-disclosure", self.expanded)
|
Disclosure::new("path-search-disclosure", self.expanded)
|
||||||
.opened_icon(IconName::ChevronUp)
|
.opened_icon(IconName::ChevronUp)
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
Renames a symbol across your codebase using the language server's semantic knowledge.
|
|
||||||
|
|
||||||
This tool performs a rename refactoring operation on a specified symbol. It uses the project's language server to analyze the code and perform the rename correctly across all files where the symbol is referenced.
|
|
||||||
|
|
||||||
Unlike a simple find and replace, this tool understands the semantic meaning of the code, so it only renames the specific symbol you specify and not unrelated text that happens to have the same name.
|
|
||||||
|
|
||||||
Examples of symbols you can rename:
|
|
||||||
- Variables
|
|
||||||
- Functions
|
|
||||||
- Classes/structs
|
|
||||||
- Fields/properties
|
|
||||||
- Methods
|
|
||||||
- Interfaces/traits
|
|
||||||
|
|
||||||
The language server handles updating all references to the renamed symbol throughout the codebase.
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
Gives detailed information about code symbols in your project such as variables, functions, classes, interface, traits, and other programming constructs, using the editor's integrated Language Server Protocol (LSP) servers.
|
|
||||||
|
|
||||||
This tool is the preferred way to do things like:
|
|
||||||
* Find out where a code symbol is first declared (or first defined - that is, assigned)
|
|
||||||
* Find all the places where a code symbol is referenced
|
|
||||||
* Find the type definition for a code symbol
|
|
||||||
* Find a code symbol's implementation
|
|
||||||
|
|
||||||
This tool gives more reliable answers than things like regex searches, because it can account for relevant semantics like aliases. It should be used over textual search tools (e.g. regex) when searching for information about code symbols that this tool supports directly.
|
|
||||||
|
|
||||||
This tool should not be used when you need to search for something that is not a code symbol.
|
|
||||||
@@ -182,9 +182,8 @@ impl Tool for TerminalTool {
|
|||||||
let mut child = pair.slave.spawn_command(cmd)?;
|
let mut child = pair.slave.spawn_command(cmd)?;
|
||||||
let mut reader = pair.master.try_clone_reader()?;
|
let mut reader = pair.master.try_clone_reader()?;
|
||||||
drop(pair);
|
drop(pair);
|
||||||
let mut content = Vec::new();
|
let mut content = String::new();
|
||||||
reader.read_to_end(&mut content)?;
|
reader.read_to_string(&mut content)?;
|
||||||
let mut content = String::from_utf8(content)?;
|
|
||||||
// Massage the pty output a bit to try to match what the terminal codepath gives us
|
// Massage the pty output a bit to try to match what the terminal codepath gives us
|
||||||
LineEnding::normalize(&mut content);
|
LineEnding::normalize(&mut content);
|
||||||
content = content
|
content = content
|
||||||
|
|||||||
@@ -166,7 +166,7 @@ impl ToolCard for WebSearchToolCard {
|
|||||||
.gap_1()
|
.gap_1()
|
||||||
.children(response.results.iter().enumerate().map(|(index, result)| {
|
.children(response.results.iter().enumerate().map(|(index, result)| {
|
||||||
let title = result.title.clone();
|
let title = result.title.clone();
|
||||||
let url = result.url.clone();
|
let url = SharedString::from(result.url.clone());
|
||||||
|
|
||||||
Button::new(("result", index), title)
|
Button::new(("result", index), title)
|
||||||
.label_size(LabelSize::Small)
|
.label_size(LabelSize::Small)
|
||||||
|
|||||||
@@ -49,8 +49,12 @@ pub enum VersionCheckType {
|
|||||||
pub enum AutoUpdateStatus {
|
pub enum AutoUpdateStatus {
|
||||||
Idle,
|
Idle,
|
||||||
Checking,
|
Checking,
|
||||||
Downloading,
|
Downloading {
|
||||||
Installing,
|
version: VersionCheckType,
|
||||||
|
},
|
||||||
|
Installing {
|
||||||
|
version: VersionCheckType,
|
||||||
|
},
|
||||||
Updated {
|
Updated {
|
||||||
binary_path: PathBuf,
|
binary_path: PathBuf,
|
||||||
version: VersionCheckType,
|
version: VersionCheckType,
|
||||||
@@ -511,12 +515,12 @@ impl AutoUpdater {
|
|||||||
Self::get_latest_release(&this, "zed", OS, ARCH, release_channel, &mut cx).await?;
|
Self::get_latest_release(&this, "zed", OS, ARCH, release_channel, &mut cx).await?;
|
||||||
let fetched_version = fetched_release_data.clone().version;
|
let fetched_version = fetched_release_data.clone().version;
|
||||||
let app_commit_sha = cx.update(|cx| AppCommitSha::try_global(cx).map(|sha| sha.full()));
|
let app_commit_sha = cx.update(|cx| AppCommitSha::try_global(cx).map(|sha| sha.full()));
|
||||||
let newer_version = Self::check_for_newer_version(
|
let newer_version = Self::check_if_fetched_version_is_newer(
|
||||||
*RELEASE_CHANNEL,
|
*RELEASE_CHANNEL,
|
||||||
app_commit_sha,
|
app_commit_sha,
|
||||||
installed_version,
|
installed_version,
|
||||||
previous_status.clone(),
|
|
||||||
fetched_version,
|
fetched_version,
|
||||||
|
previous_status.clone(),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let Some(newer_version) = newer_version else {
|
let Some(newer_version) = newer_version else {
|
||||||
@@ -531,7 +535,9 @@ impl AutoUpdater {
|
|||||||
};
|
};
|
||||||
|
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
this.status = AutoUpdateStatus::Downloading;
|
this.status = AutoUpdateStatus::Downloading {
|
||||||
|
version: newer_version.clone(),
|
||||||
|
};
|
||||||
cx.notify();
|
cx.notify();
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
@@ -540,7 +546,9 @@ impl AutoUpdater {
|
|||||||
download_release(&target_path, fetched_release_data, client, &cx).await?;
|
download_release(&target_path, fetched_release_data, client, &cx).await?;
|
||||||
|
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
this.status = AutoUpdateStatus::Installing;
|
this.status = AutoUpdateStatus::Installing {
|
||||||
|
version: newer_version.clone(),
|
||||||
|
};
|
||||||
cx.notify();
|
cx.notify();
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
@@ -557,12 +565,12 @@ impl AutoUpdater {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_for_newer_version(
|
fn check_if_fetched_version_is_newer(
|
||||||
release_channel: ReleaseChannel,
|
release_channel: ReleaseChannel,
|
||||||
app_commit_sha: Result<Option<String>>,
|
app_commit_sha: Result<Option<String>>,
|
||||||
installed_version: SemanticVersion,
|
installed_version: SemanticVersion,
|
||||||
status: AutoUpdateStatus,
|
|
||||||
fetched_version: String,
|
fetched_version: String,
|
||||||
|
status: AutoUpdateStatus,
|
||||||
) -> Result<Option<VersionCheckType>> {
|
) -> Result<Option<VersionCheckType>> {
|
||||||
let parsed_fetched_version = fetched_version.parse::<SemanticVersion>();
|
let parsed_fetched_version = fetched_version.parse::<SemanticVersion>();
|
||||||
|
|
||||||
@@ -575,7 +583,7 @@ impl AutoUpdater {
|
|||||||
return Ok(newer_version);
|
return Ok(newer_version);
|
||||||
}
|
}
|
||||||
VersionCheckType::Semantic(cached_version) => {
|
VersionCheckType::Semantic(cached_version) => {
|
||||||
return Self::check_for_newer_version_non_nightly(
|
return Self::check_if_fetched_version_is_newer_non_nightly(
|
||||||
cached_version,
|
cached_version,
|
||||||
parsed_fetched_version?,
|
parsed_fetched_version?,
|
||||||
);
|
);
|
||||||
@@ -594,7 +602,7 @@ impl AutoUpdater {
|
|||||||
.then(|| VersionCheckType::Sha(AppCommitSha::new(fetched_version)));
|
.then(|| VersionCheckType::Sha(AppCommitSha::new(fetched_version)));
|
||||||
Ok(newer_version)
|
Ok(newer_version)
|
||||||
}
|
}
|
||||||
_ => Self::check_for_newer_version_non_nightly(
|
_ => Self::check_if_fetched_version_is_newer_non_nightly(
|
||||||
installed_version,
|
installed_version,
|
||||||
parsed_fetched_version?,
|
parsed_fetched_version?,
|
||||||
),
|
),
|
||||||
@@ -631,7 +639,7 @@ impl AutoUpdater {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_for_newer_version_non_nightly(
|
fn check_if_fetched_version_is_newer_non_nightly(
|
||||||
installed_version: SemanticVersion,
|
installed_version: SemanticVersion,
|
||||||
fetched_version: SemanticVersion,
|
fetched_version: SemanticVersion,
|
||||||
) -> Result<Option<VersionCheckType>> {
|
) -> Result<Option<VersionCheckType>> {
|
||||||
@@ -925,12 +933,12 @@ mod tests {
|
|||||||
let status = AutoUpdateStatus::Idle;
|
let status = AutoUpdateStatus::Idle;
|
||||||
let fetched_version = SemanticVersion::new(1, 0, 0);
|
let fetched_version = SemanticVersion::new(1, 0, 0);
|
||||||
|
|
||||||
let newer_version = AutoUpdater::check_for_newer_version(
|
let newer_version = AutoUpdater::check_if_fetched_version_is_newer(
|
||||||
release_channel,
|
release_channel,
|
||||||
app_commit_sha,
|
app_commit_sha,
|
||||||
installed_version,
|
installed_version,
|
||||||
status,
|
|
||||||
fetched_version.to_string(),
|
fetched_version.to_string(),
|
||||||
|
status,
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(newer_version.unwrap(), None);
|
assert_eq!(newer_version.unwrap(), None);
|
||||||
@@ -944,12 +952,12 @@ mod tests {
|
|||||||
let status = AutoUpdateStatus::Idle;
|
let status = AutoUpdateStatus::Idle;
|
||||||
let fetched_version = SemanticVersion::new(1, 0, 1);
|
let fetched_version = SemanticVersion::new(1, 0, 1);
|
||||||
|
|
||||||
let newer_version = AutoUpdater::check_for_newer_version(
|
let newer_version = AutoUpdater::check_if_fetched_version_is_newer(
|
||||||
release_channel,
|
release_channel,
|
||||||
app_commit_sha,
|
app_commit_sha,
|
||||||
installed_version,
|
installed_version,
|
||||||
status,
|
|
||||||
fetched_version.to_string(),
|
fetched_version.to_string(),
|
||||||
|
status,
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@@ -969,12 +977,12 @@ mod tests {
|
|||||||
};
|
};
|
||||||
let fetched_version = SemanticVersion::new(1, 0, 1);
|
let fetched_version = SemanticVersion::new(1, 0, 1);
|
||||||
|
|
||||||
let newer_version = AutoUpdater::check_for_newer_version(
|
let newer_version = AutoUpdater::check_if_fetched_version_is_newer(
|
||||||
release_channel,
|
release_channel,
|
||||||
app_commit_sha,
|
app_commit_sha,
|
||||||
installed_version,
|
installed_version,
|
||||||
status,
|
|
||||||
fetched_version.to_string(),
|
fetched_version.to_string(),
|
||||||
|
status,
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(newer_version.unwrap(), None);
|
assert_eq!(newer_version.unwrap(), None);
|
||||||
@@ -991,12 +999,12 @@ mod tests {
|
|||||||
};
|
};
|
||||||
let fetched_version = SemanticVersion::new(1, 0, 2);
|
let fetched_version = SemanticVersion::new(1, 0, 2);
|
||||||
|
|
||||||
let newer_version = AutoUpdater::check_for_newer_version(
|
let newer_version = AutoUpdater::check_if_fetched_version_is_newer(
|
||||||
release_channel,
|
release_channel,
|
||||||
app_commit_sha,
|
app_commit_sha,
|
||||||
installed_version,
|
installed_version,
|
||||||
status,
|
|
||||||
fetched_version.to_string(),
|
fetched_version.to_string(),
|
||||||
|
status,
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@@ -1013,12 +1021,12 @@ mod tests {
|
|||||||
let status = AutoUpdateStatus::Idle;
|
let status = AutoUpdateStatus::Idle;
|
||||||
let fetched_sha = "a".to_string();
|
let fetched_sha = "a".to_string();
|
||||||
|
|
||||||
let newer_version = AutoUpdater::check_for_newer_version(
|
let newer_version = AutoUpdater::check_if_fetched_version_is_newer(
|
||||||
release_channel,
|
release_channel,
|
||||||
app_commit_sha,
|
app_commit_sha,
|
||||||
installed_version,
|
installed_version,
|
||||||
status,
|
|
||||||
fetched_sha,
|
fetched_sha,
|
||||||
|
status,
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(newer_version.unwrap(), None);
|
assert_eq!(newer_version.unwrap(), None);
|
||||||
@@ -1032,12 +1040,12 @@ mod tests {
|
|||||||
let status = AutoUpdateStatus::Idle;
|
let status = AutoUpdateStatus::Idle;
|
||||||
let fetched_sha = "b".to_string();
|
let fetched_sha = "b".to_string();
|
||||||
|
|
||||||
let newer_version = AutoUpdater::check_for_newer_version(
|
let newer_version = AutoUpdater::check_if_fetched_version_is_newer(
|
||||||
release_channel,
|
release_channel,
|
||||||
app_commit_sha,
|
app_commit_sha,
|
||||||
installed_version,
|
installed_version,
|
||||||
status,
|
|
||||||
fetched_sha.clone(),
|
fetched_sha.clone(),
|
||||||
|
status,
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@@ -1057,12 +1065,12 @@ mod tests {
|
|||||||
};
|
};
|
||||||
let fetched_sha = "b".to_string();
|
let fetched_sha = "b".to_string();
|
||||||
|
|
||||||
let newer_version = AutoUpdater::check_for_newer_version(
|
let newer_version = AutoUpdater::check_if_fetched_version_is_newer(
|
||||||
release_channel,
|
release_channel,
|
||||||
app_commit_sha,
|
app_commit_sha,
|
||||||
installed_version,
|
installed_version,
|
||||||
status,
|
|
||||||
fetched_sha,
|
fetched_sha,
|
||||||
|
status,
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(newer_version.unwrap(), None);
|
assert_eq!(newer_version.unwrap(), None);
|
||||||
@@ -1079,12 +1087,12 @@ mod tests {
|
|||||||
};
|
};
|
||||||
let fetched_sha = "c".to_string();
|
let fetched_sha = "c".to_string();
|
||||||
|
|
||||||
let newer_version = AutoUpdater::check_for_newer_version(
|
let newer_version = AutoUpdater::check_if_fetched_version_is_newer(
|
||||||
release_channel,
|
release_channel,
|
||||||
app_commit_sha,
|
app_commit_sha,
|
||||||
installed_version,
|
installed_version,
|
||||||
status,
|
|
||||||
fetched_sha.clone(),
|
fetched_sha.clone(),
|
||||||
|
status,
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@@ -1101,12 +1109,12 @@ mod tests {
|
|||||||
let status = AutoUpdateStatus::Idle;
|
let status = AutoUpdateStatus::Idle;
|
||||||
let fetched_sha = "a".to_string();
|
let fetched_sha = "a".to_string();
|
||||||
|
|
||||||
let newer_version = AutoUpdater::check_for_newer_version(
|
let newer_version = AutoUpdater::check_if_fetched_version_is_newer(
|
||||||
release_channel,
|
release_channel,
|
||||||
app_commit_sha,
|
app_commit_sha,
|
||||||
installed_version,
|
installed_version,
|
||||||
status,
|
|
||||||
fetched_sha.clone(),
|
fetched_sha.clone(),
|
||||||
|
status,
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@@ -1127,12 +1135,12 @@ mod tests {
|
|||||||
};
|
};
|
||||||
let fetched_sha = "b".to_string();
|
let fetched_sha = "b".to_string();
|
||||||
|
|
||||||
let newer_version = AutoUpdater::check_for_newer_version(
|
let newer_version = AutoUpdater::check_if_fetched_version_is_newer(
|
||||||
release_channel,
|
release_channel,
|
||||||
app_commit_sha,
|
app_commit_sha,
|
||||||
installed_version,
|
installed_version,
|
||||||
status,
|
|
||||||
fetched_sha,
|
fetched_sha,
|
||||||
|
status,
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(newer_version.unwrap(), None);
|
assert_eq!(newer_version.unwrap(), None);
|
||||||
@@ -1150,12 +1158,12 @@ mod tests {
|
|||||||
};
|
};
|
||||||
let fetched_sha = "c".to_string();
|
let fetched_sha = "c".to_string();
|
||||||
|
|
||||||
let newer_version = AutoUpdater::check_for_newer_version(
|
let newer_version = AutoUpdater::check_if_fetched_version_is_newer(
|
||||||
release_channel,
|
release_channel,
|
||||||
app_commit_sha,
|
app_commit_sha,
|
||||||
installed_version,
|
installed_version,
|
||||||
status,
|
|
||||||
fetched_sha.clone(),
|
fetched_sha.clone(),
|
||||||
|
status,
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ fn view_release_notes_locally(
|
|||||||
|
|
||||||
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
|
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
|
||||||
|
|
||||||
let tab_content = SharedString::from(body.title.to_string());
|
let tab_content = Some(SharedString::from(body.title.to_string()));
|
||||||
let editor = cx.new(|cx| {
|
let editor = cx.new(|cx| {
|
||||||
Editor::for_multibuffer(buffer, Some(project), window, cx)
|
Editor::for_multibuffer(buffer, Some(project), window, cx)
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ doctest = false
|
|||||||
editor.workspace = true
|
editor.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
itertools.workspace = true
|
itertools.workspace = true
|
||||||
|
settings.workspace = true
|
||||||
theme.workspace = true
|
theme.workspace = true
|
||||||
ui.workspace = true
|
ui.workspace = true
|
||||||
workspace.workspace = true
|
workspace.workspace = true
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
use editor::Editor;
|
use editor::Editor;
|
||||||
use gpui::{
|
use gpui::{
|
||||||
Context, Element, EventEmitter, Focusable, IntoElement, ParentElement, Render, StyledText,
|
Context, Element, EventEmitter, Focusable, FontWeight, IntoElement, ParentElement, Render,
|
||||||
Subscription, Window,
|
StyledText, Subscription, Window,
|
||||||
};
|
};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
use settings::Settings;
|
||||||
use std::cmp;
|
use std::cmp;
|
||||||
use theme::ActiveTheme;
|
use theme::ActiveTheme;
|
||||||
use ui::{ButtonLike, ButtonStyle, Label, Tooltip, prelude::*};
|
use ui::{ButtonLike, ButtonStyle, Label, Tooltip, prelude::*};
|
||||||
use workspace::{
|
use workspace::{
|
||||||
ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView,
|
TabBarSettings, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView,
|
||||||
item::{BreadcrumbText, ItemEvent, ItemHandle},
|
item::{BreadcrumbText, ItemEvent, ItemHandle},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -71,16 +72,23 @@ impl Render for Breadcrumbs {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let highlighted_segments = segments.into_iter().map(|segment| {
|
let highlighted_segments = segments.into_iter().enumerate().map(|(index, segment)| {
|
||||||
let mut text_style = window.text_style();
|
let mut text_style = window.text_style();
|
||||||
if let Some(font) = segment.font {
|
if let Some(ref font) = segment.font {
|
||||||
text_style.font_family = font.family;
|
text_style.font_family = font.family.clone();
|
||||||
text_style.font_features = font.features;
|
text_style.font_features = font.features.clone();
|
||||||
text_style.font_style = font.style;
|
text_style.font_style = font.style;
|
||||||
text_style.font_weight = font.weight;
|
text_style.font_weight = font.weight;
|
||||||
}
|
}
|
||||||
text_style.color = Color::Muted.color(cx);
|
text_style.color = Color::Muted.color(cx);
|
||||||
|
|
||||||
|
if index == 0 && !TabBarSettings::get_global(cx).show && active_item.is_dirty(cx) {
|
||||||
|
if let Some(styled_element) = apply_dirty_filename_style(&segment, &text_style, cx)
|
||||||
|
{
|
||||||
|
return styled_element;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
StyledText::new(segment.text.replace('\n', "⏎"))
|
StyledText::new(segment.text.replace('\n', "⏎"))
|
||||||
.with_default_highlights(&text_style, segment.highlights.unwrap_or_default())
|
.with_default_highlights(&text_style, segment.highlights.unwrap_or_default())
|
||||||
.into_any()
|
.into_any()
|
||||||
@@ -184,3 +192,46 @@ impl ToolbarItemView for Breadcrumbs {
|
|||||||
self.pane_focused = pane_focused;
|
self.pane_focused = pane_focused;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn apply_dirty_filename_style(
|
||||||
|
segment: &BreadcrumbText,
|
||||||
|
text_style: &gpui::TextStyle,
|
||||||
|
cx: &mut Context<Breadcrumbs>,
|
||||||
|
) -> Option<gpui::AnyElement> {
|
||||||
|
let text = segment.text.replace('\n', "⏎");
|
||||||
|
|
||||||
|
let filename_position = std::path::Path::new(&segment.text)
|
||||||
|
.file_name()
|
||||||
|
.and_then(|f| {
|
||||||
|
let filename_str = f.to_string_lossy();
|
||||||
|
segment.text.rfind(filename_str.as_ref())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let bold_weight = FontWeight::BOLD;
|
||||||
|
let default_color = Color::Default.color(cx);
|
||||||
|
|
||||||
|
if filename_position == 0 {
|
||||||
|
let mut filename_style = text_style.clone();
|
||||||
|
filename_style.font_weight = bold_weight;
|
||||||
|
filename_style.color = default_color;
|
||||||
|
|
||||||
|
return Some(
|
||||||
|
StyledText::new(text)
|
||||||
|
.with_default_highlights(&filename_style, [])
|
||||||
|
.into_any(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let highlight_style = gpui::HighlightStyle {
|
||||||
|
font_weight: Some(bold_weight),
|
||||||
|
color: Some(default_color),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let highlight = vec![(filename_position..text.len(), highlight_style)];
|
||||||
|
Some(
|
||||||
|
StyledText::new(text)
|
||||||
|
.with_default_highlights(&text_style, highlight)
|
||||||
|
.into_any(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ pub struct ChannelBuffer {
|
|||||||
pub enum ChannelBufferEvent {
|
pub enum ChannelBufferEvent {
|
||||||
CollaboratorsChanged,
|
CollaboratorsChanged,
|
||||||
Disconnected,
|
Disconnected,
|
||||||
|
Connected,
|
||||||
BufferEdited,
|
BufferEdited,
|
||||||
ChannelChanged,
|
ChannelChanged,
|
||||||
}
|
}
|
||||||
@@ -103,6 +104,17 @@ impl ChannelBuffer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn connected(&mut self, cx: &mut Context<Self>) {
|
||||||
|
self.connected = true;
|
||||||
|
if self.subscription.is_none() {
|
||||||
|
let Ok(subscription) = self.client.subscribe_to_entity(self.channel_id.0) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
self.subscription = Some(subscription.set_entity(&cx.entity(), &mut cx.to_async()));
|
||||||
|
cx.emit(ChannelBufferEvent::Connected);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn remote_id(&self, cx: &App) -> BufferId {
|
pub fn remote_id(&self, cx: &App) -> BufferId {
|
||||||
self.buffer.read(cx).remote_id()
|
self.buffer.read(cx).remote_id()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -972,6 +972,7 @@ impl ChannelStore {
|
|||||||
.log_err();
|
.log_err();
|
||||||
|
|
||||||
if let Some(operations) = operations {
|
if let Some(operations) = operations {
|
||||||
|
channel_buffer.connected(cx);
|
||||||
let client = this.client.clone();
|
let client = this.client.clone();
|
||||||
cx.background_spawn(async move {
|
cx.background_spawn(async move {
|
||||||
let operations = operations.await;
|
let operations = operations.await;
|
||||||
@@ -1012,8 +1013,8 @@ impl ChannelStore {
|
|||||||
|
|
||||||
if let Some(this) = this.upgrade() {
|
if let Some(this) = this.upgrade() {
|
||||||
this.update(cx, |this, cx| {
|
this.update(cx, |this, cx| {
|
||||||
for (_, buffer) in this.opened_buffers.drain() {
|
for (_, buffer) in &this.opened_buffers {
|
||||||
if let OpenEntityHandle::Open(buffer) = buffer {
|
if let OpenEntityHandle::Open(buffer) = &buffer {
|
||||||
if let Some(buffer) = buffer.upgrade() {
|
if let Some(buffer) = buffer.upgrade() {
|
||||||
buffer.update(cx, |buffer, cx| buffer.disconnect(cx));
|
buffer.update(cx, |buffer, cx| buffer.disconnect(cx));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ test-support = ["sqlite"]
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
async-stripe.workspace = true
|
async-stripe.workspace = true
|
||||||
|
async-trait.workspace = true
|
||||||
async-tungstenite.workspace = true
|
async-tungstenite.workspace = true
|
||||||
aws-config = { version = "1.1.5" }
|
aws-config = { version = "1.1.5" }
|
||||||
aws-sdk-s3 = { version = "1.15.0" }
|
aws-sdk-s3 = { version = "1.15.0" }
|
||||||
|
|||||||
@@ -17,8 +17,8 @@ use stripe::{
|
|||||||
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
|
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
|
||||||
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
|
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
|
||||||
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
|
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
|
||||||
CreateBillingPortalSessionFlowDataType, Customer, CustomerId, EventObject, EventType,
|
CreateBillingPortalSessionFlowDataType, CustomerId, EventObject, EventType, ListEvents,
|
||||||
Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
|
PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
|
||||||
};
|
};
|
||||||
use util::{ResultExt, maybe};
|
use util::{ResultExt, maybe};
|
||||||
|
|
||||||
@@ -29,6 +29,10 @@ use crate::db::billing_subscription::{
|
|||||||
use crate::llm::db::subscription_usage_meter::CompletionMode;
|
use crate::llm::db::subscription_usage_meter::CompletionMode;
|
||||||
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
|
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
|
||||||
use crate::rpc::{ResultExt as _, Server};
|
use crate::rpc::{ResultExt as _, Server};
|
||||||
|
use crate::stripe_client::{
|
||||||
|
StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
|
||||||
|
StripeSubscriptionId,
|
||||||
|
};
|
||||||
use crate::{AppState, Error, Result};
|
use crate::{AppState, Error, Result};
|
||||||
use crate::{db::UserId, llm::db::LlmDatabase};
|
use crate::{db::UserId, llm::db::LlmDatabase};
|
||||||
use crate::{
|
use crate::{
|
||||||
@@ -54,10 +58,6 @@ pub fn router() -> Router {
|
|||||||
"/billing/subscriptions/manage",
|
"/billing/subscriptions/manage",
|
||||||
post(manage_billing_subscription),
|
post(manage_billing_subscription),
|
||||||
)
|
)
|
||||||
.route(
|
|
||||||
"/billing/subscriptions/migrate",
|
|
||||||
post(migrate_to_new_billing),
|
|
||||||
)
|
|
||||||
.route(
|
.route(
|
||||||
"/billing/subscriptions/sync",
|
"/billing/subscriptions/sync",
|
||||||
post(sync_billing_subscription),
|
post(sync_billing_subscription),
|
||||||
@@ -282,7 +282,6 @@ async fn list_billing_subscriptions(
|
|||||||
enum ProductCode {
|
enum ProductCode {
|
||||||
ZedPro,
|
ZedPro,
|
||||||
ZedProTrial,
|
ZedProTrial,
|
||||||
ZedFree,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
@@ -338,8 +337,7 @@ async fn create_billing_subscription(
|
|||||||
}
|
}
|
||||||
|
|
||||||
let customer_id = if let Some(existing_customer) = &existing_billing_customer {
|
let customer_id = if let Some(existing_customer) = &existing_billing_customer {
|
||||||
CustomerId::from_str(&existing_customer.stripe_customer_id)
|
StripeCustomerId(existing_customer.stripe_customer_id.clone().into())
|
||||||
.context("failed to parse customer ID")?
|
|
||||||
} else {
|
} else {
|
||||||
stripe_billing
|
stripe_billing
|
||||||
.find_or_create_customer_by_email(user.email_address.as_deref())
|
.find_or_create_customer_by_email(user.email_address.as_deref())
|
||||||
@@ -354,7 +352,7 @@ async fn create_billing_subscription(
|
|||||||
let checkout_session_url = match body.product {
|
let checkout_session_url = match body.product {
|
||||||
ProductCode::ZedPro => {
|
ProductCode::ZedPro => {
|
||||||
stripe_billing
|
stripe_billing
|
||||||
.checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
|
.checkout_with_zed_pro(&customer_id, &user.github_login, &success_url)
|
||||||
.await?
|
.await?
|
||||||
}
|
}
|
||||||
ProductCode::ZedProTrial => {
|
ProductCode::ZedProTrial => {
|
||||||
@@ -371,18 +369,13 @@ async fn create_billing_subscription(
|
|||||||
|
|
||||||
stripe_billing
|
stripe_billing
|
||||||
.checkout_with_zed_pro_trial(
|
.checkout_with_zed_pro_trial(
|
||||||
customer_id,
|
&customer_id,
|
||||||
&user.github_login,
|
&user.github_login,
|
||||||
feature_flags,
|
feature_flags,
|
||||||
&success_url,
|
&success_url,
|
||||||
)
|
)
|
||||||
.await?
|
.await?
|
||||||
}
|
}
|
||||||
ProductCode::ZedFree => {
|
|
||||||
stripe_billing
|
|
||||||
.checkout_with_zed_free(customer_id, &user.github_login, &success_url)
|
|
||||||
.await?
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Json(CreateBillingSubscriptionResponse {
|
Ok(Json(CreateBillingSubscriptionResponse {
|
||||||
@@ -432,7 +425,7 @@ async fn manage_billing_subscription(
|
|||||||
.await?
|
.await?
|
||||||
.context("user not found")?;
|
.context("user not found")?;
|
||||||
|
|
||||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
let Some(stripe_client) = app.real_stripe_client.clone() else {
|
||||||
log::error!("failed to retrieve Stripe client");
|
log::error!("failed to retrieve Stripe client");
|
||||||
Err(Error::http(
|
Err(Error::http(
|
||||||
StatusCode::NOT_IMPLEMENTED,
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
@@ -498,8 +491,10 @@ async fn manage_billing_subscription(
|
|||||||
let flow = match body.intent {
|
let flow = match body.intent {
|
||||||
ManageSubscriptionIntent::ManageSubscription => None,
|
ManageSubscriptionIntent::ManageSubscription => None,
|
||||||
ManageSubscriptionIntent::UpgradeToPro => {
|
ManageSubscriptionIntent::UpgradeToPro => {
|
||||||
let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?;
|
let zed_pro_price_id: stripe::PriceId =
|
||||||
let zed_free_price_id = stripe_billing.zed_free_price_id().await?;
|
stripe_billing.zed_pro_price_id().await?.try_into()?;
|
||||||
|
let zed_free_price_id: stripe::PriceId =
|
||||||
|
stripe_billing.zed_free_price_id().await?.try_into()?;
|
||||||
|
|
||||||
let stripe_subscription =
|
let stripe_subscription =
|
||||||
Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
|
Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
|
||||||
@@ -633,86 +628,6 @@ async fn manage_billing_subscription(
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct MigrateToNewBillingBody {
|
|
||||||
github_user_id: i32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
|
||||||
struct MigrateToNewBillingResponse {
|
|
||||||
/// The ID of the subscription that was canceled.
|
|
||||||
canceled_subscription_id: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn migrate_to_new_billing(
|
|
||||||
Extension(app): Extension<Arc<AppState>>,
|
|
||||||
extract::Json(body): extract::Json<MigrateToNewBillingBody>,
|
|
||||||
) -> Result<Json<MigrateToNewBillingResponse>> {
|
|
||||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
|
||||||
log::error!("failed to retrieve Stripe client");
|
|
||||||
Err(Error::http(
|
|
||||||
StatusCode::NOT_IMPLEMENTED,
|
|
||||||
"not supported".into(),
|
|
||||||
))?
|
|
||||||
};
|
|
||||||
|
|
||||||
let user = app
|
|
||||||
.db
|
|
||||||
.get_user_by_github_user_id(body.github_user_id)
|
|
||||||
.await?
|
|
||||||
.context("user not found")?;
|
|
||||||
|
|
||||||
let old_billing_subscriptions_by_user = app
|
|
||||||
.db
|
|
||||||
.get_active_billing_subscriptions(HashSet::from_iter([user.id]))
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let canceled_subscription_id = if let Some((_billing_customer, billing_subscription)) =
|
|
||||||
old_billing_subscriptions_by_user.get(&user.id)
|
|
||||||
{
|
|
||||||
let stripe_subscription_id = billing_subscription
|
|
||||||
.stripe_subscription_id
|
|
||||||
.parse::<stripe::SubscriptionId>()
|
|
||||||
.context("failed to parse Stripe subscription ID from database")?;
|
|
||||||
|
|
||||||
Subscription::cancel(
|
|
||||||
&stripe_client,
|
|
||||||
&stripe_subscription_id,
|
|
||||||
stripe::CancelSubscription {
|
|
||||||
invoice_now: Some(true),
|
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Some(stripe_subscription_id)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
let all_feature_flags = app.db.list_feature_flags().await?;
|
|
||||||
let user_feature_flags = app.db.get_user_flags(user.id).await?;
|
|
||||||
|
|
||||||
for feature_flag in ["new-billing", "assistant2"] {
|
|
||||||
let already_in_feature_flag = user_feature_flags.iter().any(|flag| flag == feature_flag);
|
|
||||||
if already_in_feature_flag {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let feature_flag = all_feature_flags
|
|
||||||
.iter()
|
|
||||||
.find(|flag| flag.flag == feature_flag)
|
|
||||||
.context("failed to find feature flag: {feature_flag:?}")?;
|
|
||||||
|
|
||||||
app.db.add_user_flag(user.id, feature_flag.id).await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Json(MigrateToNewBillingResponse {
|
|
||||||
canceled_subscription_id: canceled_subscription_id
|
|
||||||
.map(|subscription_id| subscription_id.to_string()),
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct SyncBillingSubscriptionBody {
|
struct SyncBillingSubscriptionBody {
|
||||||
github_user_id: i32,
|
github_user_id: i32,
|
||||||
@@ -746,23 +661,13 @@ async fn sync_billing_subscription(
|
|||||||
.get_billing_customer_by_user_id(user.id)
|
.get_billing_customer_by_user_id(user.id)
|
||||||
.await?
|
.await?
|
||||||
.context("billing customer not found")?;
|
.context("billing customer not found")?;
|
||||||
let stripe_customer_id = billing_customer
|
let stripe_customer_id = StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
|
||||||
.stripe_customer_id
|
|
||||||
.parse::<stripe::CustomerId>()
|
|
||||||
.context("failed to parse Stripe customer ID from database")?;
|
|
||||||
|
|
||||||
let subscriptions = Subscription::list(
|
let subscriptions = stripe_client
|
||||||
&stripe_client,
|
.list_subscriptions_for_customer(&stripe_customer_id)
|
||||||
&stripe::ListSubscriptions {
|
.await?;
|
||||||
customer: Some(stripe_customer_id),
|
|
||||||
// Sync all non-canceled subscriptions.
|
|
||||||
status: None,
|
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
for subscription in subscriptions.data {
|
for subscription in subscriptions {
|
||||||
let subscription_id = subscription.id.clone();
|
let subscription_id = subscription.id.clone();
|
||||||
|
|
||||||
sync_subscription(&app, &stripe_client, subscription)
|
sync_subscription(&app, &stripe_client, subscription)
|
||||||
@@ -810,6 +715,10 @@ const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
|
|||||||
/// Polls the Stripe events API periodically to reconcile the records in our
|
/// Polls the Stripe events API periodically to reconcile the records in our
|
||||||
/// database with the data in Stripe.
|
/// database with the data in Stripe.
|
||||||
pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
|
pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
|
||||||
|
let Some(real_stripe_client) = app.real_stripe_client.clone() else {
|
||||||
|
log::warn!("failed to retrieve Stripe client");
|
||||||
|
return;
|
||||||
|
};
|
||||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||||
log::warn!("failed to retrieve Stripe client");
|
log::warn!("failed to retrieve Stripe client");
|
||||||
return;
|
return;
|
||||||
@@ -820,7 +729,7 @@ pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Serve
|
|||||||
let executor = executor.clone();
|
let executor = executor.clone();
|
||||||
async move {
|
async move {
|
||||||
loop {
|
loop {
|
||||||
poll_stripe_events(&app, &rpc_server, &stripe_client)
|
poll_stripe_events(&app, &rpc_server, &stripe_client, &real_stripe_client)
|
||||||
.await
|
.await
|
||||||
.log_err();
|
.log_err();
|
||||||
|
|
||||||
@@ -833,7 +742,8 @@ pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Serve
|
|||||||
async fn poll_stripe_events(
|
async fn poll_stripe_events(
|
||||||
app: &Arc<AppState>,
|
app: &Arc<AppState>,
|
||||||
rpc_server: &Arc<Server>,
|
rpc_server: &Arc<Server>,
|
||||||
stripe_client: &stripe::Client,
|
stripe_client: &Arc<dyn StripeClient>,
|
||||||
|
real_stripe_client: &stripe::Client,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
fn event_type_to_string(event_type: EventType) -> String {
|
fn event_type_to_string(event_type: EventType) -> String {
|
||||||
// Calling `to_string` on `stripe::EventType` members gives us a quoted string,
|
// Calling `to_string` on `stripe::EventType` members gives us a quoted string,
|
||||||
@@ -865,7 +775,7 @@ async fn poll_stripe_events(
|
|||||||
params.types = Some(event_types.clone());
|
params.types = Some(event_types.clone());
|
||||||
params.limit = Some(EVENTS_LIMIT_PER_PAGE);
|
params.limit = Some(EVENTS_LIMIT_PER_PAGE);
|
||||||
|
|
||||||
let mut event_pages = stripe::Event::list(&stripe_client, ¶ms)
|
let mut event_pages = stripe::Event::list(&real_stripe_client, ¶ms)
|
||||||
.await?
|
.await?
|
||||||
.paginate(params);
|
.paginate(params);
|
||||||
|
|
||||||
@@ -909,7 +819,7 @@ async fn poll_stripe_events(
|
|||||||
break;
|
break;
|
||||||
} else {
|
} else {
|
||||||
log::info!("Stripe events: retrieving next page");
|
log::info!("Stripe events: retrieving next page");
|
||||||
event_pages = event_pages.next(&stripe_client).await?;
|
event_pages = event_pages.next(&real_stripe_client).await?;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
break;
|
break;
|
||||||
@@ -949,7 +859,7 @@ async fn poll_stripe_events(
|
|||||||
|
|
||||||
let process_result = match event.type_ {
|
let process_result = match event.type_ {
|
||||||
EventType::CustomerCreated | EventType::CustomerUpdated => {
|
EventType::CustomerCreated | EventType::CustomerUpdated => {
|
||||||
handle_customer_event(app, stripe_client, event).await
|
handle_customer_event(app, real_stripe_client, event).await
|
||||||
}
|
}
|
||||||
EventType::CustomerSubscriptionCreated
|
EventType::CustomerSubscriptionCreated
|
||||||
| EventType::CustomerSubscriptionUpdated
|
| EventType::CustomerSubscriptionUpdated
|
||||||
@@ -1024,8 +934,8 @@ async fn handle_customer_event(
|
|||||||
|
|
||||||
async fn sync_subscription(
|
async fn sync_subscription(
|
||||||
app: &Arc<AppState>,
|
app: &Arc<AppState>,
|
||||||
stripe_client: &stripe::Client,
|
stripe_client: &Arc<dyn StripeClient>,
|
||||||
subscription: stripe::Subscription,
|
subscription: StripeSubscription,
|
||||||
) -> anyhow::Result<billing_customer::Model> {
|
) -> anyhow::Result<billing_customer::Model> {
|
||||||
let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
|
let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
|
||||||
stripe_billing
|
stripe_billing
|
||||||
@@ -1036,7 +946,7 @@ async fn sync_subscription(
|
|||||||
};
|
};
|
||||||
|
|
||||||
let billing_customer =
|
let billing_customer =
|
||||||
find_or_create_billing_customer(app, stripe_client, subscription.customer)
|
find_or_create_billing_customer(app, stripe_client.as_ref(), &subscription.customer)
|
||||||
.await?
|
.await?
|
||||||
.context("billing customer not found")?;
|
.context("billing customer not found")?;
|
||||||
|
|
||||||
@@ -1064,7 +974,7 @@ async fn sync_subscription(
|
|||||||
.as_ref()
|
.as_ref()
|
||||||
.and_then(|details| details.reason)
|
.and_then(|details| details.reason)
|
||||||
.map_or(false, |reason| {
|
.map_or(false, |reason| {
|
||||||
reason == CancellationDetailsReason::PaymentFailed
|
reason == StripeCancellationDetailsReason::PaymentFailed
|
||||||
});
|
});
|
||||||
|
|
||||||
if was_canceled_due_to_payment_failure {
|
if was_canceled_due_to_payment_failure {
|
||||||
@@ -1081,7 +991,7 @@ async fn sync_subscription(
|
|||||||
|
|
||||||
if let Some(existing_subscription) = app
|
if let Some(existing_subscription) = app
|
||||||
.db
|
.db
|
||||||
.get_billing_subscription_by_stripe_subscription_id(&subscription.id)
|
.get_billing_subscription_by_stripe_subscription_id(subscription.id.0.as_ref())
|
||||||
.await?
|
.await?
|
||||||
{
|
{
|
||||||
app.db
|
app.db
|
||||||
@@ -1122,20 +1032,13 @@ async fn sync_subscription(
|
|||||||
if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
|
if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
|
||||||
&& subscription_kind == Some(SubscriptionKind::ZedProTrial)
|
&& subscription_kind == Some(SubscriptionKind::ZedProTrial)
|
||||||
{
|
{
|
||||||
let stripe_subscription_id = existing_subscription
|
let stripe_subscription_id = StripeSubscriptionId(
|
||||||
.stripe_subscription_id
|
existing_subscription.stripe_subscription_id.clone().into(),
|
||||||
.parse::<stripe::SubscriptionId>()
|
);
|
||||||
.context("failed to parse Stripe subscription ID from database")?;
|
|
||||||
|
|
||||||
Subscription::cancel(
|
stripe_client
|
||||||
&stripe_client,
|
.cancel_subscription(&stripe_subscription_id)
|
||||||
&stripe_subscription_id,
|
.await?;
|
||||||
stripe::CancelSubscription {
|
|
||||||
invoice_now: None,
|
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
} else {
|
} else {
|
||||||
// If the user already has an active billing subscription, ignore the
|
// If the user already has an active billing subscription, ignore the
|
||||||
// event and return an `Ok` to signal that it was processed
|
// event and return an `Ok` to signal that it was processed
|
||||||
@@ -1186,10 +1089,8 @@ async fn sync_subscription(
|
|||||||
.has_active_billing_subscription(billing_customer.user_id)
|
.has_active_billing_subscription(billing_customer.user_id)
|
||||||
.await?;
|
.await?;
|
||||||
if !already_has_active_billing_subscription {
|
if !already_has_active_billing_subscription {
|
||||||
let stripe_customer_id = billing_customer
|
let stripe_customer_id =
|
||||||
.stripe_customer_id
|
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
|
||||||
.parse::<stripe::CustomerId>()
|
|
||||||
.context("failed to parse Stripe customer ID from database")?;
|
|
||||||
|
|
||||||
stripe_billing
|
stripe_billing
|
||||||
.subscribe_to_zed_free(stripe_customer_id)
|
.subscribe_to_zed_free(stripe_customer_id)
|
||||||
@@ -1204,7 +1105,7 @@ async fn sync_subscription(
|
|||||||
async fn handle_customer_subscription_event(
|
async fn handle_customer_subscription_event(
|
||||||
app: &Arc<AppState>,
|
app: &Arc<AppState>,
|
||||||
rpc_server: &Arc<Server>,
|
rpc_server: &Arc<Server>,
|
||||||
stripe_client: &stripe::Client,
|
stripe_client: &Arc<dyn StripeClient>,
|
||||||
event: stripe::Event,
|
event: stripe::Event,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
let EventObject::Subscription(subscription) = event.data.object else {
|
let EventObject::Subscription(subscription) = event.data.object else {
|
||||||
@@ -1213,7 +1114,7 @@ async fn handle_customer_subscription_event(
|
|||||||
|
|
||||||
log::info!("handling Stripe {} event: {}", event.type_, event.id);
|
log::info!("handling Stripe {} event: {}", event.type_, event.id);
|
||||||
|
|
||||||
let billing_customer = sync_subscription(app, stripe_client, subscription).await?;
|
let billing_customer = sync_subscription(app, stripe_client, subscription.into()).await?;
|
||||||
|
|
||||||
// When the user's subscription changes, push down any changes to their plan.
|
// When the user's subscription changes, push down any changes to their plan.
|
||||||
rpc_server
|
rpc_server
|
||||||
@@ -1409,30 +1310,20 @@ impl From<CancellationDetailsReason> for StripeCancellationReason {
|
|||||||
/// Finds or creates a billing customer using the provided customer.
|
/// Finds or creates a billing customer using the provided customer.
|
||||||
pub async fn find_or_create_billing_customer(
|
pub async fn find_or_create_billing_customer(
|
||||||
app: &Arc<AppState>,
|
app: &Arc<AppState>,
|
||||||
stripe_client: &stripe::Client,
|
stripe_client: &dyn StripeClient,
|
||||||
customer_or_id: Expandable<Customer>,
|
customer_id: &StripeCustomerId,
|
||||||
) -> anyhow::Result<Option<billing_customer::Model>> {
|
) -> anyhow::Result<Option<billing_customer::Model>> {
|
||||||
let customer_id = match &customer_or_id {
|
|
||||||
Expandable::Id(id) => id,
|
|
||||||
Expandable::Object(customer) => customer.id.as_ref(),
|
|
||||||
};
|
|
||||||
|
|
||||||
// If we already have a billing customer record associated with the Stripe customer,
|
// If we already have a billing customer record associated with the Stripe customer,
|
||||||
// there's nothing more we need to do.
|
// there's nothing more we need to do.
|
||||||
if let Some(billing_customer) = app
|
if let Some(billing_customer) = app
|
||||||
.db
|
.db
|
||||||
.get_billing_customer_by_stripe_customer_id(customer_id)
|
.get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref())
|
||||||
.await?
|
.await?
|
||||||
{
|
{
|
||||||
return Ok(Some(billing_customer));
|
return Ok(Some(billing_customer));
|
||||||
}
|
}
|
||||||
|
|
||||||
// If all we have is a customer ID, resolve it to a full customer record by
|
let customer = stripe_client.get_customer(customer_id).await?;
|
||||||
// hitting the Stripe API.
|
|
||||||
let customer = match customer_or_id {
|
|
||||||
Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
|
|
||||||
Expandable::Object(customer) => *customer,
|
|
||||||
};
|
|
||||||
|
|
||||||
let Some(email) = customer.email else {
|
let Some(email) = customer.email else {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
@@ -1542,14 +1433,10 @@ async fn sync_model_request_usage_with_stripe(
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
let stripe_customer_id = billing_customer
|
let stripe_customer_id =
|
||||||
.stripe_customer_id
|
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
|
||||||
.parse::<stripe::CustomerId>()
|
let stripe_subscription_id =
|
||||||
.context("failed to parse Stripe customer ID from database")?;
|
StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());
|
||||||
let stripe_subscription_id = billing_subscription
|
|
||||||
.stripe_subscription_id
|
|
||||||
.parse::<stripe::SubscriptionId>()
|
|
||||||
.context("failed to parse Stripe subscription ID from database")?;
|
|
||||||
|
|
||||||
let model = llm_db.model_by_id(usage_meter.model_id)?;
|
let model = llm_db.model_by_id(usage_meter.model_id)?;
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
use crate::db::{BillingCustomerId, BillingSubscriptionId};
|
use crate::db::{BillingCustomerId, BillingSubscriptionId};
|
||||||
|
use crate::stripe_client;
|
||||||
use chrono::{Datelike as _, NaiveDate, Utc};
|
use chrono::{Datelike as _, NaiveDate, Utc};
|
||||||
use sea_orm::entity::prelude::*;
|
use sea_orm::entity::prelude::*;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
@@ -159,3 +160,17 @@ pub enum StripeCancellationReason {
|
|||||||
#[sea_orm(string_value = "payment_failed")]
|
#[sea_orm(string_value = "payment_failed")]
|
||||||
PaymentFailed,
|
PaymentFailed,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<stripe_client::StripeCancellationDetailsReason> for StripeCancellationReason {
|
||||||
|
fn from(value: stripe_client::StripeCancellationDetailsReason) -> Self {
|
||||||
|
match value {
|
||||||
|
stripe_client::StripeCancellationDetailsReason::CancellationRequested => {
|
||||||
|
Self::CancellationRequested
|
||||||
|
}
|
||||||
|
stripe_client::StripeCancellationDetailsReason::PaymentDisputed => {
|
||||||
|
Self::PaymentDisputed
|
||||||
|
}
|
||||||
|
stripe_client::StripeCancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ pub mod migrations;
|
|||||||
pub mod rpc;
|
pub mod rpc;
|
||||||
pub mod seed;
|
pub mod seed;
|
||||||
pub mod stripe_billing;
|
pub mod stripe_billing;
|
||||||
|
pub mod stripe_client;
|
||||||
pub mod user_backfiller;
|
pub mod user_backfiller;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -29,6 +30,7 @@ use std::{path::PathBuf, sync::Arc};
|
|||||||
use util::ResultExt;
|
use util::ResultExt;
|
||||||
|
|
||||||
use crate::stripe_billing::StripeBilling;
|
use crate::stripe_billing::StripeBilling;
|
||||||
|
use crate::stripe_client::{RealStripeClient, StripeClient};
|
||||||
|
|
||||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||||
|
|
||||||
@@ -269,7 +271,10 @@ pub struct AppState {
|
|||||||
pub llm_db: Option<Arc<LlmDatabase>>,
|
pub llm_db: Option<Arc<LlmDatabase>>,
|
||||||
pub livekit_client: Option<Arc<dyn livekit_api::Client>>,
|
pub livekit_client: Option<Arc<dyn livekit_api::Client>>,
|
||||||
pub blob_store_client: Option<aws_sdk_s3::Client>,
|
pub blob_store_client: Option<aws_sdk_s3::Client>,
|
||||||
pub stripe_client: Option<Arc<stripe::Client>>,
|
/// This is a real instance of the Stripe client; we're working to replace references to this with the
|
||||||
|
/// [`StripeClient`] trait.
|
||||||
|
pub real_stripe_client: Option<Arc<stripe::Client>>,
|
||||||
|
pub stripe_client: Option<Arc<dyn StripeClient>>,
|
||||||
pub stripe_billing: Option<Arc<StripeBilling>>,
|
pub stripe_billing: Option<Arc<StripeBilling>>,
|
||||||
pub executor: Executor,
|
pub executor: Executor,
|
||||||
pub kinesis_client: Option<::aws_sdk_kinesis::Client>,
|
pub kinesis_client: Option<::aws_sdk_kinesis::Client>,
|
||||||
@@ -322,7 +327,9 @@ impl AppState {
|
|||||||
stripe_billing: stripe_client
|
stripe_billing: stripe_client
|
||||||
.clone()
|
.clone()
|
||||||
.map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
|
.map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
|
||||||
stripe_client,
|
real_stripe_client: stripe_client.clone(),
|
||||||
|
stripe_client: stripe_client
|
||||||
|
.map(|stripe_client| Arc::new(RealStripeClient::new(stripe_client)) as _),
|
||||||
executor,
|
executor,
|
||||||
kinesis_client: if config.kinesis_access_key.is_some() {
|
kinesis_client: if config.kinesis_access_key.is_some() {
|
||||||
build_kinesis_client(&config).await.log_err()
|
build_kinesis_client(&config).await.log_err()
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
|
|||||||
use crate::db::billing_subscription::SubscriptionKind;
|
use crate::db::billing_subscription::SubscriptionKind;
|
||||||
use crate::llm::db::LlmDatabase;
|
use crate::llm::db::LlmDatabase;
|
||||||
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
|
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
|
||||||
|
use crate::stripe_client::StripeCustomerId;
|
||||||
use crate::{
|
use crate::{
|
||||||
AppState, Error, Result, auth,
|
AppState, Error, Result, auth,
|
||||||
db::{
|
db::{
|
||||||
@@ -4033,31 +4034,26 @@ async fn get_llm_api_token(
|
|||||||
.as_ref()
|
.as_ref()
|
||||||
.context("failed to retrieve Stripe billing object")?;
|
.context("failed to retrieve Stripe billing object")?;
|
||||||
|
|
||||||
let billing_customer =
|
let billing_customer = if let Some(billing_customer) =
|
||||||
if let Some(billing_customer) = db.get_billing_customer_by_user_id(user.id).await? {
|
db.get_billing_customer_by_user_id(user.id).await?
|
||||||
billing_customer
|
{
|
||||||
} else {
|
billing_customer
|
||||||
let customer_id = stripe_billing
|
} else {
|
||||||
.find_or_create_customer_by_email(user.email_address.as_deref())
|
let customer_id = stripe_billing
|
||||||
.await?;
|
.find_or_create_customer_by_email(user.email_address.as_deref())
|
||||||
|
.await?;
|
||||||
|
|
||||||
find_or_create_billing_customer(
|
find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
|
||||||
&session.app_state,
|
|
||||||
&stripe_client,
|
|
||||||
stripe::Expandable::Id(customer_id),
|
|
||||||
)
|
|
||||||
.await?
|
.await?
|
||||||
.context("billing customer not found")?
|
.context("billing customer not found")?
|
||||||
};
|
};
|
||||||
|
|
||||||
let billing_subscription =
|
let billing_subscription =
|
||||||
if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
|
if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
|
||||||
billing_subscription
|
billing_subscription
|
||||||
} else {
|
} else {
|
||||||
let stripe_customer_id = billing_customer
|
let stripe_customer_id =
|
||||||
.stripe_customer_id
|
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
|
||||||
.parse::<stripe::CustomerId>()
|
|
||||||
.context("failed to parse Stripe customer ID from database")?;
|
|
||||||
|
|
||||||
let stripe_subscription = stripe_billing
|
let stripe_subscription = stripe_billing
|
||||||
.subscribe_to_zed_free(stripe_customer_id)
|
.subscribe_to_zed_free(stripe_customer_id)
|
||||||
|
|||||||
@@ -1,30 +1,49 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use anyhow::{Context as _, anyhow};
|
||||||
|
use chrono::Utc;
|
||||||
|
use collections::HashMap;
|
||||||
|
use stripe::SubscriptionStatus;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::Result;
|
use crate::Result;
|
||||||
use crate::db::billing_subscription::SubscriptionKind;
|
use crate::db::billing_subscription::SubscriptionKind;
|
||||||
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
|
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
|
||||||
use anyhow::{Context as _, anyhow};
|
use crate::stripe_client::{
|
||||||
use chrono::Utc;
|
RealStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
|
||||||
use collections::HashMap;
|
StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
|
||||||
use serde::{Deserialize, Serialize};
|
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
|
||||||
use stripe::{CreateCustomer, Customer, CustomerId, PriceId, SubscriptionStatus};
|
StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams,
|
||||||
use tokio::sync::RwLock;
|
StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripeSubscription,
|
||||||
use uuid::Uuid;
|
StripeSubscriptionId, StripeSubscriptionTrialSettings,
|
||||||
|
StripeSubscriptionTrialSettingsEndBehavior,
|
||||||
|
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
|
||||||
|
UpdateSubscriptionParams,
|
||||||
|
};
|
||||||
|
|
||||||
pub struct StripeBilling {
|
pub struct StripeBilling {
|
||||||
state: RwLock<StripeBillingState>,
|
state: RwLock<StripeBillingState>,
|
||||||
client: Arc<stripe::Client>,
|
client: Arc<dyn StripeClient>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
struct StripeBillingState {
|
struct StripeBillingState {
|
||||||
meters_by_event_name: HashMap<String, StripeMeter>,
|
meters_by_event_name: HashMap<String, StripeMeter>,
|
||||||
price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
|
price_ids_by_meter_id: HashMap<String, StripePriceId>,
|
||||||
prices_by_lookup_key: HashMap<String, stripe::Price>,
|
prices_by_lookup_key: HashMap<String, StripePrice>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StripeBilling {
|
impl StripeBilling {
|
||||||
pub fn new(client: Arc<stripe::Client>) -> Self {
|
pub fn new(client: Arc<stripe::Client>) -> Self {
|
||||||
|
Self {
|
||||||
|
client: Arc::new(RealStripeClient::new(client.clone())),
|
||||||
|
state: RwLock::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
client,
|
client,
|
||||||
state: RwLock::default(),
|
state: RwLock::default(),
|
||||||
@@ -36,24 +55,16 @@ impl StripeBilling {
|
|||||||
|
|
||||||
let mut state = self.state.write().await;
|
let mut state = self.state.write().await;
|
||||||
|
|
||||||
let (meters, prices) = futures::try_join!(
|
let (meters, prices) =
|
||||||
StripeMeter::list(&self.client),
|
futures::try_join!(self.client.list_meters(), self.client.list_prices())?;
|
||||||
stripe::Price::list(
|
|
||||||
&self.client,
|
|
||||||
&stripe::ListPrices {
|
|
||||||
limit: Some(100),
|
|
||||||
..Default::default()
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)?;
|
|
||||||
|
|
||||||
for meter in meters.data {
|
for meter in meters {
|
||||||
state
|
state
|
||||||
.meters_by_event_name
|
.meters_by_event_name
|
||||||
.insert(meter.event_name.clone(), meter);
|
.insert(meter.event_name.clone(), meter);
|
||||||
}
|
}
|
||||||
|
|
||||||
for price in prices.data {
|
for price in prices {
|
||||||
if let Some(lookup_key) = price.lookup_key.clone() {
|
if let Some(lookup_key) = price.lookup_key.clone() {
|
||||||
state.prices_by_lookup_key.insert(lookup_key, price.clone());
|
state.prices_by_lookup_key.insert(lookup_key, price.clone());
|
||||||
}
|
}
|
||||||
@@ -70,15 +81,15 @@ impl StripeBilling {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
|
pub async fn zed_pro_price_id(&self) -> Result<StripePriceId> {
|
||||||
self.find_price_id_by_lookup_key("zed-pro").await
|
self.find_price_id_by_lookup_key("zed-pro").await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn zed_free_price_id(&self) -> Result<PriceId> {
|
pub async fn zed_free_price_id(&self) -> Result<StripePriceId> {
|
||||||
self.find_price_id_by_lookup_key("zed-free").await
|
self.find_price_id_by_lookup_key("zed-free").await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
|
pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<StripePriceId> {
|
||||||
self.state
|
self.state
|
||||||
.read()
|
.read()
|
||||||
.await
|
.await
|
||||||
@@ -88,7 +99,7 @@ impl StripeBilling {
|
|||||||
.ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
|
.ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
|
pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<StripePrice> {
|
||||||
self.state
|
self.state
|
||||||
.read()
|
.read()
|
||||||
.await
|
.await
|
||||||
@@ -100,12 +111,12 @@ impl StripeBilling {
|
|||||||
|
|
||||||
pub async fn determine_subscription_kind(
|
pub async fn determine_subscription_kind(
|
||||||
&self,
|
&self,
|
||||||
subscription: &stripe::Subscription,
|
subscription: &StripeSubscription,
|
||||||
) -> Option<SubscriptionKind> {
|
) -> Option<SubscriptionKind> {
|
||||||
let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
|
let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
|
||||||
let zed_free_price_id = self.zed_free_price_id().await.ok()?;
|
let zed_free_price_id = self.zed_free_price_id().await.ok()?;
|
||||||
|
|
||||||
subscription.items.data.iter().find_map(|item| {
|
subscription.items.iter().find_map(|item| {
|
||||||
let price = item.price.as_ref()?;
|
let price = item.price.as_ref()?;
|
||||||
|
|
||||||
if price.id == zed_pro_price_id {
|
if price.id == zed_pro_price_id {
|
||||||
@@ -129,18 +140,11 @@ impl StripeBilling {
|
|||||||
pub async fn find_or_create_customer_by_email(
|
pub async fn find_or_create_customer_by_email(
|
||||||
&self,
|
&self,
|
||||||
email_address: Option<&str>,
|
email_address: Option<&str>,
|
||||||
) -> Result<CustomerId> {
|
) -> Result<StripeCustomerId> {
|
||||||
let existing_customer = if let Some(email) = email_address {
|
let existing_customer = if let Some(email) = email_address {
|
||||||
let customers = Customer::list(
|
let customers = self.client.list_customers_by_email(email).await?;
|
||||||
&self.client,
|
|
||||||
&stripe::ListCustomers {
|
|
||||||
email: Some(email),
|
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
customers.data.first().cloned()
|
customers.first().cloned()
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
@@ -148,14 +152,12 @@ impl StripeBilling {
|
|||||||
let customer_id = if let Some(existing_customer) = existing_customer {
|
let customer_id = if let Some(existing_customer) = existing_customer {
|
||||||
existing_customer.id
|
existing_customer.id
|
||||||
} else {
|
} else {
|
||||||
let customer = Customer::create(
|
let customer = self
|
||||||
&self.client,
|
.client
|
||||||
CreateCustomer {
|
.create_customer(crate::stripe_client::CreateCustomerParams {
|
||||||
email: email_address,
|
email: email_address,
|
||||||
..Default::default()
|
})
|
||||||
},
|
.await?;
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
customer.id
|
customer.id
|
||||||
};
|
};
|
||||||
@@ -165,11 +167,10 @@ impl StripeBilling {
|
|||||||
|
|
||||||
pub async fn subscribe_to_price(
|
pub async fn subscribe_to_price(
|
||||||
&self,
|
&self,
|
||||||
subscription_id: &stripe::SubscriptionId,
|
subscription_id: &StripeSubscriptionId,
|
||||||
price: &stripe::Price,
|
price: &StripePrice,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let subscription =
|
let subscription = self.client.get_subscription(subscription_id).await?;
|
||||||
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
|
|
||||||
|
|
||||||
if subscription_contains_price(&subscription, &price.id) {
|
if subscription_contains_price(&subscription, &price.id) {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
@@ -180,39 +181,36 @@ impl StripeBilling {
|
|||||||
let price_per_unit = price.unit_amount.unwrap_or_default();
|
let price_per_unit = price.unit_amount.unwrap_or_default();
|
||||||
let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
|
let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
|
||||||
|
|
||||||
stripe::Subscription::update(
|
self.client
|
||||||
&self.client,
|
.update_subscription(
|
||||||
subscription_id,
|
subscription_id,
|
||||||
stripe::UpdateSubscription {
|
UpdateSubscriptionParams {
|
||||||
items: Some(vec![stripe::UpdateSubscriptionItems {
|
items: Some(vec![UpdateSubscriptionItems {
|
||||||
price: Some(price.id.to_string()),
|
price: Some(price.id.clone()),
|
||||||
..Default::default()
|
}]),
|
||||||
}]),
|
trial_settings: Some(StripeSubscriptionTrialSettings {
|
||||||
trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
|
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
|
||||||
end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
|
missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel
|
||||||
missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
|
},
|
||||||
},
|
}),
|
||||||
}),
|
},
|
||||||
..Default::default()
|
)
|
||||||
},
|
.await?;
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn bill_model_request_usage(
|
pub async fn bill_model_request_usage(
|
||||||
&self,
|
&self,
|
||||||
customer_id: &stripe::CustomerId,
|
customer_id: &StripeCustomerId,
|
||||||
event_name: &str,
|
event_name: &str,
|
||||||
requests: i32,
|
requests: i32,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let timestamp = Utc::now().timestamp();
|
let timestamp = Utc::now().timestamp();
|
||||||
let idempotency_key = Uuid::new_v4();
|
let idempotency_key = Uuid::new_v4();
|
||||||
|
|
||||||
StripeMeterEvent::create(
|
self.client
|
||||||
&self.client,
|
.create_meter_event(StripeCreateMeterEventParams {
|
||||||
StripeCreateMeterEventParams {
|
|
||||||
identifier: &format!("model_requests/{}", idempotency_key),
|
identifier: &format!("model_requests/{}", idempotency_key),
|
||||||
event_name,
|
event_name,
|
||||||
payload: StripeCreateMeterEventPayload {
|
payload: StripeCreateMeterEventPayload {
|
||||||
@@ -220,39 +218,37 @@ impl StripeBilling {
|
|||||||
stripe_customer_id: customer_id,
|
stripe_customer_id: customer_id,
|
||||||
},
|
},
|
||||||
timestamp: Some(timestamp),
|
timestamp: Some(timestamp),
|
||||||
},
|
})
|
||||||
)
|
.await?;
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn checkout_with_zed_pro(
|
pub async fn checkout_with_zed_pro(
|
||||||
&self,
|
&self,
|
||||||
customer_id: stripe::CustomerId,
|
customer_id: &StripeCustomerId,
|
||||||
github_login: &str,
|
github_login: &str,
|
||||||
success_url: &str,
|
success_url: &str,
|
||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
let zed_pro_price_id = self.zed_pro_price_id().await?;
|
let zed_pro_price_id = self.zed_pro_price_id().await?;
|
||||||
|
|
||||||
let mut params = stripe::CreateCheckoutSession::new();
|
let mut params = StripeCreateCheckoutSessionParams::default();
|
||||||
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
|
params.mode = Some(StripeCheckoutSessionMode::Subscription);
|
||||||
params.customer = Some(customer_id);
|
params.customer = Some(customer_id);
|
||||||
params.client_reference_id = Some(github_login);
|
params.client_reference_id = Some(github_login);
|
||||||
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
|
params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems {
|
||||||
price: Some(zed_pro_price_id.to_string()),
|
price: Some(zed_pro_price_id.to_string()),
|
||||||
quantity: Some(1),
|
quantity: Some(1),
|
||||||
..Default::default()
|
|
||||||
}]);
|
}]);
|
||||||
params.success_url = Some(success_url);
|
params.success_url = Some(success_url);
|
||||||
|
|
||||||
let session = stripe::CheckoutSession::create(&self.client, params).await?;
|
let session = self.client.create_checkout_session(params).await?;
|
||||||
Ok(session.url.context("no checkout session URL")?)
|
Ok(session.url.context("no checkout session URL")?)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn checkout_with_zed_pro_trial(
|
pub async fn checkout_with_zed_pro_trial(
|
||||||
&self,
|
&self,
|
||||||
customer_id: stripe::CustomerId,
|
customer_id: &StripeCustomerId,
|
||||||
github_login: &str,
|
github_login: &str,
|
||||||
feature_flags: Vec<String>,
|
feature_flags: Vec<String>,
|
||||||
success_url: &str,
|
success_url: &str,
|
||||||
@@ -273,172 +269,75 @@ impl StripeBilling {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut params = stripe::CreateCheckoutSession::new();
|
let mut params = StripeCreateCheckoutSessionParams::default();
|
||||||
params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
|
params.subscription_data = Some(StripeCreateCheckoutSessionSubscriptionData {
|
||||||
trial_period_days: Some(trial_period_days),
|
trial_period_days: Some(trial_period_days),
|
||||||
trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
|
trial_settings: Some(StripeSubscriptionTrialSettings {
|
||||||
end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
|
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
|
||||||
missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
|
missing_payment_method:
|
||||||
}
|
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
|
||||||
|
},
|
||||||
}),
|
}),
|
||||||
metadata: if !subscription_metadata.is_empty() {
|
metadata: if !subscription_metadata.is_empty() {
|
||||||
Some(subscription_metadata)
|
Some(subscription_metadata)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
},
|
},
|
||||||
..Default::default()
|
|
||||||
});
|
});
|
||||||
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
|
params.mode = Some(StripeCheckoutSessionMode::Subscription);
|
||||||
params.payment_method_collection =
|
params.payment_method_collection =
|
||||||
Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
|
Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired);
|
||||||
params.customer = Some(customer_id);
|
params.customer = Some(customer_id);
|
||||||
params.client_reference_id = Some(github_login);
|
params.client_reference_id = Some(github_login);
|
||||||
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
|
params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems {
|
||||||
price: Some(zed_pro_price_id.to_string()),
|
price: Some(zed_pro_price_id.to_string()),
|
||||||
quantity: Some(1),
|
quantity: Some(1),
|
||||||
..Default::default()
|
|
||||||
}]);
|
}]);
|
||||||
params.success_url = Some(success_url);
|
params.success_url = Some(success_url);
|
||||||
|
|
||||||
let session = stripe::CheckoutSession::create(&self.client, params).await?;
|
let session = self.client.create_checkout_session(params).await?;
|
||||||
Ok(session.url.context("no checkout session URL")?)
|
Ok(session.url.context("no checkout session URL")?)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn subscribe_to_zed_free(
|
pub async fn subscribe_to_zed_free(
|
||||||
&self,
|
&self,
|
||||||
customer_id: stripe::CustomerId,
|
customer_id: StripeCustomerId,
|
||||||
) -> Result<stripe::Subscription> {
|
) -> Result<StripeSubscription> {
|
||||||
let zed_free_price_id = self.zed_free_price_id().await?;
|
let zed_free_price_id = self.zed_free_price_id().await?;
|
||||||
|
|
||||||
let existing_subscriptions = stripe::Subscription::list(
|
let existing_subscriptions = self
|
||||||
&self.client,
|
.client
|
||||||
&stripe::ListSubscriptions {
|
.list_subscriptions_for_customer(&customer_id)
|
||||||
customer: Some(customer_id.clone()),
|
.await?;
|
||||||
status: None,
|
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let existing_active_subscription =
|
let existing_active_subscription =
|
||||||
existing_subscriptions
|
existing_subscriptions.into_iter().find(|subscription| {
|
||||||
.data
|
subscription.status == SubscriptionStatus::Active
|
||||||
.into_iter()
|
|| subscription.status == SubscriptionStatus::Trialing
|
||||||
.find(|subscription| {
|
});
|
||||||
subscription.status == SubscriptionStatus::Active
|
|
||||||
|| subscription.status == SubscriptionStatus::Trialing
|
|
||||||
});
|
|
||||||
if let Some(subscription) = existing_active_subscription {
|
if let Some(subscription) = existing_active_subscription {
|
||||||
return Ok(subscription);
|
return Ok(subscription);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut params = stripe::CreateSubscription::new(customer_id);
|
let params = StripeCreateSubscriptionParams {
|
||||||
params.items = Some(vec![stripe::CreateSubscriptionItems {
|
customer: customer_id,
|
||||||
price: Some(zed_free_price_id.to_string()),
|
items: vec![StripeCreateSubscriptionItems {
|
||||||
quantity: Some(1),
|
price: Some(zed_free_price_id),
|
||||||
..Default::default()
|
quantity: Some(1),
|
||||||
}]);
|
}],
|
||||||
|
};
|
||||||
|
|
||||||
let subscription = stripe::Subscription::create(&self.client, params).await?;
|
let subscription = self.client.create_subscription(params).await?;
|
||||||
|
|
||||||
Ok(subscription)
|
Ok(subscription)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn checkout_with_zed_free(
|
|
||||||
&self,
|
|
||||||
customer_id: stripe::CustomerId,
|
|
||||||
github_login: &str,
|
|
||||||
success_url: &str,
|
|
||||||
) -> Result<String> {
|
|
||||||
let zed_free_price_id = self.zed_free_price_id().await?;
|
|
||||||
|
|
||||||
let mut params = stripe::CreateCheckoutSession::new();
|
|
||||||
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
|
|
||||||
params.payment_method_collection =
|
|
||||||
Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
|
|
||||||
params.customer = Some(customer_id);
|
|
||||||
params.client_reference_id = Some(github_login);
|
|
||||||
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
|
|
||||||
price: Some(zed_free_price_id.to_string()),
|
|
||||||
quantity: Some(1),
|
|
||||||
..Default::default()
|
|
||||||
}]);
|
|
||||||
params.success_url = Some(success_url);
|
|
||||||
|
|
||||||
let session = stripe::CheckoutSession::create(&self.client, params).await?;
|
|
||||||
Ok(session.url.context("no checkout session URL")?)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Deserialize)]
|
|
||||||
struct StripeMeter {
|
|
||||||
id: String,
|
|
||||||
event_name: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StripeMeter {
|
|
||||||
pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct Params {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
limit: Option<u64>,
|
|
||||||
}
|
|
||||||
|
|
||||||
client.get_query("/billing/meters", Params { limit: Some(100) })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct StripeMeterEvent {
|
|
||||||
identifier: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StripeMeterEvent {
|
|
||||||
pub async fn create(
|
|
||||||
client: &stripe::Client,
|
|
||||||
params: StripeCreateMeterEventParams<'_>,
|
|
||||||
) -> Result<Self, stripe::StripeError> {
|
|
||||||
let identifier = params.identifier;
|
|
||||||
match client.post_form("/billing/meter_events", params).await {
|
|
||||||
Ok(event) => Ok(event),
|
|
||||||
Err(stripe::StripeError::Stripe(error)) => {
|
|
||||||
if error.http_status == 400
|
|
||||||
&& error
|
|
||||||
.message
|
|
||||||
.as_ref()
|
|
||||||
.map_or(false, |message| message.contains(identifier))
|
|
||||||
{
|
|
||||||
Ok(Self {
|
|
||||||
identifier: identifier.to_string(),
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
Err(stripe::StripeError::Stripe(error))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(error) => Err(error),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct StripeCreateMeterEventParams<'a> {
|
|
||||||
identifier: &'a str,
|
|
||||||
event_name: &'a str,
|
|
||||||
payload: StripeCreateMeterEventPayload<'a>,
|
|
||||||
timestamp: Option<i64>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct StripeCreateMeterEventPayload<'a> {
|
|
||||||
value: u64,
|
|
||||||
stripe_customer_id: &'a stripe::CustomerId,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn subscription_contains_price(
|
fn subscription_contains_price(
|
||||||
subscription: &stripe::Subscription,
|
subscription: &StripeSubscription,
|
||||||
price_id: &stripe::PriceId,
|
price_id: &StripePriceId,
|
||||||
) -> bool {
|
) -> bool {
|
||||||
subscription.items.data.iter().any(|item| {
|
subscription.items.iter().any(|item| {
|
||||||
item.price
|
item.price
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map_or(false, |price| price.id == *price_id)
|
.map_or(false, |price| price.id == *price_id)
|
||||||
|
|||||||
229
crates/collab/src/stripe_client.rs
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
#[cfg(test)]
|
||||||
|
mod fake_stripe_client;
|
||||||
|
mod real_stripe_client;
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub use fake_stripe_client::*;
|
||||||
|
pub use real_stripe_client::*;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Serialize)]
|
||||||
|
pub struct StripeCustomerId(pub Arc<str>);
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct StripeCustomer {
|
||||||
|
pub id: StripeCustomerId,
|
||||||
|
pub email: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct CreateCustomerParams<'a> {
|
||||||
|
pub email: Option<&'a str>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
|
||||||
|
pub struct StripeSubscriptionId(pub Arc<str>);
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Clone)]
|
||||||
|
pub struct StripeSubscription {
|
||||||
|
pub id: StripeSubscriptionId,
|
||||||
|
pub customer: StripeCustomerId,
|
||||||
|
// TODO: Create our own version of this enum.
|
||||||
|
pub status: stripe::SubscriptionStatus,
|
||||||
|
pub current_period_end: i64,
|
||||||
|
pub current_period_start: i64,
|
||||||
|
pub items: Vec<StripeSubscriptionItem>,
|
||||||
|
pub cancel_at: Option<i64>,
|
||||||
|
pub cancellation_details: Option<StripeCancellationDetails>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
|
||||||
|
pub struct StripeSubscriptionItemId(pub Arc<str>);
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Clone)]
|
||||||
|
pub struct StripeSubscriptionItem {
|
||||||
|
pub id: StripeSubscriptionItemId,
|
||||||
|
pub price: Option<StripePrice>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub struct StripeCancellationDetails {
|
||||||
|
pub reason: Option<StripeCancellationDetailsReason>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||||
|
pub enum StripeCancellationDetailsReason {
|
||||||
|
CancellationRequested,
|
||||||
|
PaymentDisputed,
|
||||||
|
PaymentFailed,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct StripeCreateSubscriptionParams {
|
||||||
|
pub customer: StripeCustomerId,
|
||||||
|
pub items: Vec<StripeCreateSubscriptionItems>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct StripeCreateSubscriptionItems {
|
||||||
|
pub price: Option<StripePriceId>,
|
||||||
|
pub quantity: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct UpdateSubscriptionParams {
|
||||||
|
pub items: Option<Vec<UpdateSubscriptionItems>>,
|
||||||
|
pub trial_settings: Option<StripeSubscriptionTrialSettings>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Clone)]
|
||||||
|
pub struct UpdateSubscriptionItems {
|
||||||
|
pub price: Option<StripePriceId>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Clone)]
|
||||||
|
pub struct StripeSubscriptionTrialSettings {
|
||||||
|
pub end_behavior: StripeSubscriptionTrialSettingsEndBehavior,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Clone)]
|
||||||
|
pub struct StripeSubscriptionTrialSettingsEndBehavior {
|
||||||
|
pub missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||||
|
pub enum StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod {
|
||||||
|
Cancel,
|
||||||
|
CreateInvoice,
|
||||||
|
Pause,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
|
||||||
|
pub struct StripePriceId(pub Arc<str>);
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Clone)]
|
||||||
|
pub struct StripePrice {
|
||||||
|
pub id: StripePriceId,
|
||||||
|
pub unit_amount: Option<i64>,
|
||||||
|
pub lookup_key: Option<String>,
|
||||||
|
pub recurring: Option<StripePriceRecurring>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Clone)]
|
||||||
|
pub struct StripePriceRecurring {
|
||||||
|
pub meter: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Deserialize)]
|
||||||
|
pub struct StripeMeterId(pub Arc<str>);
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct StripeMeter {
|
||||||
|
pub id: StripeMeterId,
|
||||||
|
pub event_name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct StripeCreateMeterEventParams<'a> {
|
||||||
|
pub identifier: &'a str,
|
||||||
|
pub event_name: &'a str,
|
||||||
|
pub payload: StripeCreateMeterEventPayload<'a>,
|
||||||
|
pub timestamp: Option<i64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct StripeCreateMeterEventPayload<'a> {
|
||||||
|
pub value: u64,
|
||||||
|
pub stripe_customer_id: &'a StripeCustomerId,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
pub struct StripeCreateCheckoutSessionParams<'a> {
|
||||||
|
pub customer: Option<&'a StripeCustomerId>,
|
||||||
|
pub client_reference_id: Option<&'a str>,
|
||||||
|
pub mode: Option<StripeCheckoutSessionMode>,
|
||||||
|
pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
|
||||||
|
pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
|
||||||
|
pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
|
||||||
|
pub success_url: Option<&'a str>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||||
|
pub enum StripeCheckoutSessionMode {
|
||||||
|
Payment,
|
||||||
|
Setup,
|
||||||
|
Subscription,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Clone)]
|
||||||
|
pub struct StripeCreateCheckoutSessionLineItems {
|
||||||
|
pub price: Option<String>,
|
||||||
|
pub quantity: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||||
|
pub enum StripeCheckoutSessionPaymentMethodCollection {
|
||||||
|
Always,
|
||||||
|
IfRequired,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Clone)]
|
||||||
|
pub struct StripeCreateCheckoutSessionSubscriptionData {
|
||||||
|
pub metadata: Option<HashMap<String, String>>,
|
||||||
|
pub trial_period_days: Option<u32>,
|
||||||
|
pub trial_settings: Option<StripeSubscriptionTrialSettings>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct StripeCheckoutSession {
|
||||||
|
pub url: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait StripeClient: Send + Sync {
|
||||||
|
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
|
||||||
|
|
||||||
|
async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer>;
|
||||||
|
|
||||||
|
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>;
|
||||||
|
|
||||||
|
async fn list_subscriptions_for_customer(
|
||||||
|
&self,
|
||||||
|
customer_id: &StripeCustomerId,
|
||||||
|
) -> Result<Vec<StripeSubscription>>;
|
||||||
|
|
||||||
|
async fn get_subscription(
|
||||||
|
&self,
|
||||||
|
subscription_id: &StripeSubscriptionId,
|
||||||
|
) -> Result<StripeSubscription>;
|
||||||
|
|
||||||
|
async fn create_subscription(
|
||||||
|
&self,
|
||||||
|
params: StripeCreateSubscriptionParams,
|
||||||
|
) -> Result<StripeSubscription>;
|
||||||
|
|
||||||
|
async fn update_subscription(
|
||||||
|
&self,
|
||||||
|
subscription_id: &StripeSubscriptionId,
|
||||||
|
params: UpdateSubscriptionParams,
|
||||||
|
) -> Result<()>;
|
||||||
|
|
||||||
|
async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()>;
|
||||||
|
|
||||||
|
async fn list_prices(&self) -> Result<Vec<StripePrice>>;
|
||||||
|
|
||||||
|
async fn list_meters(&self) -> Result<Vec<StripeMeter>>;
|
||||||
|
|
||||||
|
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()>;
|
||||||
|
|
||||||
|
async fn create_checkout_session(
|
||||||
|
&self,
|
||||||
|
params: StripeCreateCheckoutSessionParams<'_>,
|
||||||
|
) -> Result<StripeCheckoutSession>;
|
||||||
|
}
|
||||||
224
crates/collab/src/stripe_client/fake_stripe_client.rs
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use anyhow::{Result, anyhow};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use chrono::{Duration, Utc};
|
||||||
|
use collections::HashMap;
|
||||||
|
use parking_lot::Mutex;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::stripe_client::{
|
||||||
|
CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
|
||||||
|
StripeCheckoutSessionPaymentMethodCollection, StripeClient,
|
||||||
|
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
|
||||||
|
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
|
||||||
|
StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripeMeterId,
|
||||||
|
StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
|
||||||
|
StripeSubscriptionItemId, UpdateSubscriptionParams,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct StripeCreateMeterEventCall {
|
||||||
|
pub identifier: Arc<str>,
|
||||||
|
pub event_name: Arc<str>,
|
||||||
|
pub value: u64,
|
||||||
|
pub stripe_customer_id: StripeCustomerId,
|
||||||
|
pub timestamp: Option<i64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct StripeCreateCheckoutSessionCall {
|
||||||
|
pub customer: Option<StripeCustomerId>,
|
||||||
|
pub client_reference_id: Option<String>,
|
||||||
|
pub mode: Option<StripeCheckoutSessionMode>,
|
||||||
|
pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
|
||||||
|
pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
|
||||||
|
pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
|
||||||
|
pub success_url: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct FakeStripeClient {
|
||||||
|
pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
|
||||||
|
pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
|
||||||
|
pub update_subscription_calls:
|
||||||
|
Arc<Mutex<Vec<(StripeSubscriptionId, UpdateSubscriptionParams)>>>,
|
||||||
|
pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
|
||||||
|
pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
|
||||||
|
pub create_meter_event_calls: Arc<Mutex<Vec<StripeCreateMeterEventCall>>>,
|
||||||
|
pub create_checkout_session_calls: Arc<Mutex<Vec<StripeCreateCheckoutSessionCall>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FakeStripeClient {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
customers: Arc::new(Mutex::new(HashMap::default())),
|
||||||
|
subscriptions: Arc::new(Mutex::new(HashMap::default())),
|
||||||
|
update_subscription_calls: Arc::new(Mutex::new(Vec::new())),
|
||||||
|
prices: Arc::new(Mutex::new(HashMap::default())),
|
||||||
|
meters: Arc::new(Mutex::new(HashMap::default())),
|
||||||
|
create_meter_event_calls: Arc::new(Mutex::new(Vec::new())),
|
||||||
|
create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl StripeClient for FakeStripeClient {
|
||||||
|
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
|
||||||
|
Ok(self
|
||||||
|
.customers
|
||||||
|
.lock()
|
||||||
|
.values()
|
||||||
|
.filter(|customer| customer.email.as_deref() == Some(email))
|
||||||
|
.cloned()
|
||||||
|
.collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer> {
|
||||||
|
self.customers
|
||||||
|
.lock()
|
||||||
|
.get(customer_id)
|
||||||
|
.cloned()
|
||||||
|
.ok_or_else(|| anyhow!("no customer found for {customer_id:?}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
|
||||||
|
let customer = StripeCustomer {
|
||||||
|
id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()),
|
||||||
|
email: params.email.map(|email| email.to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
self.customers
|
||||||
|
.lock()
|
||||||
|
.insert(customer.id.clone(), customer.clone());
|
||||||
|
|
||||||
|
Ok(customer)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_subscriptions_for_customer(
|
||||||
|
&self,
|
||||||
|
customer_id: &StripeCustomerId,
|
||||||
|
) -> Result<Vec<StripeSubscription>> {
|
||||||
|
let subscriptions = self
|
||||||
|
.subscriptions
|
||||||
|
.lock()
|
||||||
|
.values()
|
||||||
|
.filter(|subscription| subscription.customer == *customer_id)
|
||||||
|
.cloned()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(subscriptions)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_subscription(
|
||||||
|
&self,
|
||||||
|
subscription_id: &StripeSubscriptionId,
|
||||||
|
) -> Result<StripeSubscription> {
|
||||||
|
self.subscriptions
|
||||||
|
.lock()
|
||||||
|
.get(subscription_id)
|
||||||
|
.cloned()
|
||||||
|
.ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_subscription(
|
||||||
|
&self,
|
||||||
|
params: StripeCreateSubscriptionParams,
|
||||||
|
) -> Result<StripeSubscription> {
|
||||||
|
let now = Utc::now();
|
||||||
|
|
||||||
|
let subscription = StripeSubscription {
|
||||||
|
id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()),
|
||||||
|
customer: params.customer,
|
||||||
|
status: stripe::SubscriptionStatus::Active,
|
||||||
|
current_period_start: now.timestamp(),
|
||||||
|
current_period_end: (now + Duration::days(30)).timestamp(),
|
||||||
|
items: params
|
||||||
|
.items
|
||||||
|
.into_iter()
|
||||||
|
.map(|item| StripeSubscriptionItem {
|
||||||
|
id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()),
|
||||||
|
price: item
|
||||||
|
.price
|
||||||
|
.and_then(|price_id| self.prices.lock().get(&price_id).cloned()),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
cancel_at: None,
|
||||||
|
cancellation_details: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
self.subscriptions
|
||||||
|
.lock()
|
||||||
|
.insert(subscription.id.clone(), subscription.clone());
|
||||||
|
|
||||||
|
Ok(subscription)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn update_subscription(
|
||||||
|
&self,
|
||||||
|
subscription_id: &StripeSubscriptionId,
|
||||||
|
params: UpdateSubscriptionParams,
|
||||||
|
) -> Result<()> {
|
||||||
|
let subscription = self.get_subscription(subscription_id).await?;
|
||||||
|
|
||||||
|
self.update_subscription_calls
|
||||||
|
.lock()
|
||||||
|
.push((subscription.id, params));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> {
|
||||||
|
// TODO: Implement fake subscription cancellation.
|
||||||
|
let _ = subscription_id;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_prices(&self) -> Result<Vec<StripePrice>> {
|
||||||
|
let prices = self.prices.lock().values().cloned().collect();
|
||||||
|
|
||||||
|
Ok(prices)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
|
||||||
|
let meters = self.meters.lock().values().cloned().collect();
|
||||||
|
|
||||||
|
Ok(meters)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
|
||||||
|
self.create_meter_event_calls
|
||||||
|
.lock()
|
||||||
|
.push(StripeCreateMeterEventCall {
|
||||||
|
identifier: params.identifier.into(),
|
||||||
|
event_name: params.event_name.into(),
|
||||||
|
value: params.payload.value,
|
||||||
|
stripe_customer_id: params.payload.stripe_customer_id.clone(),
|
||||||
|
timestamp: params.timestamp,
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_checkout_session(
|
||||||
|
&self,
|
||||||
|
params: StripeCreateCheckoutSessionParams<'_>,
|
||||||
|
) -> Result<StripeCheckoutSession> {
|
||||||
|
self.create_checkout_session_calls
|
||||||
|
.lock()
|
||||||
|
.push(StripeCreateCheckoutSessionCall {
|
||||||
|
customer: params.customer.cloned(),
|
||||||
|
client_reference_id: params.client_reference_id.map(|id| id.to_string()),
|
||||||
|
mode: params.mode,
|
||||||
|
line_items: params.line_items,
|
||||||
|
payment_method_collection: params.payment_method_collection,
|
||||||
|
subscription_data: params.subscription_data,
|
||||||
|
success_url: params.success_url.map(|url| url.to_string()),
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(StripeCheckoutSession {
|
||||||
|
url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
500
crates/collab/src/stripe_client/real_stripe_client.rs
Normal file
@@ -0,0 +1,500 @@
|
|||||||
|
use std::str::FromStr as _;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde::Serialize;
|
||||||
|
use stripe::{
|
||||||
|
CancellationDetails, CancellationDetailsReason, CheckoutSession, CheckoutSessionMode,
|
||||||
|
CheckoutSessionPaymentMethodCollection, CreateCheckoutSession, CreateCheckoutSessionLineItems,
|
||||||
|
CreateCheckoutSessionSubscriptionData, CreateCheckoutSessionSubscriptionDataTrialSettings,
|
||||||
|
CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior,
|
||||||
|
CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod,
|
||||||
|
CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring, Subscription,
|
||||||
|
SubscriptionId, SubscriptionItem, SubscriptionItemId, UpdateSubscriptionItems,
|
||||||
|
UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior,
|
||||||
|
UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::stripe_client::{
|
||||||
|
CreateCustomerParams, StripeCancellationDetails, StripeCancellationDetailsReason,
|
||||||
|
StripeCheckoutSession, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
|
||||||
|
StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
|
||||||
|
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
|
||||||
|
StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice,
|
||||||
|
StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId,
|
||||||
|
StripeSubscriptionItem, StripeSubscriptionItemId, StripeSubscriptionTrialSettings,
|
||||||
|
StripeSubscriptionTrialSettingsEndBehavior,
|
||||||
|
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionParams,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub struct RealStripeClient {
|
||||||
|
client: Arc<stripe::Client>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RealStripeClient {
|
||||||
|
pub fn new(client: Arc<stripe::Client>) -> Self {
|
||||||
|
Self { client }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl StripeClient for RealStripeClient {
|
||||||
|
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
|
||||||
|
let response = Customer::list(
|
||||||
|
&self.client,
|
||||||
|
&ListCustomers {
|
||||||
|
email: Some(email),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(response
|
||||||
|
.data
|
||||||
|
.into_iter()
|
||||||
|
.map(StripeCustomer::from)
|
||||||
|
.collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer> {
|
||||||
|
let customer_id = customer_id.try_into()?;
|
||||||
|
|
||||||
|
let customer = Customer::retrieve(&self.client, &customer_id, &[]).await?;
|
||||||
|
|
||||||
|
Ok(StripeCustomer::from(customer))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
|
||||||
|
let customer = Customer::create(
|
||||||
|
&self.client,
|
||||||
|
CreateCustomer {
|
||||||
|
email: params.email,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(StripeCustomer::from(customer))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_subscriptions_for_customer(
|
||||||
|
&self,
|
||||||
|
customer_id: &StripeCustomerId,
|
||||||
|
) -> Result<Vec<StripeSubscription>> {
|
||||||
|
let customer_id = customer_id.try_into()?;
|
||||||
|
|
||||||
|
let subscriptions = stripe::Subscription::list(
|
||||||
|
&self.client,
|
||||||
|
&stripe::ListSubscriptions {
|
||||||
|
customer: Some(customer_id),
|
||||||
|
status: None,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(subscriptions
|
||||||
|
.data
|
||||||
|
.into_iter()
|
||||||
|
.map(StripeSubscription::from)
|
||||||
|
.collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_subscription(
|
||||||
|
&self,
|
||||||
|
subscription_id: &StripeSubscriptionId,
|
||||||
|
) -> Result<StripeSubscription> {
|
||||||
|
let subscription_id = subscription_id.try_into()?;
|
||||||
|
|
||||||
|
let subscription = Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
|
||||||
|
|
||||||
|
Ok(StripeSubscription::from(subscription))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_subscription(
|
||||||
|
&self,
|
||||||
|
params: StripeCreateSubscriptionParams,
|
||||||
|
) -> Result<StripeSubscription> {
|
||||||
|
let customer_id = params.customer.try_into()?;
|
||||||
|
|
||||||
|
let mut create_subscription = stripe::CreateSubscription::new(customer_id);
|
||||||
|
create_subscription.items = Some(
|
||||||
|
params
|
||||||
|
.items
|
||||||
|
.into_iter()
|
||||||
|
.map(|item| stripe::CreateSubscriptionItems {
|
||||||
|
price: item.price.map(|price| price.to_string()),
|
||||||
|
quantity: item.quantity,
|
||||||
|
..Default::default()
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let subscription = Subscription::create(&self.client, create_subscription).await?;
|
||||||
|
|
||||||
|
Ok(StripeSubscription::from(subscription))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn update_subscription(
|
||||||
|
&self,
|
||||||
|
subscription_id: &StripeSubscriptionId,
|
||||||
|
params: UpdateSubscriptionParams,
|
||||||
|
) -> Result<()> {
|
||||||
|
let subscription_id = subscription_id.try_into()?;
|
||||||
|
|
||||||
|
stripe::Subscription::update(
|
||||||
|
&self.client,
|
||||||
|
&subscription_id,
|
||||||
|
stripe::UpdateSubscription {
|
||||||
|
items: params.items.map(|items| {
|
||||||
|
items
|
||||||
|
.into_iter()
|
||||||
|
.map(|item| UpdateSubscriptionItems {
|
||||||
|
price: item.price.map(|price| price.to_string()),
|
||||||
|
..Default::default()
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}),
|
||||||
|
trial_settings: params.trial_settings.map(Into::into),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> {
|
||||||
|
let subscription_id = subscription_id.try_into()?;
|
||||||
|
|
||||||
|
Subscription::cancel(
|
||||||
|
&self.client,
|
||||||
|
&subscription_id,
|
||||||
|
stripe::CancelSubscription {
|
||||||
|
invoice_now: None,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_prices(&self) -> Result<Vec<StripePrice>> {
|
||||||
|
let response = stripe::Price::list(
|
||||||
|
&self.client,
|
||||||
|
&stripe::ListPrices {
|
||||||
|
limit: Some(100),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(response.data.into_iter().map(StripePrice::from).collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct Params {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
limit: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.get_query::<stripe::List<StripeMeter>, _>(
|
||||||
|
"/billing/meters",
|
||||||
|
Params { limit: Some(100) },
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(response.data)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
|
||||||
|
let identifier = params.identifier;
|
||||||
|
match self.client.post_form("/billing/meter_events", params).await {
|
||||||
|
Ok(event) => Ok(event),
|
||||||
|
Err(stripe::StripeError::Stripe(error)) => {
|
||||||
|
if error.http_status == 400
|
||||||
|
&& error
|
||||||
|
.message
|
||||||
|
.as_ref()
|
||||||
|
.map_or(false, |message| message.contains(identifier))
|
||||||
|
{
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(anyhow!(stripe::StripeError::Stripe(error)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(error) => Err(anyhow!(error)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_checkout_session(
|
||||||
|
&self,
|
||||||
|
params: StripeCreateCheckoutSessionParams<'_>,
|
||||||
|
) -> Result<StripeCheckoutSession> {
|
||||||
|
let params = params.try_into()?;
|
||||||
|
let session = CheckoutSession::create(&self.client, params).await?;
|
||||||
|
|
||||||
|
Ok(session.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<CustomerId> for StripeCustomerId {
|
||||||
|
fn from(value: CustomerId) -> Self {
|
||||||
|
Self(value.as_str().into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<StripeCustomerId> for CustomerId {
|
||||||
|
type Error = anyhow::Error;
|
||||||
|
|
||||||
|
fn try_from(value: StripeCustomerId) -> Result<Self, Self::Error> {
|
||||||
|
Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<&StripeCustomerId> for CustomerId {
|
||||||
|
type Error = anyhow::Error;
|
||||||
|
|
||||||
|
fn try_from(value: &StripeCustomerId) -> Result<Self, Self::Error> {
|
||||||
|
Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Customer> for StripeCustomer {
|
||||||
|
fn from(value: Customer) -> Self {
|
||||||
|
StripeCustomer {
|
||||||
|
id: value.id.into(),
|
||||||
|
email: value.email,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<SubscriptionId> for StripeSubscriptionId {
|
||||||
|
fn from(value: SubscriptionId) -> Self {
|
||||||
|
Self(value.as_str().into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<&StripeSubscriptionId> for SubscriptionId {
|
||||||
|
type Error = anyhow::Error;
|
||||||
|
|
||||||
|
fn try_from(value: &StripeSubscriptionId) -> Result<Self, Self::Error> {
|
||||||
|
Self::from_str(value.0.as_ref()).context("failed to parse Stripe subscription ID")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Subscription> for StripeSubscription {
|
||||||
|
fn from(value: Subscription) -> Self {
|
||||||
|
Self {
|
||||||
|
id: value.id.into(),
|
||||||
|
customer: value.customer.id().into(),
|
||||||
|
status: value.status,
|
||||||
|
current_period_start: value.current_period_start,
|
||||||
|
current_period_end: value.current_period_end,
|
||||||
|
items: value.items.data.into_iter().map(Into::into).collect(),
|
||||||
|
cancel_at: value.cancel_at,
|
||||||
|
cancellation_details: value.cancellation_details.map(Into::into),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<CancellationDetails> for StripeCancellationDetails {
|
||||||
|
fn from(value: CancellationDetails) -> Self {
|
||||||
|
Self {
|
||||||
|
reason: value.reason.map(Into::into),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<CancellationDetailsReason> for StripeCancellationDetailsReason {
|
||||||
|
fn from(value: CancellationDetailsReason) -> Self {
|
||||||
|
match value {
|
||||||
|
CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
|
||||||
|
CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
|
||||||
|
CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<SubscriptionItemId> for StripeSubscriptionItemId {
|
||||||
|
fn from(value: SubscriptionItemId) -> Self {
|
||||||
|
Self(value.as_str().into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<SubscriptionItem> for StripeSubscriptionItem {
|
||||||
|
fn from(value: SubscriptionItem) -> Self {
|
||||||
|
Self {
|
||||||
|
id: value.id.into(),
|
||||||
|
price: value.price.map(Into::into),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<StripeSubscriptionTrialSettings> for UpdateSubscriptionTrialSettings {
|
||||||
|
fn from(value: StripeSubscriptionTrialSettings) -> Self {
|
||||||
|
Self {
|
||||||
|
end_behavior: value.end_behavior.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<StripeSubscriptionTrialSettingsEndBehavior>
|
||||||
|
for UpdateSubscriptionTrialSettingsEndBehavior
|
||||||
|
{
|
||||||
|
fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self {
|
||||||
|
Self {
|
||||||
|
missing_payment_method: value.missing_payment_method.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
|
||||||
|
for UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod
|
||||||
|
{
|
||||||
|
fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self {
|
||||||
|
match value {
|
||||||
|
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
|
||||||
|
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
|
||||||
|
Self::CreateInvoice
|
||||||
|
}
|
||||||
|
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<PriceId> for StripePriceId {
|
||||||
|
fn from(value: PriceId) -> Self {
|
||||||
|
Self(value.as_str().into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<StripePriceId> for PriceId {
|
||||||
|
type Error = anyhow::Error;
|
||||||
|
|
||||||
|
fn try_from(value: StripePriceId) -> Result<Self, Self::Error> {
|
||||||
|
Self::from_str(value.0.as_ref()).context("failed to parse Stripe price ID")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Price> for StripePrice {
|
||||||
|
fn from(value: Price) -> Self {
|
||||||
|
Self {
|
||||||
|
id: value.id.into(),
|
||||||
|
unit_amount: value.unit_amount,
|
||||||
|
lookup_key: value.lookup_key,
|
||||||
|
recurring: value.recurring.map(StripePriceRecurring::from),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Recurring> for StripePriceRecurring {
|
||||||
|
fn from(value: Recurring) -> Self {
|
||||||
|
Self { meter: value.meter }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> TryFrom<StripeCreateCheckoutSessionParams<'a>> for CreateCheckoutSession<'a> {
|
||||||
|
type Error = anyhow::Error;
|
||||||
|
|
||||||
|
fn try_from(value: StripeCreateCheckoutSessionParams<'a>) -> Result<Self, Self::Error> {
|
||||||
|
Ok(Self {
|
||||||
|
customer: value
|
||||||
|
.customer
|
||||||
|
.map(|customer_id| customer_id.try_into())
|
||||||
|
.transpose()?,
|
||||||
|
client_reference_id: value.client_reference_id,
|
||||||
|
mode: value.mode.map(Into::into),
|
||||||
|
line_items: value
|
||||||
|
.line_items
|
||||||
|
.map(|line_items| line_items.into_iter().map(Into::into).collect()),
|
||||||
|
payment_method_collection: value.payment_method_collection.map(Into::into),
|
||||||
|
subscription_data: value.subscription_data.map(Into::into),
|
||||||
|
success_url: value.success_url,
|
||||||
|
..Default::default()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<StripeCheckoutSessionMode> for CheckoutSessionMode {
|
||||||
|
fn from(value: StripeCheckoutSessionMode) -> Self {
|
||||||
|
match value {
|
||||||
|
StripeCheckoutSessionMode::Payment => Self::Payment,
|
||||||
|
StripeCheckoutSessionMode::Setup => Self::Setup,
|
||||||
|
StripeCheckoutSessionMode::Subscription => Self::Subscription,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<StripeCreateCheckoutSessionLineItems> for CreateCheckoutSessionLineItems {
|
||||||
|
fn from(value: StripeCreateCheckoutSessionLineItems) -> Self {
|
||||||
|
Self {
|
||||||
|
price: value.price,
|
||||||
|
quantity: value.quantity,
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<StripeCheckoutSessionPaymentMethodCollection> for CheckoutSessionPaymentMethodCollection {
|
||||||
|
fn from(value: StripeCheckoutSessionPaymentMethodCollection) -> Self {
|
||||||
|
match value {
|
||||||
|
StripeCheckoutSessionPaymentMethodCollection::Always => Self::Always,
|
||||||
|
StripeCheckoutSessionPaymentMethodCollection::IfRequired => Self::IfRequired,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<StripeCreateCheckoutSessionSubscriptionData> for CreateCheckoutSessionSubscriptionData {
|
||||||
|
fn from(value: StripeCreateCheckoutSessionSubscriptionData) -> Self {
|
||||||
|
Self {
|
||||||
|
trial_period_days: value.trial_period_days,
|
||||||
|
trial_settings: value.trial_settings.map(Into::into),
|
||||||
|
metadata: value.metadata,
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<StripeSubscriptionTrialSettings> for CreateCheckoutSessionSubscriptionDataTrialSettings {
|
||||||
|
fn from(value: StripeSubscriptionTrialSettings) -> Self {
|
||||||
|
Self {
|
||||||
|
end_behavior: value.end_behavior.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<StripeSubscriptionTrialSettingsEndBehavior>
|
||||||
|
for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior
|
||||||
|
{
|
||||||
|
fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self {
|
||||||
|
Self {
|
||||||
|
missing_payment_method: value.missing_payment_method.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
|
||||||
|
for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod
|
||||||
|
{
|
||||||
|
fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self {
|
||||||
|
match value {
|
||||||
|
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
|
||||||
|
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
|
||||||
|
Self::CreateInvoice
|
||||||
|
}
|
||||||
|
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<CheckoutSession> for StripeCheckoutSession {
|
||||||
|
fn from(value: CheckoutSession) -> Self {
|
||||||
|
Self { url: value.url }
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -18,6 +18,7 @@ mod random_channel_buffer_tests;
|
|||||||
mod random_project_collaboration_tests;
|
mod random_project_collaboration_tests;
|
||||||
mod randomized_test_helpers;
|
mod randomized_test_helpers;
|
||||||
mod remote_editing_collaboration_tests;
|
mod remote_editing_collaboration_tests;
|
||||||
|
mod stripe_billing_tests;
|
||||||
mod test_server;
|
mod test_server;
|
||||||
|
|
||||||
use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
|
use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
|
||||||
|
|||||||
@@ -1010,7 +1010,6 @@ async fn test_peers_following_each_other(cx_a: &mut TestAppContext, cx_b: &mut T
|
|||||||
workspace_b.update_in(cx_b, |workspace, window, cx| {
|
workspace_b.update_in(cx_b, |workspace, window, cx| {
|
||||||
workspace.active_pane().update(cx, |pane, cx| {
|
workspace.active_pane().update(cx, |pane, cx| {
|
||||||
pane.close_inactive_items(&Default::default(), window, cx)
|
pane.close_inactive_items(&Default::default(), window, cx)
|
||||||
.unwrap()
|
|
||||||
.detach();
|
.detach();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
565
crates/collab/src/tests/stripe_billing_tests.rs
Normal file
@@ -0,0 +1,565 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use chrono::{Duration, Utc};
|
||||||
|
use pretty_assertions::assert_eq;
|
||||||
|
|
||||||
|
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
|
||||||
|
use crate::stripe_billing::StripeBilling;
|
||||||
|
use crate::stripe_client::{
|
||||||
|
FakeStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
|
||||||
|
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionSubscriptionData,
|
||||||
|
StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
|
||||||
|
StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
|
||||||
|
StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
|
||||||
|
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
|
||||||
|
};
|
||||||
|
|
||||||
|
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
|
||||||
|
let stripe_client = Arc::new(FakeStripeClient::new());
|
||||||
|
let stripe_billing = StripeBilling::test(stripe_client.clone());
|
||||||
|
|
||||||
|
(stripe_billing, stripe_client)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_initialize() {
|
||||||
|
let (stripe_billing, stripe_client) = make_stripe_billing();
|
||||||
|
|
||||||
|
// Add test meters
|
||||||
|
let meter1 = StripeMeter {
|
||||||
|
id: StripeMeterId("meter_1".into()),
|
||||||
|
event_name: "event_1".to_string(),
|
||||||
|
};
|
||||||
|
let meter2 = StripeMeter {
|
||||||
|
id: StripeMeterId("meter_2".into()),
|
||||||
|
event_name: "event_2".to_string(),
|
||||||
|
};
|
||||||
|
stripe_client
|
||||||
|
.meters
|
||||||
|
.lock()
|
||||||
|
.insert(meter1.id.clone(), meter1);
|
||||||
|
stripe_client
|
||||||
|
.meters
|
||||||
|
.lock()
|
||||||
|
.insert(meter2.id.clone(), meter2);
|
||||||
|
|
||||||
|
// Add test prices
|
||||||
|
let price1 = StripePrice {
|
||||||
|
id: StripePriceId("price_1".into()),
|
||||||
|
unit_amount: Some(1_000),
|
||||||
|
lookup_key: Some("zed-pro".to_string()),
|
||||||
|
recurring: None,
|
||||||
|
};
|
||||||
|
let price2 = StripePrice {
|
||||||
|
id: StripePriceId("price_2".into()),
|
||||||
|
unit_amount: Some(0),
|
||||||
|
lookup_key: Some("zed-free".to_string()),
|
||||||
|
recurring: None,
|
||||||
|
};
|
||||||
|
let price3 = StripePrice {
|
||||||
|
id: StripePriceId("price_3".into()),
|
||||||
|
unit_amount: Some(500),
|
||||||
|
lookup_key: None,
|
||||||
|
recurring: Some(StripePriceRecurring {
|
||||||
|
meter: Some("meter_1".to_string()),
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
stripe_client
|
||||||
|
.prices
|
||||||
|
.lock()
|
||||||
|
.insert(price1.id.clone(), price1);
|
||||||
|
stripe_client
|
||||||
|
.prices
|
||||||
|
.lock()
|
||||||
|
.insert(price2.id.clone(), price2);
|
||||||
|
stripe_client
|
||||||
|
.prices
|
||||||
|
.lock()
|
||||||
|
.insert(price3.id.clone(), price3);
|
||||||
|
|
||||||
|
// Initialize the billing system
|
||||||
|
stripe_billing.initialize().await.unwrap();
|
||||||
|
|
||||||
|
// Verify that prices can be found by lookup key
|
||||||
|
let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
|
||||||
|
assert_eq!(zed_pro_price_id.to_string(), "price_1");
|
||||||
|
|
||||||
|
let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
|
||||||
|
assert_eq!(zed_free_price_id.to_string(), "price_2");
|
||||||
|
|
||||||
|
// Verify that a price can be found by lookup key
|
||||||
|
let zed_pro_price = stripe_billing
|
||||||
|
.find_price_by_lookup_key("zed-pro")
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(zed_pro_price.id.to_string(), "price_1");
|
||||||
|
assert_eq!(zed_pro_price.unit_amount, Some(1_000));
|
||||||
|
|
||||||
|
// Verify that finding a non-existent lookup key returns an error
|
||||||
|
let result = stripe_billing
|
||||||
|
.find_price_by_lookup_key("non-existent")
|
||||||
|
.await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_find_or_create_customer_by_email() {
|
||||||
|
let (stripe_billing, stripe_client) = make_stripe_billing();
|
||||||
|
|
||||||
|
// Create a customer with an email that doesn't yet correspond to a customer.
|
||||||
|
{
|
||||||
|
let email = "user@example.com";
|
||||||
|
|
||||||
|
let customer_id = stripe_billing
|
||||||
|
.find_or_create_customer_by_email(Some(email))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let customer = stripe_client
|
||||||
|
.customers
|
||||||
|
.lock()
|
||||||
|
.get(&customer_id)
|
||||||
|
.unwrap()
|
||||||
|
.clone();
|
||||||
|
assert_eq!(customer.email.as_deref(), Some(email));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a customer with an email that corresponds to an existing customer.
|
||||||
|
{
|
||||||
|
let email = "user2@example.com";
|
||||||
|
|
||||||
|
let existing_customer_id = stripe_billing
|
||||||
|
.find_or_create_customer_by_email(Some(email))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let customer_id = stripe_billing
|
||||||
|
.find_or_create_customer_by_email(Some(email))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(customer_id, existing_customer_id);
|
||||||
|
|
||||||
|
let customer = stripe_client
|
||||||
|
.customers
|
||||||
|
.lock()
|
||||||
|
.get(&customer_id)
|
||||||
|
.unwrap()
|
||||||
|
.clone();
|
||||||
|
assert_eq!(customer.email.as_deref(), Some(email));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_subscribe_to_price() {
|
||||||
|
let (stripe_billing, stripe_client) = make_stripe_billing();
|
||||||
|
|
||||||
|
let price = StripePrice {
|
||||||
|
id: StripePriceId("price_test".into()),
|
||||||
|
unit_amount: Some(2000),
|
||||||
|
lookup_key: Some("test-price".to_string()),
|
||||||
|
recurring: None,
|
||||||
|
};
|
||||||
|
stripe_client
|
||||||
|
.prices
|
||||||
|
.lock()
|
||||||
|
.insert(price.id.clone(), price.clone());
|
||||||
|
|
||||||
|
let now = Utc::now();
|
||||||
|
let subscription = StripeSubscription {
|
||||||
|
id: StripeSubscriptionId("sub_test".into()),
|
||||||
|
customer: StripeCustomerId("cus_test".into()),
|
||||||
|
status: stripe::SubscriptionStatus::Active,
|
||||||
|
current_period_start: now.timestamp(),
|
||||||
|
current_period_end: (now + Duration::days(30)).timestamp(),
|
||||||
|
items: vec![],
|
||||||
|
cancel_at: None,
|
||||||
|
cancellation_details: None,
|
||||||
|
};
|
||||||
|
stripe_client
|
||||||
|
.subscriptions
|
||||||
|
.lock()
|
||||||
|
.insert(subscription.id.clone(), subscription.clone());
|
||||||
|
|
||||||
|
stripe_billing
|
||||||
|
.subscribe_to_price(&subscription.id, &price)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let update_subscription_calls = stripe_client
|
||||||
|
.update_subscription_calls
|
||||||
|
.lock()
|
||||||
|
.iter()
|
||||||
|
.map(|(id, params)| (id.clone(), params.clone()))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
assert_eq!(update_subscription_calls.len(), 1);
|
||||||
|
assert_eq!(update_subscription_calls[0].0, subscription.id);
|
||||||
|
assert_eq!(
|
||||||
|
update_subscription_calls[0].1.items,
|
||||||
|
Some(vec![UpdateSubscriptionItems {
|
||||||
|
price: Some(price.id.clone())
|
||||||
|
}])
|
||||||
|
);
|
||||||
|
|
||||||
|
// Subscribing to a price that is already on the subscription is a no-op.
|
||||||
|
{
|
||||||
|
let now = Utc::now();
|
||||||
|
let subscription = StripeSubscription {
|
||||||
|
id: StripeSubscriptionId("sub_test".into()),
|
||||||
|
customer: StripeCustomerId("cus_test".into()),
|
||||||
|
status: stripe::SubscriptionStatus::Active,
|
||||||
|
current_period_start: now.timestamp(),
|
||||||
|
current_period_end: (now + Duration::days(30)).timestamp(),
|
||||||
|
items: vec![StripeSubscriptionItem {
|
||||||
|
id: StripeSubscriptionItemId("si_test".into()),
|
||||||
|
price: Some(price.clone()),
|
||||||
|
}],
|
||||||
|
cancel_at: None,
|
||||||
|
cancellation_details: None,
|
||||||
|
};
|
||||||
|
stripe_client
|
||||||
|
.subscriptions
|
||||||
|
.lock()
|
||||||
|
.insert(subscription.id.clone(), subscription.clone());
|
||||||
|
|
||||||
|
stripe_billing
|
||||||
|
.subscribe_to_price(&subscription.id, &price)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_subscribe_to_zed_free() {
|
||||||
|
let (stripe_billing, stripe_client) = make_stripe_billing();
|
||||||
|
|
||||||
|
let zed_pro_price = StripePrice {
|
||||||
|
id: StripePriceId("price_1".into()),
|
||||||
|
unit_amount: Some(0),
|
||||||
|
lookup_key: Some("zed-pro".to_string()),
|
||||||
|
recurring: None,
|
||||||
|
};
|
||||||
|
stripe_client
|
||||||
|
.prices
|
||||||
|
.lock()
|
||||||
|
.insert(zed_pro_price.id.clone(), zed_pro_price.clone());
|
||||||
|
let zed_free_price = StripePrice {
|
||||||
|
id: StripePriceId("price_2".into()),
|
||||||
|
unit_amount: Some(0),
|
||||||
|
lookup_key: Some("zed-free".to_string()),
|
||||||
|
recurring: None,
|
||||||
|
};
|
||||||
|
stripe_client
|
||||||
|
.prices
|
||||||
|
.lock()
|
||||||
|
.insert(zed_free_price.id.clone(), zed_free_price.clone());
|
||||||
|
|
||||||
|
stripe_billing.initialize().await.unwrap();
|
||||||
|
|
||||||
|
// Customer is subscribed to Zed Free when not already subscribed to a plan.
|
||||||
|
{
|
||||||
|
let customer_id = StripeCustomerId("cus_no_plan".into());
|
||||||
|
|
||||||
|
let subscription = stripe_billing
|
||||||
|
.subscribe_to_zed_free(customer_id)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Customer is not subscribed to Zed Free when they already have an active subscription.
|
||||||
|
{
|
||||||
|
let customer_id = StripeCustomerId("cus_active_subscription".into());
|
||||||
|
|
||||||
|
let now = Utc::now();
|
||||||
|
let existing_subscription = StripeSubscription {
|
||||||
|
id: StripeSubscriptionId("sub_existing_active".into()),
|
||||||
|
customer: customer_id.clone(),
|
||||||
|
status: stripe::SubscriptionStatus::Active,
|
||||||
|
current_period_start: now.timestamp(),
|
||||||
|
current_period_end: (now + Duration::days(30)).timestamp(),
|
||||||
|
items: vec![StripeSubscriptionItem {
|
||||||
|
id: StripeSubscriptionItemId("si_test".into()),
|
||||||
|
price: Some(zed_pro_price.clone()),
|
||||||
|
}],
|
||||||
|
cancel_at: None,
|
||||||
|
cancellation_details: None,
|
||||||
|
};
|
||||||
|
stripe_client.subscriptions.lock().insert(
|
||||||
|
existing_subscription.id.clone(),
|
||||||
|
existing_subscription.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let subscription = stripe_billing
|
||||||
|
.subscribe_to_zed_free(customer_id)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(subscription, existing_subscription);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Customer is not subscribed to Zed Free when they already have a trial subscription.
|
||||||
|
{
|
||||||
|
let customer_id = StripeCustomerId("cus_trial_subscription".into());
|
||||||
|
|
||||||
|
let now = Utc::now();
|
||||||
|
let existing_subscription = StripeSubscription {
|
||||||
|
id: StripeSubscriptionId("sub_existing_trial".into()),
|
||||||
|
customer: customer_id.clone(),
|
||||||
|
status: stripe::SubscriptionStatus::Trialing,
|
||||||
|
current_period_start: now.timestamp(),
|
||||||
|
current_period_end: (now + Duration::days(14)).timestamp(),
|
||||||
|
items: vec![StripeSubscriptionItem {
|
||||||
|
id: StripeSubscriptionItemId("si_test".into()),
|
||||||
|
price: Some(zed_pro_price.clone()),
|
||||||
|
}],
|
||||||
|
cancel_at: None,
|
||||||
|
cancellation_details: None,
|
||||||
|
};
|
||||||
|
stripe_client.subscriptions.lock().insert(
|
||||||
|
existing_subscription.id.clone(),
|
||||||
|
existing_subscription.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let subscription = stripe_billing
|
||||||
|
.subscribe_to_zed_free(customer_id)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(subscription, existing_subscription);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_bill_model_request_usage() {
|
||||||
|
let (stripe_billing, stripe_client) = make_stripe_billing();
|
||||||
|
|
||||||
|
let customer_id = StripeCustomerId("cus_test".into());
|
||||||
|
|
||||||
|
stripe_billing
|
||||||
|
.bill_model_request_usage(&customer_id, "some_model/requests", 73)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let create_meter_event_calls = stripe_client
|
||||||
|
.create_meter_event_calls
|
||||||
|
.lock()
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
assert_eq!(create_meter_event_calls.len(), 1);
|
||||||
|
assert!(
|
||||||
|
create_meter_event_calls[0]
|
||||||
|
.identifier
|
||||||
|
.starts_with("model_requests/")
|
||||||
|
);
|
||||||
|
assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id);
|
||||||
|
assert_eq!(
|
||||||
|
create_meter_event_calls[0].event_name.as_ref(),
|
||||||
|
"some_model/requests"
|
||||||
|
);
|
||||||
|
assert_eq!(create_meter_event_calls[0].value, 73);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_checkout_with_zed_pro() {
|
||||||
|
let (stripe_billing, stripe_client) = make_stripe_billing();
|
||||||
|
|
||||||
|
let customer_id = StripeCustomerId("cus_test".into());
|
||||||
|
let github_login = "zeduser1";
|
||||||
|
let success_url = "https://example.com/success";
|
||||||
|
|
||||||
|
// It returns an error when the Zed Pro price doesn't exist.
|
||||||
|
{
|
||||||
|
let result = stripe_billing
|
||||||
|
.checkout_with_zed_pro(&customer_id, github_login, success_url)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert_eq!(
|
||||||
|
result.err().unwrap().to_string(),
|
||||||
|
r#"no price ID found for "zed-pro""#
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Successful checkout.
|
||||||
|
{
|
||||||
|
let price = StripePrice {
|
||||||
|
id: StripePriceId("price_1".into()),
|
||||||
|
unit_amount: Some(2000),
|
||||||
|
lookup_key: Some("zed-pro".to_string()),
|
||||||
|
recurring: None,
|
||||||
|
};
|
||||||
|
stripe_client
|
||||||
|
.prices
|
||||||
|
.lock()
|
||||||
|
.insert(price.id.clone(), price.clone());
|
||||||
|
|
||||||
|
stripe_billing.initialize().await.unwrap();
|
||||||
|
|
||||||
|
let checkout_url = stripe_billing
|
||||||
|
.checkout_with_zed_pro(&customer_id, github_login, success_url)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
|
||||||
|
|
||||||
|
let create_checkout_session_calls = stripe_client
|
||||||
|
.create_checkout_session_calls
|
||||||
|
.lock()
|
||||||
|
.drain(..)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
assert_eq!(create_checkout_session_calls.len(), 1);
|
||||||
|
let call = create_checkout_session_calls.into_iter().next().unwrap();
|
||||||
|
assert_eq!(call.customer, Some(customer_id));
|
||||||
|
assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
|
||||||
|
assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
|
||||||
|
assert_eq!(
|
||||||
|
call.line_items,
|
||||||
|
Some(vec![StripeCreateCheckoutSessionLineItems {
|
||||||
|
price: Some(price.id.to_string()),
|
||||||
|
quantity: Some(1)
|
||||||
|
}])
|
||||||
|
);
|
||||||
|
assert_eq!(call.payment_method_collection, None);
|
||||||
|
assert_eq!(call.subscription_data, None);
|
||||||
|
assert_eq!(call.success_url.as_deref(), Some(success_url));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_checkout_with_zed_pro_trial() {
|
||||||
|
let (stripe_billing, stripe_client) = make_stripe_billing();
|
||||||
|
|
||||||
|
let customer_id = StripeCustomerId("cus_test".into());
|
||||||
|
let github_login = "zeduser1";
|
||||||
|
let success_url = "https://example.com/success";
|
||||||
|
|
||||||
|
// It returns an error when the Zed Pro price doesn't exist.
|
||||||
|
{
|
||||||
|
let result = stripe_billing
|
||||||
|
.checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert_eq!(
|
||||||
|
result.err().unwrap().to_string(),
|
||||||
|
r#"no price ID found for "zed-pro""#
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let price = StripePrice {
|
||||||
|
id: StripePriceId("price_1".into()),
|
||||||
|
unit_amount: Some(2000),
|
||||||
|
lookup_key: Some("zed-pro".to_string()),
|
||||||
|
recurring: None,
|
||||||
|
};
|
||||||
|
stripe_client
|
||||||
|
.prices
|
||||||
|
.lock()
|
||||||
|
.insert(price.id.clone(), price.clone());
|
||||||
|
|
||||||
|
stripe_billing.initialize().await.unwrap();
|
||||||
|
|
||||||
|
// Successful checkout.
|
||||||
|
{
|
||||||
|
let checkout_url = stripe_billing
|
||||||
|
.checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
|
||||||
|
|
||||||
|
let create_checkout_session_calls = stripe_client
|
||||||
|
.create_checkout_session_calls
|
||||||
|
.lock()
|
||||||
|
.drain(..)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
assert_eq!(create_checkout_session_calls.len(), 1);
|
||||||
|
let call = create_checkout_session_calls.into_iter().next().unwrap();
|
||||||
|
assert_eq!(call.customer.as_ref(), Some(&customer_id));
|
||||||
|
assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
|
||||||
|
assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
|
||||||
|
assert_eq!(
|
||||||
|
call.line_items,
|
||||||
|
Some(vec![StripeCreateCheckoutSessionLineItems {
|
||||||
|
price: Some(price.id.to_string()),
|
||||||
|
quantity: Some(1)
|
||||||
|
}])
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
call.payment_method_collection,
|
||||||
|
Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
call.subscription_data,
|
||||||
|
Some(StripeCreateCheckoutSessionSubscriptionData {
|
||||||
|
trial_period_days: Some(14),
|
||||||
|
trial_settings: Some(StripeSubscriptionTrialSettings {
|
||||||
|
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
|
||||||
|
missing_payment_method:
|
||||||
|
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
metadata: None,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
assert_eq!(call.success_url.as_deref(), Some(success_url));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Successful checkout with extended trial.
|
||||||
|
{
|
||||||
|
let checkout_url = stripe_billing
|
||||||
|
.checkout_with_zed_pro_trial(
|
||||||
|
&customer_id,
|
||||||
|
github_login,
|
||||||
|
vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()],
|
||||||
|
success_url,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
|
||||||
|
|
||||||
|
let create_checkout_session_calls = stripe_client
|
||||||
|
.create_checkout_session_calls
|
||||||
|
.lock()
|
||||||
|
.drain(..)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
assert_eq!(create_checkout_session_calls.len(), 1);
|
||||||
|
let call = create_checkout_session_calls.into_iter().next().unwrap();
|
||||||
|
assert_eq!(call.customer, Some(customer_id));
|
||||||
|
assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
|
||||||
|
assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
|
||||||
|
assert_eq!(
|
||||||
|
call.line_items,
|
||||||
|
Some(vec![StripeCreateCheckoutSessionLineItems {
|
||||||
|
price: Some(price.id.to_string()),
|
||||||
|
quantity: Some(1)
|
||||||
|
}])
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
call.payment_method_collection,
|
||||||
|
Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
call.subscription_data,
|
||||||
|
Some(StripeCreateCheckoutSessionSubscriptionData {
|
||||||
|
trial_period_days: Some(60),
|
||||||
|
trial_settings: Some(StripeSubscriptionTrialSettings {
|
||||||
|
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
|
||||||
|
missing_payment_method:
|
||||||
|
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
metadata: Some(std::collections::HashMap::from_iter([(
|
||||||
|
"promo_feature_flag".into(),
|
||||||
|
AGENT_EXTENDED_TRIAL_FEATURE_FLAG.into()
|
||||||
|
)])),
|
||||||
|
})
|
||||||
|
);
|
||||||
|
assert_eq!(call.success_url.as_deref(), Some(success_url));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
use crate::stripe_client::FakeStripeClient;
|
||||||
use crate::{
|
use crate::{
|
||||||
AppState, Config,
|
AppState, Config,
|
||||||
db::{NewUserParams, UserId, tests::TestDb},
|
db::{NewUserParams, UserId, tests::TestDb},
|
||||||
@@ -522,7 +523,8 @@ impl TestServer {
|
|||||||
llm_db: None,
|
llm_db: None,
|
||||||
livekit_client: Some(Arc::new(livekit_test_server.create_api_client())),
|
livekit_client: Some(Arc::new(livekit_test_server.create_api_client())),
|
||||||
blob_store_client: None,
|
blob_store_client: None,
|
||||||
stripe_client: None,
|
real_stripe_client: None,
|
||||||
|
stripe_client: Some(Arc::new(FakeStripeClient::new())),
|
||||||
stripe_billing: None,
|
stripe_billing: None,
|
||||||
executor,
|
executor,
|
||||||
kinesis_client: None,
|
kinesis_client: None,
|
||||||
|
|||||||
@@ -354,6 +354,10 @@ impl ChannelView {
|
|||||||
editor.set_read_only(true);
|
editor.set_read_only(true);
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}),
|
}),
|
||||||
|
ChannelBufferEvent::Connected => self.editor.update(cx, |editor, cx| {
|
||||||
|
editor.set_read_only(false);
|
||||||
|
cx.notify();
|
||||||
|
}),
|
||||||
ChannelBufferEvent::ChannelChanged => {
|
ChannelBufferEvent::ChannelChanged => {
|
||||||
self.editor.update(cx, |_, cx| {
|
self.editor.update(cx, |_, cx| {
|
||||||
cx.emit(editor::EditorEvent::TitleChanged);
|
cx.emit(editor::EditorEvent::TitleChanged);
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ use language::{
|
|||||||
Anchor, Buffer, BufferSnapshot, CodeLabel, LanguageRegistry, ToOffset,
|
Anchor, Buffer, BufferSnapshot, CodeLabel, LanguageRegistry, ToOffset,
|
||||||
language_settings::SoftWrap,
|
language_settings::SoftWrap,
|
||||||
};
|
};
|
||||||
use project::{Completion, CompletionSource, search::SearchQuery};
|
use project::{Completion, CompletionResponse, CompletionSource, search::SearchQuery};
|
||||||
use settings::Settings;
|
use settings::Settings;
|
||||||
use std::{
|
use std::{
|
||||||
cell::RefCell,
|
cell::RefCell,
|
||||||
@@ -64,9 +64,9 @@ impl CompletionProvider for MessageEditorCompletionProvider {
|
|||||||
_: editor::CompletionContext,
|
_: editor::CompletionContext,
|
||||||
_window: &mut Window,
|
_window: &mut Window,
|
||||||
cx: &mut Context<Editor>,
|
cx: &mut Context<Editor>,
|
||||||
) -> Task<Result<Option<Vec<Completion>>>> {
|
) -> Task<Result<Vec<CompletionResponse>>> {
|
||||||
let Some(handle) = self.0.upgrade() else {
|
let Some(handle) = self.0.upgrade() else {
|
||||||
return Task::ready(Ok(None));
|
return Task::ready(Ok(Vec::new()));
|
||||||
};
|
};
|
||||||
handle.update(cx, |message_editor, cx| {
|
handle.update(cx, |message_editor, cx| {
|
||||||
message_editor.completions(buffer, buffer_position, cx)
|
message_editor.completions(buffer, buffer_position, cx)
|
||||||
@@ -248,22 +248,21 @@ impl MessageEditor {
|
|||||||
buffer: &Entity<Buffer>,
|
buffer: &Entity<Buffer>,
|
||||||
end_anchor: Anchor,
|
end_anchor: Anchor,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Task<Result<Option<Vec<Completion>>>> {
|
) -> Task<Result<Vec<CompletionResponse>>> {
|
||||||
if let Some((start_anchor, query, candidates)) =
|
if let Some((start_anchor, query, candidates)) =
|
||||||
self.collect_mention_candidates(buffer, end_anchor, cx)
|
self.collect_mention_candidates(buffer, end_anchor, cx)
|
||||||
{
|
{
|
||||||
if !candidates.is_empty() {
|
if !candidates.is_empty() {
|
||||||
return cx.spawn(async move |_, cx| {
|
return cx.spawn(async move |_, cx| {
|
||||||
Ok(Some(
|
let completion_response = Self::resolve_completions_for_candidates(
|
||||||
Self::resolve_completions_for_candidates(
|
&cx,
|
||||||
&cx,
|
query.as_str(),
|
||||||
query.as_str(),
|
&candidates,
|
||||||
&candidates,
|
start_anchor..end_anchor,
|
||||||
start_anchor..end_anchor,
|
Self::completion_for_mention,
|
||||||
Self::completion_for_mention,
|
)
|
||||||
)
|
.await;
|
||||||
.await,
|
Ok(vec![completion_response])
|
||||||
))
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -273,21 +272,23 @@ impl MessageEditor {
|
|||||||
{
|
{
|
||||||
if !candidates.is_empty() {
|
if !candidates.is_empty() {
|
||||||
return cx.spawn(async move |_, cx| {
|
return cx.spawn(async move |_, cx| {
|
||||||
Ok(Some(
|
let completion_response = Self::resolve_completions_for_candidates(
|
||||||
Self::resolve_completions_for_candidates(
|
&cx,
|
||||||
&cx,
|
query.as_str(),
|
||||||
query.as_str(),
|
candidates,
|
||||||
candidates,
|
start_anchor..end_anchor,
|
||||||
start_anchor..end_anchor,
|
Self::completion_for_emoji,
|
||||||
Self::completion_for_emoji,
|
)
|
||||||
)
|
.await;
|
||||||
.await,
|
Ok(vec![completion_response])
|
||||||
))
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Task::ready(Ok(Some(Vec::new())))
|
Task::ready(Ok(vec![CompletionResponse {
|
||||||
|
completions: Vec::new(),
|
||||||
|
is_incomplete: false,
|
||||||
|
}]))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn resolve_completions_for_candidates(
|
async fn resolve_completions_for_candidates(
|
||||||
@@ -296,18 +297,19 @@ impl MessageEditor {
|
|||||||
candidates: &[StringMatchCandidate],
|
candidates: &[StringMatchCandidate],
|
||||||
range: Range<Anchor>,
|
range: Range<Anchor>,
|
||||||
completion_fn: impl Fn(&StringMatch) -> (String, CodeLabel),
|
completion_fn: impl Fn(&StringMatch) -> (String, CodeLabel),
|
||||||
) -> Vec<Completion> {
|
) -> CompletionResponse {
|
||||||
|
const LIMIT: usize = 10;
|
||||||
let matches = fuzzy::match_strings(
|
let matches = fuzzy::match_strings(
|
||||||
candidates,
|
candidates,
|
||||||
query,
|
query,
|
||||||
true,
|
true,
|
||||||
10,
|
LIMIT,
|
||||||
&Default::default(),
|
&Default::default(),
|
||||||
cx.background_executor().clone(),
|
cx.background_executor().clone(),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
matches
|
let completions = matches
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|mat| {
|
.map(|mat| {
|
||||||
let (new_text, label) = completion_fn(&mat);
|
let (new_text, label) = completion_fn(&mat);
|
||||||
@@ -322,7 +324,12 @@ impl MessageEditor {
|
|||||||
source: CompletionSource::Custom,
|
source: CompletionSource::Custom,
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.collect()
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
CompletionResponse {
|
||||||
|
is_incomplete: completions.len() >= LIMIT,
|
||||||
|
completions,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn completion_for_mention(mat: &StringMatch) -> (String, CodeLabel) {
|
fn completion_for_mention(mat: &StringMatch) -> (String, CodeLabel) {
|
||||||
|
|||||||
@@ -298,6 +298,7 @@ pub async fn download_adapter_from_github(
|
|||||||
response.status().to_string()
|
response.status().to_string()
|
||||||
);
|
);
|
||||||
|
|
||||||
|
delegate.output_to_console("Download complete".to_owned());
|
||||||
match file_type {
|
match file_type {
|
||||||
DownloadedFileType::GzipTar => {
|
DownloadedFileType::GzipTar => {
|
||||||
let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
|
let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
|
||||||
@@ -369,21 +370,19 @@ pub trait DebugAdapter: 'static + Send + Sync {
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
fn validate_config(
|
/// Extracts the kind (attach/launch) of debug configuration from the given JSON config.
|
||||||
|
/// This method should only return error when the kind cannot be determined for a given configuration;
|
||||||
|
/// in particular, it *should not* validate whether the request as a whole is valid, because that's best left to the debug adapter itself to decide.
|
||||||
|
fn request_kind(
|
||||||
&self,
|
&self,
|
||||||
config: &serde_json::Value,
|
config: &serde_json::Value,
|
||||||
) -> Result<StartDebuggingRequestArgumentsRequest> {
|
) -> Result<StartDebuggingRequestArgumentsRequest> {
|
||||||
let map = config.as_object().context("Config isn't an object")?;
|
match config.get("request") {
|
||||||
|
Some(val) if val == "launch" => Ok(StartDebuggingRequestArgumentsRequest::Launch),
|
||||||
let request_variant = map
|
Some(val) if val == "attach" => Ok(StartDebuggingRequestArgumentsRequest::Attach),
|
||||||
.get("request")
|
_ => Err(anyhow!(
|
||||||
.and_then(|val| val.as_str())
|
"missing or invalid `request` field in config. Expected 'launch' or 'attach'"
|
||||||
.context("request argument is not found or invalid")?;
|
)),
|
||||||
|
|
||||||
match request_variant {
|
|
||||||
"launch" => Ok(StartDebuggingRequestArgumentsRequest::Launch),
|
|
||||||
"attach" => Ok(StartDebuggingRequestArgumentsRequest::Attach),
|
|
||||||
_ => Err(anyhow!("request must be either 'launch' or 'attach'")),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -413,7 +412,7 @@ impl DebugAdapter for FakeAdapter {
|
|||||||
serde_json::Value::Null
|
serde_json::Value::Null
|
||||||
}
|
}
|
||||||
|
|
||||||
fn validate_config(
|
fn request_kind(
|
||||||
&self,
|
&self,
|
||||||
config: &serde_json::Value,
|
config: &serde_json::Value,
|
||||||
) -> Result<StartDebuggingRequestArgumentsRequest> {
|
) -> Result<StartDebuggingRequestArgumentsRequest> {
|
||||||
@@ -458,7 +457,7 @@ impl DebugAdapter for FakeAdapter {
|
|||||||
envs: HashMap::default(),
|
envs: HashMap::default(),
|
||||||
cwd: None,
|
cwd: None,
|
||||||
request_args: StartDebuggingRequestArguments {
|
request_args: StartDebuggingRequestArguments {
|
||||||
request: self.validate_config(&task_definition.config)?,
|
request: self.request_kind(&task_definition.config)?,
|
||||||
configuration: task_definition.config.clone(),
|
configuration: task_definition.config.clone(),
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ pub fn send_telemetry(scenario: &DebugScenario, location: TelemetrySpawnLocation
|
|||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
let kind = adapter
|
let kind = adapter
|
||||||
.validate_config(&scenario.config)
|
.request_kind(&scenario.config)
|
||||||
.ok()
|
.ok()
|
||||||
.map(serde_json::to_value)
|
.map(serde_json::to_value)
|
||||||
.and_then(Result::ok);
|
.and_then(Result::ok);
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use dap_types::{
|
|||||||
messages::{Message, Response},
|
messages::{Message, Response},
|
||||||
};
|
};
|
||||||
use futures::{AsyncRead, AsyncReadExt as _, AsyncWrite, FutureExt as _, channel::oneshot, select};
|
use futures::{AsyncRead, AsyncReadExt as _, AsyncWrite, FutureExt as _, channel::oneshot, select};
|
||||||
use gpui::AsyncApp;
|
use gpui::{AppContext as _, AsyncApp, Task};
|
||||||
use settings::Settings as _;
|
use settings::Settings as _;
|
||||||
use smallvec::SmallVec;
|
use smallvec::SmallVec;
|
||||||
use smol::{
|
use smol::{
|
||||||
@@ -22,7 +22,7 @@ use std::{
|
|||||||
time::Duration,
|
time::Duration,
|
||||||
};
|
};
|
||||||
use task::TcpArgumentsTemplate;
|
use task::TcpArgumentsTemplate;
|
||||||
use util::{ResultExt as _, TryFutureExt};
|
use util::{ConnectionResult, ResultExt as _};
|
||||||
|
|
||||||
use crate::{adapters::DebugAdapterBinary, debugger_settings::DebuggerSettings};
|
use crate::{adapters::DebugAdapterBinary, debugger_settings::DebuggerSettings};
|
||||||
|
|
||||||
@@ -126,7 +126,7 @@ pub(crate) struct TransportDelegate {
|
|||||||
pending_requests: Requests,
|
pending_requests: Requests,
|
||||||
transport: Transport,
|
transport: Transport,
|
||||||
server_tx: Arc<Mutex<Option<Sender<Message>>>>,
|
server_tx: Arc<Mutex<Option<Sender<Message>>>>,
|
||||||
_tasks: Vec<gpui::Task<Option<()>>>,
|
_tasks: Vec<Task<()>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TransportDelegate {
|
impl TransportDelegate {
|
||||||
@@ -141,7 +141,7 @@ impl TransportDelegate {
|
|||||||
log_handlers: Default::default(),
|
log_handlers: Default::default(),
|
||||||
current_requests: Default::default(),
|
current_requests: Default::default(),
|
||||||
pending_requests: Default::default(),
|
pending_requests: Default::default(),
|
||||||
_tasks: Default::default(),
|
_tasks: Vec::new(),
|
||||||
};
|
};
|
||||||
let messages = this.start_handlers(transport_pipes, cx).await?;
|
let messages = this.start_handlers(transport_pipes, cx).await?;
|
||||||
Ok((messages, this))
|
Ok((messages, this))
|
||||||
@@ -166,45 +166,76 @@ impl TransportDelegate {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let adapter_log_handler = log_handler.clone();
|
||||||
cx.update(|cx| {
|
cx.update(|cx| {
|
||||||
if let Some(stdout) = params.stdout.take() {
|
if let Some(stdout) = params.stdout.take() {
|
||||||
self._tasks.push(
|
self._tasks.push(cx.background_spawn(async move {
|
||||||
cx.background_executor()
|
match Self::handle_adapter_log(stdout, adapter_log_handler).await {
|
||||||
.spawn(Self::handle_adapter_log(stdout, log_handler.clone()).log_err()),
|
ConnectionResult::Timeout => {
|
||||||
);
|
log::error!("Timed out when handling debugger log");
|
||||||
|
}
|
||||||
|
ConnectionResult::ConnectionReset => {
|
||||||
|
log::info!("Debugger logs connection closed");
|
||||||
|
}
|
||||||
|
ConnectionResult::Result(Ok(())) => {}
|
||||||
|
ConnectionResult::Result(Err(e)) => {
|
||||||
|
log::error!("Error handling debugger log: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
self._tasks.push(
|
let pending_requests = self.pending_requests.clone();
|
||||||
cx.background_executor().spawn(
|
let output_log_handler = log_handler.clone();
|
||||||
Self::handle_output(
|
self._tasks.push(cx.background_spawn(async move {
|
||||||
params.output,
|
match Self::handle_output(
|
||||||
client_tx,
|
params.output,
|
||||||
self.pending_requests.clone(),
|
client_tx,
|
||||||
log_handler.clone(),
|
pending_requests,
|
||||||
)
|
output_log_handler,
|
||||||
.log_err(),
|
)
|
||||||
),
|
.await
|
||||||
);
|
{
|
||||||
|
Ok(()) => {}
|
||||||
|
Err(e) => log::error!("Error handling debugger output: {e}"),
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
|
||||||
if let Some(stderr) = params.stderr.take() {
|
if let Some(stderr) = params.stderr.take() {
|
||||||
self._tasks.push(
|
let log_handlers = self.log_handlers.clone();
|
||||||
cx.background_executor()
|
self._tasks.push(cx.background_spawn(async move {
|
||||||
.spawn(Self::handle_error(stderr, self.log_handlers.clone()).log_err()),
|
match Self::handle_error(stderr, log_handlers).await {
|
||||||
);
|
ConnectionResult::Timeout => {
|
||||||
|
log::error!("Timed out reading debugger error stream")
|
||||||
|
}
|
||||||
|
ConnectionResult::ConnectionReset => {
|
||||||
|
log::info!("Debugger closed its error stream")
|
||||||
|
}
|
||||||
|
ConnectionResult::Result(Ok(())) => {}
|
||||||
|
ConnectionResult::Result(Err(e)) => {
|
||||||
|
log::error!("Error handling debugger error: {e}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
self._tasks.push(
|
let current_requests = self.current_requests.clone();
|
||||||
cx.background_executor().spawn(
|
let pending_requests = self.pending_requests.clone();
|
||||||
Self::handle_input(
|
let log_handler = log_handler.clone();
|
||||||
params.input,
|
self._tasks.push(cx.background_spawn(async move {
|
||||||
client_rx,
|
match Self::handle_input(
|
||||||
self.current_requests.clone(),
|
params.input,
|
||||||
self.pending_requests.clone(),
|
client_rx,
|
||||||
log_handler.clone(),
|
current_requests,
|
||||||
)
|
pending_requests,
|
||||||
.log_err(),
|
log_handler,
|
||||||
),
|
)
|
||||||
);
|
.await
|
||||||
|
{
|
||||||
|
Ok(()) => {}
|
||||||
|
Err(e) => log::error!("Error handling debugger input: {e}"),
|
||||||
|
}
|
||||||
|
}));
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -235,7 +266,7 @@ impl TransportDelegate {
|
|||||||
async fn handle_adapter_log<Stdout>(
|
async fn handle_adapter_log<Stdout>(
|
||||||
stdout: Stdout,
|
stdout: Stdout,
|
||||||
log_handlers: Option<LogHandlers>,
|
log_handlers: Option<LogHandlers>,
|
||||||
) -> Result<()>
|
) -> ConnectionResult<()>
|
||||||
where
|
where
|
||||||
Stdout: AsyncRead + Unpin + Send + 'static,
|
Stdout: AsyncRead + Unpin + Send + 'static,
|
||||||
{
|
{
|
||||||
@@ -245,13 +276,14 @@ impl TransportDelegate {
|
|||||||
let result = loop {
|
let result = loop {
|
||||||
line.truncate(0);
|
line.truncate(0);
|
||||||
|
|
||||||
let bytes_read = match reader.read_line(&mut line).await {
|
match reader
|
||||||
Ok(bytes_read) => bytes_read,
|
.read_line(&mut line)
|
||||||
Err(e) => break Err(e.into()),
|
.await
|
||||||
};
|
.context("reading adapter log line")
|
||||||
|
{
|
||||||
if bytes_read == 0 {
|
Ok(0) => break ConnectionResult::ConnectionReset,
|
||||||
anyhow::bail!("Debugger log stream closed");
|
Ok(_) => {}
|
||||||
|
Err(e) => break ConnectionResult::Result(Err(e)),
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(log_handlers) = log_handlers.as_ref() {
|
if let Some(log_handlers) = log_handlers.as_ref() {
|
||||||
@@ -337,35 +369,35 @@ impl TransportDelegate {
|
|||||||
let mut reader = BufReader::new(server_stdout);
|
let mut reader = BufReader::new(server_stdout);
|
||||||
|
|
||||||
let result = loop {
|
let result = loop {
|
||||||
let message =
|
match Self::receive_server_message(&mut reader, &mut recv_buffer, log_handlers.as_ref())
|
||||||
Self::receive_server_message(&mut reader, &mut recv_buffer, log_handlers.as_ref())
|
.await
|
||||||
.await;
|
{
|
||||||
|
ConnectionResult::Timeout => anyhow::bail!("Timed out when connecting to debugger"),
|
||||||
match message {
|
ConnectionResult::ConnectionReset => {
|
||||||
Ok(Message::Response(res)) => {
|
log::info!("Debugger closed the connection");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
ConnectionResult::Result(Ok(Message::Response(res))) => {
|
||||||
if let Some(tx) = pending_requests.lock().await.remove(&res.request_seq) {
|
if let Some(tx) = pending_requests.lock().await.remove(&res.request_seq) {
|
||||||
if let Err(e) = tx.send(Self::process_response(res)) {
|
if let Err(e) = tx.send(Self::process_response(res)) {
|
||||||
log::trace!("Did not send response `{:?}` for a cancelled", e);
|
log::trace!("Did not send response `{:?}` for a cancelled", e);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
client_tx.send(Message::Response(res)).await?;
|
client_tx.send(Message::Response(res)).await?;
|
||||||
};
|
}
|
||||||
}
|
}
|
||||||
Ok(message) => {
|
ConnectionResult::Result(Ok(message)) => client_tx.send(message).await?,
|
||||||
client_tx.send(message).await?;
|
ConnectionResult::Result(Err(e)) => break Err(e),
|
||||||
}
|
|
||||||
Err(e) => break Err(e),
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
drop(client_tx);
|
drop(client_tx);
|
||||||
|
|
||||||
log::debug!("Handle adapter output dropped");
|
log::debug!("Handle adapter output dropped");
|
||||||
|
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_error<Stderr>(stderr: Stderr, log_handlers: LogHandlers) -> Result<()>
|
async fn handle_error<Stderr>(stderr: Stderr, log_handlers: LogHandlers) -> ConnectionResult<()>
|
||||||
where
|
where
|
||||||
Stderr: AsyncRead + Unpin + Send + 'static,
|
Stderr: AsyncRead + Unpin + Send + 'static,
|
||||||
{
|
{
|
||||||
@@ -375,8 +407,12 @@ impl TransportDelegate {
|
|||||||
let mut reader = BufReader::new(stderr);
|
let mut reader = BufReader::new(stderr);
|
||||||
|
|
||||||
let result = loop {
|
let result = loop {
|
||||||
match reader.read_line(&mut buffer).await {
|
match reader
|
||||||
Ok(0) => anyhow::bail!("debugger error stream closed"),
|
.read_line(&mut buffer)
|
||||||
|
.await
|
||||||
|
.context("reading error log line")
|
||||||
|
{
|
||||||
|
Ok(0) => break ConnectionResult::ConnectionReset,
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
for (kind, log_handler) in log_handlers.lock().iter_mut() {
|
for (kind, log_handler) in log_handlers.lock().iter_mut() {
|
||||||
if matches!(kind, LogKind::Adapter) {
|
if matches!(kind, LogKind::Adapter) {
|
||||||
@@ -386,7 +422,7 @@ impl TransportDelegate {
|
|||||||
|
|
||||||
buffer.truncate(0);
|
buffer.truncate(0);
|
||||||
}
|
}
|
||||||
Err(error) => break Err(error.into()),
|
Err(error) => break ConnectionResult::Result(Err(error)),
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -420,7 +456,7 @@ impl TransportDelegate {
|
|||||||
reader: &mut BufReader<Stdout>,
|
reader: &mut BufReader<Stdout>,
|
||||||
buffer: &mut String,
|
buffer: &mut String,
|
||||||
log_handlers: Option<&LogHandlers>,
|
log_handlers: Option<&LogHandlers>,
|
||||||
) -> Result<Message>
|
) -> ConnectionResult<Message>
|
||||||
where
|
where
|
||||||
Stdout: AsyncRead + Unpin + Send + 'static,
|
Stdout: AsyncRead + Unpin + Send + 'static,
|
||||||
{
|
{
|
||||||
@@ -428,48 +464,58 @@ impl TransportDelegate {
|
|||||||
loop {
|
loop {
|
||||||
buffer.truncate(0);
|
buffer.truncate(0);
|
||||||
|
|
||||||
if reader
|
match reader
|
||||||
.read_line(buffer)
|
.read_line(buffer)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "reading a message from server")?
|
.with_context(|| "reading a message from server")
|
||||||
== 0
|
|
||||||
{
|
{
|
||||||
anyhow::bail!("debugger reader stream closed");
|
Ok(0) => return ConnectionResult::ConnectionReset,
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(e) => return ConnectionResult::Result(Err(e)),
|
||||||
};
|
};
|
||||||
|
|
||||||
if buffer == "\r\n" {
|
if buffer == "\r\n" {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
let parts = buffer.trim().split_once(": ");
|
if let Some(("Content-Length", value)) = buffer.trim().split_once(": ") {
|
||||||
|
match value.parse().context("invalid content length") {
|
||||||
match parts {
|
Ok(length) => content_length = Some(length),
|
||||||
Some(("Content-Length", value)) => {
|
Err(e) => return ConnectionResult::Result(Err(e)),
|
||||||
content_length = Some(value.parse().context("invalid content length")?);
|
|
||||||
}
|
}
|
||||||
_ => {}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let content_length = content_length.context("missing content length")?;
|
let content_length = match content_length.context("missing content length") {
|
||||||
|
Ok(length) => length,
|
||||||
|
Err(e) => return ConnectionResult::Result(Err(e)),
|
||||||
|
};
|
||||||
|
|
||||||
let mut content = vec![0; content_length];
|
let mut content = vec![0; content_length];
|
||||||
reader
|
if let Err(e) = reader
|
||||||
.read_exact(&mut content)
|
.read_exact(&mut content)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "reading after a loop")?;
|
.with_context(|| "reading after a loop")
|
||||||
|
{
|
||||||
|
return ConnectionResult::Result(Err(e));
|
||||||
|
}
|
||||||
|
|
||||||
let message = std::str::from_utf8(&content).context("invalid utf8 from server")?;
|
let message_str = match std::str::from_utf8(&content).context("invalid utf8 from server") {
|
||||||
|
Ok(str) => str,
|
||||||
|
Err(e) => return ConnectionResult::Result(Err(e)),
|
||||||
|
};
|
||||||
|
|
||||||
if let Some(log_handlers) = log_handlers {
|
if let Some(log_handlers) = log_handlers {
|
||||||
for (kind, log_handler) in log_handlers.lock().iter_mut() {
|
for (kind, log_handler) in log_handlers.lock().iter_mut() {
|
||||||
if matches!(kind, LogKind::Rpc) {
|
if matches!(kind, LogKind::Rpc) {
|
||||||
log_handler(IoKind::StdOut, &message);
|
log_handler(IoKind::StdOut, message_str);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(serde_json::from_str::<Message>(message)?)
|
ConnectionResult::Result(
|
||||||
|
serde_json::from_str::<Message>(message_str).context("deserializing server message"),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn shutdown(&self) -> Result<()> {
|
pub async fn shutdown(&self) -> Result<()> {
|
||||||
@@ -658,9 +704,13 @@ impl StdioTransport {
|
|||||||
.stderr(Stdio::piped())
|
.stderr(Stdio::piped())
|
||||||
.kill_on_drop(true);
|
.kill_on_drop(true);
|
||||||
|
|
||||||
let mut process = command
|
let mut process = command.spawn().with_context(|| {
|
||||||
.spawn()
|
format!(
|
||||||
.with_context(|| "failed to spawn command.")?;
|
"failed to spawn command `{} {}`.",
|
||||||
|
binary.command,
|
||||||
|
binary.arguments.join(" ")
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
let stdin = process.stdin.take().context("Failed to open stdin")?;
|
let stdin = process.stdin.take().context("Failed to open stdin")?;
|
||||||
let stdout = process.stdout.take().context("Failed to open stdout")?;
|
let stdout = process.stdout.take().context("Failed to open stdout")?;
|
||||||
@@ -773,71 +823,31 @@ impl FakeTransport {
|
|||||||
let response_handlers = this.response_handlers.clone();
|
let response_handlers = this.response_handlers.clone();
|
||||||
let stdout_writer = Arc::new(Mutex::new(stdout_writer));
|
let stdout_writer = Arc::new(Mutex::new(stdout_writer));
|
||||||
|
|
||||||
cx.background_executor()
|
cx.background_spawn(async move {
|
||||||
.spawn(async move {
|
let mut reader = BufReader::new(stdin_reader);
|
||||||
let mut reader = BufReader::new(stdin_reader);
|
let mut buffer = String::new();
|
||||||
let mut buffer = String::new();
|
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let message =
|
match TransportDelegate::receive_server_message(&mut reader, &mut buffer, None)
|
||||||
TransportDelegate::receive_server_message(&mut reader, &mut buffer, None)
|
.await
|
||||||
.await;
|
{
|
||||||
|
ConnectionResult::Timeout => {
|
||||||
match message {
|
anyhow::bail!("Timed out when connecting to debugger");
|
||||||
Err(error) => {
|
}
|
||||||
break anyhow::anyhow!(error);
|
ConnectionResult::ConnectionReset => {
|
||||||
}
|
log::info!("Debugger closed the connection");
|
||||||
Ok(message) => {
|
break Ok(());
|
||||||
match message {
|
}
|
||||||
Message::Request(request) => {
|
ConnectionResult::Result(Err(e)) => break Err(e),
|
||||||
// redirect reverse requests to stdout writer/reader
|
ConnectionResult::Result(Ok(message)) => {
|
||||||
if request.command == RunInTerminal::COMMAND
|
match message {
|
||||||
|| request.command == StartDebugging::COMMAND
|
Message::Request(request) => {
|
||||||
{
|
// redirect reverse requests to stdout writer/reader
|
||||||
let message =
|
if request.command == RunInTerminal::COMMAND
|
||||||
serde_json::to_string(&Message::Request(request))
|
|| request.command == StartDebugging::COMMAND
|
||||||
.unwrap();
|
{
|
||||||
|
|
||||||
let mut writer = stdout_writer.lock().await;
|
|
||||||
writer
|
|
||||||
.write_all(
|
|
||||||
TransportDelegate::build_rpc_message(message)
|
|
||||||
.as_bytes(),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
writer.flush().await.unwrap();
|
|
||||||
} else {
|
|
||||||
let response = if let Some(handle) = request_handlers
|
|
||||||
.lock()
|
|
||||||
.get_mut(request.command.as_str())
|
|
||||||
{
|
|
||||||
handle(
|
|
||||||
request.seq,
|
|
||||||
request.arguments.unwrap_or(json!({})),
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
panic!("No request handler for {}", request.command);
|
|
||||||
};
|
|
||||||
let message =
|
|
||||||
serde_json::to_string(&Message::Response(response))
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let mut writer = stdout_writer.lock().await;
|
|
||||||
|
|
||||||
writer
|
|
||||||
.write_all(
|
|
||||||
TransportDelegate::build_rpc_message(message)
|
|
||||||
.as_bytes(),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
writer.flush().await.unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Message::Event(event) => {
|
|
||||||
let message =
|
let message =
|
||||||
serde_json::to_string(&Message::Event(event)).unwrap();
|
serde_json::to_string(&Message::Request(request)).unwrap();
|
||||||
|
|
||||||
let mut writer = stdout_writer.lock().await;
|
let mut writer = stdout_writer.lock().await;
|
||||||
writer
|
writer
|
||||||
@@ -848,22 +858,58 @@ impl FakeTransport {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
writer.flush().await.unwrap();
|
writer.flush().await.unwrap();
|
||||||
}
|
} else {
|
||||||
Message::Response(response) => {
|
let response = if let Some(handle) =
|
||||||
if let Some(handle) =
|
request_handlers.lock().get_mut(request.command.as_str())
|
||||||
response_handlers.lock().get(response.command.as_str())
|
|
||||||
{
|
{
|
||||||
handle(response);
|
handle(request.seq, request.arguments.unwrap_or(json!({})))
|
||||||
} else {
|
} else {
|
||||||
log::error!("No response handler for {}", response.command);
|
panic!("No request handler for {}", request.command);
|
||||||
}
|
};
|
||||||
|
let message =
|
||||||
|
serde_json::to_string(&Message::Response(response))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let mut writer = stdout_writer.lock().await;
|
||||||
|
|
||||||
|
writer
|
||||||
|
.write_all(
|
||||||
|
TransportDelegate::build_rpc_message(message)
|
||||||
|
.as_bytes(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
writer.flush().await.unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Message::Event(event) => {
|
||||||
|
let message =
|
||||||
|
serde_json::to_string(&Message::Event(event)).unwrap();
|
||||||
|
|
||||||
|
let mut writer = stdout_writer.lock().await;
|
||||||
|
writer
|
||||||
|
.write_all(
|
||||||
|
TransportDelegate::build_rpc_message(message).as_bytes(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
writer.flush().await.unwrap();
|
||||||
|
}
|
||||||
|
Message::Response(response) => {
|
||||||
|
if let Some(handle) =
|
||||||
|
response_handlers.lock().get(response.command.as_str())
|
||||||
|
{
|
||||||
|
handle(response);
|
||||||
|
} else {
|
||||||
|
log::error!("No response handler for {}", response.command);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
}
|
||||||
.detach();
|
})
|
||||||
|
.detach();
|
||||||
|
|
||||||
Ok((
|
Ok((
|
||||||
TransportPipe::new(Box::new(stdin_writer), Box::new(stdout_reader), None, None),
|
TransportPipe::new(Box::new(stdin_writer), Box::new(stdout_reader), None, None),
|
||||||
|
|||||||
@@ -1,11 +1,8 @@
|
|||||||
use std::{collections::HashMap, path::PathBuf, sync::OnceLock};
|
use std::{collections::HashMap, path::PathBuf, sync::OnceLock};
|
||||||
|
|
||||||
use anyhow::{Context as _, Result, anyhow};
|
use anyhow::{Context as _, Result};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use dap::{
|
use dap::adapters::{DebugTaskDefinition, latest_github_release};
|
||||||
StartDebuggingRequestArgumentsRequest,
|
|
||||||
adapters::{DebugTaskDefinition, latest_github_release},
|
|
||||||
};
|
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use gpui::AsyncApp;
|
use gpui::AsyncApp;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
@@ -37,7 +34,7 @@ impl CodeLldbDebugAdapter {
|
|||||||
Value::String(String::from(task_definition.label.as_ref())),
|
Value::String(String::from(task_definition.label.as_ref())),
|
||||||
);
|
);
|
||||||
|
|
||||||
let request = self.validate_config(&configuration)?;
|
let request = self.request_kind(&configuration)?;
|
||||||
|
|
||||||
Ok(dap::StartDebuggingRequestArguments {
|
Ok(dap::StartDebuggingRequestArguments {
|
||||||
request,
|
request,
|
||||||
@@ -89,48 +86,6 @@ impl DebugAdapter for CodeLldbDebugAdapter {
|
|||||||
DebugAdapterName(Self::ADAPTER_NAME.into())
|
DebugAdapterName(Self::ADAPTER_NAME.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn validate_config(
|
|
||||||
&self,
|
|
||||||
config: &serde_json::Value,
|
|
||||||
) -> Result<StartDebuggingRequestArgumentsRequest> {
|
|
||||||
let map = config
|
|
||||||
.as_object()
|
|
||||||
.ok_or_else(|| anyhow!("Config isn't an object"))?;
|
|
||||||
|
|
||||||
let request_variant = map
|
|
||||||
.get("request")
|
|
||||||
.and_then(|r| r.as_str())
|
|
||||||
.ok_or_else(|| anyhow!("request field is required and must be a string"))?;
|
|
||||||
|
|
||||||
match request_variant {
|
|
||||||
"launch" => {
|
|
||||||
// For launch, verify that one of the required configs exists
|
|
||||||
if !(map.contains_key("program")
|
|
||||||
|| map.contains_key("targetCreateCommands")
|
|
||||||
|| map.contains_key("cargo"))
|
|
||||||
{
|
|
||||||
return Err(anyhow!(
|
|
||||||
"launch request requires either 'program', 'targetCreateCommands', or 'cargo' field"
|
|
||||||
));
|
|
||||||
}
|
|
||||||
Ok(StartDebuggingRequestArgumentsRequest::Launch)
|
|
||||||
}
|
|
||||||
"attach" => {
|
|
||||||
// For attach, verify that either pid or program exists
|
|
||||||
if !(map.contains_key("pid") || map.contains_key("program")) {
|
|
||||||
return Err(anyhow!(
|
|
||||||
"attach request requires either 'pid' or 'program' field"
|
|
||||||
));
|
|
||||||
}
|
|
||||||
Ok(StartDebuggingRequestArgumentsRequest::Attach)
|
|
||||||
}
|
|
||||||
_ => Err(anyhow!(
|
|
||||||
"request must be either 'launch' or 'attach', got '{}'",
|
|
||||||
request_variant
|
|
||||||
)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn config_from_zed_format(&self, zed_scenario: ZedDebugConfig) -> Result<DebugScenario> {
|
fn config_from_zed_format(&self, zed_scenario: ZedDebugConfig) -> Result<DebugScenario> {
|
||||||
let mut configuration = json!({
|
let mut configuration = json!({
|
||||||
"request": match zed_scenario.request {
|
"request": match zed_scenario.request {
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ pub fn init(cx: &mut App) {
|
|||||||
registry.add_adapter(Arc::from(PhpDebugAdapter::default()));
|
registry.add_adapter(Arc::from(PhpDebugAdapter::default()));
|
||||||
registry.add_adapter(Arc::from(JsDebugAdapter::default()));
|
registry.add_adapter(Arc::from(JsDebugAdapter::default()));
|
||||||
registry.add_adapter(Arc::from(RubyDebugAdapter));
|
registry.add_adapter(Arc::from(RubyDebugAdapter));
|
||||||
registry.add_adapter(Arc::from(GoDebugAdapter));
|
registry.add_adapter(Arc::from(GoDebugAdapter::default()));
|
||||||
registry.add_adapter(Arc::from(GdbDebugAdapter));
|
registry.add_adapter(Arc::from(GdbDebugAdapter));
|
||||||
|
|
||||||
#[cfg(any(test, feature = "test-support"))]
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ impl DebugAdapter for GdbDebugAdapter {
|
|||||||
let gdb_path = user_setting_path.unwrap_or(gdb_path?);
|
let gdb_path = user_setting_path.unwrap_or(gdb_path?);
|
||||||
|
|
||||||
let request_args = StartDebuggingRequestArguments {
|
let request_args = StartDebuggingRequestArguments {
|
||||||
request: self.validate_config(&config.config)?,
|
request: self.request_kind(&config.config)?,
|
||||||
configuration: config.config.clone(),
|
configuration: config.config.clone(),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -1,22 +1,87 @@
|
|||||||
use anyhow::{Context as _, anyhow, bail};
|
use anyhow::{Context as _, bail};
|
||||||
use dap::{
|
use dap::{
|
||||||
StartDebuggingRequestArguments, StartDebuggingRequestArgumentsRequest,
|
StartDebuggingRequestArguments,
|
||||||
adapters::DebugTaskDefinition,
|
adapters::{
|
||||||
|
DebugTaskDefinition, DownloadedFileType, download_adapter_from_github,
|
||||||
|
latest_github_release,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
use gpui::{AsyncApp, SharedString};
|
use gpui::{AsyncApp, SharedString};
|
||||||
use language::LanguageName;
|
use language::LanguageName;
|
||||||
use std::{collections::HashMap, ffi::OsStr, path::PathBuf};
|
use std::{collections::HashMap, env::consts, ffi::OsStr, path::PathBuf, sync::OnceLock};
|
||||||
use util;
|
use util;
|
||||||
|
|
||||||
use crate::*;
|
use crate::*;
|
||||||
|
|
||||||
#[derive(Default, Debug)]
|
#[derive(Default, Debug)]
|
||||||
pub(crate) struct GoDebugAdapter;
|
pub(crate) struct GoDebugAdapter {
|
||||||
|
shim_path: OnceLock<PathBuf>,
|
||||||
|
}
|
||||||
|
|
||||||
impl GoDebugAdapter {
|
impl GoDebugAdapter {
|
||||||
const ADAPTER_NAME: &'static str = "Delve";
|
const ADAPTER_NAME: &'static str = "Delve";
|
||||||
const DEFAULT_TIMEOUT_MS: u64 = 60000;
|
async fn fetch_latest_adapter_version(
|
||||||
|
delegate: &Arc<dyn DapDelegate>,
|
||||||
|
) -> Result<AdapterVersion> {
|
||||||
|
let release = latest_github_release(
|
||||||
|
&"zed-industries/delve-shim-dap",
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
delegate.http_client(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let os = match consts::OS {
|
||||||
|
"macos" => "apple-darwin",
|
||||||
|
"linux" => "unknown-linux-gnu",
|
||||||
|
"windows" => "pc-windows-msvc",
|
||||||
|
other => bail!("Running on unsupported os: {other}"),
|
||||||
|
};
|
||||||
|
let suffix = if consts::OS == "windows" {
|
||||||
|
".zip"
|
||||||
|
} else {
|
||||||
|
".tar.gz"
|
||||||
|
};
|
||||||
|
let asset_name = format!("delve-shim-dap-{}-{os}{suffix}", consts::ARCH);
|
||||||
|
let asset = release
|
||||||
|
.assets
|
||||||
|
.iter()
|
||||||
|
.find(|asset| asset.name == asset_name)
|
||||||
|
.with_context(|| format!("no asset found matching `{asset_name:?}`"))?;
|
||||||
|
|
||||||
|
Ok(AdapterVersion {
|
||||||
|
tag_name: release.tag_name,
|
||||||
|
url: asset.browser_download_url.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
async fn install_shim(&self, delegate: &Arc<dyn DapDelegate>) -> anyhow::Result<PathBuf> {
|
||||||
|
if let Some(path) = self.shim_path.get().cloned() {
|
||||||
|
return Ok(path);
|
||||||
|
}
|
||||||
|
|
||||||
|
let asset = Self::fetch_latest_adapter_version(delegate).await?;
|
||||||
|
let ty = if consts::OS == "windows" {
|
||||||
|
DownloadedFileType::Zip
|
||||||
|
} else {
|
||||||
|
DownloadedFileType::GzipTar
|
||||||
|
};
|
||||||
|
download_adapter_from_github(
|
||||||
|
"delve-shim-dap".into(),
|
||||||
|
asset.clone(),
|
||||||
|
ty,
|
||||||
|
delegate.as_ref(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let path = paths::debug_adapters_dir()
|
||||||
|
.join("delve-shim-dap")
|
||||||
|
.join(format!("delve-shim-dap_{}", asset.tag_name))
|
||||||
|
.join(format!("delve-shim-dap{}", std::env::consts::EXE_SUFFIX));
|
||||||
|
self.shim_path.set(path.clone()).ok();
|
||||||
|
|
||||||
|
Ok(path)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait(?Send)]
|
#[async_trait(?Send)]
|
||||||
@@ -285,24 +350,6 @@ impl DebugAdapter for GoDebugAdapter {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn validate_config(
|
|
||||||
&self,
|
|
||||||
config: &serde_json::Value,
|
|
||||||
) -> Result<StartDebuggingRequestArgumentsRequest> {
|
|
||||||
let map = config.as_object().context("Config isn't an object")?;
|
|
||||||
|
|
||||||
let request_variant = map
|
|
||||||
.get("request")
|
|
||||||
.and_then(|val| val.as_str())
|
|
||||||
.context("request argument is not found or invalid")?;
|
|
||||||
|
|
||||||
match request_variant {
|
|
||||||
"launch" => Ok(StartDebuggingRequestArgumentsRequest::Launch),
|
|
||||||
"attach" => Ok(StartDebuggingRequestArgumentsRequest::Attach),
|
|
||||||
_ => Err(anyhow!("request must be either 'launch' or 'attach'")),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn config_from_zed_format(&self, zed_scenario: ZedDebugConfig) -> Result<DebugScenario> {
|
fn config_from_zed_format(&self, zed_scenario: ZedDebugConfig) -> Result<DebugScenario> {
|
||||||
let mut args = match &zed_scenario.request {
|
let mut args = match &zed_scenario.request {
|
||||||
dap::DebugRequest::Attach(attach_config) => {
|
dap::DebugRequest::Attach(attach_config) => {
|
||||||
@@ -349,13 +396,15 @@ impl DebugAdapter for GoDebugAdapter {
|
|||||||
&self,
|
&self,
|
||||||
delegate: &Arc<dyn DapDelegate>,
|
delegate: &Arc<dyn DapDelegate>,
|
||||||
task_definition: &DebugTaskDefinition,
|
task_definition: &DebugTaskDefinition,
|
||||||
_user_installed_path: Option<PathBuf>,
|
user_installed_path: Option<PathBuf>,
|
||||||
_cx: &mut AsyncApp,
|
_cx: &mut AsyncApp,
|
||||||
) -> Result<DebugAdapterBinary> {
|
) -> Result<DebugAdapterBinary> {
|
||||||
let adapter_path = paths::debug_adapters_dir().join(&Self::ADAPTER_NAME);
|
let adapter_path = paths::debug_adapters_dir().join(&Self::ADAPTER_NAME);
|
||||||
let dlv_path = adapter_path.join("dlv");
|
let dlv_path = adapter_path.join("dlv");
|
||||||
|
|
||||||
let delve_path = if let Some(path) = delegate.which(OsStr::new("dlv")).await {
|
let delve_path = if let Some(path) = user_installed_path {
|
||||||
|
path.to_string_lossy().to_string()
|
||||||
|
} else if let Some(path) = delegate.which(OsStr::new("dlv")).await {
|
||||||
path.to_string_lossy().to_string()
|
path.to_string_lossy().to_string()
|
||||||
} else if delegate.fs().is_file(&dlv_path).await {
|
} else if delegate.fs().is_file(&dlv_path).await {
|
||||||
dlv_path.to_string_lossy().to_string()
|
dlv_path.to_string_lossy().to_string()
|
||||||
@@ -384,16 +433,10 @@ impl DebugAdapter for GoDebugAdapter {
|
|||||||
|
|
||||||
adapter_path.join("dlv").to_string_lossy().to_string()
|
adapter_path.join("dlv").to_string_lossy().to_string()
|
||||||
};
|
};
|
||||||
|
let minidelve_path = self.install_shim(delegate).await?;
|
||||||
|
let tcp_connection = task_definition.tcp_connection.clone().unwrap_or_default();
|
||||||
|
|
||||||
let mut tcp_connection = task_definition.tcp_connection.clone().unwrap_or_default();
|
let (host, port, _) = crate::configure_tcp_connection(tcp_connection).await?;
|
||||||
|
|
||||||
if tcp_connection.timeout.is_none()
|
|
||||||
|| tcp_connection.timeout.unwrap_or(0) < Self::DEFAULT_TIMEOUT_MS
|
|
||||||
{
|
|
||||||
tcp_connection.timeout = Some(Self::DEFAULT_TIMEOUT_MS);
|
|
||||||
}
|
|
||||||
|
|
||||||
let (host, port, timeout) = crate::configure_tcp_connection(tcp_connection).await?;
|
|
||||||
|
|
||||||
let cwd = task_definition
|
let cwd = task_definition
|
||||||
.config
|
.config
|
||||||
@@ -404,6 +447,7 @@ impl DebugAdapter for GoDebugAdapter {
|
|||||||
|
|
||||||
let arguments = if cfg!(windows) {
|
let arguments = if cfg!(windows) {
|
||||||
vec![
|
vec![
|
||||||
|
delve_path,
|
||||||
"dap".into(),
|
"dap".into(),
|
||||||
"--listen".into(),
|
"--listen".into(),
|
||||||
format!("{}:{}", host, port),
|
format!("{}:{}", host, port),
|
||||||
@@ -411,6 +455,7 @@ impl DebugAdapter for GoDebugAdapter {
|
|||||||
]
|
]
|
||||||
} else {
|
} else {
|
||||||
vec![
|
vec![
|
||||||
|
delve_path,
|
||||||
"dap".into(),
|
"dap".into(),
|
||||||
"--listen".into(),
|
"--listen".into(),
|
||||||
format!("{}:{}", host, port),
|
format!("{}:{}", host, port),
|
||||||
@@ -418,18 +463,14 @@ impl DebugAdapter for GoDebugAdapter {
|
|||||||
};
|
};
|
||||||
|
|
||||||
Ok(DebugAdapterBinary {
|
Ok(DebugAdapterBinary {
|
||||||
command: delve_path,
|
command: minidelve_path.to_string_lossy().into_owned(),
|
||||||
arguments,
|
arguments,
|
||||||
cwd: Some(cwd),
|
cwd: Some(cwd),
|
||||||
envs: HashMap::default(),
|
envs: HashMap::default(),
|
||||||
connection: Some(adapters::TcpArguments {
|
connection: None,
|
||||||
host,
|
|
||||||
port,
|
|
||||||
timeout,
|
|
||||||
}),
|
|
||||||
request_args: StartDebuggingRequestArguments {
|
request_args: StartDebuggingRequestArguments {
|
||||||
configuration: task_definition.config.clone(),
|
configuration: task_definition.config.clone(),
|
||||||
request: self.validate_config(&task_definition.config)?,
|
request: self.request_kind(&task_definition.config)?,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,6 @@
|
|||||||
use adapters::latest_github_release;
|
use adapters::latest_github_release;
|
||||||
use anyhow::{Context as _, anyhow};
|
use anyhow::Context as _;
|
||||||
use dap::{
|
use dap::{StartDebuggingRequestArguments, adapters::DebugTaskDefinition};
|
||||||
StartDebuggingRequestArguments, StartDebuggingRequestArgumentsRequest,
|
|
||||||
adapters::DebugTaskDefinition,
|
|
||||||
};
|
|
||||||
use gpui::AsyncApp;
|
use gpui::AsyncApp;
|
||||||
use std::{collections::HashMap, path::PathBuf, sync::OnceLock};
|
use std::{collections::HashMap, path::PathBuf, sync::OnceLock};
|
||||||
use task::DebugRequest;
|
use task::DebugRequest;
|
||||||
@@ -26,7 +23,7 @@ impl JsDebugAdapter {
|
|||||||
delegate: &Arc<dyn DapDelegate>,
|
delegate: &Arc<dyn DapDelegate>,
|
||||||
) -> Result<AdapterVersion> {
|
) -> Result<AdapterVersion> {
|
||||||
let release = latest_github_release(
|
let release = latest_github_release(
|
||||||
&format!("{}/{}", "microsoft", Self::ADAPTER_NPM_NAME),
|
&format!("microsoft/{}", Self::ADAPTER_NPM_NAME),
|
||||||
true,
|
true,
|
||||||
false,
|
false,
|
||||||
delegate.http_client(),
|
delegate.http_client(),
|
||||||
@@ -95,7 +92,7 @@ impl JsDebugAdapter {
|
|||||||
}),
|
}),
|
||||||
request_args: StartDebuggingRequestArguments {
|
request_args: StartDebuggingRequestArguments {
|
||||||
configuration: task_definition.config.clone(),
|
configuration: task_definition.config.clone(),
|
||||||
request: self.validate_config(&task_definition.config)?,
|
request: self.request_kind(&task_definition.config)?,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -107,29 +104,6 @@ impl DebugAdapter for JsDebugAdapter {
|
|||||||
DebugAdapterName(Self::ADAPTER_NAME.into())
|
DebugAdapterName(Self::ADAPTER_NAME.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn validate_config(
|
|
||||||
&self,
|
|
||||||
config: &serde_json::Value,
|
|
||||||
) -> Result<dap::StartDebuggingRequestArgumentsRequest> {
|
|
||||||
match config.get("request") {
|
|
||||||
Some(val) if val == "launch" => {
|
|
||||||
if config.get("program").is_none() && config.get("url").is_none() {
|
|
||||||
return Err(anyhow!(
|
|
||||||
"either program or url is required for launch request"
|
|
||||||
));
|
|
||||||
}
|
|
||||||
Ok(StartDebuggingRequestArgumentsRequest::Launch)
|
|
||||||
}
|
|
||||||
Some(val) if val == "attach" => {
|
|
||||||
if !config.get("processId").is_some_and(|val| val.is_u64()) {
|
|
||||||
return Err(anyhow!("processId must be a number"));
|
|
||||||
}
|
|
||||||
Ok(StartDebuggingRequestArgumentsRequest::Attach)
|
|
||||||
}
|
|
||||||
_ => Err(anyhow!("missing or invalid request field in config")),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn config_from_zed_format(&self, zed_scenario: ZedDebugConfig) -> Result<DebugScenario> {
|
fn config_from_zed_format(&self, zed_scenario: ZedDebugConfig) -> Result<DebugScenario> {
|
||||||
let mut args = json!({
|
let mut args = json!({
|
||||||
"type": "pwa-node",
|
"type": "pwa-node",
|
||||||
@@ -449,6 +423,8 @@ impl DebugAdapter for JsDebugAdapter {
|
|||||||
delegate.as_ref(),
|
delegate.as_ref(),
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
} else {
|
||||||
|
delegate.output_to_console(format!("{} debug adapter is up to date", self.name()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ impl PhpDebugAdapter {
|
|||||||
envs: HashMap::default(),
|
envs: HashMap::default(),
|
||||||
request_args: StartDebuggingRequestArguments {
|
request_args: StartDebuggingRequestArguments {
|
||||||
configuration: task_definition.config.clone(),
|
configuration: task_definition.config.clone(),
|
||||||
request: <Self as DebugAdapter>::validate_config(self, &task_definition.config)?,
|
request: <Self as DebugAdapter>::request_kind(self, &task_definition.config)?,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -149,22 +149,8 @@ impl DebugAdapter for PhpDebugAdapter {
|
|||||||
"default": false
|
"default": false
|
||||||
},
|
},
|
||||||
"pathMappings": {
|
"pathMappings": {
|
||||||
"type": "array",
|
"type": "object",
|
||||||
"description": "A list of server paths mapping to the local source paths on your machine for remote host debugging",
|
"description": "A mapping of server paths to local paths.",
|
||||||
"items": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"serverPath": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Path on the server"
|
|
||||||
},
|
|
||||||
"localPath": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Corresponding path on the local machine"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["serverPath", "localPath"]
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"log": {
|
"log": {
|
||||||
"type": "boolean",
|
"type": "boolean",
|
||||||
@@ -296,10 +282,7 @@ impl DebugAdapter for PhpDebugAdapter {
|
|||||||
Some(SharedString::new_static("PHP").into())
|
Some(SharedString::new_static("PHP").into())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn validate_config(
|
fn request_kind(&self, _: &serde_json::Value) -> Result<StartDebuggingRequestArgumentsRequest> {
|
||||||
&self,
|
|
||||||
_: &serde_json::Value,
|
|
||||||
) -> Result<StartDebuggingRequestArgumentsRequest> {
|
|
||||||
Ok(StartDebuggingRequestArgumentsRequest::Launch)
|
Ok(StartDebuggingRequestArgumentsRequest::Launch)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,6 @@
|
|||||||
use crate::*;
|
use crate::*;
|
||||||
use anyhow::{Context as _, anyhow};
|
use anyhow::Context as _;
|
||||||
use dap::{
|
use dap::{DebugRequest, StartDebuggingRequestArguments, adapters::DebugTaskDefinition};
|
||||||
DebugRequest, StartDebuggingRequestArguments, StartDebuggingRequestArgumentsRequest,
|
|
||||||
adapters::DebugTaskDefinition,
|
|
||||||
};
|
|
||||||
use gpui::{AsyncApp, SharedString};
|
use gpui::{AsyncApp, SharedString};
|
||||||
use json_dotpath::DotPaths;
|
use json_dotpath::DotPaths;
|
||||||
use language::{LanguageName, Toolchain};
|
use language::{LanguageName, Toolchain};
|
||||||
@@ -86,7 +83,7 @@ impl PythonDebugAdapter {
|
|||||||
&self,
|
&self,
|
||||||
task_definition: &DebugTaskDefinition,
|
task_definition: &DebugTaskDefinition,
|
||||||
) -> Result<StartDebuggingRequestArguments> {
|
) -> Result<StartDebuggingRequestArguments> {
|
||||||
let request = self.validate_config(&task_definition.config)?;
|
let request = self.request_kind(&task_definition.config)?;
|
||||||
|
|
||||||
let mut configuration = task_definition.config.clone();
|
let mut configuration = task_definition.config.clone();
|
||||||
if let Ok(console) = configuration.dot_get_mut("console") {
|
if let Ok(console) = configuration.dot_get_mut("console") {
|
||||||
@@ -254,24 +251,6 @@ impl DebugAdapter for PythonDebugAdapter {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn validate_config(
|
|
||||||
&self,
|
|
||||||
config: &serde_json::Value,
|
|
||||||
) -> Result<StartDebuggingRequestArgumentsRequest> {
|
|
||||||
let map = config.as_object().context("Config isn't an object")?;
|
|
||||||
|
|
||||||
let request_variant = map
|
|
||||||
.get("request")
|
|
||||||
.and_then(|val| val.as_str())
|
|
||||||
.context("request is not valid")?;
|
|
||||||
|
|
||||||
match request_variant {
|
|
||||||
"launch" => Ok(StartDebuggingRequestArgumentsRequest::Launch),
|
|
||||||
"attach" => Ok(StartDebuggingRequestArgumentsRequest::Attach),
|
|
||||||
_ => Err(anyhow!("request must be either 'launch' or 'attach'")),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn dap_schema(&self) -> serde_json::Value {
|
async fn dap_schema(&self) -> serde_json::Value {
|
||||||
json!({
|
json!({
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -660,7 +639,7 @@ impl DebugAdapter for PythonDebugAdapter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
self.get_installed_binary(delegate, &config, None, None, false)
|
self.get_installed_binary(delegate, &config, None, toolchain, false)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -265,7 +265,7 @@ impl DebugAdapter for RubyDebugAdapter {
|
|||||||
cwd: None,
|
cwd: None,
|
||||||
envs: std::collections::HashMap::default(),
|
envs: std::collections::HashMap::default(),
|
||||||
request_args: StartDebuggingRequestArguments {
|
request_args: StartDebuggingRequestArguments {
|
||||||
request: self.validate_config(&definition.config)?,
|
request: self.request_kind(&definition.config)?,
|
||||||
configuration: definition.config.clone(),
|
configuration: definition.config.clone(),
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|||||||